Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions graphix/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@
# Layers
LAYER_C = "gray"
LAYER_FS = 10
LAYER_LABEL_OFFSET_PT = 30
LAYER_ANNOTATION_PADDING_PT = 8
LAYER_ANNOTATION_BOTTOM_PT = LAYER_LABEL_OFFSET_PT + LAYER_FS + LAYER_ANNOTATION_PADDING_PT

# Other labels
LABEL_MEAS_FS = 9.5
Expand Down Expand Up @@ -315,6 +318,7 @@ def visualize(self) -> None:
self._draw_nodes()

if self.n_layers:
self._set_plot_lims(plot_lims)
plot_lims = self._draw_layers(plot_lims)

if self.options.measurement_labels:
Expand Down Expand Up @@ -374,8 +378,14 @@ def _set_plot_lims(plot_lims: PlotLims) -> None:
Current plot limits in axis coordinates.
"""
offset = 0.7
ymin = plot_lims.ymin - offset
ymax = plot_lims.ymax + offset
if ymax - ymin < 2 * offset:
mid = (plot_lims.ymin + plot_lims.ymax) / 2
ymin = mid - offset
ymax = mid + offset
plt.xlim(plot_lims.xmin - offset, plot_lims.xmax + offset)
plt.ylim(plot_lims.ymin - offset, plot_lims.ymax + offset)
plt.ylim(ymin, ymax)

def _draw_nodes(self) -> None:
"""Draw graph nodes with style indicating their role and measurement type.
Expand Down Expand Up @@ -523,7 +533,7 @@ def _draw_layers(self, plot_lims: PlotLims) -> PlotLims:
arrowprops={"arrowstyle": "->", "color": "gray", "lw": 1.2},
)

offset = mtransforms.ScaledTranslation(0, -30 / 72, fig.dpi_scale_trans)
offset = mtransforms.ScaledTranslation(0, -LAYER_LABEL_OFFSET_PT / 72, fig.dpi_scale_trans)
mid_x = (self.n_layers - 1) / 2 * self.options.node_distance[0]
plt.text(
mid_x,
Expand All @@ -536,8 +546,9 @@ def _draw_layers(self, plot_lims: PlotLims) -> PlotLims:
transform=base + offset,
)

# Update plot_lims to take into account label
trans = base + offset
# Update plot_lims to take into account label and arrow below the nodes.
bottom_offset = mtransforms.ScaledTranslation(0, -LAYER_ANNOTATION_BOTTOM_PT / 72, fig.dpi_scale_trans)
trans = base + bottom_offset
_, ydisp = trans.transform((0, plot_lims.ymin))
return replace(plot_lims, ymin=base.inverted().transform((0, ydisp))[1])

Expand Down
28 changes: 27 additions & 1 deletion tests/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import networkx as nx
import pytest

from graphix import Circuit, Pattern, command
from graphix import Circuit, Pattern, Plane, command
from graphix.fundamentals import ANGLE_PI
from graphix.measurements import Measurement
from graphix.opengraph import OpenGraph, OpenGraphError
Expand Down Expand Up @@ -208,6 +208,32 @@ def test_og_draw() -> Figure:
return plt.gcf()


@pytest.mark.usefixtures("mock_plot")
def test_linear_causal_flow_layer_label_in_frame() -> None:
"""Regression for TeamGraphix/graphix#535."""
og = OpenGraph(
graph=nx.Graph([(0, 1), (1, 2), (2, 3)]),
input_nodes=[0],
output_nodes=[3],
measurements=dict.fromkeys(range(3), Plane.XY),
)
flow = og.extract_causal_flow()
flow.draw(legend=False)

ax = plt.gca()
renderer = plt.gcf().canvas.get_renderer()
axes_bbox = ax.get_window_extent(renderer)
layer_labels = [
artist
for artist in ax.get_children()
if hasattr(artist, "get_text") and artist.get_text() == "Layer"
]
assert len(layer_labels) == 1
label_bbox = layer_labels[0].get_window_extent(renderer)
assert label_bbox.y0 >= axes_bbox.y0
plt.close()


@pytest.mark.usefixtures("mock_plot")
@pytest.mark.mpl_image_compare
def test_causal_flow_draw() -> Figure:
Expand Down
Loading