diff --git a/snuba/web/rpc/proto_visitor.py b/snuba/web/rpc/proto_visitor.py index 17571557da..a87098307e 100644 --- a/snuba/web/rpc/proto_visitor.py +++ b/snuba/web/rpc/proto_visitor.py @@ -59,6 +59,10 @@ def accept(self, visitor: ProtoVisitor) -> None: class AggregationComparisonFilterWrapper(ProtoWrapper[AggregationComparisonFilter]): def accept(self, visitor: ProtoVisitor) -> None: visitor.visit_AggregationComparisonFilterWrapper(self) + comparison_filter = self.underlying_proto + if comparison_filter.HasField("formula"): + ColumnWrapper(comparison_filter.formula.left).accept(visitor) + ColumnWrapper(comparison_filter.formula.right).accept(visitor) class AggregationFilterWrapper(ProtoWrapper[AggregationFilter]): diff --git a/tests/web/rpc/v1/test_conditional_aggregation.py b/tests/web/rpc/v1/test_conditional_aggregation.py index 4ea54ff3cf..189b064046 100644 --- a/tests/web/rpc/v1/test_conditional_aggregation.py +++ b/tests/web/rpc/v1/test_conditional_aggregation.py @@ -14,12 +14,19 @@ Column, TraceItemTableRequest, ) +from sentry_protos.snuba.v1.formula_pb2 import Literal from sentry_protos.snuba.v1.request_common_pb2 import RequestMeta, TraceItemType from sentry_protos.snuba.v1.trace_item_attribute_pb2 import ( AttributeAggregation, AttributeKey, + AttributeValue, ExtrapolationMode, Function, + StrArray, +) +from sentry_protos.snuba.v1.trace_item_filter_pb2 import ( + ComparisonFilter, + TraceItemFilter, ) from snuba.web.rpc.proto_visitor import ( @@ -83,6 +90,63 @@ def _build_avg_conditional_aggregation_comparison_filter_with_name( ) +def _build_failure_rate_formula(deprecated: bool) -> Column.BinaryFormula: + """ + Mirrors the `failure_rate()` formula clients send to EndpointTraceItemTable: + (count(status NOT IN [ok, cancelled, unknown]) * 1.0) / count(duration_ms). + + The `right` count(duration_ms) leaf uses the deprecated `aggregation` field + when `deprecated` is True, and the newer `conditional_aggregation` otherwise. + Everything else is identical and must pass through the conversion unchanged. + """ + duration_key = AttributeKey(type=AttributeKey.TYPE_DOUBLE, name="sentry.duration_ms") + if deprecated: + duration_count_column = Column( + aggregation=AttributeAggregation( + aggregate=Function.FUNCTION_COUNT, + key=duration_key, + extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_SAMPLE_WEIGHTED, + ) + ) + else: + duration_count_column = Column( + conditional_aggregation=AttributeConditionalAggregation( + aggregate=Function.FUNCTION_COUNT, + key=duration_key, + extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_SAMPLE_WEIGHTED, + ) + ) + return Column.BinaryFormula( + op=Column.BinaryFormula.OP_DIVIDE, + left=Column( + formula=Column.BinaryFormula( + op=Column.BinaryFormula.OP_MULTIPLY, + left=Column( + conditional_aggregation=AttributeConditionalAggregation( + aggregate=Function.FUNCTION_COUNT, + key=AttributeKey(type=AttributeKey.TYPE_STRING, name="sentry.status"), + extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_SAMPLE_WEIGHTED, + filter=TraceItemFilter( + comparison_filter=ComparisonFilter( + key=AttributeKey( + type=AttributeKey.TYPE_STRING, name="sentry.status" + ), + op=ComparisonFilter.OP_NOT_IN, + value=AttributeValue( + val_str_array=StrArray(values=["ok", "cancelled", "unknown"]) + ), + ) + ), + ) + ), + right=Column(literal=Literal(val_double=1.0)), + default_value_double=0.0, + ) + ), + right=duration_count_column, + ) + + _UNIMPORTANT_REQUEST_META = RequestMeta( project_ids=[1, 2, 3], organization_id=1, @@ -326,6 +390,58 @@ def test_convert_aggregation_to_conditional_aggregation_in_having(self) -> None: ) ) + def test_convert_aggregation_to_conditional_aggregation_in_comparison_filter_formula( + self, + ) -> None: + # Mirrors the real failure_rate() query clients send: the same formula + # appears both as a SELECT column and inside the aggregation filter + # (the HAVING clause). Its count(duration_ms) leaf uses the deprecated + # `aggregation` field. The SELECT copy already converted via + # ColumnWrapper; the aggregation-filter copy did not, so the pipeline failed + message = TraceItemTableRequest( + meta=_UNIMPORTANT_REQUEST_META, + columns=[ + Column( + key=AttributeKey(type=AttributeKey.TYPE_STRING, name="sentry.transaction"), + label="transaction", + ), + Column( + formula=_build_failure_rate_formula(deprecated=True), + label="failure_rate()", + ), + ], + aggregation_filter=AggregationFilter( + comparison_filter=AggregationComparisonFilter( + formula=_build_failure_rate_formula(deprecated=True), + op=AggregationComparisonFilter.OP_GREATER_THAN, + val=0.5, + ) + ), + ) + aggregation_to_conditional_aggregation_visitor = ( + AggregationToConditionalAggregationVisitor() + ) + message_wrapper = TraceItemTableRequestWrapper(message) + message_wrapper.accept(aggregation_to_conditional_aggregation_visitor) + + assert list(message.columns) == [ + Column( + key=AttributeKey(type=AttributeKey.TYPE_STRING, name="sentry.transaction"), + label="transaction", + ), + Column( + formula=_build_failure_rate_formula(deprecated=False), + label="failure_rate()", + ), + ] + assert message.aggregation_filter == AggregationFilter( + comparison_filter=AggregationComparisonFilter( + formula=_build_failure_rate_formula(deprecated=False), + op=AggregationComparisonFilter.OP_GREATER_THAN, + val=0.5, + ) + ) + def test_convert_aggregation_to_conditional_aggregation_in_all_of_select_and_order_by_and_having( self, ) -> None: