From 03302b3f917d229f6f79750fa1490fc56b12dd6a Mon Sep 17 00:00:00 2001 From: shantoshdurai Date: Wed, 10 Jun 2026 22:21:08 +0530 Subject: [PATCH] fix(visualization): keep layer arrow in frame for linear flows Set preliminary plot limits before drawing layer annotations so point-based offsets map correctly when all nodes share one y row. Expand ymin for layer label height and enforce a minimum vertical axis span for degenerate layouts. Fixes TeamGraphix/graphix#535 --- graphix/visualization.py | 19 +++++++++++++++---- tests/test_visualization.py | 28 +++++++++++++++++++++++++++- 2 files changed, 42 insertions(+), 5 deletions(-) diff --git a/graphix/visualization.py b/graphix/visualization.py index 8bd39d68b..aa90a3a25 100644 --- a/graphix/visualization.py +++ b/graphix/visualization.py @@ -66,6 +66,9 @@ # Layers LAYER_C = "gray" LAYER_FS = 10 +LAYER_LABEL_OFFSET_PT = 30 +LAYER_ANNOTATION_PADDING_PT = 8 +LAYER_ANNOTATION_BOTTOM_PT = LAYER_LABEL_OFFSET_PT + LAYER_FS + LAYER_ANNOTATION_PADDING_PT # Other labels LABEL_MEAS_FS = 9.5 @@ -315,6 +318,7 @@ def visualize(self) -> None: self._draw_nodes() if self.n_layers: + self._set_plot_lims(plot_lims) plot_lims = self._draw_layers(plot_lims) if self.options.measurement_labels: @@ -374,8 +378,14 @@ def _set_plot_lims(plot_lims: PlotLims) -> None: Current plot limits in axis coordinates. """ offset = 0.7 + ymin = plot_lims.ymin - offset + ymax = plot_lims.ymax + offset + if ymax - ymin < 2 * offset: + mid = (plot_lims.ymin + plot_lims.ymax) / 2 + ymin = mid - offset + ymax = mid + offset plt.xlim(plot_lims.xmin - offset, plot_lims.xmax + offset) - plt.ylim(plot_lims.ymin - offset, plot_lims.ymax + offset) + plt.ylim(ymin, ymax) def _draw_nodes(self) -> None: """Draw graph nodes with style indicating their role and measurement type. @@ -523,7 +533,7 @@ def _draw_layers(self, plot_lims: PlotLims) -> PlotLims: arrowprops={"arrowstyle": "->", "color": "gray", "lw": 1.2}, ) - offset = mtransforms.ScaledTranslation(0, -30 / 72, fig.dpi_scale_trans) + offset = mtransforms.ScaledTranslation(0, -LAYER_LABEL_OFFSET_PT / 72, fig.dpi_scale_trans) mid_x = (self.n_layers - 1) / 2 * self.options.node_distance[0] plt.text( mid_x, @@ -536,8 +546,9 @@ def _draw_layers(self, plot_lims: PlotLims) -> PlotLims: transform=base + offset, ) - # Update plot_lims to take into account label - trans = base + offset + # Update plot_lims to take into account label and arrow below the nodes. + bottom_offset = mtransforms.ScaledTranslation(0, -LAYER_ANNOTATION_BOTTOM_PT / 72, fig.dpi_scale_trans) + trans = base + bottom_offset _, ydisp = trans.transform((0, plot_lims.ymin)) return replace(plot_lims, ymin=base.inverted().transform((0, ydisp))[1]) diff --git a/tests/test_visualization.py b/tests/test_visualization.py index c2f6942b2..9d5797147 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -8,7 +8,7 @@ import networkx as nx import pytest -from graphix import Circuit, Pattern, command +from graphix import Circuit, Pattern, Plane, command from graphix.fundamentals import ANGLE_PI from graphix.measurements import Measurement from graphix.opengraph import OpenGraph, OpenGraphError @@ -208,6 +208,32 @@ def test_og_draw() -> Figure: return plt.gcf() +@pytest.mark.usefixtures("mock_plot") +def test_linear_causal_flow_layer_label_in_frame() -> None: + """Regression for TeamGraphix/graphix#535.""" + og = OpenGraph( + graph=nx.Graph([(0, 1), (1, 2), (2, 3)]), + input_nodes=[0], + output_nodes=[3], + measurements=dict.fromkeys(range(3), Plane.XY), + ) + flow = og.extract_causal_flow() + flow.draw(legend=False) + + ax = plt.gca() + renderer = plt.gcf().canvas.get_renderer() + axes_bbox = ax.get_window_extent(renderer) + layer_labels = [ + artist + for artist in ax.get_children() + if hasattr(artist, "get_text") and artist.get_text() == "Layer" + ] + assert len(layer_labels) == 1 + label_bbox = layer_labels[0].get_window_extent(renderer) + assert label_bbox.y0 >= axes_bbox.y0 + plt.close() + + @pytest.mark.usefixtures("mock_plot") @pytest.mark.mpl_image_compare def test_causal_flow_draw() -> Figure: