From dbac33ea51302de312f35a564b468b15b64d0e3c Mon Sep 17 00:00:00 2001 From: Pierre Massat Date: Mon, 15 Jun 2026 22:47:12 -0700 Subject: [PATCH] chore: Modernize Python tooling (ruff, uv, mypy) Bring the lint/type toolchain up to modern standards and clear the debt the upgrades surface. Ruff: expand the lint set from E/F/W/I to add B (bugbear), UP (pyupgrade), C4 (comprehensions), SIM (simplify) and RET (return), enforce E402, and set isort known-first-party. Apply the resulting ~5,100 fixes (typing modernization, comprehension/return cleanups). PEP 695 native-generic rules (UP040/046/047) are deliberately left ignored - mass-migrating generics under custom metaclasses is a separate, risky change. mypy: upgrade 1.1.1 -> 2.1.0 (mypyc-compiled). The jump surfaced 110 pre-existing errors that 1.1.1 missed; all are fixed (41 redundant casts plus arg-type/attr-defined/type-arg across ~25 files). Resolve the duplicate-conftest collision and align the pre-commit mypy hook's excludes with the config so 'mypy .' and pre-commit check the same files. type: ignore: reduce from 159 to 62 by fixing the underlying types; every remaining suppression is now error-code-specific (zero bare, zero unused). Config: move pytest config from setup.cfg (deleted) into pyproject.toml, add Makefile lint/lint-check/typecheck targets, and remove the obsolete no-op tests/utils/conftest.py. Several latent bugs were caught in the process: missing test assertions, a profile-events call missing an argument, and DeletionSettings constructed with positional args in the wrong fields. Co-Authored-By: Claude Opus 4.8 (1M context) Agent transcript: https://claudescope.sentry.dev/share/DD3ny1HpZNB23VAt3zOpAjqDbo3ncHFJvSHXcXKnr-Y --- .pre-commit-config.yaml | 4 + Makefile | 14 + pyproject.toml | 52 +++- scripts/check-migrations.py | 9 +- scripts/copy_tables.py | 4 +- scripts/ddl-changes.py | 37 ++- scripts/fetch_service_refs.py | 22 +- scripts/generate_items.py | 7 +- setup.cfg | 11 - snuba/admin/audit_log/base.py | 11 +- snuba/admin/audit_log/query.py | 15 +- snuba/admin/auth.py | 4 +- snuba/admin/auth_roles.py | 5 +- snuba/admin/clickhouse/common.py | 10 +- snuba/admin/clickhouse/copy_tables.py | 15 +- snuba/admin/clickhouse/database_clusters.py | 4 +- snuba/admin/clickhouse/migration_checks.py | 22 +- snuba/admin/clickhouse/nodes.py | 38 ++- snuba/admin/clickhouse/profile_events.py | 21 +- snuba/admin/clickhouse/querylog.py | 4 +- snuba/admin/clickhouse/system_queries.py | 12 +- snuba/admin/clickhouse/trace_log_parsing.py | 11 +- snuba/admin/clickhouse/tracing.py | 5 +- snuba/admin/dead_letter_queue/__init__.py | 20 +- snuba/admin/google.py | 5 +- snuba/admin/jwt.py | 4 +- snuba/admin/kafka/topics.py | 3 +- snuba/admin/migrations_policies.py | 27 +- snuba/admin/notifications/slack/client.py | 29 +- snuba/admin/notifications/slack/utils.py | 36 ++- .../admin/production_queries/prod_queries.py | 8 +- snuba/admin/rpc/rpc_queries.py | 8 +- snuba/admin/runtime_config/__init__.py | 27 +- snuba/admin/tool_policies.py | 3 +- snuba/admin/user.py | 2 +- snuba/admin/views.py | 166 ++++++----- snuba/admin/wsgi.py | 4 +- snuba/cleanup.py | 2 +- snuba/cli/accepted_outcomes_consumer.py | 14 +- snuba/cli/admin.py | 5 +- snuba/cli/api.py | 13 +- snuba/cli/bootstrap.py | 6 +- snuba/cli/bulk_load.py | 7 +- snuba/cli/cleanup.py | 14 +- snuba/cli/config.py | 11 +- snuba/cli/consumer.py | 31 +- snuba/cli/devserver.py | 5 +- snuba/cli/dlq_consumer.py | 12 +- snuba/cli/entities.py | 4 +- snuba/cli/jobs.py | 11 +- snuba/cli/lw_deletions_consumer.py | 7 +- snuba/cli/migrations.py | 34 +-- snuba/cli/offline_replacer.py | 3 +- snuba/cli/optimize.py | 13 +- snuba/cli/querylog_to_csv.py | 11 +- snuba/cli/replacer.py | 9 +- snuba/cli/rust_consumer.py | 32 +-- snuba/cli/subscriptions_executor.py | 13 +- snuba/cli/subscriptions_scheduler.py | 11 +- snuba/cli/subscriptions_scheduler_executor.py | 11 +- snuba/clickhouse/columns.py | 8 +- snuba/clickhouse/errors.py | 2 +- snuba/clickhouse/escaping.py | 33 ++- snuba/clickhouse/formatter/expression.py | 45 ++- snuba/clickhouse/formatter/nodes.py | 16 +- snuba/clickhouse/formatter/query.py | 44 ++- snuba/clickhouse/http.py | 49 ++-- snuba/clickhouse/native.py | 59 ++-- snuba/clickhouse/optimize/optimize.py | 22 +- .../clickhouse/optimize/optimize_scheduler.py | 32 +-- snuba/clickhouse/optimize/optimize_tracker.py | 15 +- snuba/clickhouse/query.py | 30 +- snuba/clickhouse/query_dsl/accessors.py | 40 +-- snuba/clickhouse/query_inspector.py | 39 ++- snuba/clickhouse/query_profiler.py | 16 +- snuba/clickhouse/translators/snuba/allowed.py | 44 ++- .../clickhouse/translators/snuba/defaults.py | 16 +- .../snuba/function_call_mappers.py | 11 +- snuba/clickhouse/translators/snuba/mappers.py | 64 ++--- snuba/clickhouse/translators/snuba/mapping.py | 2 +- snuba/clusters/cluster.py | 88 +++--- snuba/clusters/storage_sets.py | 15 +- snuba/configs/configuration.py | 71 +++-- snuba/consumers/consumer.py | 92 +++--- snuba/consumers/consumer_builder.py | 35 +-- snuba/consumers/consumer_config.py | 81 +++--- snuba/consumers/dlq.py | 8 +- snuba/consumers/rust_processor.py | 55 ++-- snuba/consumers/schemas.py | 5 +- snuba/consumers/strategy_factory.py | 25 +- snuba/datasets/cdc/cdcprocessors.py | 22 +- snuba/datasets/cdc/groupassignee_processor.py | 15 +- .../datasets/cdc/groupedmessage_processor.py | 19 +- snuba/datasets/cdc/row_processors.py | 6 +- snuba/datasets/cdc/types.py | 11 +- .../datasets/configuration/entity_builder.py | 17 +- .../datasets/configuration/storage_builder.py | 18 +- snuba/datasets/configuration/utils.py | 6 +- snuba/datasets/dataset.py | 2 +- snuba/datasets/deletion_settings.py | 9 +- snuba/datasets/entities/entity_data_model.py | 2 +- snuba/datasets/entities/entity_key.py | 5 +- snuba/datasets/entities/factory.py | 8 +- .../entities/storage_selectors/__init__.py | 7 +- .../entities/storage_selectors/eap_items.py | 2 +- .../entities/storage_selectors/errors.py | 2 +- .../entities/storage_selectors/outcomes.py | 2 +- .../entities/storage_selectors/selector.py | 2 +- snuba/datasets/entity.py | 18 +- .../entity_subscriptions/processors.py | 15 +- .../entity_subscriptions/validators.py | 13 +- snuba/datasets/events_format.py | 35 +-- snuba/datasets/factory.py | 5 +- snuba/datasets/message_filters.py | 6 +- snuba/datasets/metrics_messages.py | 3 +- snuba/datasets/plans/cluster_selector.py | 7 +- snuba/datasets/plans/entity_processing.py | 7 +- snuba/datasets/plans/entity_validation.py | 10 +- snuba/datasets/plans/query_plan.py | 30 +- snuba/datasets/plans/storage_processing.py | 7 +- snuba/datasets/plans/translator/mapper.py | 5 +- snuba/datasets/pluggable_entity.py | 23 +- snuba/datasets/processors/__init__.py | 6 +- .../processors/group_attributes_processor.py | 3 +- .../processors/rust_compat_processor.py | 15 +- .../processors/search_issues_processor.py | 25 +- .../processors/transactions_processor.py | 28 +- snuba/datasets/schemas/__init__.py | 13 +- snuba/datasets/schemas/tables.py | 10 +- snuba/datasets/slicing.py | 2 +- snuba/datasets/storage.py | 41 +-- snuba/datasets/storages/factory.py | 7 +- snuba/datasets/storages/storage_key.py | 5 +- snuba/datasets/storages/validator.py | 4 +- snuba/datasets/table_storage.py | 83 +++--- snuba/environment.py | 5 +- snuba/lw_deletions/batching.py | 31 +- snuba/lw_deletions/formatters.py | 21 +- snuba/lw_deletions/off_peak.py | 9 +- snuba/lw_deletions/strategy.py | 22 +- snuba/lw_deletions/types.py | 7 +- snuba/manual_jobs/__init__.py | 7 +- .../delete_events_by_tag_key_value.py | 7 +- snuba/manual_jobs/extract_span_data.py | 5 +- snuba/manual_jobs/redis.py | 6 +- .../manual_jobs/rerun_idempotent_migration.py | 5 +- snuba/manual_jobs/runner.py | 5 +- snuba/manual_jobs/scrub_ips_from_eap_spans.py | 7 +- .../manual_jobs/scrub_users_from_eap_spans.py | 7 +- .../scrub_users_from_eap_spans_str_attrs.py | 7 +- snuba/manual_jobs/update_migration_status.py | 5 +- snuba/migrations/autogeneration/diff.py | 14 +- snuba/migrations/autogeneration/main.py | 5 +- snuba/migrations/check_dangerous.py | 41 ++- snuba/migrations/columns.py | 20 +- snuba/migrations/connect.py | 18 +- snuba/migrations/context.py | 3 +- snuba/migrations/group_loader.py | 15 +- snuba/migrations/groups.py | 15 +- snuba/migrations/migration.py | 14 +- snuba/migrations/migration_utilities.py | 8 +- snuba/migrations/operations.py | 53 ++-- snuba/migrations/parse_schema.py | 10 +- snuba/migrations/policies.py | 6 +- snuba/migrations/runner.py | 51 ++-- .../system_migrations/0001_migrations.py | 2 +- snuba/migrations/table_engines.py | 69 +++-- snuba/migrations/validator.py | 68 ++--- snuba/pipeline/composite_entity_processing.py | 33 +-- .../pipeline/composite_storage_processing.py | 5 +- snuba/pipeline/query_pipeline.py | 21 +- snuba/pipeline/stages/query_execution.py | 56 ++-- snuba/pipeline/stages/query_processing.py | 16 +- .../storage_query_identity_translate.py | 45 ++- snuba/pipeline/utils/storage_finder.py | 3 +- snuba/processor.py | 39 ++- snuba/protos/common.py | 14 +- snuba/query/__init__.py | 150 +++++----- snuba/query/accessors.py | 6 +- snuba/query/allocation_policies/__init__.py | 14 +- .../bytes_scanned_rejecting_policy.py | 63 ++-- .../bytes_scanned_window_policy.py | 128 ++++----- .../concurrent_rate_limit.py | 30 +- snuba/query/allocation_policies/cross_org.py | 17 +- snuba/query/allocation_policies/utils.py | 4 +- snuba/query/composite.py | 50 ++-- snuba/query/conditions.py | 40 +-- snuba/query/data_source/join.py | 13 +- snuba/query/data_source/simple.py | 10 +- snuba/query/data_source/visitor.py | 25 +- snuba/query/dsl.py | 38 +-- snuba/query/dsl_mapper.py | 6 +- snuba/query/exceptions.py | 6 +- snuba/query/expressions.py | 48 ++-- snuba/query/formatters/tracing.py | 62 ++-- snuba/query/indexer/resolver.py | 3 +- snuba/query/joins/classifier.py | 45 ++- snuba/query/joins/equivalence_adder.py | 10 +- .../query/joins/metrics_subquery_generator.py | 6 +- snuba/query/joins/pre_processor.py | 13 +- snuba/query/joins/semi_joins.py | 6 +- snuba/query/joins/subquery_generator.py | 15 +- snuba/query/logical.py | 52 ++-- snuba/query/matchers.py | 65 +++-- snuba/query/mql/mql_context.py | 5 +- snuba/query/mql/parser.py | 161 +++++------ snuba/query/parser/__init__.py | 62 ++-- snuba/query/parser/expressions.py | 29 +- snuba/query/parser/validation/functions.py | 20 +- snuba/query/parsing.py | 5 +- .../processors/condition_checkers/__init__.py | 6 +- .../processors/logical/basic_functions.py | 14 +- .../processors/logical/custom_function.py | 17 +- .../logical/granularity_processor.py | 9 +- .../logical/hash_bucket_functions.py | 2 +- .../logical/low_cardinality_processor.py | 2 +- .../logical/optional_attribute_aggregation.py | 11 +- .../logical/timeseries_processor.py | 23 +- .../physical/abstract_array_join_optimizer.py | 43 +-- .../physical/array_has_optimizer.py | 2 +- .../physical/arrayjoin_keyvalue_optimizer.py | 36 ++- .../physical/arrayjoin_optimizer.py | 24 +- .../physical/bloom_filter_optimizer.py | 12 +- .../physical/clickhouse_settings_override.py | 3 +- .../physical/column_filter_processor.py | 2 +- .../physical/conditions_enforcer.py | 9 +- .../fixedstring_array_column_processor.py | 8 +- .../physical/group_id_column_processor.py | 19 +- .../physical/hexint_column_processor.py | 12 +- .../processors/physical/mapping_optimizer.py | 35 +-- .../processors/physical/mapping_promoter.py | 18 +- .../processors/physical/null_column_caster.py | 55 ++-- snuba/query/processors/physical/prewhere.py | 10 +- .../processors/physical/replaced_groups.py | 10 +- .../processors/physical/tuple_unaliaser.py | 7 +- .../processors/physical/type_converters.py | 22 +- .../physical/uniq_in_select_and_having.py | 16 +- .../physical/user_column_processor.py | 13 +- .../physical/uuid_array_column_processor.py | 7 +- .../physical/uuid_column_processor.py | 7 +- snuba/query/query_settings.py | 11 +- snuba/query/snql/discover_entity_selection.py | 15 +- snuba/query/snql/expression_visitor.py | 43 ++- snuba/query/snql/joins.py | 24 +- snuba/query/snql/parser.py | 270 ++++++++---------- snuba/query/validation/__init__.py | 2 +- snuba/query/validation/functions.py | 5 +- snuba/query/validation/signature.py | 61 ++-- snuba/query/validation/validators.py | 29 +- snuba/querylog/__init__.py | 31 +- snuba/querylog/query_metadata.py | 29 +- snuba/reader.py | 67 ++--- snuba/redis.py | 32 +-- snuba/replacer.py | 32 +-- snuba/replacers/errors_replacer.py | 166 +++++------ snuba/replacers/projects_query_flags.py | 41 +-- snuba/replacers/replacements_and_expiry.py | 7 +- snuba/replacers/replacer_processor.py | 11 +- snuba/request/__init__.py | 6 +- snuba/request/schema.py | 19 +- snuba/request/validation.py | 36 +-- snuba/schemas.py | 8 +- snuba/settings/__init__.py | 33 +-- snuba/settings/settings_self_hosted.py | 9 +- snuba/settings/settings_test.py | 5 +- .../settings_test_distributed_migrations.py | 3 +- snuba/settings/validation.py | 12 +- snuba/snapshots/__init__.py | 22 +- snuba/snapshots/loaders/__init__.py | 6 +- snuba/snapshots/loaders/single_table.py | 10 +- snuba/snapshots/postgres_snapshot.py | 39 ++- .../discover/0001_discover_merge_table.py | 12 +- ...0002_discover_add_deleted_tags_hash_map.py | 10 +- .../discover/0003_discover_fix_user_column.py | 10 +- .../0004_discover_fix_title_and_message.py | 10 +- .../0005_discover_fix_transaction_name.py | 14 +- .../discover/0006_discover_add_trace_id.py | 10 +- .../discover/0007_discover_add_span_id.py | 10 +- .../0008_discover_fix_add_local_table.py | 8 +- .../discover/0009_discover_add_replay_id.py | 2 +- .../events/0007_groupedmessages.py | 2 +- .../events/0008_groupassignees.py | 2 +- ...groupedmessages_onpremise_compatibility.py | 9 +- .../events/0011_rebuild_errors.py | 19 +- .../events/0012_errors_make_level_nullable.py | 2 +- .../0013_errors_add_hierarchical_hashes.py | 10 +- .../events/0017_errors_add_indexes.py | 2 +- .../0018_errors_ro_add_tags_hash_map.py | 6 +- .../events/0019_add_replay_id_column.py | 2 +- .../events/0020_add_main_thread_column.py | 6 +- .../events/0021_add_replay_id_errors_ro.py | 2 +- .../0022_add_main_thread_column_errors_ro.py | 6 +- ...e_sampled_num_processing_errors_columns.py | 6 +- ...ampled_num_processing_errors_columns_ro.py | 6 +- .../events/0025_add_flags_column.py | 6 +- .../0026_add_symbolicated_in_app_column.py | 2 +- .../0027_add_symbolicated_in_app_column_ro.py | 2 +- .../0028_add_timestamp_ms_column_errors.py | 2 +- ...0029_add_sample_weight_column_to_errors.py | 2 +- ...0_add_group_first_seen_column_to_errors.py | 14 +- .../events_analytics_platform/0001_spans.py | 6 +- .../0002_spans_attributes_mv.py | 2 +- .../0003_eap_spans_project_id_index.py | 2 +- .../0004_modify_sampling_weight.py | 2 +- .../0005_remove_attribute_mv.py | 2 +- ...6_drop_attribute_key_project_id_indexes.py | 2 +- .../0007_drop_project_id_index.py | 2 +- .../0008_drop_index_attribute_key_bucket_0.py | 2 +- ...9_drop_index_attribute_key_buckets_1_19.py | 2 +- .../0010_drop_indexes_on_attribute_values.py | 2 +- .../0011_span_attribute_table.py | 2 +- .../0012_span_attribute_table_numeric.py | 2 +- .../0013_span_attribute_table_shard_keys.py | 2 +- .../0014_span_attribute_table_smaller.py | 2 +- .../0015_span_attribute_table_namespaced.py | 2 +- .../0016_spans_v2.py | 6 +- .../0017_span_attribute_table_v3.py | 2 +- .../0018_drop_unused_span_tables.py | 2 +- .../0019_uptime_monitors_init.py | 4 +- .../0020_ourlogs_init.py | 4 +- .../0021_ourlogs_attrs.py | 4 +- .../0022_uptime_monitors_init_v2.py | 4 +- .../0023_smart_autocomplete_mv.py | 5 +- .../events_analytics_platform/0024_items.py | 6 +- .../0025_smart_autocomplete_index.py | 11 +- .../0026_items_add_attributes_hash_map.py | 2 +- ...27_uptime_checks_add_column_in_incident.py | 2 +- .../0028_ourlogs_v3.py | 4 +- ..._remove_smart_autocomplete_experimental.py | 5 +- .../0030_smart_autocomplete_items.py | 17 +- .../0032_sampled_storage_views.py | 5 +- .../0033_items_attribute_table_v1.py | 2 +- .../0034_materialize_sampled_storage_views.py | 5 +- .../0035_drop_item_attrs.py | 2 +- .../0036_items_attribute_table_v1.py | 2 +- .../0037_remove_items_attribute_mv_v1.py | 2 +- .../0038_eap_items_add_sampling_factor.py | 10 +- .../0039_update_mv_with_sampling_factor.py | 9 +- ...ms_downsampled_dist_add_sampling_factor.py | 10 +- .../0041_hashed_attributes_index.py | 15 +- .../0042_remove_hashed_columns.py | 13 +- ...ampled_retention_for_downsampled_tables.py | 22 +- ...rver_sample_rates_in_materialized_views.py | 8 +- .../0051_add_bool_keys_to_autocomplete.py | 10 +- .../0052_create_deletes_workload.py | 2 +- ...0053_alter_deletes_workload_max_threads.py | 2 +- .../0054_fix_bools_in_autocomplete.py | 2 +- ...55_fix_attribute_keys_hash_missing_bool.py | 2 +- .../0056_eap_items_dist_ro.py | 6 +- .../functions/0001_functions.py | 16 +- .../0002_add_new_columns_to_raw_functions.py | 14 +- .../0003_add_new_columns_to_raw_functions.py | 2 +- .../functions/0004_functions_v2.py | 4 +- .../0001_sets_aggregate_table.py | 2 +- .../generic_metrics/0002_sets_raw_table.py | 2 +- .../generic_metrics/0003_sets_mv.py | 2 +- .../0004_sets_raw_add_granularities.py | 2 +- .../generic_metrics/0005_sets_replace_mv.py | 2 +- ...6_sets_raw_add_granularities_dist_table.py | 2 +- .../0007_distributions_aggregate_table.py | 2 +- .../0008_distributions_raw_table.py | 2 +- .../generic_metrics/0009_distributions_mv.py | 2 +- .../0010_counters_aggregate_table.py | 2 +- .../0011_counters_raw_table.py | 2 +- .../generic_metrics/0012_counters_mv.py | 2 +- .../0013_distributions_dist_tags_hash.py | 2 +- .../0014_distribution_add_options.py | 12 +- .../generic_metrics/0015_sets_add_options.py | 10 +- .../0016_counters_add_options.py | 10 +- .../generic_metrics/0017_distributions_mv2.py | 2 +- .../0018_sets_update_opt_default.py | 10 +- .../0019_counters_update_opt_default.py | 10 +- .../generic_metrics/0020_sets_mv2.py | 2 +- .../generic_metrics/0021_counters_mv2.py | 2 +- .../0022_gauges_aggregate_table.py | 6 +- .../generic_metrics/0023_gauges_raw_table.py | 10 +- .../generic_metrics/0024_gauges_mv.py | 2 +- .../0025_counters_add_raw_tags_hash_column.py | 2 +- .../0026_gauges_add_raw_tags_hash_column.py | 2 +- .../0027_sets_add_raw_tags_column.py | 2 +- ...8_distributions_add_indexed_tags_column.py | 2 +- .../0029_add_use_case_id_index.py | 2 +- .../0030_add_record_meta_column.py | 6 +- .../0031_counters_meta_table.py | 2 +- .../0032_counters_meta_table_mv.py | 2 +- .../0033_counters_meta_tag_values_table.py | 2 +- .../0034_counters_meta_tag_values_table_mv.py | 2 +- ...create_counters_meta_tag_value_table_mv.py | 2 +- .../0036_counters_meta_tables_final.py | 2 +- .../0037_add_record_meta_column_sets.py | 6 +- ...38_add_record_meta_column_distributions.py | 6 +- .../0039_add_record_meta_column_gauges.py | 6 +- .../0040_remove_counters_meta_tables.py | 2 +- .../0041_adjust_partitioning_meta_tables.py | 2 +- .../0042_rename_counters_meta_tables.py | 10 +- .../generic_metrics/0043_sets_meta_tables.py | 2 +- .../0044_gauges_meta_tables.py | 2 +- .../0045_distributions_meta_tables.py | 2 +- ...6_distributions_add_disable_percentiles.py | 6 +- .../generic_metrics/0047_distributions_mv3.py | 2 +- ...counters_meta_tables_support_empty_tags.py | 2 +- ...049_sets_meta_tables_support_empty_tags.py | 2 +- ...ibutions_meta_tables_support_empty_tags.py | 2 +- ...1_gauges_meta_tables_support_empty_tags.py | 2 +- .../0052_counters_raw_add_sampling_weight.py | 4 +- ...counters_aggregated_add_sampling_weight.py | 2 +- .../generic_metrics/0054_counters_mv3.py | 2 +- .../0055_gauges_raw_add_sampling_weight.py | 4 +- ..._gauges_aggregated_add_weighted_columns.py | 2 +- .../generic_metrics/0057_gauges_mv3.py | 2 +- ...8_distributions_raw_add_sampling_weight.py | 4 +- ...butions_aggregated_add_weighted_columns.py | 2 +- .../generic_metrics/0060_distributions_mv4.py | 2 +- ..._remove_distribution_meta_tag_values_mv.py | 2 +- .../group_attributes/0001_group_attributes.py | 4 +- .../0002_add_priority_to_group_attributes.py | 2 +- ...dd_first_release_id_to_group_attributes.py | 2 +- ...irst_release_column_to_group_attributes.py | 2 +- .../metrics/0001_metrics_buckets.py | 2 +- .../metrics/0002_metrics_sets.py | 2 +- .../metrics/0003_counters_to_buckets.py | 2 +- .../metrics/0004_metrics_counters.py | 2 +- .../0005_metrics_distributions_buckets.py | 2 +- .../metrics/0006_metrics_distributions.py | 6 +- .../0007_metrics_sets_granularity_10.py | 8 +- .../0008_metrics_counters_granularity_10.py | 6 +- ...09_metrics_distributions_granularity_10.py | 2 +- .../0010_metrics_sets_granularity_1h.py | 6 +- .../0011_metrics_counters_granularity_1h.py | 2 +- ...12_metrics_distributions_granularity_1h.py | 2 +- .../0013_metrics_sets_granularity_1d.py | 2 +- .../0014_metrics_counters_granularity_1d.py | 2 +- ...15_metrics_distributions_granularity_1d.py | 2 +- ...6_metrics_sets_consolidated_granularity.py | 2 +- ...trics_counters_consolidated_granularity.py | 2 +- ..._distributions_consolidated_granularity.py | 2 +- .../metrics/0019_aggregate_tables_add_ttl.py | 5 +- .../metrics/0020_polymorphic_buckets_table.py | 6 +- ...1_polymorphic_bucket_materialized_views.py | 2 +- .../0022_repartition_polymorphic_table.py | 6 +- ...olymorphic_repartitioned_bucket_matview.py | 2 +- ...024_metrics_distributions_add_histogram.py | 10 +- .../0025_metrics_counters_aggregate_v2.py | 12 +- ...026_metrics_counters_v2_writing_matview.py | 2 +- .../metrics/0027_fix_migration_0026.py | 2 +- .../metrics/0028_metrics_sets_aggregate_v2.py | 12 +- ...0029_metrics_distributions_aggregate_v2.py | 12 +- ...030_metrics_distributions_v2_writing_mv.py | 10 +- .../0031_metrics_sets_v2_writing_mv.py | 2 +- ...2_redo_0030_and_0031_without_timestamps.py | 10 +- .../metrics/0033_metrics_cleanup_old_views.py | 6 +- .../0034_metrics_cleanup_old_tables.py | 10 +- .../metrics/0035_metrics_raw_timeseries_id.py | 3 +- snuba/snuba_migrations/metrics/templates.py | 7 +- .../outcomes/0001_outcomes.py | 2 +- .../0002_outcomes_remove_size_and_bytes.py | 2 +- ...0003_outcomes_add_category_and_quantity.py | 34 +-- .../0004_outcomes_matview_additions.py | 2 +- .../outcomes/0005_outcomes_ttl.py | 2 +- .../outcomes/0006_outcomes_add_size_col.py | 2 +- .../0007_outcomes_add_event_id_ttl_codec.py | 2 +- .../outcomes/0008_outcomes_add_indexes.py | 2 +- .../outcomes/0009_outcomes_daily_table.py | 2 +- .../0010_outcomes_daily_fixed_partitioning.py | 2 +- .../outcomes/0011_add_quantity64.py | 2 +- .../0001_create_profile_chunks_table.py | 4 +- .../0002_add_environment_column.py | 2 +- .../profiles/0001_profiles.py | 8 +- .../0002_disable_vertical_merge_algorithm.py | 2 +- .../profiles/0003_add_device_architecture.py | 2 +- .../profiles/0004_drop_profile_column.py | 2 +- .../querylog/0001_querylog.py | 2 +- .../querylog/0002_status_type_change.py | 14 +- .../querylog/0003_add_profile_fields.py | 36 +-- .../querylog/0004_add_bytes_scanned.py | 10 +- .../0005_add_codec_update_settings.py | 3 +- .../querylog/0006_sorting_key_change.py | 10 +- .../querylog/0007_add_offset_column.py | 10 +- .../snuba_migrations/replays/0001_replays.py | 10 +- .../snuba_migrations/replays/0002_add_url.py | 4 +- .../replays/0003_alter_url_allow_null.py | 2 +- .../replays/0004_add_error_ids_column.py | 12 +- ..._urls_user_agent_replay_start_timestamp.py | 12 +- .../replays/0006_add_is_archived_column.py | 24 +- .../replays/0007_add_replay_type_column.py | 24 +- .../replays/0008_add_sample_rate.py | 4 +- .../replays/0009_add_dom_index_columns.py | 4 +- .../replays/0010_add_nullable_columns.py | 4 +- .../replays/0011_add_is_dead_rage.py | 4 +- .../replays/0012_materialize_counts.py | 4 +- .../0013_add_low_cardinality_codecs.py | 20 +- .../replays/0014_add_id_event_columns.py | 2 +- .../0015_index_frequently_accessed_columns.py | 2 +- .../0016_materialize_new_event_counts.py | 8 +- .../replays/0017_add_component_name_column.py | 4 +- .../replays/0018_add_viewed_by_id_column.py | 6 +- .../replays/0019_add_materialization.py | 24 +- ..._add_dist_migration_for_materialization.py | 24 +- .../replays/0021_index_tags.py | 2 +- .../replays/0022_add_context_ota_updates.py | 4 +- .../replays/0023_add_geo_columns.py | 4 +- .../replays/0024_add_tap_columns.py | 4 +- .../search_issues/0001_search_issues.py | 8 +- .../0002_search_issues_add_tags_hash_map.py | 2 +- ...h_issues_modify_occurrence_type_id_size.py | 2 +- ...0004_rebuild_search_issues_with_version.py | 8 +- .../search_issues/0005_search_issues_v2.py | 8 +- ..._add_subtitle_culprit_level_resource_id.py | 2 +- .../0007_add_transaction_duration.py | 2 +- .../0008_add_profile_id_replay_id.py | 2 +- .../search_issues/0009_add_message.py | 2 +- .../0010_add_group_first_seen.py | 6 +- .../search_issues/0011_add_timestamp_ms.py | 2 +- .../0012_add_group_id_bloom_filter_index.py | 2 +- .../test_migration/0001_create_test_table.py | 3 +- .../test_migration/0002_add_test_col.py | 3 +- .../transactions/0001_transactions.py | 4 +- ...s_onpremise_fix_orderby_and_partitionby.py | 10 +- ...0003_transactions_onpremise_fix_columns.py | 34 +-- .../0004_transactions_add_tags_hash_map.py | 6 +- .../0005_transactions_add_measurements.py | 10 +- .../0006_transactions_add_http_fields.py | 18 +- .../0007_transactions_add_discover_cols.py | 26 +- .../0008_transactions_add_timestamp_index.py | 2 +- ...0009_transactions_fix_title_and_message.py | 18 +- .../0010_transactions_nullable_trace_id.py | 2 +- ...011_transactions_add_span_op_breakdowns.py | 2 +- .../0012_transactions_add_spans.py | 17 +- ...ransactions_reduce_spans_exclusive_time.py | 3 +- ...4_transactions_remove_flattened_columns.py | 3 +- .../0015_transactions_add_source_column.py | 14 +- .../0016_transactions_add_group_ids_column.py | 14 +- ..._transactions_add_app_start_type_column.py | 14 +- .../0018_transactions_add_profile_id.py | 2 +- ...ansactions_add_indexes_and_context_hash.py | 8 +- .../0020_transactions_add_codecs.py | 6 +- .../0021_transactions_add_replay_id.py | 2 +- ...0022_transactions_add_index_on_trace_id.py | 2 +- .../0023_add_profiler_id_column.py | 2 +- snuba/state/__init__.py | 71 ++--- snuba/state/cache/abstract.py | 7 +- snuba/state/cache/redis/backend.py | 60 ++-- snuba/state/explain_meta.py | 7 +- snuba/state/rate_limit.py | 29 +- snuba/subscriptions/codecs.py | 4 +- .../combined_scheduler_executor.py | 15 +- snuba/subscriptions/data.py | 55 ++-- snuba/subscriptions/executor_consumer.py | 30 +- snuba/subscriptions/scheduler.py | 54 ++-- snuba/subscriptions/scheduler_consumer.py | 61 ++-- .../scheduler_processing_strategy.py | 49 ++-- snuba/subscriptions/store.py | 6 +- snuba/subscriptions/subscription.py | 2 +- snuba/util.py | 17 +- snuba/utils/bucket_timer.py | 4 +- snuba/utils/describer.py | 13 +- snuba/utils/gcs.py | 15 +- snuba/utils/health_info.py | 8 +- snuba/utils/iterators.py | 3 +- snuba/utils/manage_topics.py | 21 +- snuba/utils/metrics/addr_config.py | 3 +- snuba/utils/metrics/backends/abstract.py | 27 +- snuba/utils/metrics/backends/datadog.py | 44 +-- snuba/utils/metrics/backends/dummy.py | 28 +- snuba/utils/metrics/backends/sentry.py | 2 +- snuba/utils/metrics/backends/testing.py | 42 +-- snuba/utils/metrics/gauge.py | 14 +- snuba/utils/metrics/timer.py | 22 +- snuba/utils/metrics/types.py | 2 +- snuba/utils/metrics/util.py | 19 +- snuba/utils/metrics/wrapper.py | 42 ++- snuba/utils/profiler.py | 3 +- snuba/utils/rate_limiter.py | 10 +- snuba/utils/registered_class.py | 33 ++- snuba/utils/schemas.py | 241 +++++++--------- snuba/utils/serializable_exception.py | 20 +- snuba/utils/server.py | 6 +- snuba/utils/streams/configuration_builder.py | 43 ++- snuba/utils/streams/metrics_adapter.py | 10 +- snuba/utils/streams/topics.py | 2 +- snuba/utils/streams/types.py | 4 +- snuba/utils/threaded_function_delegator.py | 11 +- snuba/utils/types.py | 4 +- snuba/web/__init__.py | 14 +- snuba/web/bulk_delete_query.py | 21 +- snuba/web/db_query.py | 61 ++-- snuba/web/delete_query.py | 13 +- snuba/web/query.py | 14 +- snuba/web/rpc/__init__.py | 36 ++- snuba/web/rpc/common/common.py | 77 +++-- snuba/web/rpc/common/debug_info.py | 17 +- snuba/web/rpc/common/exceptions.py | 4 +- snuba/web/rpc/proto_visitor.py | 7 +- snuba/web/rpc/storage_routing/common.py | 7 +- .../web/rpc/storage_routing/load_retriever.py | 3 +- .../routing_strategies/outcomes_based.py | 15 +- .../routing_strategies/outcomes_flex_time.py | 4 +- .../routing_strategies/storage_routing.py | 45 ++- .../routing_strategy_selector.py | 13 +- snuba/web/rpc/v1/create_subscription.py | 6 +- .../web/rpc/v1/endpoint_delete_trace_items.py | 26 +- .../web/rpc/v1/endpoint_export_trace_items.py | 32 ++- snuba/web/rpc/v1/endpoint_get_trace.py | 67 +++-- snuba/web/rpc/v1/endpoint_get_traces.py | 70 +++-- snuba/web/rpc/v1/endpoint_time_series.py | 41 ++- .../v1/endpoint_trace_item_attribute_names.py | 27 +- .../web/rpc/v1/endpoint_trace_item_details.py | 41 ++- snuba/web/rpc/v1/endpoint_trace_item_stats.py | 10 +- snuba/web/rpc/v1/endpoint_trace_item_table.py | 10 +- .../R_eap_items/resolver_time_series.py | 21 +- .../R_eap_items/resolver_trace_item_stats.py | 7 +- .../R_eap_items/resolver_trace_item_table.py | 84 +++--- .../rpc/v1/resolvers/common/aggregation.py | 31 +- .../v1/resolvers/common/cross_item_queries.py | 4 +- .../resolvers/common/formula_reliability.py | 4 +- .../v1/resolvers/common/trace_item_table.py | 39 ++- .../web/rpc/v1/trace_item_attribute_values.py | 5 +- .../sparse_aggregate_attribute_transformer.py | 45 ++- .../visitors/time_series_request_visitor.py | 13 +- snuba/web/rpc/v1/visitors/visitor_v2.py | 2 +- snuba/web/views.py | 58 ++-- snuba/writer.py | 5 +- test_distributed_migrations/conftest.py | 8 +- test_initialization/test_initialize.py | 2 +- tests/admin/clickhouse/test_querylog.py | 4 +- tests/admin/clickhouse_migrations/test_api.py | 43 +-- .../test_migration_checks.py | 16 +- tests/admin/test_api.py | 21 +- tests/admin/test_jwt.py | 9 +- tests/admin/test_migration_policies.py | 8 +- tests/admin/test_querylog_audit_log.py | 2 +- tests/admin/test_system_queries.py | 16 +- tests/assertions.py | 3 +- tests/backends/metrics.py | 57 ++-- tests/base.py | 3 +- tests/cli/test_consumer.py | 2 +- tests/cli/test_migrations.py | 4 +- tests/cli/test_subscriptions.py | 5 +- tests/clickhouse/optimize/test_optimize.py | 60 ++-- .../optimize/test_optimize_scheduler.py | 18 +- .../optimize/test_optimize_tracker.py | 7 +- tests/clickhouse/query_dsl/test_accessors.py | 6 +- tests/clickhouse/query_dsl/test_project_id.py | 7 +- tests/clickhouse/test_http.py | 4 +- tests/clickhouse/test_native.py | 7 +- tests/clickhouse/test_profile_events.py | 28 +- tests/clickhouse/test_query_format.py | 5 +- tests/clusters/fake_cluster.py | 29 +- tests/clusters/test_cluster.py | 37 ++- tests/conftest.py | 66 ++--- tests/consumers/test_consumer_builder.py | 22 +- tests/consumers/test_message_processors.py | 14 +- tests/consumers/test_schemas.py | 7 +- tests/consumers/test_utils.py | 5 +- tests/datasets/cdc/test_groupassignee.py | 14 +- tests/datasets/cdc/test_groupedmessage.py | 14 +- .../configuration/test_entity_loader.py | 9 +- .../configuration/test_storage_loader.py | 28 +- tests/datasets/configuration/utils.py | 4 +- .../entities/storage_selectors/test_errors.py | 4 +- .../storage_selectors/test_selector.py | 4 +- tests/datasets/entities/test_entity_key.py | 2 +- .../entities/test_pluggable_entity.py | 9 +- tests/datasets/plans/test_cluster_selector.py | 3 +- .../datasets/plans/translator/test_mapping.py | 14 +- .../processors/test_replaced_groups.py | 41 ++- tests/datasets/storages/test_storages.py | 2 +- tests/datasets/test_context_promotion.py | 12 +- tests/datasets/test_dataset_factory.py | 4 +- tests/datasets/test_discover.py | 3 +- tests/datasets/test_errors_processor.py | 13 +- tests/datasets/test_errors_replacer.py | 117 ++++---- tests/datasets/test_events.py | 6 +- tests/datasets/test_functions_processor.py | 21 +- .../test_generic_metrics_processor.py | 9 +- tests/datasets/test_group_attributes_join.py | 27 +- .../test_group_attributes_processor.py | 7 +- tests/datasets/test_metrics_processing.py | 16 +- tests/datasets/test_metrics_processor.py | 17 +- tests/datasets/test_processors_idempotency.py | 3 +- tests/datasets/test_profiles_processor.py | 17 +- .../datasets/test_search_issues_processor.py | 72 ++--- tests/datasets/test_table_storage.py | 3 +- tests/datasets/test_transaction_processor.py | 148 ++++------ .../test_datetime_condition_validator.py | 7 +- .../validation/test_entity_validation.py | 7 +- ...illegal_aggregate_conditions_validation.py | 5 +- .../test_no_time_condition_validator.py | 8 +- .../test_subscription_clauses_validator.py | 30 +- tests/fixtures.py | 21 +- tests/helpers.py | 7 +- tests/lw_deletions/test_formatters.py | 6 +- tests/lw_deletions/test_lw_deletions.py | 2 +- tests/lw_deletions/test_off_peak.py | 6 +- tests/manual_jobs/test_extract_span_data.py | 9 +- .../test_generate_python_migration.py | 2 +- tests/migrations/test_check_dangerous.py | 5 +- tests/migrations/test_connect.py | 8 +- tests/migrations/test_legacy_use.py | 4 +- tests/migrations/test_operations.py | 2 +- tests/migrations/test_parse_schema.py | 6 +- tests/migrations/test_policies.py | 14 +- tests/migrations/test_runner.py | 27 +- tests/migrations/test_runner_individual.py | 9 +- tests/migrations/test_validator.py | 14 +- tests/pipeline/conftest.py | 6 +- tests/pipeline/test_execution_stage.py | 6 +- tests/pipeline/test_pipeline_stage.py | 4 +- .../pipeline/test_storage_processing_stage.py | 6 +- .../test_storage_query_identity_translate.py | 5 +- .../test_allocation_policy_base.py | 90 +++--- ..._bytes_scanned_window_allocation_policy.py | 12 +- .../test_concurrent_rate_limit_policy.py | 24 +- .../test_cross_org_policy.py | 5 +- .../allocation_policies/test_per_referrer.py | 8 +- tests/query/data_source/test_join.py | 4 +- tests/query/formatters/test_query.py | 25 +- tests/query/joins/equivalence_schema.py | 2 +- tests/query/joins/join_structures.py | 31 +- tests/query/joins/test_branch_cutter.py | 4 +- tests/query/joins/test_equivalence_adder.py | 24 +- tests/query/joins/test_equivalences.py | 32 +-- tests/query/joins/test_metrics_subqueries.py | 12 +- tests/query/joins/test_semi_join.py | 7 +- tests/query/joins/test_subqueries.py | 68 ++--- tests/query/parser/test_formula_mql_query.py | 98 ++----- .../query/parser/test_invalid_legacy_query.py | 5 +- tests/query/parser/test_parser.py | 87 ++---- .../test_parse_snql_query_initial.py | 4 +- .../test_post_process_and_validate_query.py | 3 +- .../unit_tests/test_resolver_visitor.py | 16 +- .../query/parser/validation/test_functions.py | 8 +- tests/query/processors/query_builders.py | 8 +- .../processors/test_array_has_optimizer.py | 4 +- .../processors/test_arrayjoin_optimizer.py | 13 +- .../test_arrayjoin_spans_optimizer.py | 97 ++----- .../test_clickhouse_settings_override.py | 3 +- .../test_empty_tag_condition_processor.py | 2 +- ...test_fixedstring_array_column_processor.py | 2 +- .../processors/test_granularity_processor.py | 17 +- .../processors/test_handled_functions.py | 4 +- .../test_hexint_column_processor.py | 2 +- .../test_low_cardinality_processor.py | 8 +- .../test_mandatory_condition_applier.py | 3 +- .../processors/test_null_column_caster.py | 24 +- tests/query/processors/test_prewhere.py | 13 +- .../processors/test_timeseries_processor.py | 6 +- .../query/processors/test_tuple_unaliaser.py | 14 +- .../test_uuid_array_column_processor.py | 2 +- .../processors/test_uuid_column_processor.py | 4 +- tests/query/snql/test_invalid_queries.py | 3 +- tests/query/snql/test_joins.py | 6 +- tests/query/snql/test_query.py | 34 ++- .../snql/test_query_column_validation.py | 33 +-- tests/query/snql/test_storage_query.py | 34 +-- tests/query/test_expressions.py | 3 +- tests/query/test_matcher.py | 14 +- tests/query/test_query.py | 2 +- tests/query/test_query_ast.py | 15 +- tests/query/test_visitor.py | 32 +-- tests/query/validation/test_signature.py | 2 +- tests/querylog/test_query_metadata.py | 8 +- tests/replacer/test_cluster_replacements.py | 23 +- tests/replacer/test_load_balancer.py | 2 +- .../replacer/test_replacements_and_expiry.py | 21 +- tests/request/test_build_request.py | 8 +- tests/settings/test_settings.py | 6 +- tests/state/test_cache.py | 23 +- tests/state/test_rate_limit.py | 96 ++++--- tests/state/test_state.py | 8 +- tests/subscriptions/__init__.py | 4 +- .../test_entity_subscriptions.py | 17 +- .../test_entity_subscriptions_data.py | 6 +- .../subscriptions/test_builder_mode_state.py | 4 +- tests/subscriptions/test_codecs.py | 3 +- .../test_combined_scheduler_executor.py | 2 +- tests/subscriptions/test_data.py | 17 +- tests/subscriptions/test_executor_consumer.py | 10 +- .../test_filter_subscriptions.py | 10 +- tests/subscriptions/test_scheduler.py | 8 +- .../subscriptions/test_scheduler_consumer.py | 5 +- .../test_scheduler_processing_strategy.py | 4 +- tests/subscriptions/test_store.py | 2 +- tests/subscriptions/test_subscription.py | 10 +- tests/subscriptions/test_task_builder.py | 8 +- tests/subscriptions/test_types.py | 2 +- tests/test_api.py | 66 ++--- tests/test_api_status.py | 3 +- tests/test_cleanup.py | 55 ++-- tests/test_cli.py | 2 +- tests/test_configurable_component.py | 6 +- tests/test_consumer.py | 4 +- tests/test_discover_api.py | 17 +- tests/test_generic_metrics_api.py | 19 +- tests/test_group_attributes_api.py | 5 +- tests/test_metrics_api.py | 109 +++---- tests/test_metrics_meta_api.py | 13 +- tests/test_metrics_mql_api.py | 7 +- tests/test_metrics_sdk_api.py | 13 +- tests/test_outcomes_api.py | 23 +- tests/test_replacer.py | 7 +- tests/test_replays_api.py | 6 +- tests/test_search_issues_api.py | 87 +++--- tests/test_snql_api.py | 17 +- tests/test_transactions_api.py | 42 ++- tests/test_writer.py | 3 +- tests/utils/conftest.py | 6 - tests/utils/metrics/test_gauge.py | 3 +- tests/utils/test_columns_validator.py | 10 +- tests/utils/test_describer.py | 8 +- tests/utils/test_rate_limiter.py | 23 +- tests/utils/test_registered_class.py | 8 +- .../utils/test_threaded_function_delegator.py | 6 +- tests/web/rpc/test_base.py | 27 +- tests/web/rpc/test_common.py | 10 +- tests/web/rpc/v1/routing_strategies/common.py | 11 +- tests/web/rpc/v1/test_create_subscription.py | 3 +- .../v1/test_endpoint_delete_trace_items.py | 32 ++- .../v1/test_endpoint_export_trace_items.py | 20 +- tests/web/rpc/v1/test_endpoint_get_trace.py | 18 +- tests/web/rpc/v1/test_endpoint_get_traces.py | 20 +- .../test_endpoint_time_series.py | 23 +- ...ndpoint_time_series_cross_item_sampling.py | 112 ++++---- ...test_endpoint_time_series_extrapolation.py | 8 +- .../test_endpoint_time_series_logs.py | 6 +- ...est_endpoint_trace_item_attribute_names.py | 6 +- .../v1/test_endpoint_trace_item_details.py | 18 +- .../rpc/v1/test_endpoint_trace_item_stats.py | 6 +- .../test_endpoint_trace_item_stats_heatmap.py | 6 +- .../v1/test_endpoint_trace_item_stats_logs.py | 6 +- .../test_endpoint_trace_item_table.py | 78 ++--- ...nt_trace_item_table_cross_item_sampling.py | 120 ++++---- ...endpoint_trace_item_table_extrapolation.py | 54 ++-- .../test_endpoint_trace_item_table_logs.py | 6 +- .../test_occurrence_hourly_event_rate.py | 16 +- .../test_trace_item_table_flex_time.py | 8 +- tests/web/rpc/v1/test_storage_routing.py | 4 +- .../v1/test_trace_item_attribute_values_v1.py | 23 +- tests/web/rpc/v1/test_utils.py | 45 +-- tests/web/test__get_allocation_policy.py | 3 +- tests/web/test_bulk_delete_query.py | 77 ++--- tests/web/test_db_query.py | 82 +++--- tests/web/test_max_rows_enforcer.py | 3 +- tests/web/test_project_finder.py | 10 +- tests/web/test_results.py | 2 +- tests/web/test_tables_collector.py | 8 +- tests/web/test_views.py | 4 +- uv.lock | 88 +++++- 848 files changed, 6476 insertions(+), 7543 deletions(-) delete mode 100644 setup.cfg delete mode 100644 tests/utils/conftest.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d65fbce85cd..eaf9805cacf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,6 +28,10 @@ repos: language: system types: [python] require_serial: true + # Keep in sync with [tool.mypy] exclude in pyproject.toml. pre-commit + # passes filenames explicitly, which bypasses mypy's own `exclude`, so + # these dirs must be re-excluded here to honor the same policy. + exclude: '^(rust_snuba/|tests/datasets/|tests/query/|test_distributed_migrations/)' - id: validate-configs-syntax name: validate-configs-syntax diff --git a/Makefile b/Makefile index a0363e57048..3be64981dc1 100644 --- a/Makefile +++ b/Makefile @@ -37,6 +37,20 @@ test-distributed: tests: test +lint: + uv run ruff check --fix . + uv run ruff format . +.PHONY: lint + +lint-check: + uv run ruff check . + uv run ruff format --check . +.PHONY: lint-check + +typecheck: + uv run mypy . +.PHONY: typecheck + api-tests: SNUBA_SETTINGS=test pytest -vv tests/*_api.py diff --git a/pyproject.toml b/pyproject.toml index 9c4dcac2563..75704ce6d1c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,7 +78,7 @@ snuba = "snuba.cli:main" dev = [ "devservices>=1.2.1", "freezegun>=1.5.5", - "mypy>=1.1.1", + "mypy>=1.18.2", "pre-commit>=4.2.0", "pytest>=9.0.3", "pytest-cov>=4.1.0", @@ -96,6 +96,19 @@ dev = [ "typing-extensions>=4.12.2", ] +[tool.pytest.ini_options] +python_files = "test*.py" +addopts = "--tb=native -p no:doctest -p no:warnings" +norecursedirs = "bin dist docs htmlcov script hooks node_modules .*" +looponfailroots = ["snuba", "tests"] +markers = [ + "clickhouse_db: Use clickhouse", + "redis_db: Use redis", + "ci_only: Only run in CI", + "eap: Use clickhouse with EAP migrations only", + "genmetrics_db: Use clickhouse with generic metrics migrations only", +] + [tool.ruff] # File filtering is taken care of in pre-commit. line-length = 100 @@ -103,24 +116,43 @@ target-version = "py313" [tool.ruff.lint] select = [ - # todo: eventually we should get this enabled - # "B", flake8-bugbear - "E", # pycodestyle errors - "F", # pyflakes - "W", # pycodestyle warnings - "I", # isort + "E", # pycodestyle errors + "F", # pyflakes + "W", # pycodestyle warnings + "I", # isort + "B", # flake8-bugbear + "UP", # pyupgrade + "C4", # flake8-comprehensions + "SIM", # flake8-simplify + "RET", # flake8-return ] ignore = [ - "E501", # line too long (handled by formatter) - "E402", # module level import not at top of file + "E501", # line too long (handled by formatter) + "SIM108", # use ternary instead of if-else (often hurts readability) + "RET504", # unnecessary assignment before return (named results aid debugging) + # PEP 695 native generics/type-aliases are a deliberate migration, not a + # mechanical sweep — risky alongside custom metaclasses. Adopt case by case. + "UP046", # non-pep695-generic-class + "UP047", # non-pep695-generic-function + "UP040", # non-pep695-type-alias ] +[tool.ruff.lint.isort] +known-first-party = ["snuba", "tests"] + [tool.mypy] python_version = "3.13" strict = true ignore_missing_imports = false files = ["."] -exclude = ["^rust_snuba/", "^tests/datasets/", "^tests/query/"] +exclude = [ + "^rust_snuba/", + "^tests/datasets/", + "^tests/query/", + # docker-only distributed-migration scaffolding; its top-level conftest.py + # collides with the repo-root conftest.py module name under `mypy .` + "^test_distributed_migrations/", +] [[tool.mypy.overrides]] module = [ diff --git a/scripts/check-migrations.py b/scripts/check-migrations.py index c9050d19e2b..b0cf918f24e 100755 --- a/scripts/check-migrations.py +++ b/scripts/check-migrations.py @@ -2,8 +2,8 @@ import argparse import os import subprocess +from collections.abc import Sequence from shutil import ExecError -from typing import Optional, Sequence import requests @@ -107,13 +107,10 @@ def _get_changes(globs: Sequence[str], workdir: str, to: str) -> str: ) if changes.returncode != 0: raise ExecError(changes.stdout) - else: - return changes.stdout + return changes.stdout -def main( - to: str = "origin/master", workdir: str = ".", labels: Optional[Sequence[str]] = [] -) -> None: +def main(to: str = "origin/master", workdir: str = ".", labels: Sequence[str] | None = []) -> None: if labels: for label in labels: if SKIP_LABEL in label: diff --git a/scripts/copy_tables.py b/scripts/copy_tables.py index 03439266bcd..75db4ff43d4 100755 --- a/scripts/copy_tables.py +++ b/scripts/copy_tables.py @@ -3,7 +3,7 @@ import argparse import re from collections import OrderedDict -from typing import Mapping, Optional, Sequence +from collections.abc import Mapping, Sequence from clickhouse_driver import Client @@ -100,7 +100,7 @@ def copy_tables( source_database: str, target_database: str, execute: bool, - tables: Optional[Sequence[str]], + tables: Sequence[str] | None, ) -> None: """ When adding a replica to a clickhouse cluster, that node will not have any tables diff --git a/scripts/ddl-changes.py b/scripts/ddl-changes.py index 791a5e218ce..fd9988a3e3b 100644 --- a/scripts/ddl-changes.py +++ b/scripts/ddl-changes.py @@ -26,26 +26,25 @@ def _main() -> None: ) if diff_result.returncode != 0: raise ExecError(diff_result.stdout) - else: - lines = diff_result.stdout.splitlines() - if len(lines) > 0: - print("-- start migrations") - print() - for line in lines: - migration_filename = os.path.basename(line) - migration_group = MigrationGroup(os.path.basename(os.path.dirname(line))) - migration_id, _ = os.path.splitext(migration_filename) - runner = Runner() - migration_key = MigrationKey(migration_group, migration_id) - print(f"-- forward migration {migration_group.value} : {migration_id}") - runner.run_migration(migration_key, dry_run=True) - print(f"-- end forward migration {migration_group.value} : {migration_id}") + lines = diff_result.stdout.splitlines() + if len(lines) > 0: + print("-- start migrations") + print() + for line in lines: + migration_filename = os.path.basename(line) + migration_group = MigrationGroup(os.path.basename(os.path.dirname(line))) + migration_id, _ = os.path.splitext(migration_filename) + runner = Runner() + migration_key = MigrationKey(migration_group, migration_id) + print(f"-- forward migration {migration_group.value} : {migration_id}") + runner.run_migration(migration_key, dry_run=True) + print(f"-- end forward migration {migration_group.value} : {migration_id}") - print("\n\n\n") - migration_key = MigrationKey(migration_group, migration_id) - print(f"-- backward migration {migration_group.value} : {migration_id}") - runner.reverse_migration(migration_key, dry_run=True) - print(f"-- end backward migration {migration_group.value} : {migration_id}") + print("\n\n\n") + migration_key = MigrationKey(migration_group, migration_id) + print(f"-- backward migration {migration_group.value} : {migration_id}") + runner.reverse_migration(migration_key, dry_run=True) + print(f"-- end backward migration {migration_group.value} : {migration_id}") if __name__ == "__main__": diff --git a/scripts/fetch_service_refs.py b/scripts/fetch_service_refs.py index 822c32263df..2d4600f032c 100755 --- a/scripts/fetch_service_refs.py +++ b/scripts/fetch_service_refs.py @@ -8,7 +8,7 @@ import time import urllib.error import urllib.request -from typing import Any, Dict, Optional +from typing import Any GO_SERVER_URL = os.environ["GO_SERVER_URL"] @@ -16,12 +16,12 @@ MAX_FETCHES = 100 -def pipeline_passed(pipeline: Dict[str, Any]) -> bool: - stage_status_dict: Dict[str, str] = { +def pipeline_passed(pipeline: dict[str, Any]) -> bool: + stage_status_dict: dict[str, str] = { stage["name"]: stage["status"] for stage in pipeline["stages"] } - return stage_status_dict.get("pipeline-complete", None) == "Passed" + return stage_status_dict.get("pipeline-complete") == "Passed" # print the most recent passing sha for a repo @@ -33,7 +33,7 @@ def main(pipeline_name: str = "deploy-snuba-us", repo: str = "snuba") -> int: GOCD_ACCESS_TOKEN not set. It should be an access token belonging to bot@sentry.io. """ ) - fetch_url: Optional[str] = f"{GO_SERVER_URL}/api/pipelines/{pipeline_name}/history" + fetch_url: str | None = f"{GO_SERVER_URL}/api/pipelines/{pipeline_name}/history" fetches = 0 while fetch_url and fetches < MAX_FETCHES: fetches += 1 @@ -47,7 +47,7 @@ def main(pipeline_name: str = "deploy-snuba-us", repo: str = "snuba") -> int: try: resp = urllib.request.urlopen(req) except urllib.error.HTTPError as e: - raise SystemExit(f"Failed to fetch pipeline history:\n{e.read().decode()}") + raise SystemExit(f"Failed to fetch pipeline history:\n{e.read().decode()}") from e print("fetching pipeline history for", pipeline_name, fetch_url, file=sys.stderr) data = json.loads(resp.read()) @@ -65,11 +65,11 @@ def main(pipeline_name: str = "deploy-snuba-us", repo: str = "snuba") -> int: for r in pipeline["build_cause"]["material_revisions"]: # example material description format... `in` is good enough # 'URL: git@github.com:getsentry/devinfra-example-service.git, Branch: main' - if f"git@github.com:getsentry/{repo}.git" in r["material"]["description"]: - rev = r["modifications"][0]["revision"] - print(rev) - return 0 - elif f"https://github.com/getsentry/{repo}.git" in r["material"]["description"]: + if ( + f"git@github.com:getsentry/{repo}.git" in r["material"]["description"] + or f"https://github.com/getsentry/{repo}.git" + in r["material"]["description"] + ): rev = r["modifications"][0]["revision"] print(rev) return 0 diff --git a/scripts/generate_items.py b/scripts/generate_items.py index d6644ebc112..bb4b61328c7 100644 --- a/scripts/generate_items.py +++ b/scripts/generate_items.py @@ -1,10 +1,9 @@ import time import uuid from datetime import UTC, datetime, timedelta -from typing import Optional +from confluent_kafka import KafkaError, Producer from confluent_kafka import Message as KafkaMessage -from confluent_kafka import Producer from google.protobuf.timestamp_pb2 import Timestamp from sentry_protos.snuba.v1.request_common_pb2 import TraceItemType from sentry_protos.snuba.v1.trace_item_pb2 import AnyValue, TraceItem @@ -15,12 +14,12 @@ producer = Producer(kafka_config) -def delivery_report(err: Optional[Exception], msg: KafkaMessage) -> None: +def delivery_report(err: KafkaError | None, msg: KafkaMessage) -> None: if err is not None: print(f"Message delivery failed: {err}") -def generate_item_message(start_timestamp: Optional[datetime] = None) -> bytes: +def generate_item_message(start_timestamp: datetime | None = None) -> bytes: if start_timestamp is None: start_timestamp = datetime.now(tz=UTC) diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 9bd528e51de..00000000000 --- a/setup.cfg +++ /dev/null @@ -1,11 +0,0 @@ -[tool:pytest] -python_files = test*.py -addopts = --tb=native -p no:doctest -p no:warnings -norecursedirs = bin dist docs htmlcov script hooks node_modules .* {args} -looponfailroots = snuba tests -markers = - clickhouse_db: Use clickhouse - redis_db: Use redis - ci_only: Only run in CI - eap: Use clickhouse with EAP migrations only - genmetrics_db: Use clickhouse with generic metrics migrations only diff --git a/snuba/admin/audit_log/base.py b/snuba/admin/audit_log/base.py index 136670d0ecb..0cfef2e4834 100644 --- a/snuba/admin/audit_log/base.py +++ b/snuba/admin/audit_log/base.py @@ -1,5 +1,6 @@ -from datetime import datetime, timezone -from typing import Any, Mapping, MutableMapping, Optional, Union +from collections.abc import Mapping, MutableMapping +from datetime import UTC, datetime +from typing import Any import structlog @@ -24,10 +25,10 @@ def record( self, user: str, action: AuditLogAction, - data: Mapping[str, Union[str, int]], - notify: Optional[bool] = False, + data: Mapping[str, str | int], + notify: bool | None = False, ) -> None: - timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%fZ") + timestamp = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%S.%fZ") self.logger.info( event=action.value, user=user, diff --git a/snuba/admin/audit_log/query.py b/snuba/admin/audit_log/query.py index c5a900901c5..5ee0a8065d0 100644 --- a/snuba/admin/audit_log/query.py +++ b/snuba/admin/audit_log/query.py @@ -1,15 +1,16 @@ from __future__ import annotations -from datetime import datetime, timezone +from collections.abc import Callable, MutableMapping +from datetime import UTC, datetime from enum import Enum from functools import partial -from typing import Callable, MutableMapping, TypeVar, Union - -DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ" +from typing import TypeVar from snuba.admin.audit_log.action import AuditLogAction from snuba.admin.audit_log.base import AuditLog +DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ" + Return = TypeVar("Return") @@ -30,7 +31,7 @@ def audit_log(fn: Callable[[str, str], Return]) -> Callable[[str, str], Return]: """ def audit_log_wrapper(query: str, user: str) -> Return: - data: MutableMapping[str, Union[str, QueryExecutionStatus]] = { + data: MutableMapping[str, str | int] = { "query": query, } audit_log_notify = partial( @@ -42,11 +43,11 @@ def audit_log_wrapper(query: str, user: str) -> Return: result = fn(query, user) except Exception: data["status"] = QueryExecutionStatus.FAILED.value - data["end_timestamp"] = datetime.now(timezone.utc).strftime(DATETIME_FORMAT) + data["end_timestamp"] = datetime.now(UTC).strftime(DATETIME_FORMAT) audit_log_notify(data=data) raise data["status"] = QueryExecutionStatus.SUCCEEDED.value - data["end_timestamp"] = datetime.now(timezone.utc).strftime(DATETIME_FORMAT) + data["end_timestamp"] = datetime.now(UTC).strftime(DATETIME_FORMAT) audit_log_notify(data=data) return result diff --git a/snuba/admin/auth.py b/snuba/admin/auth.py index 5bb5bb3df66..15962c183af 100644 --- a/snuba/admin/auth.py +++ b/snuba/admin/auth.py @@ -1,7 +1,7 @@ from __future__ import annotations import json -from typing import Sequence +from collections.abc import Sequence import rapidjson import structlog @@ -48,7 +48,7 @@ def _is_member_of_group(user: AdminUser, group: str) -> bool: def get_iam_roles_from_user(user: AdminUser) -> Sequence[str]: iam_roles = [] try: - with open(settings.ADMIN_IAM_POLICY_FILE, "r") as policy_file: + with open(settings.ADMIN_IAM_POLICY_FILE) as policy_file: policy = json.load(policy_file) for binding in policy["bindings"]: role: str = binding["role"].split("roles/")[-1] diff --git a/snuba/admin/auth_roles.py b/snuba/admin/auth_roles.py index c1eeec596f5..c6a6b814cc6 100644 --- a/snuba/admin/auth_roles.py +++ b/snuba/admin/auth_roles.py @@ -1,9 +1,10 @@ from __future__ import annotations from abc import ABC, abstractmethod, abstractproperty +from collections.abc import Sequence from dataclasses import dataclass from enum import Enum -from typing import Generic, Sequence, Set, TypeVar +from typing import Generic, TypeVar from snuba import settings from snuba.migrations.runner import get_active_migration_groups @@ -127,7 +128,7 @@ class ExecuteSudoSystemQuery(ToolAction): @dataclass(frozen=True) class Role: name: str - actions: Set[MigrationAction | ToolAction] + actions: set[MigrationAction | ToolAction] def generate_tool_test_role(tool: str) -> Role: diff --git a/snuba/admin/clickhouse/common.py b/snuba/admin/clickhouse/common.py index 25c71940b75..fdab1ad81f7 100644 --- a/snuba/admin/clickhouse/common.py +++ b/snuba/admin/clickhouse/common.py @@ -1,9 +1,9 @@ from __future__ import annotations import re -from typing import MutableMapping +from collections.abc import MutableMapping -from sql_metadata import Parser, QueryType # type: ignore +from sql_metadata import Parser, QueryType # type: ignore[import-untyped] from snuba import settings from snuba.clickhouse.native import ClickhousePool @@ -41,7 +41,7 @@ def is_valid_node(host: str, port: int, cluster: ClickhouseCluster, storage_name "port": port, "nodes": ",".join([node.host_name for node in nodes]), }, - ) + ) from e return any(node.host_name == host and node.port == port for node in nodes) @@ -50,11 +50,11 @@ def _get_storage(storage_name: str) -> ReadableTableStorage: storage_key = None try: storage_key = StorageKey(storage_name) - except ValueError: + except ValueError as e: raise InvalidStorageError( f"storage {storage_name} is not a valid storage name", extra_data={"storage_name": storage_name}, - ) + ) from e return get_storage(storage_key) diff --git a/snuba/admin/clickhouse/copy_tables.py b/snuba/admin/clickhouse/copy_tables.py index 52cdd90bb78..ef6b382c01a 100644 --- a/snuba/admin/clickhouse/copy_tables.py +++ b/snuba/admin/clickhouse/copy_tables.py @@ -1,5 +1,6 @@ +from collections.abc import MutableMapping, Sequence from dataclasses import dataclass -from typing import MutableMapping, Optional, Sequence, Tuple, TypedDict +from typing import TypedDict from snuba.admin.clickhouse.common import _get_storage, get_clusterless_node_connection from snuba.clickhouse.native import ClickhousePool @@ -26,7 +27,7 @@ def get_create_table_statements( tables: Sequence[str], source_connection: ClickhousePool, source_database: str, - cluster_name: Optional[str], + cluster_name: str | None, ) -> Sequence[TableStatement]: table_statements = [] @@ -55,7 +56,7 @@ def get_create_table_statements( table_engine = source_connection.execute( f"SELECT engine FROM system.tables WHERE name = '{table}'" ).results[0][0] - is_mergetree = True if "MergeTree" in table_engine else False + is_mergetree = "MergeTree" in table_engine if cluster_name: table_statement = table_statement.replace( @@ -77,10 +78,10 @@ def get_tables(connection: ClickhousePool) -> Sequence[str]: def verify_tables_on_replicas( connection: ClickhousePool, - cluster_name: Optional[str], + cluster_name: str | None, database_name: str, table_names: Sequence[str], -) -> Tuple[MutableMapping[str, list[str]], int]: +) -> tuple[MutableMapping[str, list[str]], int]: """ Checks that the tables we have copied are present on all hosts. Returns a count of the verified hosts (host that have all the @@ -124,9 +125,9 @@ def copy_tables( source_host: str, storage_name: str, dry_run: bool, - target_host: Optional[str] = None, + target_host: str | None = None, skip_on_cluster: bool = False, - cluster_name_override: Optional[str] = None, + cluster_name_override: str | None = None, ) -> CopyTablesResponse: settings = ClickhouseClientSettings.QUERY source_connection = get_clusterless_node_connection( diff --git a/snuba/admin/clickhouse/database_clusters.py b/snuba/admin/clickhouse/database_clusters.py index e7db0c63210..a2486035742 100644 --- a/snuba/admin/clickhouse/database_clusters.py +++ b/snuba/admin/clickhouse/database_clusters.py @@ -1,7 +1,7 @@ import threading +from collections.abc import Sequence from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass -from typing import List, Sequence from snuba.admin.clickhouse.common import get_ro_node_connection from snuba.admin.clickhouse.nodes import get_storage_info @@ -71,7 +71,7 @@ def fetch_node_info_from_host(host_info: HostInfo) -> Sequence[Node]: def get_node_info() -> Sequence[Node]: - node_info: List[Node] = [] + node_info: list[Node] = [] hosts = set() for storage_info in get_storage_info(): for node in storage_info["dist_nodes"]: diff --git a/snuba/admin/clickhouse/migration_checks.py b/snuba/admin/clickhouse/migration_checks.py index 895819c32bf..4663d79cd18 100644 --- a/snuba/admin/clickhouse/migration_checks.py +++ b/snuba/admin/clickhouse/migration_checks.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod +from collections.abc import Mapping, Sequence from dataclasses import dataclass from enum import Enum -from typing import List, Mapping, Optional, Sequence, Set, Tuple, Union from snuba.migrations.groups import MigrationGroup, get_group_loader from snuba.migrations.policies import MigrationPolicy @@ -28,7 +28,7 @@ class ReverseReason(Enum): @dataclass class Result: allowed: bool - reason: Optional[Union[RunReason, ReverseReason]] = None + reason: RunReason | ReverseReason | None = None def __post_init__(self) -> None: if self.allowed and self.reason: @@ -39,12 +39,12 @@ def __post_init__(self) -> None: @dataclass class RunResult(Result): - reason: Optional[RunReason] = None + reason: RunReason | None = None @dataclass class ReverseResult(Result): - reason: Optional[ReverseReason] = None + reason: ReverseReason | None = None @dataclass @@ -154,8 +154,8 @@ def can_reverse(self, migration_key: MigrationKey) -> ReverseResult: def run_migration_checks_and_policies( - group_policies: Mapping[str, Set[MigrationPolicy]], runner: Runner -) -> Sequence[Tuple[MigrationGroup, Sequence[MigrationData]]]: + group_policies: Mapping[str, set[MigrationPolicy]], runner: Runner +) -> Sequence[tuple[MigrationGroup, Sequence[MigrationData]]]: """ Runs the policies for the given groups in addition to status checks for all groups. @@ -163,17 +163,19 @@ def run_migration_checks_and_policies( Returns the results of those checks along with the statuses for the migrations. """ - group_results: List[Tuple[MigrationGroup, Sequence[MigrationData]]] = [] + group_results: list[tuple[MigrationGroup, Sequence[MigrationData]]] = [] - for group, migrations in runner.show_all([g for g in group_policies.keys()]): - migration_ids: List[MigrationData] = [] + for group, migrations in runner.show_all(list(group_policies.keys())): + migration_ids: list[MigrationData] = [] status_checker = StatusChecker(group, migrations) policies = group_policies[group.value] def do_checking( migration_key: MigrationKey, - ) -> Tuple[RunResult, ReverseResult]: + status_checker: StatusChecker = status_checker, + policies: set[MigrationPolicy] = policies, + ) -> tuple[RunResult, ReverseResult]: run_result = status_checker.can_run(migration_key) reverse_result = status_checker.can_reverse(migration_key) diff --git a/snuba/admin/clickhouse/nodes.py b/snuba/admin/clickhouse/nodes.py index 59a2f4b8536..39905b3124f 100644 --- a/snuba/admin/clickhouse/nodes.py +++ b/snuba/admin/clickhouse/nodes.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Optional, Sequence, TypedDict +from collections.abc import Sequence +from typing import TypedDict import structlog @@ -14,18 +15,18 @@ logger = structlog.get_logger().bind(module=__name__) -Node = TypedDict("Node", {"host": str, "port": int}) -Storage = TypedDict( - "Storage", - { - "storage_name": str, - "local_table_name": str, - "local_nodes": Sequence[Node], - "dist_nodes": Sequence[Node], - "query_node": Optional[Node], - }, -) +class Node(TypedDict): + host: str + port: int + + +class Storage(TypedDict): + storage_name: str + local_table_name: str + local_nodes: Sequence[Node] + dist_nodes: Sequence[Node] + query_node: Node | None def _get_local_table_name(storage_key: StorageKey) -> str: @@ -46,19 +47,16 @@ def _get_nodes(storage_key: StorageKey, local: bool = True) -> Sequence[Node]: # The get_nodes cluster methods would result in an error because # discover is not a single node, but also does not belong to any cluster. return [] - else: - return [ - {"host": node.host_name, "port": node.port} - for node in ( - cluster.get_local_nodes() if local else cluster.get_distributed_nodes() - ) - ] + return [ + {"host": node.host_name, "port": node.port} + for node in (cluster.get_local_nodes() if local else cluster.get_distributed_nodes()) + ] except (AssertionError, KeyError, UndefinedClickhouseCluster) as e: logger.warning(str(e), storage_key=storage_key.value, local=local) return [] -def _get_query_node(storage_key: StorageKey) -> Optional[Node]: +def _get_query_node(storage_key: StorageKey) -> Node | None: try: cluster = get_storage(storage_key).get_cluster() query_node = cluster.get_query_node() diff --git a/snuba/admin/clickhouse/profile_events.py b/snuba/admin/clickhouse/profile_events.py index 6873f7411d6..5032e20ecff 100644 --- a/snuba/admin/clickhouse/profile_events.py +++ b/snuba/admin/clickhouse/profile_events.py @@ -1,7 +1,7 @@ import json import socket import time -from typing import Dict, List, cast +from typing import cast import structlog from flask import g @@ -26,14 +26,14 @@ def gather_profile_events(query_trace: TraceOutput, storage: str) -> None: query_trace: TraceOutput object to update with profile events storage: Storage identifier """ - profile_events_raw_sql = "SELECT ProfileEvents FROM system.query_log WHERE query_id = '{}' AND type = 'QueryFinish'" + profile_events_raw_sql = ( + "SELECT ProfileEvents FROM system.query_log WHERE query_id = '{}' AND type = 'QueryFinish'" + ) for query_trace_data in parse_trace_for_query_ids(query_trace): sql = profile_events_raw_sql.format(query_trace_data.query_id) logger.info( - "Gathering profile event using host: {}, port = {}, storage = {}, sql = {}, g.user = {}".format( - query_trace_data.host, query_trace_data.port, storage, sql, g.user - ) + f"Gathering profile event using host: {query_trace_data.host}, port = {query_trace_data.port}, storage = {storage}, sql = {sql}, g.user = {g.user}" ) system_query_result = None @@ -47,6 +47,7 @@ def gather_profile_events(query_trace: TraceOutput, storage: str) -> None: storage, sql, False, + False, g.user, ) except InvalidNodeError as exc: @@ -65,9 +66,7 @@ def gather_profile_events(query_trace: TraceOutput, storage: str) -> None: if system_query_result is not None and len(system_query_result.results) > 0: query_trace.profile_events_meta.append(system_query_result.meta) - query_trace.profile_events_profile = cast( - Dict[str, int], system_query_result.profile - ) + query_trace.profile_events_profile = cast(dict[str, int], system_query_result.profile) columns = system_query_result.meta if columns: res = {} @@ -82,19 +81,19 @@ def gather_profile_events(query_trace: TraceOutput, storage: str) -> None: def hostname_resolves(hostname: str) -> bool: try: socket.gethostbyname(hostname) - except socket.error: + except OSError: return False else: return True -def parse_trace_for_query_ids(trace_output: TraceOutput) -> List[QueryTraceData]: +def parse_trace_for_query_ids(trace_output: TraceOutput) -> list[QueryTraceData]: summarized_trace_output = trace_output.summarized_trace_output node_name_to_query_id = { node_name: query_summary.query_id for node_name, query_summary in summarized_trace_output.query_summaries.items() } - logger.info("node to query id mapping: {}".format(node_name_to_query_id)) + logger.info(f"node to query id mapping: {node_name_to_query_id}") return [ QueryTraceData( host=node_name if hostname_resolves(node_name) else "127.0.0.1", diff --git a/snuba/admin/clickhouse/querylog.py b/snuba/admin/clickhouse/querylog.py index f78c5e9d6a7..e9c8c56dac4 100644 --- a/snuba/admin/clickhouse/querylog.py +++ b/snuba/admin/clickhouse/querylog.py @@ -45,11 +45,11 @@ def _get_clickhouse_threads() -> int: int(config_threads) if config_threads is not None else _MAX_CH_THREADS, _MAX_CH_THREADS, ) - except ValueError: + except ValueError as e: # in case the config is set incorrectly raise BadThreadsValue( f"{config_threads} is not a valid configuration option for Clickhouse `max_threads`" - ) + ) from e def __run_querylog_query(query: str) -> ClickhouseResult: diff --git a/snuba/admin/clickhouse/system_queries.py b/snuba/admin/clickhouse/system_queries.py index 4db2123db41..b63258a1dd7 100644 --- a/snuba/admin/clickhouse/system_queries.py +++ b/snuba/admin/clickhouse/system_queries.py @@ -243,7 +243,7 @@ def is_query_show(sql_query: str) -> bool: """ sql_query = " ".join(sql_query.split()) match = SHOW_QUERY_RE.match(sql_query) - return True if match else False + return bool(match) def is_query_describe(sql_query: str) -> bool: @@ -252,7 +252,7 @@ def is_query_describe(sql_query: str) -> bool: """ sql_query = " ".join(sql_query.split()) match = DESCRIBE_QUERY_RE.match(sql_query) - return True if match else False + return bool(match) def is_system_command(sql_query: str) -> bool: @@ -273,7 +273,7 @@ def is_query_optimize(sql_query: str) -> bool: """ sql_query = " ".join(sql_query.split()) match = OPTIMIZE_QUERY_RE.match(sql_query) - return True if match else False + return bool(match) def is_query_alter(sql_query: str) -> bool: @@ -282,7 +282,7 @@ def is_query_alter(sql_query: str) -> bool: """ sql_query = " ".join(sql_query.split()) match = ALTER_QUERY_RE.match(sql_query) - return True if match else False + return bool(match) def is_query_drop(sql_query: str) -> bool: @@ -291,7 +291,7 @@ def is_query_drop(sql_query: str) -> bool: """ sql_query = " ".join(sql_query.split()) match = DROP_TABLE_QUERY_RE.match(sql_query) - return True if match else False + return bool(match) def validate_query( @@ -368,7 +368,7 @@ def run_system_query_on_host_with_sql( # Don't send error to Snuba if it is an unknown table or column as it # will be too noisy if exc.code in (ErrorCodes.UNKNOWN_TABLE, ErrorCodes.UNKNOWN_IDENTIFIER): - raise InvalidCustomQuery(f"Invalid query: {exc.message} {exc.code}") + raise InvalidCustomQuery(f"Invalid query: {exc.message} {exc.code}") from exc raise finally: diff --git a/snuba/admin/clickhouse/trace_log_parsing.py b/snuba/admin/clickhouse/trace_log_parsing.py index 9255af58808..2947fca37ad 100644 --- a/snuba/admin/clickhouse/trace_log_parsing.py +++ b/snuba/admin/clickhouse/trace_log_parsing.py @@ -2,7 +2,7 @@ import re from dataclasses import dataclass -from typing import Any +from typing import Any, Protocol # [ spans-clickhouse-1 ] [ 65011 ] {0.21246445055947638} default.spans_optimized_v2_traces (aacb1a4f-32d0-49ea-8985-9c0d92a079ae) (SelectExecutor): Index `bf_attr_str_5` has dropped 0/2199 granules. INDEX_MATCHER_RE = re.compile( @@ -200,7 +200,12 @@ class TracingSummary: query_summaries: dict[str, QuerySummary] -line_types = [ +class LogLineType(Protocol): + @staticmethod + def from_log(log_line: str) -> Any: ... + + +line_types: list[type[LogLineType]] = [ IndexSummary, SelectSummary, StreamSummary, @@ -224,7 +229,7 @@ def summarize_trace_output(raw_trace_logs: str) -> TracingSummary: query_summary = summary.query_summaries[line["node_name"]] for line_type in line_types: - parsed_line = line_type.from_log(line["log_content"]) # type: ignore + parsed_line = line_type.from_log(line["log_content"]) if parsed_line is not None: attr_name = line_type.__name__.lower().replace("summary", "") + "_summaries" if getattr(query_summary, attr_name) is None: diff --git a/snuba/admin/clickhouse/tracing.py b/snuba/admin/clickhouse/tracing.py index f0492bc8610..6891c66790b 100644 --- a/snuba/admin/clickhouse/tracing.py +++ b/snuba/admin/clickhouse/tracing.py @@ -1,9 +1,10 @@ from __future__ import annotations import math +from collections.abc import Mapping from dataclasses import dataclass from datetime import datetime -from typing import Any, Mapping +from typing import Any, cast from uuid import UUID from snuba.admin.clickhouse.common import ( @@ -49,7 +50,7 @@ def run_query_and_get_trace( return TraceOutput( trace_output=query_result.trace_output, summarized_trace_output=summarized_trace_output, - cols=query_result.meta, # type: ignore + cols=cast("list[tuple[str, str]]", query_result.meta), num_rows_result=len(query_result.results), result=list(map(scrub_row, query_result.results)), profile_events_results={}, diff --git a/snuba/admin/dead_letter_queue/__init__.py b/snuba/admin/dead_letter_queue/__init__.py index f5db026ea67..3917293d555 100644 --- a/snuba/admin/dead_letter_queue/__init__.py +++ b/snuba/admin/dead_letter_queue/__init__.py @@ -1,22 +1,25 @@ from __future__ import annotations -from typing import NamedTuple, Optional, Sequence, TypedDict +from collections.abc import Sequence +from typing import NamedTuple, TypedDict from snuba import settings from snuba.datasets.slicing import is_storage_set_sliced from snuba.datasets.storage import WritableTableStorage from snuba.datasets.storages.factory import get_writable_storages -Topic = TypedDict( - "Topic", - {"logicalName": str, "physicalName": str, "slice": Optional[int], "storage": str}, -) + +class Topic(TypedDict): + logicalName: str + physicalName: str + slice: int | None + storage: str class DlqTopic(NamedTuple): logical_name: str physical_name: str - slice_id: Optional[int] + slice_id: int | None storage: str def to_json(self) -> Topic: @@ -51,10 +54,9 @@ def get_dlq_topics() -> Sequence[Topic]: return [t.to_json() for t in dlq_topics] -def get_slices(storage: WritableTableStorage) -> Sequence[Optional[int]]: +def get_slices(storage: WritableTableStorage) -> Sequence[int | None]: storage_set_key = storage.get_storage_set_key() if is_storage_set_sliced(storage_set_key): return list(range(settings.SLICED_STORAGE_SETS[storage_set_key.value])) - else: - return [None] + return [None] diff --git a/snuba/admin/google.py b/snuba/admin/google.py index cb89af5826a..32bf62848df 100644 --- a/snuba/admin/google.py +++ b/snuba/admin/google.py @@ -1,4 +1,3 @@ -from typing import Optional from urllib.parse import urlencode import structlog @@ -26,7 +25,7 @@ def __init__(self, service: Resource = None) -> None: except Exception as e: logger.exception(e) - def _get_group_id(self, group_email: str) -> Optional[str]: + def _get_group_id(self, group_email: str) -> str | None: if not self.initialized: return None @@ -53,7 +52,7 @@ def _check_transitive_membership(self, group_resource_name: str, member: str) -> return False try: - query_params = urlencode({"query": "member_key_id == '{}'".format(member)}) + query_params = urlencode({"query": f"member_key_id == '{member}'"}) request = ( self.service.groups() .memberships() diff --git a/snuba/admin/jwt.py b/snuba/admin/jwt.py index c402c09a1ba..4d32337ec0c 100644 --- a/snuba/admin/jwt.py +++ b/snuba/admin/jwt.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any import jwt import requests @@ -6,7 +6,7 @@ from snuba import settings from snuba.admin.user import AdminUser -CERTS: Optional[Any] = None +CERTS: Any | None = None def _certs() -> Any: diff --git a/snuba/admin/kafka/topics.py b/snuba/admin/kafka/topics.py index 89888f55679..6d014de5b4c 100644 --- a/snuba/admin/kafka/topics.py +++ b/snuba/admin/kafka/topics.py @@ -1,5 +1,6 @@ import json -from typing import Any, Sequence +from collections.abc import Sequence +from typing import Any from confluent_kafka.admin import AdminClient diff --git a/snuba/admin/migrations_policies.py b/snuba/admin/migrations_policies.py index 00b4cca61a5..b65e3b85049 100644 --- a/snuba/admin/migrations_policies.py +++ b/snuba/admin/migrations_policies.py @@ -1,8 +1,9 @@ from __future__ import annotations from collections import defaultdict +from collections.abc import Callable, MutableMapping from functools import wraps -from typing import Any, Callable, Dict, MutableMapping, Set +from typing import Any from flask import Response, g, jsonify, make_response, request @@ -26,12 +27,12 @@ def get_migration_group_policies( user: AdminUser, -) -> Dict[str, Set[MigrationPolicy]]: +) -> dict[str, set[MigrationPolicy]]: """ Creates a mapping of migration groups to policies based on a user's roles. """ - group_policies: MutableMapping[str, Set[str]] = defaultdict(set) + group_policies: MutableMapping[str, set[str]] = defaultdict(set) allowed_groups = [group.value for group in get_active_migration_groups()] for role in user.roles: @@ -76,16 +77,16 @@ def str_to_bool(s: str) -> bool: dry_run = request.args.get("dry_run", False, type=str_to_bool) if not dry_run: - if action == "run": - if not any(policy.can_run(migration_key) for policy in policies): - return make_response( - jsonify({"error": "Group not allowed run policy"}), 403 - ) - elif action == "reverse": - if not any(policy.can_reverse(migration_key) for policy in policies): - return make_response( - jsonify({"error": "Group not allowed reverse policy"}), 403 - ) + if action == "run" and not any( + policy.can_run(migration_key) for policy in policies + ): + return make_response(jsonify({"error": "Group not allowed run policy"}), 403) + if action == "reverse" and not any( + policy.can_reverse(migration_key) for policy in policies + ): + return make_response( + jsonify({"error": "Group not allowed reverse policy"}), 403 + ) return f(*args, **kwargs) return check_group_perms diff --git a/snuba/admin/notifications/slack/client.py b/snuba/admin/notifications/slack/client.py index b7c40589101..94f34568bdf 100644 --- a/snuba/admin/notifications/slack/client.py +++ b/snuba/admin/notifications/slack/client.py @@ -1,4 +1,5 @@ -from typing import Any, MutableMapping, Optional +from collections.abc import MutableMapping +from typing import Any import requests import structlog @@ -6,8 +7,8 @@ logger = structlog.get_logger().bind(module=__name__) -class SlackClient(object): - def __init__(self, channel_id: Optional[str] = None, token: Optional[str] = None) -> None: +class SlackClient: + def __init__(self, channel_id: str | None = None, token: str | None = None) -> None: self.__channel_id = channel_id self.__token = token @@ -51,7 +52,7 @@ def post_file( file_name: str, file_path: str, file_type: str, - initial_comment: Optional[str] = None, + initial_comment: str | None = None, ) -> None: headers = { "Authorization": f"Bearer {self.__token}", @@ -62,17 +63,17 @@ def post_file( "initial_comment": initial_comment, } - files = { - "file": (file_name, open(file_path, "rb"), file_type), - } - try: - resp = requests.post( - "https://slack.com/api/files.upload", - headers=headers, - data=data, - files=files, - ) + with open(file_path, "rb") as file_obj: + files = { + "file": (file_name, file_obj, file_type), + } + resp = requests.post( + "https://slack.com/api/files.upload", + headers=headers, + data=data, + files=files, + ) except Exception as exc: logger.error(exc, exc_info=True) return diff --git a/snuba/admin/notifications/slack/utils.py b/snuba/admin/notifications/slack/utils.py index 6144d44b68f..c404b952bd8 100644 --- a/snuba/admin/notifications/slack/utils.py +++ b/snuba/admin/notifications/slack/utils.py @@ -1,5 +1,5 @@ import os -from typing import Any, Dict, List, Optional, Union +from typing import Any import sentry_sdk @@ -12,7 +12,7 @@ ) -def build_blocks(data: Any, action: AuditLogAction, timestamp: str, user: str) -> List[Any]: +def build_blocks(data: Any, action: AuditLogAction, timestamp: str, user: str) -> list[Any]: if action in RUNTIME_CONFIG_ACTIONS: text = build_runtime_config_text(data, action) elif action in MIGRATION_ACTIONS: @@ -30,25 +30,22 @@ def build_blocks(data: Any, action: AuditLogAction, timestamp: str, user: str) - return [section, build_context(user, timestamp, action)] -def build_configurable_component_changed_text(data: Any, action: AuditLogAction) -> Optional[str]: +def build_configurable_component_changed_text(data: Any, action: AuditLogAction) -> str | None: base = f"*Resource {data['resource_identifier']} Configurable Component {data['configurable_component_class_name']} Changed:*" if action == AuditLogAction.CONFIGURABLE_COMPONENT_DELETE: removed = f"~```'{data['configurable_component_class_name']}.{data['key']}({data.get('params', {})})'```~" return f"{base} :put_litter_in_its_place:\n\n{removed}" - elif action == AuditLogAction.CONFIGURABLE_COMPONENT_UPDATE: + if action == AuditLogAction.CONFIGURABLE_COMPONENT_UPDATE: updated = f"```'{data['configurable_component_class_name']}.{data['key']}({data.get('params', {})})' = '{data['value']}'```" return f"{base} :up: :date:\n\n{updated}" - else: - # todo: raise error, cause slack won't accept this - # if it is none - sentry_sdk.capture_message( - f"Unknown action: {action.value} with data: {data}", level="error" - ) - return f"{action.value}: {data}" + # todo: raise error, cause slack won't accept this + # if it is none + sentry_sdk.capture_message(f"Unknown action: {action.value} with data: {data}", level="error") + return f"{action.value}: {data}" -def build_runtime_config_text(data: Any, action: AuditLogAction) -> Optional[str]: +def build_runtime_config_text(data: Any, action: AuditLogAction) -> str | None: base = "*Runtime Config Option:*" removed = f"~```{{'{data['option']}': {data.get('old')}}}```~" added = f"```{{'{data['option']}': {data.get('new')}}}```" @@ -56,17 +53,16 @@ def build_runtime_config_text(data: Any, action: AuditLogAction) -> Optional[str if action == AuditLogAction.REMOVED_OPTION: return f"{base} :put_litter_in_its_place:\n\n {removed}" - elif action == AuditLogAction.ADDED_OPTION: + if action == AuditLogAction.ADDED_OPTION: return f"{base} :new:\n\n {added}" - elif action == AuditLogAction.UPDATED_OPTION: + if action == AuditLogAction.UPDATED_OPTION: return f"{base} :up: :date:\n\n {updated}" - else: - # todo: raise error, cause slack won't accept this - # if it is none - return None + # todo: raise error, cause slack won't accept this + # if it is none + return None -def build_migration_run_text(data: Any, action: AuditLogAction) -> Optional[str]: +def build_migration_run_text(data: Any, action: AuditLogAction) -> str | None: if action in [ AuditLogAction.RAN_MIGRATION_COMPLETED, AuditLogAction.RAN_MIGRATION_FAILED, @@ -93,7 +89,7 @@ def build_migration_run_text(data: Any, action: AuditLogAction) -> Optional[str] def build_context( user: str, timestamp: str, action: AuditLogAction -) -> Dict[str, Union[str, List[Dict[str, str]]]]: +) -> dict[str, str | list[dict[str, str]]]: url = f"{settings.ADMIN_URL}/#auditlog" environ = os.environ.get("SENTRY_ENVIRONMENT") or "unknown environment" return { diff --git a/snuba/admin/production_queries/prod_queries.py b/snuba/admin/production_queries/prod_queries.py index 91f25c219e1..5f9c12772e7 100644 --- a/snuba/admin/production_queries/prod_queries.py +++ b/snuba/admin/production_queries/prod_queries.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any from flask import Response @@ -15,7 +15,7 @@ from snuba.web.views import dataset_query -def run_snql_query(body: Dict[str, Any], user: str) -> Response: +def run_snql_query(body: dict[str, Any], user: str) -> Response: """ Validates, audit logs, and executes given query. """ @@ -36,7 +36,7 @@ def run_query_with_audit(query: str, user: str) -> Response: return run_query_with_audit(body["query"], user) -def _validate_projects_in_query(body: Dict[str, Any], dataset: Dataset, is_mql: bool) -> None: +def _validate_projects_in_query(body: dict[str, Any], dataset: Dataset, is_mql: bool) -> None: """ Validates that the projects accessed by the query are allowed to be accessed. """ @@ -65,7 +65,7 @@ def _validate_projects_in_query(body: Dict[str, Any], dataset: Dataset, is_mql: ) -def run_mql_query(body: Dict[str, Any], user: str) -> Response: +def run_mql_query(body: dict[str, Any], user: str) -> Response: """ Validates, audit logs, and executes given query. """ diff --git a/snuba/admin/rpc/rpc_queries.py b/snuba/admin/rpc/rpc_queries.py index 3178591f45d..cbaac3e508a 100644 --- a/snuba/admin/rpc/rpc_queries.py +++ b/snuba/admin/rpc/rpc_queries.py @@ -1,14 +1,14 @@ -from typing import Any, List, Set +from typing import Any from snuba import settings from snuba.web.rpc.common.exceptions import BadSnubaRPCRequestException -def _validate_projects_in_query(project_ids: List[int]) -> None: +def _validate_projects_in_query(project_ids: list[int]) -> None: if settings.DEBUG and len(settings.ADMIN_ALLOWED_PROD_PROJECTS) == 0: return - allowed_projects: Set[int] = set(settings.ADMIN_ALLOWED_PROD_PROJECTS) - query_projects: Set[int] = set(project_ids) + allowed_projects: set[int] = set(settings.ADMIN_ALLOWED_PROD_PROJECTS) + query_projects: set[int] = set(project_ids) if len(query_projects - allowed_projects) > 0: raise BadSnubaRPCRequestException(f"Project IDs {query_projects} are not allowed") diff --git a/snuba/admin/runtime_config/__init__.py b/snuba/admin/runtime_config/__init__.py index 4eed1eb3f5d..7b31948a3dd 100644 --- a/snuba/admin/runtime_config/__init__.py +++ b/snuba/admin/runtime_config/__init__.py @@ -1,22 +1,19 @@ -from typing import Optional, TypedDict, Union +from typing import TypedDict -ConfigType = Union[str, int, float] +ConfigType = str | int | float -ConfigChange = TypedDict( - "ConfigChange", - { - "key": str, - "timestamp": float, - "user": Optional[str], - "before": Optional[str], - "beforeType": Optional[str], - "after": Optional[str], - "afterType": Optional[str], - }, -) +class ConfigChange(TypedDict): + key: str + timestamp: float + user: str | None + before: str | None + beforeType: str | None + after: str | None + afterType: str | None -def get_config_type_from_value(value: Optional[Union[str, int, float]]) -> Optional[str]: + +def get_config_type_from_value(value: str | int | float | None) -> str | None: if value is None: return None diff --git a/snuba/admin/tool_policies.py b/snuba/admin/tool_policies.py index c6268233f71..c0d29f2adf4 100644 --- a/snuba/admin/tool_policies.py +++ b/snuba/admin/tool_policies.py @@ -1,8 +1,9 @@ from __future__ import annotations +from collections.abc import Callable from enum import Enum from functools import wraps -from typing import Any, Callable +from typing import Any from flask import Response, g, jsonify, make_response diff --git a/snuba/admin/user.py b/snuba/admin/user.py index c6a0aab7b53..2ec62b2bb9f 100644 --- a/snuba/admin/user.py +++ b/snuba/admin/user.py @@ -1,5 +1,5 @@ +from collections.abc import Sequence from dataclasses import dataclass, field -from typing import Sequence from snuba.admin.auth_roles import Role diff --git a/snuba/admin/views.py b/snuba/admin/views.py index daab95bc0fa..a6110431153 100644 --- a/snuba/admin/views.py +++ b/snuba/admin/views.py @@ -2,10 +2,11 @@ import io import sys +from collections.abc import Mapping, Sequence from contextlib import redirect_stdout from dataclasses import asdict from datetime import datetime -from typing import Any, List, Mapping, Optional, Sequence, Tuple, Type, cast +from typing import Any, cast import sentry_sdk import simplejson as json @@ -198,7 +199,7 @@ def migrations_groups() -> Response: group_policies = get_migration_group_policies(g.user) allowed_groups = group_policies.keys() - res: List[Mapping[str, str | Sequence[Mapping[str, str | bool]]]] = [] + res: list[Mapping[str, str | Sequence[Mapping[str, str | bool]]]] = [] if not allowed_groups: return make_response(jsonify(res), 200) @@ -437,7 +438,7 @@ def clickhouse_system_query() -> Response: host, port, storage, raw_sql, sudo_mode, clusterless_mode, g.user ) rows = [] - rows, columns = cast(List[List[str]], result.results), result.meta + rows, columns = cast(list[list[str]], result.results), result.meta if columns is not None: res = {} @@ -834,28 +835,27 @@ def configs() -> Response: return Response(json.dumps(config), 200, {"Content-Type": "application/json"}) - else: - descriptions = state.get_all_config_descriptions() + descriptions = state.get_all_config_descriptions() - raw_configs: Sequence[Tuple[str, Any]] = state.get_raw_configs().items() + raw_configs: Sequence[tuple[str, Any]] = state.get_raw_configs().items() - sorted_configs = sorted(raw_configs, key=lambda c: c[0]) + sorted_configs = sorted(raw_configs, key=lambda c: c[0]) - config_data = [ - { - "key": k, - "value": str(v) if v is not None else None, - "description": str(descriptions.get(k)) if k in descriptions else None, - "type": get_config_type_from_value(v), - } - for (k, v) in sorted_configs - ] + config_data = [ + { + "key": k, + "value": str(v) if v is not None else None, + "description": str(descriptions.get(k)) if k in descriptions else None, + "type": get_config_type_from_value(v), + } + for (k, v) in sorted_configs + ] - return Response( - json.dumps(config_data), - 200, - {"Content-Type": "application/json"}, - ) + return Response( + json.dumps(config_data), + 200, + {"Content-Type": "application/json"}, + ) @application.route("/all_config_descriptions", methods=["GET"]) @@ -885,76 +885,75 @@ def config(config_key: str) -> Response: audit_log.record( user or "", AuditLogAction.REMOVED_OPTION, - {"option": config_key, "old": str(old) if not old else old}, + {"option": config_key, "old": old if old else str(old)}, notify=True, ) return Response("", 200) - else: - # PUT currently only supports editing existing config when old and - # new types match. Does not currently support passing force to - # set_config to override the type check. + # PUT currently only supports editing existing config when old and + # new types match. Does not currently support passing force to + # set_config to override the type check. - user = request.headers.get(USER_HEADER_KEY) - data = json.loads(request.data) + user = request.headers.get(USER_HEADER_KEY) + data = json.loads(request.data) - # Get the previous value for notifications - old = state.get_uncached_config(config_key) + # Get the previous value for notifications + old = state.get_uncached_config(config_key) - try: - new_value = data["value"] - new_desc = data["description"] + try: + new_value = data["value"] + new_desc = data["description"] - assert isinstance(config_key, str), "Invalid key" - assert isinstance(new_value, str), "Invalid value" - assert config_key != "", "Key cannot be empty string" + assert isinstance(config_key, str), "Invalid key" + assert isinstance(new_value, str), "Invalid value" + assert config_key != "", "Key cannot be empty string" - state.set_config( - config_key, - new_value, - user=user, - ) - state.set_config_description(config_key, new_desc, user=user) + state.set_config( + config_key, + new_value, + user=user, + ) + state.set_config_description(config_key, new_desc, user=user) - except (KeyError, AssertionError) as exc: - return Response( - json.dumps({"error": f"Invalid config: {str(exc)}"}), - 400, - {"Content-Type": "application/json"}, - ) - except state.MismatchedTypeException: - return Response( - json.dumps({"error": "Mismatched type"}), - 400, - {"Content-Type": "application/json"}, - ) + except (KeyError, AssertionError) as exc: + return Response( + json.dumps({"error": f"Invalid config: {str(exc)}"}), + 400, + {"Content-Type": "application/json"}, + ) + except state.MismatchedTypeException: + return Response( + json.dumps({"error": "Mismatched type"}), + 400, + {"Content-Type": "application/json"}, + ) - # Value was updated successfully, refetch and return it - evaluated_value = state.get_uncached_config(config_key) - assert evaluated_value is not None - evaluated_type = get_config_type_from_value(evaluated_value) + # Value was updated successfully, refetch and return it + evaluated_value = state.get_uncached_config(config_key) + assert evaluated_value is not None + evaluated_type = get_config_type_from_value(evaluated_value) - # Send notification - audit_log.record( - user or "", - AuditLogAction.UPDATED_OPTION, - { - "option": config_key, - "old": str(old) if not old else old, - "new": evaluated_value, - }, - notify=True, - ) + # Send notification + audit_log.record( + user or "", + AuditLogAction.UPDATED_OPTION, + { + "option": config_key, + "old": old if old else str(old), + "new": evaluated_value, + }, + notify=True, + ) - config = { - "key": config_key, - "value": str(evaluated_value), - "description": state.get_config_description(config_key), - "type": evaluated_type, - } + config = { + "key": config_key, + "value": str(evaluated_value), + "description": state.get_config_description(config_key), + "type": evaluated_type, + } - return Response(json.dumps(config), 200, {"Content-Type": "application/json"}) + return Response(json.dumps(config), 200, {"Content-Type": "application/json"}) @application.route("/config_auditlog") @@ -963,9 +962,9 @@ def config_changes() -> Response: def serialize( key: str, ts: float, - user: Optional[str], - before: Optional[ConfigType], - after: Optional[ConfigType], + user: str | None, + before: ConfigType | None, + after: ConfigType | None, ) -> ConfigChange: return { "key": key, @@ -1120,7 +1119,7 @@ def set_configuration() -> Response: notify=True, ) return Response("", 200) - elif request.method == "POST": + if request.method == "POST": try: value = data["value"] assert isinstance(value, str), "Invalid value" @@ -1281,7 +1280,7 @@ def list_rpc_endpoints() -> Response: @check_tool_perms(tools=[AdminTools.RPC_ENDPOINTS]) def execute_rpc_endpoint(endpoint_name: str, version: str) -> Response: try: - endpoint_class: Type[RPCEndpoint[Any, Any]] = RPCEndpoint.get_from_name( + endpoint_class: type[RPCEndpoint[Any, Any]] = RPCEndpoint.get_from_name( endpoint_name, version ) except InvalidConfigKeyError: @@ -1407,9 +1406,8 @@ def delete() -> Response: from traceback import format_exception return make_response(jsonify({"error": format_exception(e)}), 500) - else: - sentry_sdk.capture_exception(e) - return make_response(jsonify({"error": "unexpected internal error"}), 500) + sentry_sdk.capture_exception(e) + return make_response(jsonify({"error": "unexpected internal error"}), 500) return Response(json.dumps(delete_results), 200, {"Content-Type": "application/json"}) diff --git a/snuba/admin/wsgi.py b/snuba/admin/wsgi.py index f184957246b..beadee506df 100644 --- a/snuba/admin/wsgi.py +++ b/snuba/admin/wsgi.py @@ -3,7 +3,7 @@ setup_logging() setup_sentry() -from snuba.core.initialize import initialize_snuba +from snuba.core.initialize import initialize_snuba # noqa: E402 # must run after setup_sentry initialize_snuba() -from snuba.admin.views import application # noqa +from snuba.admin.views import application # noqa: E402, F401 # WSGI entrypoint; import after init diff --git a/snuba/cleanup.py b/snuba/cleanup.py index eeca7c915ab..94bcb82f672 100644 --- a/snuba/cleanup.py +++ b/snuba/cleanup.py @@ -1,6 +1,6 @@ import logging +from collections.abc import Sequence from datetime import datetime, timedelta -from typing import Sequence from snuba import util from snuba.clickhouse.native import ClickhousePool diff --git a/snuba/cli/accepted_outcomes_consumer.py b/snuba/cli/accepted_outcomes_consumer.py index ac082d896b5..3e11c1af731 100644 --- a/snuba/cli/accepted_outcomes_consumer.py +++ b/snuba/cli/accepted_outcomes_consumer.py @@ -1,7 +1,7 @@ import json import sys +from collections.abc import Sequence from dataclasses import asdict -from typing import Optional, Sequence import click @@ -140,7 +140,7 @@ def accepted_outcomes_consumer( no_strict_offset_reset: bool, queued_max_messages_kbytes: int, queued_min_messages: int, - raw_topic: Optional[str], + raw_topic: str | None, accepted_outcomes_topic: str, bootstrap_servers: Sequence[str], accepted_outcomes_bootstrap_server: Sequence[str], @@ -149,13 +149,13 @@ def accepted_outcomes_consumer( bucket_interval: int, commit_frequency_sec: int, log_level: str, - concurrency: Optional[int], + concurrency: int | None, max_poll_interval_ms: int, health_check: str, - health_check_file: Optional[str], + health_check_file: str | None, enforce_schema: bool, - max_dlq_buffer_length: Optional[int], - join_timeout_ms: Optional[int], + max_dlq_buffer_length: int | None, + join_timeout_ms: int | None, ) -> None: """ Accepted outcomes consumer @@ -191,7 +191,7 @@ def accepted_outcomes_consumer( os.environ["RUST_LOG"] = log_level.lower() - exitcode = rust_snuba.accepted_outcomes_consumer( # type: ignore + exitcode = rust_snuba.accepted_outcomes_consumer( # type: ignore[attr-defined] consumer_group, auto_offset_reset, no_strict_offset_reset, diff --git a/snuba/cli/admin.py b/snuba/cli/admin.py index 2cbd19efcac..ce96407eb45 100644 --- a/snuba/cli/admin.py +++ b/snuba/cli/admin.py @@ -1,5 +1,4 @@ import os -from typing import Optional import click @@ -16,9 +15,9 @@ def admin( *, debug: bool, - log_level: Optional[str], + log_level: str | None, processes: int, - threads: Optional[int], + threads: int | None, backlog: int, ) -> None: from snuba import settings diff --git a/snuba/cli/api.py b/snuba/cli/api.py index 8550ba554cd..ec4f9b2dc18 100644 --- a/snuba/cli/api.py +++ b/snuba/cli/api.py @@ -1,5 +1,4 @@ import os -from typing import Optional, Union import click @@ -16,16 +15,16 @@ @click.option("--backlog", type=click.IntRange(128)) def api( *, - bind: Optional[str], + bind: str | None, debug: bool, - log_level: Optional[str], - processes: Optional[int], - threads: Optional[int], - backlog: Optional[int], + log_level: str | None, + processes: int | None, + threads: int | None, + backlog: int | None, ) -> None: from snuba import settings - port: Union[int, str] + port: int | str if bind: if ":" in bind: host, port = bind.split(":", 1) diff --git a/snuba/cli/bootstrap.py b/snuba/cli/bootstrap.py index 3c21d4f55f8..8ff6ae4697b 100644 --- a/snuba/cli/bootstrap.py +++ b/snuba/cli/bootstrap.py @@ -1,6 +1,6 @@ import logging +from collections.abc import Sequence from syslog import LOG_CRIT -from typing import Optional, Sequence import click from confluent_kafka import KafkaException @@ -30,7 +30,7 @@ def bootstrap( kafka: bool, migrate: bool, force: bool, - log_level: Optional[str] = None, + log_level: str | None = None, ) -> None: """ Warning: Not intended to be used in production yet. @@ -80,7 +80,7 @@ def bootstrap( logger.info("Connected to Kafka on attempt %d", attempts) - create_topics(client, [t for t in Topic]) + create_topics(client, list(Topic)) if migrate: check_clickhouse_connections(CLUSTERS) diff --git a/snuba/cli/bulk_load.py b/snuba/cli/bulk_load.py index c5b66713b12..6a149de0cd6 100644 --- a/snuba/cli/bulk_load.py +++ b/snuba/cli/bulk_load.py @@ -1,6 +1,5 @@ import logging from functools import partial -from typing import Optional import click import progressbar @@ -52,12 +51,12 @@ def bulk_load( *, storage_name: str, - dest_table: Optional[str], + dest_table: str | None, source: str, ignore_existing_data: bool, pre_processed: bool, show_progress: bool, - log_level: Optional[str] = None, + log_level: str | None = None, ) -> None: setup_logging(log_level) setup_sentry() @@ -88,7 +87,7 @@ def progress_callback(bar: progressbar.ProgressBar, progress: int) -> None: progress = progressbar.ProgressBar( max_value=snapshot_source.get_table_file_size(storage.get_postgres_table()) ) - progress_func: Optional[ProgressCallback] = partial(progress_callback, progress) + progress_func: ProgressCallback | None = partial(progress_callback, progress) else: progress_func = None diff --git a/snuba/cli/cleanup.py b/snuba/cli/cleanup.py index 10ce232eb62..5c2d1d82029 100644 --- a/snuba/cli/cleanup.py +++ b/snuba/cli/cleanup.py @@ -1,5 +1,3 @@ -from typing import Optional - import click from snuba.clusters.cluster import ClickhouseClientSettings @@ -52,14 +50,14 @@ @click.option("--log-level", help="Logging level to use.") def cleanup( *, - clickhouse_host: Optional[str], - clickhouse_port: Optional[int], + clickhouse_host: str | None, + clickhouse_port: int | None, clickhouse_secure: bool, - clickhouse_ca_certs: Optional[str], - clickhouse_verify: Optional[bool], + clickhouse_ca_certs: str | None, + clickhouse_verify: bool | None, dry_run: bool, storage_name: str, - log_level: Optional[str] = None, + log_level: str | None = None, ) -> None: """ Deletes stale partitions for ClickHouse tables @@ -98,4 +96,4 @@ def cleanup( connection = cluster.get_query_connection(ClickhouseClientSettings.CLEANUP) num_dropped = run_cleanup(connection, storage, database, dry_run=dry_run) - logger.info("Dropped %s partitions on %s" % (num_dropped, cluster)) + logger.info(f"Dropped {num_dropped} partitions on {cluster}") diff --git a/snuba/cli/config.py b/snuba/cli/config.py index 8c08b2e0b1c..788b1d2dc03 100644 --- a/snuba/cli/config.py +++ b/snuba/cli/config.py @@ -1,5 +1,6 @@ import json -from typing import Any, Mapping +from collections.abc import Mapping +from typing import Any import click @@ -49,8 +50,8 @@ def get(*, key: str, format: str) -> None: try: rv = state.get_raw_configs()[key] - except KeyError: - raise click.ClickException(f"Key {key!r} not found.") + except KeyError as e: + raise click.ClickException(f"Key {key!r} not found.") from e click.echo(FORMATS[format]({key: rv})) @@ -88,8 +89,8 @@ def delete(*, key: str) -> None: try: rv = state.get_raw_configs()[key] - except KeyError: - raise click.ClickException(f"Key {key!r} not found.") + except KeyError as e: + raise click.ClickException(f"Key {key!r} not found.") from e click.echo(human_fmt({key: rv})) click.confirm("\nAre you sure you want to delete this?", abort=True) diff --git a/snuba/cli/consumer.py b/snuba/cli/consumer.py index d6910abdbab..66f95851d7a 100644 --- a/snuba/cli/consumer.py +++ b/snuba/cli/consumer.py @@ -1,6 +1,7 @@ import logging import signal -from typing import Any, Optional, Sequence +from collections.abc import Sequence +from typing import Any import click import sentry_sdk @@ -169,33 +170,33 @@ def consumer( *, storage_name: str, - raw_events_topic: Optional[str], - replacements_topic: Optional[str], - commit_log_topic: Optional[str], + raw_events_topic: str | None, + replacements_topic: str | None, + commit_log_topic: str | None, consumer_group: str, bootstrap_server: Sequence[str], commit_log_bootstrap_server: Sequence[str], replacement_bootstrap_server: Sequence[str], - slice_id: Optional[int], + slice_id: int | None, max_batch_size: int, max_batch_time_ms: int, - max_insert_batch_size: Optional[int], - max_insert_batch_time_ms: Optional[int], + max_insert_batch_size: int | None, + max_insert_batch_time_ms: int | None, auto_offset_reset: str, no_strict_offset_reset: bool, queued_max_messages_kbytes: int, queued_min_messages: int, - processes: Optional[int], - input_block_size: Optional[int], - output_block_size: Optional[int], + processes: int | None, + input_block_size: int | None, + output_block_size: int | None, join_timeout: int, enforce_schema: bool, - log_level: Optional[str], - profile_path: Optional[str], + log_level: str | None, + profile_path: str | None, max_poll_interval_ms: int, - quantized_rebalance_consumer_group_delay_secs: Optional[int], - health_check_file: Optional[str], - group_instance_id: Optional[str], + quantized_rebalance_consumer_group_delay_secs: int | None, + health_check_file: str | None, + group_instance_id: str | None, ) -> None: setup_logging(log_level) setup_sentry() diff --git a/snuba/cli/devserver.py b/snuba/cli/devserver.py index 744b1ccff29..041c969c0dd 100644 --- a/snuba/cli/devserver.py +++ b/snuba/cli/devserver.py @@ -568,9 +568,8 @@ def stream(name: str, proc: subprocess.Popen[bytes]) -> None: sys.stdout.flush() rc = proc.wait() with failure_lock: - if rc != 0 and not cleanup_started.is_set(): - if not first_failure: - first_failure.append(rc) + if rc != 0 and not cleanup_started.is_set() and not first_failure: + first_failure.append(rc) except BaseException: with failure_lock: if not cleanup_started.is_set() and not first_failure: diff --git a/snuba/cli/dlq_consumer.py b/snuba/cli/dlq_consumer.py index d3c715babde..1cc1302c739 100644 --- a/snuba/cli/dlq_consumer.py +++ b/snuba/cli/dlq_consumer.py @@ -2,7 +2,7 @@ import signal import time from dataclasses import replace -from typing import Any, Optional +from typing import Any import click from arroyo import configure_metrics @@ -105,13 +105,13 @@ def dlq_consumer( no_strict_offset_reset: bool, queued_max_messages_kbytes: int, queued_min_messages: int, - processes: Optional[int], - input_block_size: Optional[int], - output_block_size: Optional[int], - log_level: Optional[str] = None, + processes: int | None, + input_block_size: int | None, + output_block_size: int | None, + log_level: str | None = None, ) -> None: shutdown_requested = False - consumer: Optional[StreamProcessor[KafkaPayload]] = None + consumer: StreamProcessor[KafkaPayload] | None = None def handler(signum: int, frame: Any) -> None: nonlocal shutdown_requested diff --git a/snuba/cli/entities.py b/snuba/cli/entities.py index f2b85223449..3af6d710b04 100644 --- a/snuba/cli/entities.py +++ b/snuba/cli/entities.py @@ -1,5 +1,3 @@ -from typing import Optional - import click from snuba.datasets.configuration.entity_builder import build_entity_from_config @@ -16,7 +14,7 @@ def __init__(self) -> None: def __indent(self) -> str: return " " * self.__current_indentation - def visit_header(self, header: Optional[str]) -> None: + def visit_header(self, header: str | None) -> None: if header is not None: click.echo(f"{self.__indent()}{header}") click.echo(f"{self.__indent()}--------------------------------") diff --git a/snuba/cli/jobs.py b/snuba/cli/jobs.py index 45af005460a..1c04d3f732e 100644 --- a/snuba/cli/jobs.py +++ b/snuba/cli/jobs.py @@ -1,4 +1,5 @@ -from typing import Any, MutableMapping, Tuple +from collections.abc import MutableMapping +from typing import Any import click @@ -36,21 +37,21 @@ def _run_job_and_echo_status(job_spec: JobSpec) -> None: @click.option("--job_id") def run_from_manifest(*, json_manifest: str, job_id: str) -> None: job_specs = list_job_specs(json_manifest) - if job_id not in job_specs.keys(): + if job_id not in job_specs: raise click.ClickException("Provide a valid job id") _run_job_and_echo_status(job_specs[job_id]) -def _parse_params(pairs: Tuple[str, ...]) -> MutableMapping[Any, Any]: - return {k: v for k, v in (pair.split("=") for pair in pairs)} +def _parse_params(pairs: tuple[str, ...]) -> MutableMapping[Any, Any]: + return dict(pair.split("=") for pair in pairs) @jobs.command() @click.option("--job_type") @click.option("--job_id") @click.argument("pairs", nargs=-1) -def run(*, job_type: str, job_id: str, pairs: Tuple[str, ...]) -> None: +def run(*, job_type: str, job_id: str, pairs: tuple[str, ...]) -> None: if not job_type or not job_id: raise click.ClickException(JOB_SPECIFICATION_ERROR_MSG) job_spec = JobSpec(job_id=job_id, job_type=job_type, params=_parse_params(pairs)) diff --git a/snuba/cli/lw_deletions_consumer.py b/snuba/cli/lw_deletions_consumer.py index 4fd8df2f89a..ae0ec0cc95d 100644 --- a/snuba/cli/lw_deletions_consumer.py +++ b/snuba/cli/lw_deletions_consumer.py @@ -1,6 +1,7 @@ import logging import signal -from typing import Any, Optional, Sequence +from collections.abc import Sequence +from typing import Any import click import sentry_sdk @@ -104,7 +105,7 @@ def lw_deletions_consumer( queued_max_messages_kbytes: int, queued_min_messages: int, log_level: str, - group_instance_id: Optional[str], + group_instance_id: str | None, no_batch: bool, ) -> None: setup_logging(log_level) @@ -114,7 +115,7 @@ def lw_deletions_consumer( sentry_sdk.set_tag("storage", storage) shutdown_requested = False - consumer: Optional[StreamProcessor[KafkaPayload]] = None + consumer: StreamProcessor[KafkaPayload] | None = None def handler(signum: int, frame: Any) -> None: nonlocal shutdown_requested diff --git a/snuba/cli/migrations.py b/snuba/cli/migrations.py index bd6ea711565..04b8cba2b18 100644 --- a/snuba/cli/migrations.py +++ b/snuba/cli/migrations.py @@ -1,5 +1,5 @@ import re -from typing import Optional, Sequence +from collections.abc import Sequence import click @@ -73,13 +73,13 @@ def list() -> None: @click.option("--check-dangerous", is_flag=True) @click.option("--log-level", help="Logging level to use.", type=click.Choice(LOG_LEVELS)) def migrate( - group: Optional[str], - readiness_state: Optional[Sequence[str]], + group: str | None, + readiness_state: Sequence[str] | None, through: str, force: bool, fake: bool, check_dangerous: bool, - log_level: Optional[str] = None, + log_level: str | None = None, ) -> None: """ If group is specified, runs all the migrations for a group (including any pending @@ -117,7 +117,7 @@ def migrate( check_dangerous=check_dangerous, ) except MigrationError as e: - raise click.ClickException(str(e)) + raise click.ClickException(str(e)) from e click.echo("Finished running migrations") @@ -137,13 +137,13 @@ def migrate( @click.option("--include-system", is_flag=True) @click.option("--log-level", help="Logging level to use.", type=click.Choice(LOG_LEVELS)) def revert( - group: Optional[str], - readiness_state: Optional[Sequence[str]], + group: str | None, + readiness_state: Sequence[str] | None, through: str, force: bool, fake: bool, include_system: bool, - log_level: Optional[str] = None, + log_level: str | None = None, ) -> None: """ If group is specified, reverse all the migrations for a group. @@ -185,7 +185,7 @@ def revert( readiness_states=readiness_states, ) except MigrationError as e: - raise click.ClickException(str(e)) + raise click.ClickException(str(e)) from e click.echo("Finished reversing migrations") @@ -207,7 +207,7 @@ def run( dry_run: bool, yes: bool, check_dangerous: bool, - log_level: Optional[str] = None, + log_level: str | None = None, ) -> None: """ Runs a single migration. @@ -237,7 +237,7 @@ def run( ) runner.run_migration(migration_key, force=force, fake=fake, check_dangerous=check_dangerous) except MigrationError as e: - raise click.ClickException(str(e)) + raise click.ClickException(str(e)) from e click.echo(f"Finished running migration {migration_key}") @@ -257,7 +257,7 @@ def reverse( fake: bool, dry_run: bool, yes: bool, - log_level: Optional[str] = None, + log_level: str | None = None, ) -> None: """ Reverses a single migration. @@ -284,7 +284,7 @@ def reverse( ) runner.reverse_migration(migration_key, force=force, fake=fake) except MigrationError as e: - raise click.ClickException(str(e)) + raise click.ClickException(str(e)) from e click.echo(f"Finished reversing migration {migration_key}") @@ -299,8 +299,8 @@ def reverse_in_progress( fake: bool, dry_run: bool, yes: bool, - group: Optional[str] = None, - log_level: Optional[str] = None, + group: str | None = None, + log_level: str | None = None, ) -> None: """ Reverses any in progress migrations for all migration groups. @@ -332,7 +332,7 @@ def reverse_in_progress( ) runner.reverse_in_progress(group=migration_group, fake=fake) except MigrationError as e: - raise click.ClickException(str(e)) + raise click.ClickException(str(e)) from e click.echo("Finished reversing in progress migrations") @@ -340,7 +340,7 @@ def reverse_in_progress( @migrations.command() @click.argument("storage_path", type=str) @click.option("--name", type=str, help="optional name for the migration") -def generate(storage_path: str, name: Optional[str] = None) -> None: +def generate(storage_path: str, name: str | None = None) -> None: """ Given a path to user-modified storage.yaml definition (inside snuba/datasets/configuration/*/storages/*.yaml), and an optional name for the migration, diff --git a/snuba/cli/offline_replacer.py b/snuba/cli/offline_replacer.py index 4710e15ec1c..2278e70b185 100644 --- a/snuba/cli/offline_replacer.py +++ b/snuba/cli/offline_replacer.py @@ -1,7 +1,6 @@ import logging import time from datetime import datetime -from typing import Optional import click import simplejson as json @@ -61,7 +60,7 @@ def offline_replacer( dry_run: bool, start_from: int, delay: int, - log_level: Optional[str] = None, + log_level: str | None = None, ) -> None: """ Executes a list of replacements taken from a file instead of taking diff --git a/snuba/cli/optimize.py b/snuba/cli/optimize.py index f5d91dbe4a9..27e671c5b4f 100644 --- a/snuba/cli/optimize.py +++ b/snuba/cli/optimize.py @@ -1,5 +1,4 @@ from datetime import UTC, timedelta -from typing import Optional import click @@ -66,14 +65,14 @@ ) def optimize( *, - clickhouse_host: Optional[str], - clickhouse_port: Optional[int], + clickhouse_host: str | None, + clickhouse_port: int | None, clickhouse_secure: bool, - clickhouse_ca_certs: Optional[str], - clickhouse_verify: Optional[bool], + clickhouse_ca_certs: str | None, + clickhouse_verify: bool | None, storage_name: str, default_parallel_threads: int, - log_level: Optional[str] = None, + log_level: str | None = None, divide_partitions: int, ) -> None: from datetime import datetime @@ -157,4 +156,4 @@ def optimize( ) tracker.delete_all_states() - logger.info("Optimized %s partitions on %s" % (num_dropped, clickhouse_host)) + logger.info(f"Optimized {num_dropped} partitions on {clickhouse_host}") diff --git a/snuba/cli/querylog_to_csv.py b/snuba/cli/querylog_to_csv.py index 497bde52738..ba1028c4436 100644 --- a/snuba/cli/querylog_to_csv.py +++ b/snuba/cli/querylog_to_csv.py @@ -1,6 +1,7 @@ import csv +from collections.abc import Sequence from datetime import datetime -from typing import NamedTuple, Optional, Sequence, Tuple +from typing import NamedTuple import click import structlog @@ -53,7 +54,7 @@ def get_query_results( databases: list[str], tables: list[str], start_time: str, - end_time: Optional[str], + end_time: str | None, ) -> str: if start_time and end_time: start = f"toDateTime('{start_time}')" @@ -81,7 +82,7 @@ def get_query_results( """ -def get_credentials() -> Tuple[str, str]: +def get_credentials() -> tuple[str, str]: # TOOO don't hardcode credentials, use settings return ("default", "") @@ -143,8 +144,8 @@ def querylog_to_csv( event_type: str, start_time: str, notify: bool, - end_time: Optional[str] = None, - log_level: Optional[str] = None, + end_time: str | None = None, + log_level: str | None = None, ) -> None: """ Use this command when you want to capture the results from the diff --git a/snuba/cli/replacer.py b/snuba/cli/replacer.py index 4750357558f..77823762716 100644 --- a/snuba/cli/replacer.py +++ b/snuba/cli/replacer.py @@ -1,5 +1,6 @@ import signal -from typing import Any, Optional, Sequence +from collections.abc import Sequence +from typing import Any import click @@ -70,7 +71,7 @@ ) def replacer( *, - replacements_topic: Optional[str], + replacements_topic: str | None, consumer_group: str, bootstrap_server: Sequence[str], storage_name: str, @@ -78,8 +79,8 @@ def replacer( no_strict_offset_reset: bool, queued_max_messages_kbytes: int, queued_min_messages: int, - log_level: Optional[str] = None, - health_check_file: Optional[str] = None, + log_level: str | None = None, + health_check_file: str | None = None, max_poll_interval_ms: int = 30000, ) -> None: from arroyo import Topic, configure_metrics diff --git a/snuba/cli/rust_consumer.py b/snuba/cli/rust_consumer.py index e153ed41296..eb1755baad9 100644 --- a/snuba/cli/rust_consumer.py +++ b/snuba/cli/rust_consumer.py @@ -1,7 +1,7 @@ import json import sys +from collections.abc import Sequence from dataclasses import asdict -from typing import Optional, Sequence import click @@ -206,9 +206,9 @@ def rust_consumer( no_strict_offset_reset: bool, queued_max_messages_kbytes: int, queued_min_messages: int, - raw_events_topic: Optional[str], - commit_log_topic: Optional[str], - replacements_topic: Optional[str], + raw_events_topic: str | None, + commit_log_topic: str | None, + replacements_topic: str | None, bootstrap_servers: Sequence[str], commit_log_bootstrap_servers: Sequence[str], replacement_bootstrap_servers: Sequence[str], @@ -216,23 +216,23 @@ def rust_consumer( max_batch_size_calculation: str, max_batch_time_ms: int, log_level: str, - concurrency: Optional[int], - clickhouse_concurrency: Optional[int], + concurrency: int | None, + clickhouse_concurrency: int | None, use_rust_processor: bool, - group_instance_id: Optional[str], + group_instance_id: str | None, max_poll_interval_ms: int, async_inserts: bool, health_check: str, - python_max_queue_depth: Optional[int], - health_check_file: Optional[str], + python_max_queue_depth: int | None, + health_check_file: str | None, enforce_schema: bool, - stop_at_timestamp: Optional[int], - batch_write_timeout_ms: Optional[int], - max_dlq_buffer_length: Optional[int], - quantized_rebalance_consumer_group_delay_secs: Optional[int], - join_timeout_ms: Optional[int], + stop_at_timestamp: int | None, + batch_write_timeout_ms: int | None, + max_dlq_buffer_length: int | None, + quantized_rebalance_consumer_group_delay_secs: int | None, + join_timeout_ms: int | None, use_row_binary: bool, - consumer_version: Optional[str], + consumer_version: str | None, ) -> None: """ Experimental alternative to `snuba consumer` @@ -270,7 +270,7 @@ def rust_consumer( # the number of inserts/sec on clickhouse clickhouse_concurrency = 2 - exitcode = rust_snuba.consumer( # type: ignore + exitcode = rust_snuba.consumer( # type: ignore[attr-defined] consumer_group, auto_offset_reset, no_strict_offset_reset, diff --git a/snuba/cli/subscriptions_executor.py b/snuba/cli/subscriptions_executor.py index be0b18e6023..03896c1ac8e 100644 --- a/snuba/cli/subscriptions_executor.py +++ b/snuba/cli/subscriptions_executor.py @@ -1,6 +1,7 @@ import signal +from collections.abc import Iterator, Sequence from contextlib import contextmanager -from typing import Any, Iterator, Optional, Sequence +from typing import Any import click import structlog @@ -90,13 +91,13 @@ def subscriptions_executor( consumer_group: str, bootstrap_server: Sequence[str], result_bootstrap_server: Sequence[str], - slice_id: Optional[int], + slice_id: int | None, total_concurrent_queries: int, auto_offset_reset: str, no_strict_offset_reset: bool, - log_level: Optional[str], - stale_threshold_seconds: Optional[int], - health_check_file: Optional[str], + log_level: str | None, + stale_threshold_seconds: int | None, + health_check_file: str | None, ) -> None: """ The subscription's executor consumes scheduled subscriptions from the scheduled @@ -168,7 +169,7 @@ def handler(signum: int, frame: Any) -> None: @contextmanager -def closing(producer: KafkaProducer) -> Iterator[Optional[KafkaProducer]]: +def closing(producer: KafkaProducer) -> Iterator[KafkaProducer | None]: try: yield producer finally: diff --git a/snuba/cli/subscriptions_scheduler.py b/snuba/cli/subscriptions_scheduler.py index d2743991882..4839950c93e 100644 --- a/snuba/cli/subscriptions_scheduler.py +++ b/snuba/cli/subscriptions_scheduler.py @@ -1,7 +1,8 @@ import logging import signal +from collections.abc import Sequence from contextlib import closing -from typing import Any, Optional, Sequence +from typing import Any import click import structlog @@ -89,10 +90,10 @@ def subscriptions_scheduler( auto_offset_reset: str, no_strict_offset_reset: bool, schedule_ttl: int, - slice_id: Optional[int], - log_level: Optional[str], - stale_threshold_seconds: Optional[int], - health_check_file: Optional[str], + slice_id: int | None, + log_level: str | None, + stale_threshold_seconds: int | None, + health_check_file: str | None, ) -> None: """ The subscriptions scheduler's job is to schedule subscriptions for a single entity. diff --git a/snuba/cli/subscriptions_scheduler_executor.py b/snuba/cli/subscriptions_scheduler_executor.py index 3d66fc4adb5..b7472e1ea8d 100644 --- a/snuba/cli/subscriptions_scheduler_executor.py +++ b/snuba/cli/subscriptions_scheduler_executor.py @@ -1,6 +1,7 @@ import signal +from collections.abc import Iterator, Sequence from contextlib import contextmanager -from typing import Any, Iterator, Optional, Sequence +from typing import Any import click import structlog @@ -87,9 +88,9 @@ def subscriptions_scheduler_executor( auto_offset_reset: str, no_strict_offset_reset: bool, schedule_ttl: int, - stale_threshold_seconds: Optional[int], - health_check_file: Optional[str], - log_level: Optional[str], + stale_threshold_seconds: int | None, + health_check_file: str | None, + log_level: str | None, ) -> None: """ Combined subscriptions scheduler and executor. Alternative to the separate scheduler and executor processes. @@ -154,7 +155,7 @@ def handler(signum: int, frame: Any) -> None: @contextmanager -def closing(producer: KafkaProducer) -> Iterator[Optional[KafkaProducer]]: +def closing(producer: KafkaProducer) -> Iterator[KafkaProducer | None]: try: yield producer finally: diff --git a/snuba/clickhouse/columns.py b/snuba/clickhouse/columns.py index 5066c9412ca..80738763e78 100644 --- a/snuba/clickhouse/columns.py +++ b/snuba/clickhouse/columns.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Sequence, Union +from collections.abc import Sequence from snuba.utils.schemas import ( JSON, @@ -86,7 +86,7 @@ class ColumnSet(BaseColumnSet): def __init__( self, - columns: Sequence[Union[Column[SchemaModifiers], tuple[str, ColumnType[SchemaModifiers]]]], + columns: Sequence[Column[SchemaModifiers] | tuple[str, ColumnType[SchemaModifiers]]], ) -> None: for column in columns: assert not isinstance(column, WildcardColumn) @@ -94,11 +94,11 @@ def __init__( super().__init__(Column.to_columns(columns)) def __repr__(self) -> str: - return "ColumnSet({})".format(repr(self.columns)) + return f"ColumnSet({repr(self.columns)})" def __add__( self, - other: Union[ColumnSet, Sequence[tuple[str, ColumnType[SchemaModifiers]]]], + other: ColumnSet | Sequence[tuple[str, ColumnType[SchemaModifiers]]], ) -> ColumnSet: if isinstance(other, ColumnSet): return ColumnSet([*self.columns, *other.columns]) diff --git a/snuba/clickhouse/errors.py b/snuba/clickhouse/errors.py index def7d31535a..40929964f6b 100644 --- a/snuba/clickhouse/errors.py +++ b/snuba/clickhouse/errors.py @@ -11,7 +11,7 @@ def format_message(self, message: str) -> str: if "Stack trace:" in message: msg = message.split("Stack trace:")[0].strip() - return 'Code: {}. {}"'.format(self.code, msg) + return f'Code: {self.code}. {msg}"' @property def code(self) -> int: diff --git a/snuba/clickhouse/escaping.py b/snuba/clickhouse/escaping.py index 0c239aa8cf9..d55206fb050 100644 --- a/snuba/clickhouse/escaping.py +++ b/snuba/clickhouse/escaping.py @@ -1,5 +1,5 @@ import re -from typing import Optional, Pattern +from re import Pattern ESCAPE_STRING_RE = re.compile(r"(['\\])") ESCAPE_COL_RE = re.compile(r"([`\\])") @@ -13,29 +13,28 @@ def escape_string(str: str) -> str: str = ESCAPE_STRING_RE.sub(r"\\\1", str) - return "'{}'".format(str) + return f"'{str}'" -def escape_expression(expr: Optional[str], regex: Pattern[str]) -> Optional[str]: +def escape_expression(expr: str | None, regex: Pattern[str]) -> str | None: if not expr: return expr - elif regex.match(expr): + if regex.match(expr): # Column/Alias is safe to use without wrapping. return expr - else: - # Column/Alias needs special characters escaped, and to be wrapped with - # backticks. If the column starts with a '-', keep that outside the - # backticks as it is not part of the column name, but used by the query - # generator to signify the sort order if we are sorting by this column. - col = ESCAPE_COL_RE.sub(r"\\\1", expr) - negate_match = NEGATE_RE.match(col) - assert negate_match is not None - return "{}`{}`".format(*negate_match.groups()) - - -def escape_alias(alias: Optional[str]) -> Optional[str]: + # Column/Alias needs special characters escaped, and to be wrapped with + # backticks. If the column starts with a '-', keep that outside the + # backticks as it is not part of the column name, but used by the query + # generator to signify the sort order if we are sorting by this column. + col = ESCAPE_COL_RE.sub(r"\\\1", expr) + negate_match = NEGATE_RE.match(col) + assert negate_match is not None + return "{}`{}`".format(*negate_match.groups()) + + +def escape_alias(alias: str | None) -> str | None: return escape_expression(alias, SAFE_ALIAS_RE) -def escape_identifier(col: Optional[str]) -> Optional[str]: +def escape_identifier(col: str | None) -> str | None: return escape_expression(col, SAFE_COL_RE) diff --git a/snuba/clickhouse/formatter/expression.py b/snuba/clickhouse/formatter/expression.py index 7d351646bb1..2c688fc62f5 100644 --- a/snuba/clickhouse/formatter/expression.py +++ b/snuba/clickhouse/formatter/expression.py @@ -1,8 +1,9 @@ import re from abc import ABC, abstractmethod +from collections.abc import Sequence from datetime import date, datetime from functools import lru_cache -from typing import Optional, Sequence, cast +from typing import cast from snuba.clickhouse.escaping import escape_alias, escape_identifier, escape_string from snuba.query.conditions import ( @@ -46,22 +47,21 @@ class ExpressionFormatterBase(ExpressionVisitor[str], ABC): the visited expression, the return value is the formatted string. """ - def __init__(self, parsing_context: Optional[ParsingContext] = None) -> None: + def __init__(self, parsing_context: ParsingContext | None = None) -> None: self._parsing_context = parsing_context if parsing_context is not None else ParsingContext() - def _alias(self, formatted_exp: str, alias: Optional[str]) -> str: + def _alias(self, formatted_exp: str, alias: str | None) -> str: if not alias: return formatted_exp - elif self._parsing_context.is_alias_present(alias): + if self._parsing_context.is_alias_present(alias): ret = escape_alias(alias) # This is for the type checker. escape_alias can return None if # we pass None. But here we do not pass None so a None return value # is not valid. assert ret is not None return ret - else: - self._parsing_context.add_alias(alias) - return f"({formatted_exp} AS {escape_alias(alias)})" + self._parsing_context.add_alias(alias) + return f"({formatted_exp} AS {escape_alias(alias)})" @abstractmethod def _format_string_literal(self, exp: Literal) -> str: @@ -88,16 +88,15 @@ def visit_literal(self, exp: Literal) -> str: return self._alias("NULL", exp.alias) if isinstance(exp.value, bool): return self._format_boolean_literal(exp) - elif isinstance(exp.value, str): + if isinstance(exp.value, str): return self._format_string_literal(exp) - elif isinstance(exp.value, (int, float)): + if isinstance(exp.value, (int, float)): return self._format_number_literal(exp) - elif isinstance(exp.value, datetime): + if isinstance(exp.value, datetime): return self._format_datetime_literal(exp) - elif isinstance(exp.value, date): + if isinstance(exp.value, date): return self._format_date_literal(exp) - else: - raise ValueError(f"Unexpected literal type {type(exp.value)}") + raise ValueError(f"Unexpected literal type {type(exp.value)}") def visit_column(self, exp: Column) -> str: ret = [] @@ -127,8 +126,7 @@ def visit_column(self, exp: Column) -> str: # parsing so the names are preserved during query processing. if exp.alias != "".join(ret_unescaped): return self._alias("".join(ret), exp.alias) - else: - return "".join(ret) + return "".join(ret) def __visit_params(self, parameters: Sequence[Expression]) -> str: ret = [p.accept(self) for p in parameters] @@ -157,11 +155,11 @@ def visit_function_call(self, exp: FunctionCall) -> str: # will interpret (1) -> 1 which will break things like 1 IN tuple(1) return self._alias(f"({self.__visit_params(exp.parameters)})", exp.alias) - elif exp.function_name == BooleanFunctions.AND: + if exp.function_name == BooleanFunctions.AND: formatted = (c.accept(self) for c in get_first_level_and_conditions(exp)) return " AND ".join(formatted) - elif exp.function_name == BooleanFunctions.OR: + if exp.function_name == BooleanFunctions.OR: formatted = (c.accept(self) for c in get_first_level_or_conditions(exp)) return f"({' OR '.join(formatted)})" @@ -231,11 +229,11 @@ def _format_boolean_literal(self, exp: Literal) -> str: def _format_datetime_literal(self, exp: Literal) -> str: value = cast(datetime, exp.value).replace(tzinfo=None, microsecond=0) - return self._alias("toDateTime('{}', 'Universal')".format(value.isoformat()), exp.alias) + return self._alias(f"toDateTime('{value.isoformat()}', 'Universal')", exp.alias) def _format_date_literal(self, exp: Literal) -> str: return self._alias( - "toDate('{}', 'Universal')".format(cast(date, exp.value).isoformat()), + f"toDate('{cast(date, exp.value).isoformat()}', 'Universal')", exp.alias, ) @@ -260,16 +258,15 @@ def _anonimize_alias(self, alias: str) -> str: # if they are input by the user, but that is better than leaking PII return _BETWEEN_SQUARE_BRACKETS_REGEX.sub("$A", alias) - def _alias(self, formatted_exp: str, alias: Optional[str]) -> str: + def _alias(self, formatted_exp: str, alias: str | None) -> str: if not alias: return formatted_exp - elif self._parsing_context.is_alias_present(alias): + if self._parsing_context.is_alias_present(alias): ret = escape_alias(alias) # This is for the type checker. escape_alias can return None if # we pass None. But here we do not pass None so a None return value # is not valid. assert ret is not None return ret - else: - self._parsing_context.add_alias(alias) - return f"({formatted_exp} AS {escape_alias(self._anonimize_alias(alias))})" + self._parsing_context.add_alias(alias) + return f"({formatted_exp} AS {escape_alias(self._anonimize_alias(alias))})" diff --git a/snuba/clickhouse/formatter/nodes.py b/snuba/clickhouse/formatter/nodes.py index f81f1b3c2f2..09a8d8f8c27 100644 --- a/snuba/clickhouse/formatter/nodes.py +++ b/snuba/clickhouse/formatter/nodes.py @@ -1,10 +1,11 @@ from abc import ABC +from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, Optional, Sequence, Union +from typing import Any @dataclass(frozen=True) -class FormattedNode(ABC): +class FormattedNode(ABC): # noqa: B024 - methods raise NotImplementedError rather than using @abstractmethod to preserve runtime behavior """ After formatting all the clauses of a query, we may serialize the query itself as a string or exporting it in a structured format for @@ -20,7 +21,7 @@ def __str__(self) -> str: """ raise NotImplementedError - def structured(self) -> Union[str, Sequence[Any]]: + def structured(self) -> str | Sequence[Any]: """ This exports the query as a Sequence of clauses. Each clause is either a string or a Sequence itself (like for subqueries). @@ -53,9 +54,9 @@ def structured(self) -> Sequence[Any]: @dataclass(frozen=True) class PaddingNode(FormattedNode): - prefix: Optional[str] + prefix: str | None node: FormattedNode - suffix: Optional[str] = None + suffix: str | None = None def __str__(self) -> str: prefix = f"{self.prefix} " if self.prefix else "" @@ -74,8 +75,7 @@ def structured(self) -> Sequence[Any]: return ret + [ self.suffix, ] - else: - return ret + return ret @dataclass(frozen=True) @@ -86,7 +86,7 @@ class FormattedQuery(SequenceNode): differently for different usages (running the query or tracing). """ - def get_sql(self, format: Optional[str] = None) -> str: + def get_sql(self, format: str | None = None) -> str: query = str(self) if format is not None: query = f"{query} FORMAT {format}" diff --git a/snuba/clickhouse/formatter/query.py b/snuba/clickhouse/formatter/query.py index 77e55ceb650..12624284df2 100644 --- a/snuba/clickhouse/formatter/query.py +++ b/snuba/clickhouse/formatter/query.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Optional, Sequence, Type, Union +from collections.abc import Sequence from snuba.clickhouse.escaping import escape_alias from snuba.clickhouse.formatter.expression import ( @@ -31,7 +31,7 @@ from snuba.query.expressions import Expression, ExpressionVisitor from snuba.query.parsing import ParsingContext -FormattableQuery = Union[Query, CompositeQuery[Table]] +FormattableQuery = Query | CompositeQuery[Table] def format_query(query: FormattableQuery) -> FormattedQuery: @@ -45,8 +45,7 @@ def format_query(query: FormattableQuery) -> FormattedQuery: """ if isinstance(query, Query) and query.is_delete(): return FormattedQuery(_format_delete_query_content(query, ClickhouseExpressionFormatter)) - else: - return FormattedQuery(_format_query_content(query, ClickhouseExpressionFormatter)) + return FormattedQuery(_format_query_content(query, ClickhouseExpressionFormatter)) def format_query_anonymized(query: FormattableQuery) -> FormattedQuery: @@ -60,7 +59,7 @@ class DataSourceFormatter(DataSourceVisitor[FormattedNode, Table]): in a Composite query. """ - def __init__(self, expression_formatter_type: Type[ExpressionFormatterBase]): + def __init__(self, expression_formatter_type: type[ExpressionFormatterBase]): self.__expression_formatter_type = expression_formatter_type def _visit_simple_source(self, data_source: Table) -> StringNode: @@ -89,7 +88,7 @@ def _visit_composite_query(self, data_source: CompositeQuery[Table]) -> Formatte def _format_query_content( query: FormattableQuery, - expression_formatter_type: Type[ExpressionFormatterBase], + expression_formatter_type: type[ExpressionFormatterBase], ) -> Sequence[FormattedNode]: """ Produces the content of the formatted query. @@ -127,7 +126,7 @@ def _format_query_content( def _format_delete_query_content( - query: FormattableQuery, expression_formatter_type: Type[ExpressionFormatterBase] + query: FormattableQuery, expression_formatter_type: type[ExpressionFormatterBase] ) -> Sequence[FormattedNode]: formatter = expression_formatter_type() return [ @@ -147,7 +146,7 @@ def _format_delete_query_content( def _format_on_cluster( query: AbstractQuery, formatter: ExpressionVisitor[str] -) -> Optional[StringNode]: +) -> StringNode | None: on_cluster = query.get_on_cluster() if on_cluster: return StringNode(f"ON CLUSTER {on_cluster.accept(formatter)}") @@ -161,16 +160,14 @@ def _format_select(query: AbstractQuery, formatter: ExpressionVisitor[str]) -> S def _build_optional_string_node( name: str, - expression: Optional[Expression], + expression: Expression | None, formatter: ExpressionVisitor[str], -) -> Optional[StringNode]: +) -> StringNode | None: return StringNode(f"{name} {expression.accept(formatter)}") if expression is not None else None -def _format_groupby( - query: AbstractQuery, formatter: ExpressionVisitor[str] -) -> Optional[StringNode]: - group_clause: Optional[StringNode] = None +def _format_groupby(query: AbstractQuery, formatter: ExpressionVisitor[str]) -> StringNode | None: + group_clause: StringNode | None = None ast_groupby = query.get_groupby() if ast_groupby: groupby_expressions = [e.accept(formatter) for e in ast_groupby] @@ -181,20 +178,15 @@ def _format_groupby( return group_clause -def _format_orderby( - query: AbstractQuery, formatter: ExpressionVisitor[str] -) -> Optional[StringNode]: +def _format_orderby(query: AbstractQuery, formatter: ExpressionVisitor[str]) -> StringNode | None: ast_orderby = query.get_orderby() if ast_orderby: orderby = [f"{e.expression.accept(formatter)} {e.direction.value}" for e in ast_orderby] return StringNode(f"ORDER BY {', '.join(orderby)}") - else: - return None + return None -def _format_limitby( - query: AbstractQuery, formatter: ExpressionVisitor[str] -) -> Optional[StringNode]: +def _format_limitby(query: AbstractQuery, formatter: ExpressionVisitor[str]) -> StringNode | None: ast_limitby = query.get_limitby() if ast_limitby is not None: @@ -204,9 +196,7 @@ def _format_limitby( return None -def _format_arrayjoin( - query: AbstractQuery, formatter: ExpressionVisitor[str] -) -> Optional[StringNode]: +def _format_arrayjoin(query: AbstractQuery, formatter: ExpressionVisitor[str]) -> StringNode | None: array_join = query.get_arrayjoin() if array_join is not None: column_likes_joined = [el.accept(formatter) for el in array_join] @@ -215,7 +205,7 @@ def _format_arrayjoin( return None -def _format_limit(query: AbstractQuery, formatter: ExpressionVisitor[str]) -> Optional[StringNode]: +def _format_limit(query: AbstractQuery, formatter: ExpressionVisitor[str]) -> StringNode | None: ast_limit = query.get_limit() return ( StringNode(f"LIMIT {ast_limit} OFFSET {query.get_offset()}") @@ -229,7 +219,7 @@ class JoinFormatter(JoinVisitor[FormattedNode, Table]): Formats a Join tree. """ - def __init__(self, ExpressionFormatter: Type[ExpressionFormatterBase]): + def __init__(self, ExpressionFormatter: type[ExpressionFormatterBase]): self.ExpressionFormatter = ExpressionFormatter def visit_individual_node(self, node: IndividualNode[Table]) -> FormattedNode: diff --git a/snuba/clickhouse/http.py b/snuba/clickhouse/http.py index 2b32dd5dd0c..ffd58e11e5d 100644 --- a/snuba/clickhouse/http.py +++ b/snuba/clickhouse/http.py @@ -2,18 +2,12 @@ import logging import re +from collections.abc import Iterable, Iterator, Mapping, Sequence from concurrent.futures import ThreadPoolExecutor from datetime import datetime from queue import Queue, SimpleQueue from typing import ( Any, - Iterable, - Iterator, - List, - Mapping, - Optional, - Sequence, - Union, cast, ) from urllib.parse import urlencode @@ -53,8 +47,7 @@ class JSONRowEncoder(Encoder[bytes, WriterTableRow]): def __default(self, value: Any) -> Any: if isinstance(value, datetime): return value.strftime(DATETIME_FORMAT) - else: - raise TypeError + raise TypeError def encode(self, value: WriterTableRow) -> bytes: return cast(bytes, rapidjson.dumps(value, default=self.__default).encode("utf-8")) @@ -68,21 +61,20 @@ def __init__(self, columns: Iterable[str]) -> None: def encode_value(self, value: Any) -> str: if isinstance(value, Expression): return value.accept(self.__formatter) - else: - raise TypeError("unknown Clickhouse value type", value.__class__) + raise TypeError("unknown Clickhouse value type", value.__class__) def encode(self, row: WriterTableRow) -> bytes: ordered_columns = [self.encode_value(row.get(column)) for column in self.__columns] ordered_columns_str = ",".join(ordered_columns) - return f"({ordered_columns_str})".encode("utf-8") + return f"({ordered_columns_str})".encode() class InsertStatement: def __init__(self, table_name: str) -> None: self.__table_name = table_name - self.__database: Optional[str] = None - self.__format: Optional[str] = None - self.__column_names: Optional[Sequence[str]] = None + self.__database: str | None = None + self.__format: str | None = None + self.__column_names: Sequence[str] | None = None def with_database(self, database_name: str) -> InsertStatement: self.__database = database_name @@ -104,7 +96,7 @@ def build_statement(self) -> str: format_statement = "" if self.__format: - if not self.__format == "VALUES": + if self.__format != "VALUES": format_statement = f"FORMAT {self.__format}" else: format_statement = "VALUES" @@ -156,16 +148,16 @@ def __init__( user: str, password: str, statement: InsertStatement, - encoding: Optional[str], + encoding: str | None, options: Mapping[str, Any], # should be ``Mapping[str, str]``? - chunk_size: Optional[int] = None, + chunk_size: int | None = None, buffer_size: int = 0, # 0 means unbounded - debug_buffer_size_bytes: Optional[int] = None, # None means disabled + debug_buffer_size_bytes: int | None = None, # None means disabled ) -> None: if chunk_size is None: chunk_size = settings.CLICKHOUSE_HTTP_CHUNK_SIZE - self.__queue: Union[Queue[Union[bytes, None]], SimpleQueue[Union[bytes, None]]] = ( + self.__queue: Queue[bytes | None] | SimpleQueue[bytes | None] = ( Queue(buffer_size) if buffer_size else SimpleQueue() ) @@ -193,7 +185,7 @@ def __init__( body=body, ) - self.__debug_buffer: List[bytes] = [] + self.__debug_buffer: list[bytes] = [] self.__size = 0 self.__debug_buffer_size_bytes = debug_buffer_size_bytes self.__closed = False @@ -226,7 +218,7 @@ def close(self) -> None: self.__queue.put(None) self.__closed = True - def join(self, timeout: Optional[float] = None) -> None: + def join(self, timeout: float | None = None) -> None: try: response = self._result.result(timeout) except TimeoutError: @@ -268,8 +260,7 @@ def join(self, timeout: Optional[float] = None) -> None: sentry_sdk.set_tag("snuba_has_errored_row", "false") raise ClickhouseWriterError(message, code=code, row=row) - else: - raise HTTPError(f"Received unexpected {response.status} response: {content}") + raise HTTPError(f"Received unexpected {response.status} response: {content}") class HTTPBatchWriter(BatchWriter[bytes]): @@ -280,13 +271,13 @@ def __init__( user: str, password: str, secure: bool, - ca_certs: Optional[str], - verify: Optional[bool], + ca_certs: str | None, + verify: bool | None, metrics: MetricsBackend, statement: InsertStatement, - encoding: Optional[str], - options: Optional[Mapping[str, Any]] = None, - chunk_size: Optional[int] = None, + encoding: str | None, + options: Mapping[str, Any] | None = None, + chunk_size: int | None = None, buffer_size: int = 0, max_connections: int = 1, block_connections: bool = False, diff --git a/snuba/clickhouse/native.py b/snuba/clickhouse/native.py index a737a5e30d4..fea9a822433 100644 --- a/snuba/clickhouse/native.py +++ b/snuba/clickhouse/native.py @@ -4,19 +4,14 @@ import queue import re import time +from collections.abc import Generator, Mapping, Sequence from contextlib import contextmanager from dataclasses import dataclass, field from datetime import date, datetime from io import StringIO from typing import ( Any, - Dict, - Generator, - Mapping, - Optional, - Sequence, TypedDict, - Union, cast, ) from uuid import UUID @@ -39,7 +34,7 @@ trace_logger = logging.getLogger("clickhouse_driver.log") trace_logger.setLevel("INFO") -Params = Optional[Union[Sequence[Any], Mapping[str, Any]]] +Params = Sequence[Any] | Mapping[str, Any] | None metrics = MetricsWrapper(environment.metrics, "clickhouse.native") @@ -61,7 +56,7 @@ class ClickhouseResult: @contextmanager -def capture_logging() -> Generator[StringIO, None, None]: +def capture_logging() -> Generator[StringIO]: buffer = StringIO() new_handler = logging.StreamHandler(buffer) trace_logger.addHandler(new_handler) @@ -72,7 +67,7 @@ def capture_logging() -> Generator[StringIO, None, None]: buffer.close() -class ClickhousePool(object): +class ClickhousePool: def __init__( self, host: str, @@ -81,10 +76,10 @@ def __init__( password: str, database: str, secure: bool = False, - ca_certs: Optional[str] = None, - verify: Optional[bool] = False, + ca_certs: str | None = None, + verify: bool | None = False, connect_timeout: int = 1, - send_receive_timeout: Optional[int] = 35, + send_receive_timeout: int | None = 35, max_pool_size: int = settings.CLICKHOUSE_MAX_POOL_SIZE, pool_get_timeout_seconds: float = settings.CLICKHOUSE_POOL_GET_TIMEOUT_SECONDS, client_settings: Mapping[str, Any] = {}, @@ -102,7 +97,7 @@ def __init__( self.pool_get_timeout_seconds = pool_get_timeout_seconds self.client_settings = client_settings - self.pool: queue.LifoQueue[Optional[Client]] = queue.LifoQueue(max_pool_size) + self.pool: queue.LifoQueue[Client | None] = queue.LifoQueue(max_pool_size) self.__gauge = ThreadSafeGauge(metrics, "connections") # Fill the queue up so that doing get() on it will block properly @@ -116,8 +111,8 @@ def execute( query: str, params: Params = None, with_column_types: bool = False, - query_id: Optional[str] = None, - settings: Optional[Mapping[str, Any]] = None, + query_id: str | None = None, + settings: Mapping[str, Any] | None = None, types_check: bool = False, columnar: bool = False, capture_trace: bool = False, @@ -161,12 +156,12 @@ def execute( else {"send_logs_level": "trace"} ) - def query_execute() -> Any: + def query_execute(conn: Client = conn, settings: Any = settings) -> Any: with sentry_sdk.start_span(description=query, op="db.clickhouse") as span: span.set_data(sentry_sdk.consts.SPANDATA.DB_SYSTEM, "clickhouse") span.set_data("query_id", query_id) span.set_data("settings", settings) - return conn.execute( # type: ignore + return conn.execute( query, params=params, with_column_types=with_column_types, @@ -227,12 +222,10 @@ def query_execute() -> Any: if attempts_remaining == 0: if isinstance(e, errors.Error): raise ClickhouseError(e.message, code=e.code) from e - else: - raise e - else: - # Short sleep to make sure we give the load - # balancer a chance to mark a bad host as down. - time.sleep(0.1) + raise e + # Short sleep to make sure we give the load + # balancer a chance to mark a bad host as down. + time.sleep(0.1) except errors.Error as e: if e.code == errors.ErrorCodes.TOO_MANY_SIMULTANEOUS_QUERIES: attempts_remaining -= 1 @@ -264,8 +257,8 @@ def execute_robust( query: str, params: Params = None, with_column_types: bool = False, - query_id: Optional[str] = None, - settings: Optional[Mapping[str, Any]] = None, + query_id: str | None = None, + settings: Mapping[str, Any] | None = None, types_check: bool = False, columnar: bool = False, capture_trace: bool = False, @@ -306,8 +299,7 @@ def execute_robust( if attempts_remaining <= 0: if isinstance(e, errors.Error): raise ClickhouseError(e.message, code=e.code) from e - else: - raise e + raise e time.sleep(1) continue except ClickhouseError as e: @@ -329,9 +321,8 @@ def execute_robust( float((total_attempts - attempts_remaining) * sleep_interval_seconds) ) continue - else: - # Quit immediately for other types of server errors. - raise e + # Quit immediately for other types of server errors. + raise e except errors.Error as e: raise ClickhouseError(e.message, code=e.code) from e @@ -403,9 +394,9 @@ def transform_uuid(value: UUID) -> str: class NativeDriverReader(Reader): def __init__( self, - cache_partition_id: Optional[str], + cache_partition_id: str | None, client: ClickhousePool, - query_settings_prefix: Optional[str], + query_settings_prefix: str | None, ) -> None: super().__init__( cache_partition_id=cache_partition_id, @@ -420,7 +411,7 @@ def __transform_result(self, result: ClickhouseResult, with_totals: bool) -> Res """ meta = result.meta if result.meta is not None else [] data = result.results - profile = cast(Optional[Dict[str, Any]], result.profile) + profile = cast(dict[str, Any] | None, result.profile) # XXX: Rows are represented as mappings that are keyed by column or # alias, which is problematic when the result set contains duplicate # names. To ensure that the column headers and row data are consistent @@ -458,7 +449,7 @@ def execute( self, query: FormattedQuery, # TODO: move Clickhouse specific arguments into clickhouse.query.Query - settings: Optional[Mapping[str, str]] = None, + settings: Mapping[str, str] | None = None, with_totals: bool = False, robust: bool = False, capture_trace: bool = False, diff --git a/snuba/clickhouse/optimize/optimize.py b/snuba/clickhouse/optimize/optimize.py index abe073d1146..300aad71705 100644 --- a/snuba/clickhouse/optimize/optimize.py +++ b/snuba/clickhouse/optimize/optimize.py @@ -5,9 +5,10 @@ import threading import time from collections import deque +from collections.abc import Mapping, Sequence from concurrent.futures import Future, ThreadPoolExecutor from datetime import UTC, datetime, timedelta -from typing import Any, Mapping, Optional, Sequence +from typing import Any import structlog from structlog.types import EventDict, WrappedLogger @@ -46,7 +47,7 @@ def thread_info_processor(_: WrappedLogger, __: str, event_dict: EventDict) -> E metrics = MetricsWrapper(environment.metrics, "optimize") -def _get_metrics_tags(table: str, clickhouse_host: Optional[str]) -> Mapping[str, str]: +def _get_metrics_tags(table: str, clickhouse_host: str | None) -> Mapping[str, str]: return {"table": table, "host": clickhouse_host} if clickhouse_host else {"table": table} @@ -54,7 +55,7 @@ def run_optimize( clickhouse: ClickhousePool, storage: ReadableTableStorage, database: str, - before: Optional[datetime] = None, + before: datetime | None = None, ) -> int: """ The most basic form of running an optimize final on a storage. @@ -87,7 +88,7 @@ def run_optimize_cron_job( default_parallel_threads: int, clickhouse_host: str, tracker: OptimizedPartitionTracker, - before: Optional[datetime] = None, + before: datetime | None = None, divide_partitions_count: int = 1, ) -> int: """ @@ -164,7 +165,7 @@ def get_partitions_from_clickhouse( storage: ReadableTableStorage, database: str, table: str, - before: Optional[datetime] = None, + before: datetime | None = None, ) -> Sequence[util.Part]: """ Get the partitions from ClickHouse that are active and would benefit from OPTIMIZE @@ -203,8 +204,7 @@ def get_partitions_from_clickhouse( if not response.results: logger.warning( - "Table %s.%s doesn't exist on %s:%s" - % (database, table, clickhouse.host, clickhouse.port) + f"Table {database}.{table} doesn't exist on {clickhouse.host}:{clickhouse.port}" ) return [] @@ -363,9 +363,9 @@ def optimize_partitions( database: str, table: str, partitions: Sequence[str], - cutoff_time: Optional[datetime] = None, - tracker: Optional[OptimizedPartitionTracker] = None, - clickhouse_host: Optional[str] = None, + cutoff_time: datetime | None = None, + tracker: OptimizedPartitionTracker | None = None, + clickhouse_host: str | None = None, ) -> None: query_template = f"""\ OPTIMIZE TABLE {database}.{table} @@ -433,7 +433,7 @@ def _hash_partition(partition_name: str) -> int: return int(sha1.hexdigest(), 16) -def _days_since_epoch(current_time: Optional[datetime] = None) -> int: +def _days_since_epoch(current_time: datetime | None = None) -> int: if current_time is None: current_time = datetime.now(UTC) return int(current_time.timestamp() / 86400) diff --git a/snuba/clickhouse/optimize/optimize_scheduler.py b/snuba/clickhouse/optimize/optimize_scheduler.py index d3842b5ab30..5adeb571a9e 100644 --- a/snuba/clickhouse/optimize/optimize_scheduler.py +++ b/snuba/clickhouse/optimize/optimize_scheduler.py @@ -1,7 +1,7 @@ import re +from collections.abc import MutableSequence, Sequence from dataclasses import dataclass from datetime import UTC, datetime, timedelta -from typing import MutableSequence, Sequence from snuba import settings from snuba.clickhouse.optimize.util import get_num_threads @@ -108,19 +108,17 @@ def get_next_schedule(self, partitions: Sequence[str]) -> OptimizationSchedule: cutoff_time=self.__last_midnight + timedelta(hours=settings.OPTIMIZE_JOB_CUTOFF_TIME), ) - else: - if current_time < self.__parallel_start_time: - return OptimizationSchedule( - partitions_groups=[self._sort_partitions(partitions)], - cutoff_time=self.__parallel_start_time, - ) - elif current_time < self.__parallel_end_time: - return OptimizationSchedule( - partitions_groups=self._subdivide_partitions(partitions, num_threads), - cutoff_time=self.__parallel_end_time, - ) - else: - return OptimizationSchedule( - partitions_groups=[self._sort_partitions(partitions)], - cutoff_time=self.__full_job_end_time, - ) + if current_time < self.__parallel_start_time: + return OptimizationSchedule( + partitions_groups=[self._sort_partitions(partitions)], + cutoff_time=self.__parallel_start_time, + ) + if current_time < self.__parallel_end_time: + return OptimizationSchedule( + partitions_groups=self._subdivide_partitions(partitions, num_threads), + cutoff_time=self.__parallel_end_time, + ) + return OptimizationSchedule( + partitions_groups=[self._sort_partitions(partitions)], + cutoff_time=self.__full_job_end_time, + ) diff --git a/snuba/clickhouse/optimize/optimize_tracker.py b/snuba/clickhouse/optimize/optimize_tracker.py index 4a6cf0729ea..b2b47855d86 100644 --- a/snuba/clickhouse/optimize/optimize_tracker.py +++ b/snuba/clickhouse/optimize/optimize_tracker.py @@ -1,5 +1,5 @@ +from collections.abc import Sequence from datetime import datetime -from typing import Sequence, Set from snuba.redis import RedisClientType @@ -44,11 +44,11 @@ def __init__( self.__completed_bucket = f"{common_prefix}:completed" self.__key_expire_time = expire_time - def __get_partitions(self, bucket: str) -> Set[str]: + def __get_partitions(self, bucket: str) -> set[str]: """ Get the partitions from a given bucket. """ - partitions_set: Set[str] = set() + partitions_set: set[str] = set() partitions = self.__redis_client.smembers(bucket) if partitions: for partition in partitions: @@ -57,13 +57,13 @@ def __get_partitions(self, bucket: str) -> Set[str]: return partitions_set - def get_all_partitions(self) -> Set[str]: + def get_all_partitions(self) -> set[str]: """ Get a set of partitions which need to be optimized. """ return self.__get_partitions(self.__all_bucket) - def get_completed_partitions(self) -> Set[str]: + def get_completed_partitions(self) -> set[str]: """ Get a set of partitions that have completed optimization. """ @@ -94,7 +94,7 @@ def update_completed_partitions(self, part_name: str) -> None: """ self.__update_partitions(self.__completed_bucket, [part_name.encode("utf-8")]) - def get_partitions_to_optimize(self) -> Set[str]: + def get_partitions_to_optimize(self) -> set[str]: """ Get a set of partition names which need optimization. @@ -113,8 +113,7 @@ def get_partitions_to_optimize(self) -> Set[str]: if not completed_partitions: return all_partitions - else: - return all_partitions - completed_partitions + return all_partitions - completed_partitions def delete_all_states(self) -> None: """ diff --git a/snuba/clickhouse/query.py b/snuba/clickhouse/query.py index 74e095d73fc..87f81332522 100644 --- a/snuba/clickhouse/query.py +++ b/snuba/clickhouse/query.py @@ -1,4 +1,4 @@ -from typing import Callable, Iterable, Optional, Sequence +from collections.abc import Callable, Iterable, Sequence from snuba.query import LimitBy, OrderBy, SelectedExpression from snuba.query import ProcessableQuery as AbstractQuery @@ -14,21 +14,21 @@ class Query(AbstractQuery[Table]): def __init__( self, - from_clause: Optional[Table], + from_clause: Table | None, # New data model to replace the one based on the dictionary - selected_columns: Optional[Sequence[SelectedExpression]] = None, - array_join: Optional[Sequence[Expression]] = None, - condition: Optional[Expression] = None, - prewhere: Optional[Expression] = None, - groupby: Optional[Sequence[Expression]] = None, - having: Optional[Expression] = None, - order_by: Optional[Sequence[OrderBy]] = None, - limitby: Optional[LimitBy] = None, - limit: Optional[int] = None, + selected_columns: Sequence[SelectedExpression] | None = None, + array_join: Sequence[Expression] | None = None, + condition: Expression | None = None, + prewhere: Expression | None = None, + groupby: Sequence[Expression] | None = None, + having: Expression | None = None, + order_by: Sequence[OrderBy] | None = None, + limitby: LimitBy | None = None, + limit: int | None = None, offset: int = 0, totals: bool = False, - granularity: Optional[int] = None, - on_cluster: Optional[Expression] = None, + granularity: int | None = None, + on_cluster: Expression | None = None, is_delete: bool = False, ) -> None: self.__prewhere = prewhere @@ -60,13 +60,13 @@ def _transform_impl(self, visitor: ExpressionVisitor[Expression]) -> None: if self.__prewhere is not None: self.__prewhere = self.__prewhere.accept(visitor) - def get_prewhere_ast(self) -> Optional[Expression]: + def get_prewhere_ast(self) -> Expression | None: """ Temporary method until pre where management is moved to Clickhouse query """ return self.__prewhere - def set_prewhere_ast_condition(self, condition: Optional[Expression]) -> None: + def set_prewhere_ast_condition(self, condition: Expression | None) -> None: self.__prewhere = condition def _eq_functions(self) -> Sequence[str]: diff --git a/snuba/clickhouse/query_dsl/accessors.py b/snuba/clickhouse/query_dsl/accessors.py index 91392d0c0b1..5ded64de33a 100644 --- a/snuba/clickhouse/query_dsl/accessors.py +++ b/snuba/clickhouse/query_dsl/accessors.py @@ -1,5 +1,6 @@ +from collections.abc import Sequence from datetime import datetime -from typing import Optional, Sequence, Set, Tuple, Union, cast +from typing import cast from snuba.query import FromClauseNotSet, ProcessableQuery from snuba.query import Query as AbstractQuery @@ -26,7 +27,7 @@ ) -def get_object_ids_in_condition(condition: Expression, object_column: str) -> Set[int]: +def get_object_ids_in_condition(condition: Expression, object_column: str) -> set[int]: """ Extract project ids from an expression. Returns None if no project if condition is found. It returns an empty set of conflicting project_id @@ -64,19 +65,18 @@ def get_object_ids_in_condition(condition: Expression, object_column: str) -> Se rhs_objects = get_object_ids_in_condition(match.expression("rhs"), object_column) if not lhs_objects: return rhs_objects - elif not rhs_objects: + if not rhs_objects: return lhs_objects - else: - return ( - lhs_objects & rhs_objects - if match.string("operator") == BooleanFunctions.AND - else lhs_objects | rhs_objects - ) + return ( + lhs_objects & rhs_objects + if match.string("operator") == BooleanFunctions.AND + else lhs_objects | rhs_objects + ) return set() -def get_object_ids_in_query_ast(query: AbstractQuery, object_column: str) -> Set[int]: +def get_object_ids_in_query_ast(query: AbstractQuery, object_column: str) -> set[int]: """ Finds the object ids (e.g. project ids) this query is filtering according to the AST query representation. @@ -95,7 +95,7 @@ def get_object_ids_in_query_ast(query: AbstractQuery, object_column: str) -> Set return this_query_object_ids if isinstance(from_clause, SimpleDataSource): return this_query_object_ids - elif isinstance(from_clause, AbstractQuery): + if isinstance(from_clause, AbstractQuery): subquery_project_ids = get_object_ids_in_query_ast(from_clause, object_column) return subquery_project_ids.union(this_query_object_ids) return set() @@ -104,13 +104,13 @@ def get_object_ids_in_query_ast(query: AbstractQuery, object_column: str) -> Set def get_time_range_expressions( conditions: Sequence[Expression], timestamp_field: str, - table_name: Optional[str] = None, -) -> Tuple[ - Optional[Tuple[datetime, FunctionCallExpr]], - Optional[Tuple[datetime, FunctionCallExpr]], + table_name: str | None = None, +) -> tuple[ + tuple[datetime, FunctionCallExpr] | None, + tuple[datetime, FunctionCallExpr] | None, ]: - max_lower_bound: Optional[Tuple[datetime, FunctionCallExpr]] = None - min_upper_bound: Optional[Tuple[datetime, FunctionCallExpr]] = None + max_lower_bound: tuple[datetime, FunctionCallExpr] | None = None + min_upper_bound: tuple[datetime, FunctionCallExpr] | None = None table_match = String(table_name) if table_name else None for c in conditions: match = FunctionCall( @@ -143,9 +143,9 @@ def get_time_range_expressions( def get_time_range( - query: Union[ProcessableQuery[Table], ProcessableQuery[LogicalDataSource]], + query: ProcessableQuery[Table] | ProcessableQuery[LogicalDataSource], timestamp_field: str, -) -> Tuple[Optional[datetime], Optional[datetime]]: +) -> tuple[datetime | None, datetime | None]: """ Finds the minimal time range for this query. Which means, it finds the >= timestamp condition with the highest datetime literal and @@ -169,7 +169,7 @@ def get_time_range( def get_time_range_estimate( query: ProcessableQuery[Table], -) -> Tuple[Optional[datetime], Optional[datetime]]: +) -> tuple[datetime | None, datetime | None]: """ Best guess to find the time range for the query. We pick the first column that is compared with a datetime Literal. diff --git a/snuba/clickhouse/query_inspector.py b/snuba/clickhouse/query_inspector.py index 3eb4b638c84..ee32750c30a 100644 --- a/snuba/clickhouse/query_inspector.py +++ b/snuba/clickhouse/query_inspector.py @@ -1,4 +1,4 @@ -from typing import Mapping, MutableMapping, Optional, Set +from collections.abc import Mapping, MutableMapping from snuba.clickhouse.query_dsl.accessors import get_time_range_estimate from snuba.query import ProcessableQuery @@ -12,12 +12,11 @@ from snuba.query.expressions import FunctionCall as FunctionCallExpr -def _get_date_range(query: ProcessableQuery[Table]) -> Optional[int]: +def _get_date_range(query: ProcessableQuery[Table]) -> int | None: from_date, to_date = get_time_range_estimate(query) if from_date is None or to_date is None: return None - else: - return (to_date - from_date).days + return (to_date - from_date).days class TablesCollector(DataSourceVisitor[None, Table], JoinVisitor[None, Table]): @@ -28,20 +27,20 @@ class TablesCollector(DataSourceVisitor[None, Table], JoinVisitor[None, Table]): """ def __init__(self) -> None: - self.__tables: Set[str] = set() - self.__max_time_range: Optional[int] = None + self.__tables: set[str] = set() + self.__max_time_range: int | None = None self.__has_complex_conditions: bool = False self.__final: bool = False - self.__sample_rate: Optional[float] = None - self.__all_raw_columns: MutableMapping[str, Set[ColumnExpr]] = {} + self.__sample_rate: float | None = None + self.__all_raw_columns: MutableMapping[str, set[ColumnExpr]] = {} self.__all_conditions: MutableMapping[str, Expression] = {} - self.__all_groupby: MutableMapping[str, Set[Expression]] = {} - self.__all_array_joins: MutableMapping[str, Set[Expression]] = {} + self.__all_groupby: MutableMapping[str, set[Expression]] = {} + self.__all_array_joins: MutableMapping[str, set[Expression]] = {} - def get_tables(self) -> Set[str]: + def get_tables(self) -> set[str]: return self.__tables - def get_max_time_range(self) -> Optional[int]: + def get_max_time_range(self) -> int | None: return self.__max_time_range def has_complex_condition(self) -> bool: @@ -50,19 +49,19 @@ def has_complex_condition(self) -> bool: def any_final(self) -> bool: return self.__final - def get_sample_rate(self) -> Optional[float]: + def get_sample_rate(self) -> float | None: return self.__sample_rate - def get_all_raw_columns(self) -> Mapping[str, Set[ColumnExpr]]: + def get_all_raw_columns(self) -> Mapping[str, set[ColumnExpr]]: return self.__all_raw_columns def get_all_conditions(self) -> Mapping[str, Expression]: return self.__all_conditions - def get_all_groupby(self) -> Mapping[str, Set[Expression]]: + def get_all_groupby(self) -> Mapping[str, set[Expression]]: return self.__all_groupby - def get_all_arrayjoin(self) -> Mapping[str, Set[Expression]]: + def get_all_arrayjoin(self) -> Mapping[str, set[Expression]]: return self.__all_array_joins def __find_complex_conditions(self, query: ProcessableQuery[Table]) -> bool: @@ -83,8 +82,8 @@ def _visit_simple_source(self, data_source: Table) -> None: def _visit_join(self, data_source: JoinClause[Table]) -> None: self.visit_join_clause(data_source) - def _list_array_join(self, query: ProcessableQuery[Table]) -> Set[Expression]: - ret: Set[Expression] = set() + def _list_array_join(self, query: ProcessableQuery[Table]) -> set[Expression]: + ret: set[Expression] = set() query_arrayjoin = query.get_arrayjoin() if query_arrayjoin is not None: ret.update(query_arrayjoin) @@ -105,9 +104,7 @@ def _visit_simple_query(self, data_source: ProcessableQuery[Table]) -> None: ) table_name = data_source.get_from_clause().table_name - self.__all_raw_columns[table_name] = { - c for c in data_source.get_all_ast_referenced_columns() - } + self.__all_raw_columns[table_name] = set(data_source.get_all_ast_referenced_columns()) condition = data_source.get_condition() if condition is not None: diff --git a/snuba/clickhouse/query_profiler.py b/snuba/clickhouse/query_profiler.py index 66b501cba83..66f68c04865 100644 --- a/snuba/clickhouse/query_profiler.py +++ b/snuba/clickhouse/query_profiler.py @@ -1,5 +1,5 @@ import logging -from typing import Iterable, List, Mapping, Set, Union +from collections.abc import Iterable, Mapping from snuba.clickhouse.query import Query from snuba.clickhouse.query_inspector import TablesCollector @@ -21,7 +21,7 @@ logger = logging.getLogger(__name__) -def _get_all_columns(all_columns: Mapping[str, Set[ColumnExpr]]) -> Columnset: +def _get_all_columns(all_columns: Mapping[str, set[ColumnExpr]]) -> Columnset: return { f"{table_name}.{c.column_name}" for table_name, columns in all_columns.items() @@ -33,7 +33,7 @@ def _get_columns_from_expression(expression: Expression, table_name: str) -> Col return {f"{table_name}.{c.column_name}" for c in expression if isinstance(c, ColumnExpr)} -def _list_columns(expressions: Mapping[str, Set[Expression]]) -> Columnset: +def _list_columns(expressions: Mapping[str, set[Expression]]) -> Columnset: ret = set() for table_name, expression_set in expressions.items(): for e in expression_set: @@ -42,7 +42,7 @@ def _list_columns(expressions: Mapping[str, Set[Expression]]) -> Columnset: return ret -def _flatten_col_set(nested_sets: Iterable[Set[str]]) -> Columnset: +def _flatten_col_set(nested_sets: Iterable[set[str]]) -> Columnset: ret = set() for s in nested_sets: ret |= s @@ -52,14 +52,14 @@ def _flatten_col_set(nested_sets: Iterable[Set[str]]) -> Columnset: def _list_columns_in_condition(condition_expression: Mapping[str, Expression]) -> Columnset: return _flatten_col_set( [ - {c for c in _get_columns_from_expression(expression, table_name)} + set(_get_columns_from_expression(expression, table_name)) for table_name, expression in condition_expression.items() ] ) def _list_mappings(condition_expression: Mapping[str, Expression]) -> Columnset: - nested_sets: List[Set[str]] = [] + nested_sets: list[set[str]] = [] for table_name, expression in condition_expression.items(): ret = set() for e in expression: @@ -74,7 +74,7 @@ def _list_mappings(condition_expression: Mapping[str, Expression]) -> Columnset: def generate_profile( - query: Union[Query, CompositeQuery[Table]], + query: Query | CompositeQuery[Table], ) -> ClickhouseQueryProfile: """ Takes a Physical query in, analyzes it and produces the @@ -88,7 +88,7 @@ def generate_profile( try: return ClickhouseQueryProfile( time_range=collector.get_max_time_range(), - table=",".join(sorted([t for t in collector.get_tables()])), + table=",".join(sorted(collector.get_tables())), all_columns=_get_all_columns(collector.get_all_raw_columns()), multi_level_condition=collector.has_complex_condition(), where_profile=FilterProfile( diff --git a/snuba/clickhouse/translators/snuba/allowed.py b/snuba/clickhouse/translators/snuba/allowed.py index ade9433f2d7..c499ac3e04e 100644 --- a/snuba/clickhouse/translators/snuba/allowed.py +++ b/snuba/clickhouse/translators/snuba/allowed.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import List, Optional, Type, TypeVar, Union, cast +from typing import TypeVar, cast from snuba.clickhouse.translators.snuba import SnubaClickhouseStrictTranslator from snuba.datasets.plans.translator.mapper import ExpressionMapper @@ -44,7 +44,7 @@ class LiteralMapper(SnubaClickhouseMapper[Literal, Literal]): pass -ValidColumnMappings = Union[Column, Literal, FunctionCall, CurriedFunctionCall] +ValidColumnMappings = Column | Literal | FunctionCall | CurriedFunctionCall class ColumnMapper(SnubaClickhouseMapper[Column, ValidColumnMappings], metaclass=RegisteredClass): @@ -60,8 +60,8 @@ def config_key(cls) -> str: return cls.__name__ @classmethod - def get_from_name(cls, name: str) -> Type["ColumnMapper"]: - return cast(Type["ColumnMapper"], cls.class_from_name(name)) + def get_from_name(cls, name: str) -> type[ColumnMapper]: + return cast(type["ColumnMapper"], cls.class_from_name(name)) class FunctionCallMapper( @@ -83,12 +83,12 @@ def config_key(cls) -> str: return cls.__name__ @classmethod - def get_from_name(cls, name: str) -> Type["FunctionCallMapper"]: - return cast(Type["FunctionCallMapper"], cls.class_from_name(name)) + def get_from_name(cls, name: str) -> type[FunctionCallMapper]: + return cast(type["FunctionCallMapper"], cls.class_from_name(name)) class CurriedFunctionCallMapper( - SnubaClickhouseMapper[CurriedFunctionCall, Union[CurriedFunctionCall, FunctionCall]], + SnubaClickhouseMapper[CurriedFunctionCall, CurriedFunctionCall | FunctionCall], metaclass=RegisteredClass, ): @classmethod @@ -96,14 +96,14 @@ def config_key(cls) -> str: return cls.__name__ @classmethod - def get_from_name(cls, name: str) -> Type["CurriedFunctionCallMapper"]: - return cast(Type["CurriedFunctionCallMapper"], cls.class_from_name(name)) + def get_from_name(cls, name: str) -> type[CurriedFunctionCallMapper]: + return cast(type["CurriedFunctionCallMapper"], cls.class_from_name(name)) class SubscriptableReferenceMapper( SnubaClickhouseMapper[ SubscriptableReference, - Union[FunctionCall, Literal, SubscriptableReference], + FunctionCall | Literal | SubscriptableReference, ], metaclass=RegisteredClass, ): @@ -121,8 +121,8 @@ def config_key(cls) -> str: return cls.__name__ @classmethod - def get_from_name(cls, name: str) -> Type["SubscriptableReferenceMapper"]: - return cast(Type["SubscriptableReferenceMapper"], cls.class_from_name(name)) + def get_from_name(cls, name: str) -> type[SubscriptableReferenceMapper]: + return cast(type["SubscriptableReferenceMapper"], cls.class_from_name(name)) class LambdaMapper(SnubaClickhouseMapper[Lambda, Lambda]): @@ -149,15 +149,14 @@ def attempt_map( self, expression: Column, children_translator: SnubaClickhouseStrictTranslator, - ) -> Optional[FunctionCall]: + ) -> FunctionCall | None: if expression.column_name in self.column_names: return identity( Literal(None, None), expression.alias or qualified_column(expression.column_name, expression.table_name or ""), ) - else: - return None + return None @dataclass @@ -166,7 +165,7 @@ class DefaultNoneFunctionMapper(FunctionCallMapper): Maps the list of function names to NULL. """ - function_names: List[str] + function_names: list[str] def __post_init__(self) -> None: self.function_match = FunctionCallMatch( @@ -177,7 +176,7 @@ def attempt_map( self, expression: FunctionCall, children_translator: SnubaClickhouseStrictTranslator, - ) -> Optional[FunctionCall]: + ) -> FunctionCall | None: if self.function_match.match(expression): return identity(Literal(None, None), expression.alias) @@ -199,7 +198,7 @@ def attempt_map( self, expression: FunctionCall, children_translator: SnubaClickhouseStrictTranslator, - ) -> Optional[FunctionCall]: + ) -> FunctionCall | None: # HACK: Quick fix to avoid this function dropping important conditions from the query logical_functions = {"and", "or", "xor"} @@ -231,7 +230,7 @@ def attempt_map( self, expression: CurriedFunctionCall, children_translator: SnubaClickhouseStrictTranslator, - ) -> Optional[Union[CurriedFunctionCall, FunctionCall]]: + ) -> CurriedFunctionCall | FunctionCall | None: internal_function = expression.internal_function.accept(children_translator) assert isinstance(internal_function, FunctionCall) # mypy parameters = tuple(p.accept(children_translator) for p in expression.parameters) @@ -255,14 +254,13 @@ class DefaultNoneSubscriptMapper(SubscriptableReferenceMapper): the discover dataset file. """ - subscript_names: List[str] + subscript_names: list[str] def attempt_map( self, expression: SubscriptableReference, children_translator: SnubaClickhouseStrictTranslator, - ) -> Optional[FunctionCall]: + ) -> FunctionCall | None: if expression.column.column_name in self.subscript_names: return identity(Literal(None, None), expression.alias) - else: - return None + return None diff --git a/snuba/clickhouse/translators/snuba/defaults.py b/snuba/clickhouse/translators/snuba/defaults.py index 71e0474928f..262a313f95b 100644 --- a/snuba/clickhouse/translators/snuba/defaults.py +++ b/snuba/clickhouse/translators/snuba/defaults.py @@ -1,5 +1,3 @@ -from typing import Optional - from snuba.clickhouse.translators.snuba import SnubaClickhouseStrictTranslator from snuba.clickhouse.translators.snuba.allowed import ( ArgumentMapper, @@ -26,7 +24,7 @@ def attempt_map( self, expression: Literal, children_translator: SnubaClickhouseStrictTranslator, - ) -> Optional[Literal]: + ) -> Literal | None: return expression @@ -35,7 +33,7 @@ def attempt_map( self, expression: Column, children_translator: SnubaClickhouseStrictTranslator, - ) -> Optional[Column]: + ) -> Column | None: return expression @@ -44,7 +42,7 @@ def attempt_map( self, expression: SubscriptableReference, children_translator: SnubaClickhouseStrictTranslator, - ) -> Optional[SubscriptableReference]: + ) -> SubscriptableReference | None: # TODO: remove a default for SubscriptableReference entirely. # Since there is not SubscriptableReference in clickhouse, such # columns have to be translated by a valid rule. They cannot have @@ -62,7 +60,7 @@ def attempt_map( self, expression: FunctionCall, children_translator: SnubaClickhouseStrictTranslator, - ) -> Optional[FunctionCall]: + ) -> FunctionCall | None: return FunctionCall( alias=expression.alias, function_name=expression.function_name, @@ -75,7 +73,7 @@ def attempt_map( self, expression: CurriedFunctionCall, children_translator: SnubaClickhouseStrictTranslator, - ) -> Optional[CurriedFunctionCall]: + ) -> CurriedFunctionCall | None: return CurriedFunctionCall( alias=expression.alias, internal_function=children_translator.translate_function_strict( @@ -90,7 +88,7 @@ def attempt_map( self, expression: Argument, children_translator: SnubaClickhouseStrictTranslator, - ) -> Optional[Argument]: + ) -> Argument | None: return expression @@ -99,7 +97,7 @@ def attempt_map( self, expression: Lambda, children_translator: SnubaClickhouseStrictTranslator, - ) -> Optional[Lambda]: + ) -> Lambda | None: return Lambda( alias=expression.alias, parameters=expression.parameters, diff --git a/snuba/clickhouse/translators/snuba/function_call_mappers.py b/snuba/clickhouse/translators/snuba/function_call_mappers.py index 2fb7b53dde4..c900278578e 100644 --- a/snuba/clickhouse/translators/snuba/function_call_mappers.py +++ b/snuba/clickhouse/translators/snuba/function_call_mappers.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import Optional, Tuple, Union from snuba.clickhouse.translators.snuba import SnubaClickhouseStrictTranslator from snuba.clickhouse.translators.snuba.allowed import ( @@ -11,10 +10,10 @@ def _build_parameters( - expression: Union[FunctionCall, CurriedFunctionCall], + expression: FunctionCall | CurriedFunctionCall, children_translator: SnubaClickhouseStrictTranslator, aggregated_col_name: str, -) -> Tuple[Expression, ...]: +) -> tuple[Expression, ...]: assert isinstance(expression.parameters[0], ColumnExpr) return ( ColumnExpr(None, expression.parameters[0].table_name, aggregated_col_name), @@ -26,7 +25,7 @@ def _should_transform_aggregation( function_name: str, expected_function_name: str, column_to_map: str, - function_call: Union[FunctionCall, CurriedFunctionCall], + function_call: FunctionCall | CurriedFunctionCall, ) -> bool: return ( function_name == expected_function_name @@ -52,7 +51,7 @@ def attempt_map( self, expression: FunctionCall, children_translator: SnubaClickhouseStrictTranslator, - ) -> Optional[FunctionCall]: + ) -> FunctionCall | None: if not _should_transform_aggregation( expression.function_name, self.from_name, self.column_to_map, expression ): @@ -81,7 +80,7 @@ def attempt_map( self, expression: CurriedFunctionCall, children_translator: SnubaClickhouseStrictTranslator, - ) -> Optional[CurriedFunctionCall]: + ) -> CurriedFunctionCall | None: if not _should_transform_aggregation( expression.internal_function.function_name, self.from_name, diff --git a/snuba/clickhouse/translators/snuba/mappers.py b/snuba/clickhouse/translators/snuba/mappers.py index 07064f8e2a4..62e2d9fc79d 100644 --- a/snuba/clickhouse/translators/snuba/mappers.py +++ b/snuba/clickhouse/translators/snuba/mappers.py @@ -1,6 +1,5 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Optional, Tuple from snuba.clickhouse.translators.snuba import SnubaClickhouseStrictTranslator from snuba.clickhouse.translators.snuba.allowed import ( @@ -35,7 +34,7 @@ # This is a workaround for a mypy bug, found here: https://github.com/python/mypy/issues/5374 @dataclass(frozen=True) class _ColumnToExpression: - from_table_name: Optional[str] + from_table_name: str | None from_col_name: str @@ -51,14 +50,13 @@ def attempt_map( self, expression: ColumnExpr, children_translator: SnubaClickhouseStrictTranslator, - ) -> Optional[ValidColumnMappings]: + ) -> ValidColumnMappings | None: if ( expression.column_name == self.from_col_name and expression.table_name == self.from_table_name ): return self._produce_output(expression) - else: - return None + return None @abstractmethod def _produce_output(self, expression: ColumnExpr) -> ValidColumnMappings: @@ -73,7 +71,7 @@ class ColumnToColumn(ColumnToExpression): The alias is not transformed. """ - to_table_name: Optional[str] + to_table_name: str | None to_col_name: str def _produce_output(self, expression: ColumnExpr) -> ColumnExpr: @@ -103,7 +101,7 @@ class ColumnToFunction(ColumnToExpression): """ to_function_name: str - to_function_params: Tuple[Expression, ...] + to_function_params: tuple[Expression, ...] def _produce_output(self, expression: ColumnExpr) -> FunctionCallExpr: return FunctionCallExpr( @@ -137,8 +135,8 @@ class ColumnToIPAddress(ColumnToFunction): TODO: Can remove when we support dynamic expression parsing in config """ - def __init__(self, from_table_name: Optional[str], from_col_name: str) -> None: - to_function_params: Tuple[FunctionCallExpr, ...] = ( + def __init__(self, from_table_name: str | None, from_col_name: str) -> None: + to_function_params: tuple[FunctionCallExpr, ...] = ( FunctionCallExpr( None, "IPv4NumToString", @@ -160,8 +158,8 @@ class ColumnToNullIf(ColumnToFunction): TODO: Can remove when we support dynamic expression parsing in config """ - def __init__(self, from_table_name: Optional[str], from_col_name: str) -> None: - to_function_params: Tuple[ColumnExpr, LiteralExpr] = ( + def __init__(self, from_table_name: str | None, from_col_name: str) -> None: + to_function_params: tuple[ColumnExpr, LiteralExpr] = ( ColumnExpr(None, from_table_name, from_col_name), LiteralExpr(None, ""), ) @@ -175,7 +173,7 @@ class ColumnToCurriedFunction(ColumnToExpression): """ to_internal_function: FunctionCallExpr - to_function_params: Tuple[Expression, ...] + to_function_params: tuple[Expression, ...] def _produce_output(self, expression: ColumnExpr) -> CurriedFunctionCall: return CurriedFunctionCall( @@ -192,9 +190,9 @@ class SubscriptableMapper(SubscriptableReferenceMapper): into a Clickhouse array access. """ - from_column_table: Optional[str] + from_column_table: str | None from_column_name: str - to_nested_col_table: Optional[str] + to_nested_col_table: str | None to_nested_col_name: str value_subcolumn_name: str = "value" nullable: bool = False @@ -203,7 +201,7 @@ def attempt_map( self, expression: SubscriptableReference, children_translator: SnubaClickhouseStrictTranslator, - ) -> Optional[FunctionCallExpr]: + ) -> FunctionCallExpr | None: if ( expression.column.table_name == self.from_column_table and expression.column.column_name == self.from_column_name @@ -226,8 +224,7 @@ def attempt_map( self.value_subcolumn_name, ) ) - else: - return None + return None @dataclass(frozen=True) @@ -236,9 +233,9 @@ class SubscriptableHashBucketMapper(SubscriptableReferenceMapper): Maps a key into the appropriate bucket by hashing the key. For example, hello[test] might go to attr_str_22['test'] """ - from_column_table: Optional[str] + from_column_table: str | None from_column_name: str - to_col_table: Optional[str] + to_col_table: str | None to_col_name: str num_attribute_buckets: int @@ -246,7 +243,7 @@ def attempt_map( self, expression: SubscriptableReference, children_translator: SnubaClickhouseStrictTranslator, - ) -> Optional[FunctionCallExpr]: + ) -> FunctionCallExpr | None: if ( expression.column.table_name != self.from_column_table or expression.column.column_name != self.from_column_name @@ -273,7 +270,7 @@ class ColumnToMapping(ColumnToExpression): array access. """ - to_nested_col_table_name: Optional[str] + to_nested_col_table_name: str | None to_nested_col_name: str to_nested_mapping_key: str nullable: bool = False @@ -287,19 +284,18 @@ def _produce_output(self, expression: ColumnExpr) -> FunctionCallExpr: LiteralExpr(None, self.to_nested_mapping_key), "value", ) - else: - return build_nullable_mapping_expr( - expression.alias, - self.to_nested_col_table_name, - self.to_nested_col_name, - LiteralExpr(None, self.to_nested_mapping_key), - "value", - ) + return build_nullable_mapping_expr( + expression.alias, + self.to_nested_col_table_name, + self.to_nested_col_name, + LiteralExpr(None, self.to_nested_mapping_key), + "value", + ) def build_mapping_expr( - alias: Optional[str], - table_name: Optional[str], + alias: str | None, + table_name: str | None, col_name: str, mapping_key: Expression, value_subcolumn_name: str, @@ -316,8 +312,8 @@ def build_mapping_expr( def build_nullable_mapping_expr( - alias: Optional[str], - table_name: Optional[str], + alias: str | None, + table_name: str | None, col_name: str, mapping_key: Expression, value_subcolumn_name: str, @@ -375,7 +371,7 @@ def attempt_map( self, expression: FunctionCallExpr, children_translator: SnubaClickhouseStrictTranslator, - ) -> Optional[FunctionCallExpr]: + ) -> FunctionCallExpr | None: if expression.function_name != self.from_name: return None diff --git a/snuba/clickhouse/translators/snuba/mapping.py b/snuba/clickhouse/translators/snuba/mapping.py index 51afc3cbea4..a69f88d2da2 100644 --- a/snuba/clickhouse/translators/snuba/mapping.py +++ b/snuba/clickhouse/translators/snuba/mapping.py @@ -1,7 +1,7 @@ from __future__ import annotations +from collections.abc import MutableMapping, Sequence from dataclasses import dataclass, field, replace -from typing import MutableMapping, Sequence from snuba.clickhouse.query import Expression from snuba.clickhouse.translators.snuba import SnubaClickhouseStrictTranslator diff --git a/snuba/clusters/cluster.py b/snuba/clusters/cluster.py index 60fd982a62d..e5aabb32cc1 100644 --- a/snuba/clusters/cluster.py +++ b/snuba/clusters/cluster.py @@ -1,18 +1,12 @@ from abc import ABC, abstractmethod +from collections.abc import Mapping, MutableMapping, Sequence from dataclasses import dataclass from enum import Enum from threading import Lock from typing import ( Any, - Dict, Generic, - Mapping, - MutableMapping, NamedTuple, - Optional, - Sequence, - Set, - Tuple, TypeVar, ) @@ -36,7 +30,7 @@ class ClickhouseClientSettingsType(NamedTuple): settings: Mapping[str, Any] - timeout: Optional[int] + timeout: int | None class ConnectionId(NamedTuple): @@ -104,8 +98,8 @@ class ClickhouseClientSettings(Enum): class ClickhouseNode: host_name: str port: int - shard: Optional[int] = None - replica: Optional[int] = None + shard: int | None = None + replica: int | None = None def __str__(self) -> str: return f"{self.host_name}:{self.port}" @@ -137,13 +131,13 @@ class Cluster(ABC, Generic[TWriterOptions]): - optimize """ - def __init__(self, storage_sets: Set[str]): + def __init__(self, storage_sets: set[str]): self.__storage_sets = storage_sets # register the cluster's storage sets for storage_set in storage_sets: register_storage_set_key(storage_set) - def get_storage_set_keys(self) -> Set[StorageSetKey]: + def get_storage_set_keys(self) -> set[StorageSetKey]: return {StorageSetKey(storage_set) for storage_set in self.__storage_sets} @abstractmethod @@ -155,26 +149,26 @@ def get_batch_writer( self, metrics: MetricsBackend, insert_statement: InsertStatement, - encoding: Optional[str], + encoding: str | None, options: TWriterOptions, - chunk_size: Optional[int], + chunk_size: int | None, buffer_size: int, ) -> BatchWriter[JSONRow]: raise NotImplementedError -ClickhouseWriterOptions = Optional[Mapping[str, Any]] +ClickhouseWriterOptions = Mapping[str, Any] | None -CacheKey = Tuple[ +CacheKey = tuple[ ClickhouseNode, ClickhouseClientSettings, str, str, str, bool, - Optional[str], - Optional[bool], + str | None, + bool | None, ] @@ -191,8 +185,8 @@ def get_node_connection( password: str, database: str, secure: bool, - ca_certs: Optional[str], - verify: Optional[bool], + ca_certs: str | None, + verify: bool | None, ) -> ClickhousePool: with self.__lock: settings, timeout = client_settings.value @@ -255,16 +249,16 @@ def __init__( database: str, http_port: int, secure: bool, - ca_certs: Optional[str], - verify: Optional[bool], - storage_sets: Set[str], + ca_certs: str | None, + verify: bool | None, + storage_sets: set[str], single_node: bool, # The cluster name and distributed cluster name only apply if single_node is set to False - cluster_name: Optional[str] = None, - distributed_cluster_name: Optional[str] = None, - cache_partition_id: Optional[str] = None, - query_settings_prefix: Optional[str] = None, - max_connections: Optional[int] = None, + cluster_name: str | None = None, + distributed_cluster_name: str | None = None, + cache_partition_id: str | None = None, + query_settings_prefix: str | None = None, + max_connections: int | None = None, block_connections: bool = False, ): super().__init__(storage_sets) @@ -283,8 +277,8 @@ def __init__( self.__single_node = single_node self.__cluster_name = cluster_name self.__distributed_cluster_name = distributed_cluster_name - self.__reader: Optional[Reader] = None - self.__deleter: Optional[Reader] = None + self.__reader: Reader | None = None + self.__deleter: Reader | None = None self.__connection_cache = connection_cache self.__cache_partition_id = cache_partition_id self.__query_settings_prefix = query_settings_prefix @@ -292,7 +286,7 @@ def __init__( def __str__(self) -> str: return str(self.__query_node) - def get_credentials(self) -> Tuple[str, str]: + def get_credentials(self) -> tuple[str, str]: """ Returns the user credentials for the Clickhouse connection """ @@ -354,9 +348,9 @@ def get_batch_writer( self, metrics: MetricsBackend, insert_statement: InsertStatement, - encoding: Optional[str], + encoding: str | None, options: ClickhouseWriterOptions, - chunk_size: Optional[int], + chunk_size: int | None, buffer_size: int, ) -> BatchWriter[JSONRow]: return HTTPBatchWriter( @@ -385,10 +379,10 @@ def is_single_node(self) -> bool: """ return self.__single_node - def get_clickhouse_cluster_name(self) -> Optional[str]: + def get_clickhouse_cluster_name(self) -> str | None: return self.__cluster_name - def get_clickhouse_distributed_cluster_name(self) -> Optional[str]: + def get_clickhouse_distributed_cluster_name(self) -> str | None: return self.__distributed_cluster_name def get_database(self) -> str: @@ -462,10 +456,8 @@ def get_secure(self) -> bool: verify=cluster.get("verify", False), storage_sets=cluster["storage_sets"], single_node=cluster["single_node"], - cluster_name=cluster["cluster_name"] if "cluster_name" in cluster else None, - distributed_cluster_name=( - cluster["distributed_cluster_name"] if "distributed_cluster_name" in cluster else None - ), + cluster_name=cluster.get("cluster_name", None), + distributed_cluster_name=(cluster.get("distributed_cluster_name", None)), cache_partition_id=cluster.get("cache_partition_id"), query_settings_prefix=cluster.get("query_settings_prefix"), max_connections=cluster.get("max_connections", _DEFAULT_MAX_CONNECTIONS), @@ -483,12 +475,12 @@ def get_secure(self) -> bool: "Storage set registered to more than one cluster" ) -_STORAGE_SET_CLUSTER_MAP: Dict[StorageSetKey, ClickhouseCluster] = { +_STORAGE_SET_CLUSTER_MAP: dict[StorageSetKey, ClickhouseCluster] = { storage_set: cluster for cluster in CLUSTERS for storage_set in cluster.get_storage_set_keys() } -def _get_storage_set_cluster_map() -> Dict[StorageSetKey, ClickhouseCluster]: +def _get_storage_set_cluster_map() -> dict[StorageSetKey, ClickhouseCluster]: return _STORAGE_SET_CLUSTER_MAP @@ -505,19 +497,17 @@ def _build_sliced_cluster(cluster: Mapping[str, Any]) -> ClickhouseCluster: verify=cluster.get("verify", False), storage_sets={storage_tuple[0] for storage_tuple in cluster["storage_set_slices"]}, single_node=cluster["single_node"], - cluster_name=cluster["cluster_name"] if "cluster_name" in cluster else None, - distributed_cluster_name=( - cluster["distributed_cluster_name"] if "distributed_cluster_name" in cluster else None - ), + cluster_name=cluster.get("cluster_name", None), + distributed_cluster_name=(cluster.get("distributed_cluster_name", None)), cache_partition_id=cluster.get("cache_partition_id"), query_settings_prefix=cluster.get("query_settings_prefix"), ) -_SLICED_STORAGE_SET_CLUSTER_MAP: Dict[Tuple[StorageSetKey, int], ClickhouseCluster] = {} +_SLICED_STORAGE_SET_CLUSTER_MAP: dict[tuple[StorageSetKey, int], ClickhouseCluster] = {} -def _get_sliced_storage_set_cluster_map() -> Dict[Tuple[StorageSetKey, int], ClickhouseCluster]: +def _get_sliced_storage_set_cluster_map() -> dict[tuple[StorageSetKey, int], ClickhouseCluster]: if len(_SLICED_STORAGE_SET_CLUSTER_MAP) == 0: for cluster in settings.SLICED_CLUSTERS: for storage_set_tuple in cluster["storage_set_slices"]: @@ -532,9 +522,7 @@ class UndefinedClickhouseCluster(SerializableException): pass -def get_cluster( - storage_set_key: StorageSetKey, slice_id: Optional[int] = None -) -> ClickhouseCluster: +def get_cluster(storage_set_key: StorageSetKey, slice_id: int | None = None) -> ClickhouseCluster: """Return a clickhouse cluster for a storage set key. If passing in a sliced storage set, a slice_id must be specified. diff --git a/snuba/clusters/storage_sets.py b/snuba/clusters/storage_sets.py index 403507a6a32..0858c513ffa 100644 --- a/snuba/clusters/storage_sets.py +++ b/snuba/clusters/storage_sets.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Any, FrozenSet, Iterator +from collections.abc import Iterator +from typing import Any _HARDCODED_STORAGE_SET_KEYS = { "CDC": "cdc", @@ -27,7 +28,7 @@ class _StorageSetKey(type): - def __getattr__(self, attr: str) -> "StorageSetKey": + def __getattr__(self, attr: str) -> StorageSetKey: if attr not in _HARDCODED_STORAGE_SET_KEYS and attr not in _REGISTERED_STORAGE_SET_KEYS: raise AttributeError(attr) @@ -76,12 +77,12 @@ def register_storage_set_key(key: str) -> StorageSetKey: # Storage sets enabled only when development features are enabled. -DEV_STORAGE_SETS: FrozenSet[StorageSetKey] = frozenset() +DEV_STORAGE_SETS: frozenset[StorageSetKey] = frozenset() # Storage sets in a group share the same query and distributed nodes but # do not have the same local node cluster configuration. # Joins can be performed across storage sets in the same group. -JOINABLE_STORAGE_SETS: FrozenSet[FrozenSet[StorageSetKey]] = frozenset( +JOINABLE_STORAGE_SETS: frozenset[frozenset[StorageSetKey]] = frozenset( { frozenset({StorageSetKey.EVENTS, StorageSetKey.EVENTS_RO}), frozenset( @@ -101,8 +102,4 @@ def is_valid_storage_set_combination(*storage_sets: StorageSetKey) -> bool: if len(all_storage_sets) <= 1: return True - for group in JOINABLE_STORAGE_SETS: - if all_storage_sets.issubset(group): - return True - - return False + return any(all_storage_sets.issubset(group) for group in JOINABLE_STORAGE_SETS) diff --git a/snuba/configs/configuration.py b/snuba/configs/configuration.py index 2c0bfd0f097..b72706142c2 100644 --- a/snuba/configs/configuration.py +++ b/snuba/configs/configuration.py @@ -1,7 +1,7 @@ import logging from abc import ABC, abstractmethod from dataclasses import dataclass, field, replace -from typing import Any, Type, TypedDict, TypeVar, cast, final +from typing import Any, TypedDict, TypeVar, cast, final from snuba.datasets.storages.storage_key import StorageKey from snuba.state import delete_config as delete_runtime_config @@ -86,8 +86,12 @@ def to_definition_dict(self) -> dict[str, Any]: ], } - def to_config_dict(self, value: Any = None, params: dict[str, Any] = {}) -> dict[str, Any]: + def to_config_dict( + self, value: Any = None, params: dict[str, Any] | None = None + ) -> dict[str, Any]: """Returns a dict representation of a live Config.""" + if params is None: + params = {} return { **self.__to_base_dict(), "value": value if value is not None else self.default, @@ -222,7 +226,7 @@ def _validate_config_params( for key in config.param_types if key not in params } - ) != dict(): + ) != {}: raise InvalidConfig( f"'{config_key}' missing required parameters: {diff} for {class_name}!" ) @@ -239,25 +243,24 @@ def _validate_config_params( # try casting to the right type, eg try int("10") expected_type = config.param_types[param_name] params[param_name] = expected_type(params[param_name]) - except Exception: + except Exception as e: raise InvalidConfig( f"'{config_key}' parameter '{param_name}' needs to be of type" f" {config.param_types[param_name].__name__} (not {type(params[param_name]).__name__})" f" for {class_name}!" - ) + ) from e # value isn't correct type - if value is not None: - if not isinstance(value, config.value_type): - try: - # try casting to the right type - config.value_type(value) - except Exception: - raise InvalidConfig( - f"'{config_key}' value needs to be of type" - f" {config.value_type.__name__} (not {type(value).__name__})" - f" for {class_name}!" - ) + if value is not None and not isinstance(value, config.value_type): + try: + # try casting to the right type + config.value_type(value) + except Exception as e: + raise InvalidConfig( + f"'{config_key}' value needs to be of type" + f" {config.value_type.__name__} (not {type(value).__name__})" + f" for {class_name}!" + ) from e return config @@ -276,7 +279,7 @@ def __deserialize_runtime_config_key(self, key: str) -> tuple[str, dict[str, Any # key is "storage.policy.config" or "storage.policy.config.param1:val1,param2:val2" _, _, config_key, *params = key.split(".") # (config_key, params) is ("config", []) or ("config", ["param1:val1,param2:val2"]) - params_dict = dict() + params_dict = {} if params: # convert ["param1:val1,param2:val2"] to {"param1": "val1", "param2": "val2"} [params_string] = params @@ -297,11 +300,11 @@ def get_current_configs(self) -> list[dict[str, Any]]: runtime_configs = get_all_runtime_configs(self._get_hash()) definitions = self.config_definitions() - required_configs = set( + required_configs = { config_name for config_name, config_def in definitions.items() if not config_def.param_types - ) + } detailed_configs: list[dict[str, Any]] = [] @@ -379,7 +382,7 @@ def __build_runtime_config_key(self, config: str, params: dict[str, Any]) -> str - `"mystorage.MyAllocationPolicy.my_config.a:1,b:2"` # sorted params """ parameters = "." - for param in sorted(list(params.keys())): + for param in sorted(params.keys()): param_sanitized = self.__escape_delimiter_chars(param) value_sanitized = self.__escape_delimiter_chars(params[param]) parameters += f"{param_sanitized}:{value_sanitized}," @@ -392,10 +395,12 @@ def _get_hash(self) -> str: def get_config_value( self, config_key: str, - params: dict[str, Any] = {}, + params: dict[str, Any] | None = None, validate: bool = True, ) -> Any: """Returns value of a configuration on this ConfigurableComponent, or the default if none exists in Redis.""" + if params is None: + params = {} config_definition = ( self._validate_config_params(config_key, params) if validate @@ -411,10 +416,12 @@ def set_config_value( self, config_key: str, value: Any, - params: dict[str, Any] = {}, + params: dict[str, Any] | None = None, user: str | None = None, ) -> None: """Sets a value of a configuration on this ConfigurableComponent.""" + if params is None: + params = {} config_definition = self._validate_config_params(config_key, params, value) # ensure correct type is stored value = config_definition.value_type(value) @@ -429,13 +436,15 @@ def set_config_value( def delete_config_value( self, config_key: str, - params: dict[str, Any] = {}, + params: dict[str, Any] | None = None, user: str | None = None, ) -> None: """ Deletes an instance of an optional configuration on this ConfigurableComponent. If this function is run on a required configuration, it resets the value to default instead. """ + if params is None: + params = {} self._validate_config_params(config_key, params) delete_runtime_config( key=self.__build_runtime_config_key(config_key, params), @@ -461,15 +470,15 @@ def to_dict(self) -> ConfigurableComponentData: ) @classmethod - def get_component_class(cls, namespace: str) -> Type["ConfigurableComponent"]: + def get_component_class(cls, namespace: str) -> type["ConfigurableComponent"]: return cast( - Type["ConfigurableComponent"], + type["ConfigurableComponent"], cls.class_from_name(f"{namespace}.{namespace}"), ) @classmethod - def get_from_name(cls: Type[T], name: str) -> Type[T]: - return cast(Type[T], cls.class_from_name(f"{cls.component_namespace()}.{name}")) + def get_from_name(cls: type[T], name: str) -> type[T]: + return cast(type[T], cls.class_from_name(f"{cls.component_namespace()}.{name}")) @classmethod def create_minimal_instance(cls, resource_identifier: str) -> "ConfigurableComponent": @@ -480,15 +489,17 @@ def all_names(cls) -> list[str]: """Returns all registered class names that belong to this component's namespace.""" # If called on ConfigurableComponent itself, return all registered classes if cls is ConfigurableComponent: - return list(getattr(cls, "_registry").all_names()) + return list(cls._registry.all_names()) # For subclasses, return only classes in the same namespace namespaced_classes = [] - for registered_cls in getattr(cls, "_registry").all_classes(): + for registered_cls in cls._registry.all_classes(): if ( hasattr(registered_cls, "component_namespace") and registered_cls.component_namespace() == cls.component_namespace() and registered_cls.config_key() != cls.config_key() ): - namespaced_classes.append(registered_cls.class_name()) + namespaced_classes.append( + cast("ConfigurableComponent", registered_cls).class_name() + ) return namespaced_classes diff --git a/snuba/consumers/consumer.py b/snuba/consumers/consumer.py index 875efc5128f..4415ccbe20c 100644 --- a/snuba/consumers/consumer.py +++ b/snuba/consumers/consumer.py @@ -3,22 +3,13 @@ import random import time from collections import defaultdict +from collections.abc import Callable, Mapping, MutableMapping, MutableSequence, Sequence from datetime import datetime from pickle import PickleBuffer from typing import ( Any, - Callable, - List, - Mapping, - MutableMapping, - MutableSequence, NamedTuple, - Optional, - Sequence, - Set, SupportsIndex, - Tuple, - Union, cast, ) @@ -67,12 +58,12 @@ class BytesInsertBatch(NamedTuple): # refer to InsertBatch for the meaning of these values, or the Rust # implementation of BytesInsertBatch - origin_timestamp: Optional[datetime] - sentry_received_timestamp: Optional[datetime] = None + origin_timestamp: datetime | None + sentry_received_timestamp: datetime | None = None def __reduce_ex__( self, protocol: SupportsIndex - ) -> Tuple[Any, Tuple[Sequence[Any], Optional[datetime], Optional[datetime]]]: + ) -> tuple[Any, tuple[Sequence[Any], datetime | None, datetime | None]]: if int(protocol) >= 5: return ( type(self), @@ -82,18 +73,17 @@ def __reduce_ex__( self.sentry_received_timestamp, ), ) - else: - return type(self), ( - self.rows, - self.origin_timestamp, - self.sentry_received_timestamp, - ) + return type(self), ( + self.rows, + self.origin_timestamp, + self.sentry_received_timestamp, + ) class LatencyRecorder: def __init__(self) -> None: self._sum = 0.0 - self._max: Optional[float] = None + self._max: float | None = None self._msg_count = 0 def record(self, latency_seconds: float) -> None: @@ -107,7 +97,7 @@ def avg_ms(self) -> float: return (self._sum / self._msg_count) * 1000 @property - def max_ms(self) -> Optional[float]: + def max_ms(self) -> float | None: if not self._max: return None return self._max * 1000 @@ -204,7 +194,7 @@ def submit(self, message: Message[ReplacementBatch]) -> None: self.__messages.append(message) def __delivery_callback( - self, error: Optional[Exception], message: Message[ReplacementBatch] + self, error: Exception | None, message: Message[ReplacementBatch] ) -> None: if error is not None: # errors are KafkaError objects and inherit from BaseException @@ -224,7 +214,10 @@ def close(self) -> None: self.__topic.name, key=key, value=rapidjson.dumps(value).encode("utf-8"), - on_delivery=self.__delivery_callback, + on_delivery=cast( + Callable[[KafkaError | None, ConfluentMessage], None], + self.__delivery_callback, + ), ) self.__producer.flush() @@ -234,21 +227,21 @@ class ProcessedMessageBatchWriter: def __init__( self, insert_batch_writer: InsertBatchWriter, - replacement_batch_writer: Optional[ReplacementBatchWriter] = None, + replacement_batch_writer: ReplacementBatchWriter | None = None, # If commit log config is passed, we will produce to the commit log topic # upon closing each batch. - commit_log_config: Optional[CommitLogConfig] = None, - metrics: Optional[MetricsBackend] = None, + commit_log_config: CommitLogConfig | None = None, + metrics: MetricsBackend | None = None, ) -> None: self.__insert_batch_writer = insert_batch_writer self.__replacement_batch_writer = replacement_batch_writer self.__commit_log_config = commit_log_config - self.__offsets_to_produce: MutableMapping[Partition, Tuple[int, datetime]] = {} - self.__received_timestamps: MutableMapping[Partition, List[float]] = defaultdict(list) + self.__offsets_to_produce: MutableMapping[Partition, tuple[int, datetime]] = {} + self.__received_timestamps: MutableMapping[Partition, list[float]] = defaultdict(list) self.__closed = False - def submit(self, message: Message[Union[None, BytesInsertBatch, ReplacementBatch]]) -> None: + def submit(self, message: Message[None | BytesInsertBatch | ReplacementBatch]) -> None: assert not self.__closed if message.payload is None: @@ -277,7 +270,7 @@ def submit(self, message: Message[Union[None, BytesInsertBatch, ReplacementBatch ) def __commit_message_delivery_callback( - self, error: Optional[KafkaError], message: ConfluentMessage + self, error: KafkaError | None, message: ConfluentMessage ) -> None: if error is not None: raise Exception(error.str()) @@ -315,7 +308,7 @@ def close(self) -> None: self.__commit_log_config.topic.name, key=payload.key, value=payload.value, - headers=payload.headers, + headers=cast("list[tuple[str, str | bytes | None]]", payload.headers), on_delivery=self.__commit_message_delivery_callback, ) self.__commit_log_config.producer.poll(0.0) @@ -325,7 +318,7 @@ def close(self) -> None: json_row_encoder = JSONRowEncoder() -values_row_encoders: MutableMapping[StorageKey, ValuesRowEncoder] = dict() +values_row_encoders: MutableMapping[StorageKey, ValuesRowEncoder] = {} def get_values_row_encoder(storage_key: StorageKey) -> ValuesRowEncoder: @@ -341,10 +334,10 @@ def get_values_row_encoder(storage_key: StorageKey) -> ValuesRowEncoder: def build_batch_writer( table_writer: TableWriter, metrics: MetricsBackend, - replacements_producer: Optional[ConfluentKafkaProducer] = None, - replacements_topic: Optional[Topic] = None, - commit_log_config: Optional[CommitLogConfig] = None, - slice_id: Optional[int] = None, + replacements_producer: ConfluentKafkaProducer | None = None, + replacements_topic: Topic | None = None, + commit_log_config: CommitLogConfig | None = None, + slice_id: int | None = None, ) -> Callable[[], ProcessedMessageBatchWriter]: assert not (replacements_producer is None) ^ (replacements_topic is None) supports_replacements = replacements_producer is not None @@ -358,7 +351,7 @@ def build_batch_writer( def build_writer() -> ProcessedMessageBatchWriter: insert_metrics = MetricsWrapper(metrics, "insertions") - replacement_batch_writer: Optional[ReplacementBatchWriter] + replacement_batch_writer: ReplacementBatchWriter | None if supports_replacements: assert replacements_producer is not None assert replacements_topic is not None @@ -383,23 +376,21 @@ def __init__( self, steps: Mapping[StorageKey, ProcessedMessageBatchWriter], # If passed, produces to the commit log after each batch is closed - commit_log_config: Optional[CommitLogConfig], - ignore_errors: Optional[Set[StorageKey]] = None, + commit_log_config: CommitLogConfig | None, + ignore_errors: set[StorageKey] | None = None, ): self.__steps = steps self.__closed = False self.__commit_log_config = commit_log_config self.__messages: MutableMapping[ StorageKey, - List[Message[Tuple[StorageKey, Union[None, BytesInsertBatch, ReplacementBatch]]]], + list[Message[tuple[StorageKey, None | BytesInsertBatch | ReplacementBatch]]], ] = defaultdict(list) - self.__offsets_to_produce: MutableMapping[Partition, Tuple[int, datetime]] = {} + self.__offsets_to_produce: MutableMapping[Partition, tuple[int, datetime]] = {} def submit( self, - message: Message[ - Sequence[Tuple[StorageKey, Union[None, BytesInsertBatch, ReplacementBatch]]] - ], + message: Message[Sequence[tuple[StorageKey, None | BytesInsertBatch | ReplacementBatch]]], ) -> None: assert not self.__closed @@ -423,7 +414,7 @@ def submit( def close(self) -> None: self.__closed = True - for storage_key, step in self.__steps.items(): + for _storage_key, step in self.__steps.items(): step.close() if self.__commit_log_config is not None: @@ -441,7 +432,7 @@ def close(self) -> None: self.__commit_log_config.topic.name, key=payload.key, value=payload.value, - headers=payload.headers, + headers=cast("list[tuple[str, str | bytes | None]]", payload.headers), on_delivery=self.__commit_message_delivery_callback, ) self.__commit_log_config.producer.poll(0.0) @@ -452,7 +443,7 @@ def close(self) -> None: self.__offsets_to_produce.clear() def __commit_message_delivery_callback( - self, error: Optional[KafkaError], message: ConfluentMessage + self, error: KafkaError | None, message: ConfluentMessage ) -> None: if error is not None: raise Exception(error.str()) @@ -464,7 +455,7 @@ class MultistorageKafkaPayload(NamedTuple): MultistorageProcessedMessage = Sequence[ - Tuple[StorageKey, Union[None, BytesInsertBatch, ReplacementBatch]] + tuple[StorageKey, None | BytesInsertBatch | ReplacementBatch] ] @@ -474,7 +465,7 @@ def process_message( snuba_logical_topic: SnubaTopic, enforce_schema: bool, message: Message[KafkaPayload], -) -> Union[None, BytesInsertBatch, ReplacementBatch]: +) -> None | BytesInsertBatch | ReplacementBatch: local_metrics = MetricsWrapper( metrics, tags={ @@ -557,5 +548,4 @@ def process_message( result.origin_timestamp, result.sentry_received_timestamp, ) - else: - return result + return result diff --git a/snuba/consumers/consumer_builder.py b/snuba/consumers/consumer_builder.py index 4b95a53f166..e583fc4f5a8 100644 --- a/snuba/consumers/consumer_builder.py +++ b/snuba/consumers/consumer_builder.py @@ -1,7 +1,7 @@ import functools import logging +from collections.abc import MutableMapping from dataclasses import dataclass -from typing import MutableMapping, Optional from arroyo.backends.kafka import ( KafkaConsumer, @@ -39,16 +39,16 @@ class KafkaParameters: group_id: str auto_offset_reset: str - strict_offset_reset: Optional[bool] + strict_offset_reset: bool | None queued_max_messages_kbytes: int queued_min_messages: int @dataclass(frozen=True) class ProcessingParameters: - processes: Optional[int] - input_block_size: Optional[int] - output_block_size: Optional[int] + processes: int | None + input_block_size: int | None + output_block_size: int | None class ConsumerBuilder: @@ -68,17 +68,17 @@ def __init__( processing_params: ProcessingParameters, max_batch_size: int, max_batch_time_ms: int, - max_insert_batch_size: Optional[int], - max_insert_batch_time_ms: Optional[int], + max_insert_batch_size: int | None, + max_insert_batch_time_ms: int | None, metrics: MetricsBackend, metrics_tags: MutableMapping[str, str], - slice_id: Optional[int], - join_timeout: Optional[float], + slice_id: int | None, + join_timeout: float | None, enforce_schema: bool, - profile_path: Optional[str] = None, - max_poll_interval_ms: Optional[int] = None, - health_check_file: Optional[str] = None, - group_instance_id: Optional[str] = None, + profile_path: str | None = None, + max_poll_interval_ms: int | None = None, + health_check_file: str | None = None, + group_instance_id: str | None = None, ) -> None: assert len(consumer_config.storages) == 1, "Only one storage supported" storage_key = StorageKey(consumer_config.storages[0].name) @@ -106,6 +106,7 @@ def __init__( else None ) + self.replacements_producer: Producer | None if self.__consumer_config.replacements_topic is not None: self.replacements_producer = Producer( build_kafka_configuration( @@ -125,6 +126,7 @@ def __init__( else None ) + self.commit_log_producer: Producer | None if self.__consumer_config.commit_log_topic is not None: self.commit_log_producer = Producer( build_kafka_configuration(self.__consumer_config.commit_log_topic.broker_config) @@ -151,13 +153,13 @@ def __init__( self.health_check_file = health_check_file self.group_instance_id = group_instance_id - self.dlq_producer: Optional[KafkaProducer] = None + self.dlq_producer: KafkaProducer | None = None def __build_consumer( self, strategy_factory: ProcessingStrategyFactory[KafkaPayload], input_topic: Topic, - dlq_policy: Optional[DlqPolicy[KafkaPayload]], + dlq_policy: DlqPolicy[KafkaPayload] | None, ) -> StreamProcessor[KafkaPayload]: configuration = build_kafka_consumer_configuration( self.__consumer_config.raw_topic.broker_config, @@ -215,6 +217,7 @@ def build_streaming_strategy_factory( processor = stream_loader.get_processor() if self.commit_log_topic: + assert self.commit_log_producer is not None commit_log_config = CommitLogConfig( self.commit_log_producer, self.commit_log_topic, self.group_id ) @@ -376,7 +379,7 @@ def build_lw_deletions_consumer( self.__build_default_dlq_policy(), ) - def __build_default_dlq_policy(self) -> Optional[DlqPolicy[KafkaPayload]]: + def __build_default_dlq_policy(self) -> DlqPolicy[KafkaPayload] | None: """ Default DLQ policy applies to the base consumer or the DLQ consumer when the selected policy is re-insert to DLQ. diff --git a/snuba/consumers/consumer_config.py b/snuba/consumers/consumer_config.py index 115cad18983..34aa2bdf8f6 100644 --- a/snuba/consumers/consumer_config.py +++ b/snuba/consumers/consumer_config.py @@ -1,5 +1,6 @@ +from collections.abc import Mapping, Sequence from dataclasses import dataclass, replace -from typing import Any, Mapping, Optional, Sequence +from typing import Any from snuba import settings from snuba.datasets.schemas.tables import TableSchema @@ -41,14 +42,14 @@ class TopicConfig: broker_config: Mapping[str, Any] logical_topic_name: str physical_topic_name: str - quantized_rebalance_consumer_group_delay_secs: Optional[int] + quantized_rebalance_consumer_group_delay_secs: int | None @dataclass(frozen=True) class EnvConfig: - sentry_dsn: Optional[str] - dogstatsd_host: Optional[str] - dogstatsd_port: Optional[int] + sentry_dsn: str | None + dogstatsd_host: str | None + dogstatsd_port: int | None default_retention_days: int lower_retention_days: int valid_retention_days: list[int] @@ -64,14 +65,14 @@ class ConsumerConfig: storages: Sequence[StorageConfig] raw_topic: TopicConfig - commit_log_topic: Optional[TopicConfig] - replacements_topic: Optional[TopicConfig] - accepted_outcomes_topic: Optional[TopicConfig] - dlq_topic: Optional[TopicConfig] + commit_log_topic: TopicConfig | None + replacements_topic: TopicConfig | None + accepted_outcomes_topic: TopicConfig | None + dlq_topic: TopicConfig | None max_batch_size: int max_batch_time_ms: int max_batch_size_calculation: str - env: Optional[EnvConfig] + env: EnvConfig | None accountant_topic: TopicConfig @@ -86,18 +87,18 @@ def _add_to_topic_broker_config( assert isinstance(param_key, str) # copy the broker config to avoid modifying the original - broker_config = {k: v for k, v in topic_config.broker_config.items()} + broker_config = dict(topic_config.broker_config.items()) broker_config[param_key] = param_value return replace(topic_config, broker_config=broker_config) def _resolve_topic_config( param: str, - topic_spec: Optional[KafkaTopicSpec], - cli_param: Optional[str], - slice_id: Optional[int], - quantized_rebalance_consumer_group_delay_secs: Optional[int] = None, -) -> Optional[TopicConfig]: + topic_spec: KafkaTopicSpec | None, + cli_param: str | None, + slice_id: int | None, + quantized_rebalance_consumer_group_delay_secs: int | None = None, +) -> TopicConfig | None: if topic_spec is None: if cli_param is not None: raise ValueError(f"{param} not supported for this storage") @@ -144,22 +145,22 @@ def _resolve_env_config() -> EnvConfig: def resolve_consumer_config( *, storage_names: Sequence[str], - raw_topic: Optional[str], - commit_log_topic: Optional[str], - replacements_topic: Optional[str], + raw_topic: str | None, + commit_log_topic: str | None, + replacements_topic: str | None, bootstrap_servers: Sequence[str], commit_log_bootstrap_servers: Sequence[str], replacement_bootstrap_servers: Sequence[str], - slice_id: Optional[int], + slice_id: int | None, max_batch_size: int, max_batch_time_ms: int, max_batch_size_calculation: str = "rows", - accepted_outcomes_topic: Optional[str] = None, + accepted_outcomes_topic: str | None = None, accepted_outcomes_bootstrap_servers: Sequence[str] = (), - queued_max_messages_kbytes: Optional[int] = None, - queued_min_messages: Optional[int] = None, - group_instance_id: Optional[str] = None, - quantized_rebalance_consumer_group_delay_secs: Optional[int] = None, + queued_max_messages_kbytes: int | None = None, + queued_min_messages: int | None = None, + group_instance_id: str | None = None, + quantized_rebalance_consumer_group_delay_secs: int | None = None, ) -> ConsumerConfig: """ Resolves the ClickHouse cluster and Kafka brokers, and the physical topic name @@ -318,36 +319,30 @@ def validate_storages(storages: Sequence[WritableTableStorage]) -> None: """ assert ( len( - set( - [ - storage.get_table_writer().get_stream_loader().get_default_topic_spec() - for storage in storages - ] - ) + { + storage.get_table_writer().get_stream_loader().get_default_topic_spec() + for storage in storages + } ) < 2 ), "All storages must have the same default topic spec" assert ( len( - set( - [ - storage.get_table_writer().get_stream_loader().get_commit_log_topic_spec() - for storage in storages - ] - ) + { + storage.get_table_writer().get_stream_loader().get_commit_log_topic_spec() + for storage in storages + } ) < 2 ), "All storages must have the same commit log topic spec" assert ( len( - set( - [ - storage.get_table_writer().get_stream_loader().get_replacement_topic_spec() - for storage in storages - ] - ) + { + storage.get_table_writer().get_stream_loader().get_replacement_topic_spec() + for storage in storages + } ) < 2 ), "All storages must have the same replacement topic spec" diff --git a/snuba/consumers/dlq.py b/snuba/consumers/dlq.py index 7b8c3d54b82..ce580b5d324 100644 --- a/snuba/consumers/dlq.py +++ b/snuba/consumers/dlq.py @@ -5,7 +5,7 @@ import time from dataclasses import dataclass from enum import Enum -from typing import Optional, TypeVar +from typing import TypeVar import rapidjson from arroyo.dlq import InvalidMessage @@ -44,7 +44,7 @@ class DlqInstruction: policy: DlqReplayPolicy status: DlqInstructionStatus storage_key: StorageKey - slice_id: Optional[int] + slice_id: int | None max_messages_to_process: int def to_bytes(self) -> bytes: @@ -79,7 +79,7 @@ def is_valid(self) -> bool: return self.storage_key.value not in ("errors", "transactions", "search_issues") -def load_instruction() -> Optional[DlqInstruction]: +def load_instruction() -> DlqInstruction | None: value = redis_client.get(DLQ_REDIS_KEY) if value is None: @@ -169,5 +169,5 @@ def terminate(self) -> None: logger.warning("Closing DLQ consumer after %d messages", self.__processed_messages) self.__next_step.terminate() - def join(self, timeout: Optional[float] = None) -> None: + def join(self, timeout: float | None = None) -> None: self.__next_step.join(timeout) diff --git a/snuba/consumers/rust_processor.py b/snuba/consumers/rust_processor.py index 91152d8c876..1eed623df5c 100644 --- a/snuba/consumers/rust_processor.py +++ b/snuba/consumers/rust_processor.py @@ -15,16 +15,10 @@ import logging import os from collections import deque -from datetime import datetime, timezone +from collections.abc import Mapping, Sequence +from datetime import UTC, datetime from typing import ( - Deque, - Mapping, - Optional, - Sequence, - Tuple, - Type, TypedDict, - Union, cast, ) @@ -44,10 +38,10 @@ logger = logging.getLogger(__name__) -processor: Optional[DatasetMessageProcessor] = None +processor: DatasetMessageProcessor | None = None -def initialize_processor(module: Optional[str] = None, classname: Optional[str] = None) -> None: +def initialize_processor(module: str | None = None, classname: str | None = None) -> None: if not module or not classname: module = os.environ.get("RUST_SNUBA_PROCESSOR_MODULE") classname = os.environ.get("RUST_SNUBA_PROCESSOR_CLASSNAME") @@ -56,7 +50,7 @@ def initialize_processor(module: Optional[str] = None, classname: Optional[str] return module_object = importlib.import_module(module) - Processor: Type[DatasetMessageProcessor] = getattr(module_object, classname) + Processor: type[DatasetMessageProcessor] = getattr(module_object, classname) global processor processor = Processor() @@ -65,15 +59,15 @@ def initialize_processor(module: Optional[str] = None, classname: Optional[str] initialize_processor() -def ensure_utc(value: Optional[datetime]) -> Optional[datetime]: +def ensure_utc(value: datetime | None) -> datetime | None: if value and value.tzinfo is None: - return value.replace(tzinfo=timezone.utc) + return value.replace(tzinfo=UTC) return value def process_rust_message( message: bytes, offset: int, partition: int, timestamp: datetime -) -> Tuple[Sequence[bytes], Optional[datetime], Optional[datetime]]: +) -> tuple[Sequence[bytes], datetime | None, datetime | None]: if processor is None: raise RuntimeError("processor not yet initialized") rv = processor.process_message( @@ -93,26 +87,26 @@ def process_rust_message( ) -Committable = Mapping[Tuple[str, int], int] +Committable = Mapping[tuple[str, int], int] MessageTimestamp = datetime -ReturnValue = Tuple[Sequence[bytes], Optional[datetime], Optional[datetime]] +ReturnValue = tuple[Sequence[bytes], datetime | None, datetime | None] -ReturnValueWithCommittable = Tuple[ReturnValue, MessageTimestamp, Committable] +ReturnValueWithCommittable = tuple[ReturnValue, MessageTimestamp, Committable] def wrap_process_message( message: Message[bytes], -) -> Union[FilteredPayload, ReturnValue]: +) -> FilteredPayload | ReturnValue: value = message.value assert isinstance(value, BrokerValue) try: return process_rust_message( value.payload, value.offset, value.partition.index, value.timestamp ) - except Exception: - raise InvalidMessage(value.partition, value.offset) + except Exception as e: + raise InvalidMessage(value.partition, value.offset) from e class TransformedMessages: @@ -121,7 +115,7 @@ class TransformedMessages: """ def __init__(self) -> None: - self.messages: Deque[Message[ReturnValue]] = deque() + self.messages: deque[Message[ReturnValue]] = deque() def append(self, value: Message[ReturnValue]) -> None: self.messages.append(value) @@ -135,7 +129,7 @@ def pop(self) -> Sequence[Message[ReturnValue]]: return messages -class Next(ProcessingStrategy[Union[FilteredPayload, ReturnValue]]): +class Next(ProcessingStrategy[FilteredPayload | ReturnValue]): """ Messages are passed to this step from RunTaskWithMultiprocessing, and they are added to transformed messages. They are removed and @@ -145,7 +139,7 @@ class Next(ProcessingStrategy[Union[FilteredPayload, ReturnValue]]): def __init__(self, transformed_messages: TransformedMessages) -> None: self.__transformed_messages = transformed_messages - def submit(self, message: Message[Union[FilteredPayload, ReturnValue]]) -> None: + def submit(self, message: Message[FilteredPayload | ReturnValue]) -> None: # XXX: Filtered payload are created by the multiprocessing strategy in place # of invalid messages so their offsets get committed. They are not currently # supported in the hybrid consumer. This means the offsets of invalid messages @@ -157,7 +151,7 @@ def submit(self, message: Message[Union[FilteredPayload, ReturnValue]]) -> None: def poll(self) -> None: pass - def join(self, timeout: Optional[float] = None) -> None: + def join(self, timeout: float | None = None) -> None: pass def close(self) -> None: @@ -170,9 +164,10 @@ def terminate(self) -> None: DEFAULT_BLOCK_SIZE = int(32 * 1e6) -InvalidMessageMetadata = TypedDict( - "InvalidMessageMetadata", {"topic": str, "partition": int, "offset": int} -) +class InvalidMessageMetadata(TypedDict): + topic: str + partition: int + offset: int class RunPythonMultiprocessing: @@ -189,7 +184,7 @@ def __init__(self, concurrency: int, max_queue_depth: int) -> None: transform_fn = wrap_process_message self.__pool = MultiprocessingPool(concurrency) # Message is carried over if we got MessageRejected from the next step - self.__carried_over_message: Optional[Message[bytes]] = None + self.__carried_over_message: Message[bytes] | None = None self.__inner = RunTaskWithMultiprocessing( transform_fn, @@ -208,7 +203,7 @@ def submit( offset: int, partition: int, timestamp: datetime, - ) -> Tuple[int, Optional[InvalidMessageMetadata]]: + ) -> tuple[int, InvalidMessageMetadata | None]: # HACK: There is probably a better way to handle exceptions in Rust # 0 means message successfully submitted # 1 means backpressure @@ -266,7 +261,7 @@ def poll(self) -> Sequence[ReturnValueWithCommittable]: return self.__get_transformed_messages() - def join(self, timeout: Optional[float] = None) -> Sequence[ReturnValueWithCommittable]: + def join(self, timeout: float | None = None) -> Sequence[ReturnValueWithCommittable]: """ Close and join inner strategy. Returns all available transformed rows """ diff --git a/snuba/consumers/schemas.py b/snuba/consumers/schemas.py index 880da6ede66..a4b4066ba57 100644 --- a/snuba/consumers/schemas.py +++ b/snuba/consumers/schemas.py @@ -1,5 +1,6 @@ import logging -from typing import Any, MutableMapping, Optional +from collections.abc import MutableMapping +from typing import Any import sentry_kafka_schemas import sentry_sdk @@ -11,7 +12,7 @@ logger = logging.getLogger(__name__) -def _get_codec_impl(topic: Topic) -> Optional[Codec[Any]]: +def _get_codec_impl(topic: Topic) -> Codec[Any] | None: """ This function returns either the schema if it is defined, or None if not. """ diff --git a/snuba/consumers/strategy_factory.py b/snuba/consumers/strategy_factory.py index 6b16f9d4e89..22dc2808840 100644 --- a/snuba/consumers/strategy_factory.py +++ b/snuba/consumers/strategy_factory.py @@ -1,5 +1,6 @@ from abc import abstractmethod -from typing import Callable, Mapping, MutableMapping, Optional, Protocol, Union +from collections.abc import Callable, Mapping, MutableMapping +from typing import Protocol from arroyo.backends.kafka import KafkaPayload from arroyo.commit import ONCE_PER_SECOND @@ -23,7 +24,7 @@ from snuba.consumers.dlq import ExitAfterNMessages from snuba.processor import ReplacementBatch -ProcessedMessage = Union[None, BytesInsertBatch, ReplacementBatch] +ProcessedMessage = None | BytesInsertBatch | ReplacementBatch class StreamMessageFilter(Protocol): @@ -60,22 +61,22 @@ class KafkaConsumerStrategyFactory(ProcessingStrategyFactory[KafkaPayload]): def __init__( self, - prefilter: Optional[StreamMessageFilter], + prefilter: StreamMessageFilter | None, process_message: Callable[[Message[KafkaPayload]], ProcessedMessage], collector: Callable[[], ProcessedMessageBatchWriter], max_batch_size: int, max_batch_time: float, - processes: Optional[int], - input_block_size: Optional[int], - output_block_size: Optional[int], - max_insert_batch_size: Optional[int], - max_insert_batch_time: Optional[float], + processes: int | None, + input_block_size: int | None, + output_block_size: int | None, + max_insert_batch_size: int | None, + max_insert_batch_time: float | None, metrics_tags: MutableMapping[str, str], # Passed in the case of DLQ consumer which exits after a certain number of messages # is processed - max_messages_to_process: Optional[int] = None, - initialize_parallel_transform: Optional[Callable[[], None]] = None, - health_check_file: Optional[str] = None, + max_messages_to_process: int | None = None, + initialize_parallel_transform: Callable[[], None] | None = None, + health_check_file: str | None = None, ) -> None: self.__prefilter = prefilter self.__process_message = process_message @@ -150,7 +151,7 @@ def flush_batch( transform_function = self.__process_message - strategy: ProcessingStrategy[Union[FilteredPayload, KafkaPayload]] + strategy: ProcessingStrategy[FilteredPayload | KafkaPayload] if self.__pool is None: strategy = RunTask(transform_function, collect) else: diff --git a/snuba/datasets/cdc/cdcprocessors.py b/snuba/datasets/cdc/cdcprocessors.py index 7b8a67b1cf6..1a692337fa1 100644 --- a/snuba/datasets/cdc/cdcprocessors.py +++ b/snuba/datasets/cdc/cdcprocessors.py @@ -2,8 +2,9 @@ import re from abc import ABC, abstractmethod +from collections.abc import Mapping, Sequence from datetime import datetime -from typing import Any, List, Mapping, Optional, Sequence, Type +from typing import Any from snuba.consumers.types import KafkaMessageMetadata from snuba.datasets.processors import DatasetMessageProcessor @@ -25,8 +26,7 @@ def parse_postgres_datetime(date: str) -> datetime: date = f"{date}00" if date_with_nanosec.match(date): return datetime.strptime(date, POSTGRES_DATE_FORMAT_WITH_NS) - else: - return datetime.strptime(date, POSTGRES_DATE_FORMAT_WITHOUT_NS) + return datetime.strptime(date, POSTGRES_DATE_FORMAT_WITHOUT_NS) def postgres_date_to_clickhouse(date: str) -> str: @@ -71,7 +71,7 @@ def to_clickhouse(self) -> WriterTableRow: class CdcProcessor(DatasetMessageProcessor, metaclass=RegisteredClass): - def __init__(self, pg_table: str, message_row_class: Type[CdcMessageRow]): + def __init__(self, pg_table: str, message_row_class: type[CdcMessageRow]): self.pg_table = pg_table self._message_row_class = message_row_class @@ -96,10 +96,10 @@ def _process_update( columnnames: Sequence[str], columnvalues: Sequence[Any], ) -> Sequence[WriterTableRow]: - old_key = dict(zip(key["keynames"], key["keyvalues"])) + old_key = dict(zip(key["keynames"], key["keyvalues"], strict=False)) new_key = {key: columnvalues[columnnames.index(key)] for key in key["keynames"]} - ret: List[WriterTableRow] = [] + ret: list[WriterTableRow] = [] if old_key != new_key: ret.extend(self._process_delete(offset, key)) @@ -117,12 +117,12 @@ def _process_delete( def process_message( self, value: Mapping[str, Any], metadata: KafkaMessageMetadata - ) -> Optional[ProcessedMessage]: + ) -> ProcessedMessage | None: assert isinstance(value, dict) offset = metadata.offset event = value["event"] - timestamp: Optional[datetime] = None + timestamp: datetime | None = None if event == "begin": messages = self._process_begin(offset) elif event == "commit": @@ -148,10 +148,12 @@ def process_message( messages = self._process_delete(offset, value["oldkeys"]) else: raise ValueError( - "Invalid value for operation in replication log: %s" % value["kind"] + "Invalid value for operation in replication log: {}".format(value["kind"]) ) else: - raise ValueError("Invalid value for event in replication log: %s" % value["event"]) + raise ValueError( + "Invalid value for event in replication log: {}".format(value["event"]) + ) if not messages: return None diff --git a/snuba/datasets/cdc/groupassignee_processor.py b/snuba/datasets/cdc/groupassignee_processor.py index 3431b981046..6e6f1b18051 100644 --- a/snuba/datasets/cdc/groupassignee_processor.py +++ b/snuba/datasets/cdc/groupassignee_processor.py @@ -1,8 +1,9 @@ from __future__ import annotations +from collections.abc import Mapping, Sequence from dataclasses import dataclass from datetime import datetime -from typing import Any, Mapping, Optional, Sequence, Union +from typing import Any from snuba.datasets.cdc.cdcprocessors import ( CdcMessageRow, @@ -15,18 +16,18 @@ @dataclass(frozen=True) class GroupAssigneeRecord: - date_added: Union[datetime, str] - user_id: Optional[int] - team_id: Optional[int] + date_added: datetime | str + user_id: int | None + team_id: int | None @dataclass(frozen=True) class GroupAssigneeRow(CdcMessageRow): - offset: Optional[int] + offset: int | None record_deleted: bool project_id: int group_id: int - record_content: Union[None, GroupAssigneeRecord] + record_content: None | GroupAssigneeRecord @classmethod def from_wal( @@ -35,7 +36,7 @@ def from_wal( columnnames: Sequence[str], columnvalues: Sequence[Any], ) -> GroupAssigneeRow: - raw_data = dict(zip(columnnames, columnvalues)) + raw_data = dict(zip(columnnames, columnvalues, strict=False)) return cls( offset=offset, record_deleted=False, diff --git a/snuba/datasets/cdc/groupedmessage_processor.py b/snuba/datasets/cdc/groupedmessage_processor.py index 7d465cd15b5..9dad1bd00f4 100644 --- a/snuba/datasets/cdc/groupedmessage_processor.py +++ b/snuba/datasets/cdc/groupedmessage_processor.py @@ -1,8 +1,9 @@ from __future__ import annotations +from collections.abc import Mapping, Sequence from dataclasses import dataclass from datetime import datetime -from typing import Any, Mapping, Optional, Sequence, Union +from typing import Any from snuba.datasets.cdc.cdcprocessors import ( CdcMessageRow, @@ -18,8 +19,8 @@ class GroupMessageRecord: status: int last_seen: datetime first_seen: datetime - active_at: Optional[datetime] = None - first_release_id: Optional[int] = None + active_at: datetime | None = None + first_release_id: int | None = None @dataclass(frozen=True) @@ -34,17 +35,17 @@ class RawGroupMessageRecord: status: int last_seen: str first_seen: str - active_at: Optional[str] = None - first_release_id: Optional[int] = None + active_at: str | None = None + first_release_id: int | None = None @dataclass(frozen=True) class GroupedMessageRow(CdcMessageRow): - offset: Optional[int] + offset: int | None project_id: int id: int record_deleted: bool - record_content: Union[None, GroupMessageRecord, RawGroupMessageRecord] + record_content: None | GroupMessageRecord | RawGroupMessageRecord @classmethod def from_wal( @@ -53,7 +54,7 @@ def from_wal( columnnames: Sequence[str], columnvalues: Sequence[Any], ) -> GroupedMessageRow: - raw_data = dict(zip(columnnames, columnvalues)) + raw_data = dict(zip(columnnames, columnvalues, strict=False)) return cls( offset=offset, project_id=raw_data["project_id"], @@ -111,7 +112,7 @@ def to_clickhouse(self) -> WriterTableRow: class GroupedMessageProcessor(CdcProcessor): def __init__(self) -> None: postgres_table = "sentry_groupedmessage" - super(GroupedMessageProcessor, self).__init__( + super().__init__( pg_table=postgres_table, message_row_class=GroupedMessageRow, ) diff --git a/snuba/datasets/cdc/row_processors.py b/snuba/datasets/cdc/row_processors.py index 9e7f4d0ddec..ba6b479e7c5 100644 --- a/snuba/datasets/cdc/row_processors.py +++ b/snuba/datasets/cdc/row_processors.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Type, cast +from typing import cast from snuba.datasets.cdc.groupassignee_processor import GroupAssigneeRow from snuba.datasets.cdc.groupedmessage_processor import GroupedMessageRow @@ -18,8 +18,8 @@ def from_kwargs(cls, **kwargs: str) -> "CdcRowProcessor": return cls(**kwargs) @classmethod - def get_from_name(cls, name: str) -> Type["CdcRowProcessor"]: - return cast(Type["CdcRowProcessor"], cls.class_from_name(name)) + def get_from_name(cls, name: str) -> type["CdcRowProcessor"]: + return cast(type["CdcRowProcessor"], cls.class_from_name(name)) @classmethod def config_key(cls) -> str: diff --git a/snuba/datasets/cdc/types.py b/snuba/datasets/cdc/types.py index fc49b6edc76..5d9c2ba86c6 100644 --- a/snuba/datasets/cdc/types.py +++ b/snuba/datasets/cdc/types.py @@ -1,4 +1,5 @@ -from typing import Any, Literal, Sequence, TypedDict, Union +from collections.abc import Sequence +from typing import Any, Literal, TypedDict class BeginEvent(TypedDict): @@ -44,10 +45,4 @@ class CommitEvent(TypedDict): event: Literal["commit"] -Event = Union[ - BeginEvent, - InsertEvent, - UpdateEvent, - DeleteEvent, - CommitEvent, -] +Event = BeginEvent | InsertEvent | UpdateEvent | DeleteEvent | CommitEvent diff --git a/snuba/datasets/configuration/entity_builder.py b/snuba/datasets/configuration/entity_builder.py index 4828c0d52a4..109d883453e 100644 --- a/snuba/datasets/configuration/entity_builder.py +++ b/snuba/datasets/configuration/entity_builder.py @@ -1,7 +1,8 @@ from __future__ import annotations import logging -from typing import Any, List, Optional, Sequence +from collections.abc import Sequence +from typing import Any import snuba.clickhouse.translators.snuba.function_call_mappers # noqa from snuba.clickhouse.translators.snuba.allowed import ( @@ -104,7 +105,7 @@ def _build_storage_selector(config_storage_selector: dict[str, Any]) -> QuerySto def _build_subscription_processors( config: dict[str, Any], -) -> Optional[Sequence[EntitySubscriptionProcessor]]: +) -> Sequence[EntitySubscriptionProcessor] | None: if "subscription_processors" in config: processors: Sequence[EntitySubscriptionProcessor] = [ EntitySubscriptionProcessor.get_from_name(pro_config["processor"])(**pro_config["args"]) @@ -116,7 +117,7 @@ def _build_subscription_processors( def _build_subscription_validators( config: dict[str, Any], -) -> Optional[Sequence[EntitySubscriptionValidator]]: +) -> Sequence[EntitySubscriptionValidator] | None: if "subscription_validators" in config: validators: Sequence[EntitySubscriptionValidator] = [ EntitySubscriptionValidator.get_from_name(val_config["validator"])(**val_config["args"]) @@ -128,7 +129,7 @@ def _build_subscription_validators( def _build_storage_connections( config_storages: list[dict[str, Any]], -) -> List[EntityStorageConnection]: +) -> list[EntityStorageConnection]: return [ EntityStorageConnection( storage=get_storage(StorageKey(storage_connection["storage"])), @@ -137,9 +138,7 @@ def _build_storage_connections( if "translation_mappers" in storage_connection else TranslationMappers() ), - is_writable=( - storage_connection["is_writable"] if "is_writable" in storage_connection else False - ), + is_writable=(storage_connection.get("is_writable", False)), ) for storage_connection in config_storages ] @@ -178,9 +177,9 @@ def _build_validation_mode(mode: str | None) -> ColumnValidationMode | None: if mode == "do_nothing": return ColumnValidationMode.DO_NOTHING - elif mode == "warn": + if mode == "warn": return ColumnValidationMode.WARN - elif mode == "error": + if mode == "error": return ColumnValidationMode.ERROR raise InvalidEntityConfigException(f"{mode} is not a valid validation mode") diff --git a/snuba/datasets/configuration/storage_builder.py b/snuba/datasets/configuration/storage_builder.py index 77cf00071c0..793c586666f 100644 --- a/snuba/datasets/configuration/storage_builder.py +++ b/snuba/datasets/configuration/storage_builder.py @@ -2,6 +2,8 @@ from typing import Any +from arroyo.backends.kafka import KafkaPayload + from snuba.clickhouse.columns import ColumnSet from snuba.clusters.storage_sets import StorageSetKey from snuba.datasets.cdc.cdcstorage import CdcStorage @@ -79,17 +81,13 @@ def __build_readable_storage_kwargs(config: dict[str, Any]) -> dict[str, Any]: "storage_set_key": StorageSetKey(config[STORAGE][SET_KEY]), SCHEMA: __build_storage_schema(config), READINESS_STATE: ReadinessState(config[READINESS_STATE]), - QUERY_PROCESSORS: get_query_processors( - config[QUERY_PROCESSORS] if QUERY_PROCESSORS in config else [] - ), + QUERY_PROCESSORS: get_query_processors(config.get(QUERY_PROCESSORS, [])), DELETION_SETTINGS: ( DeletionSettings(**config[DELETION_SETTINGS]) if DELETION_SETTINGS in config else {} ), - DELETION_PROCESSORS: get_query_processors( - config[DELETION_PROCESSORS] if DELETION_PROCESSORS in config else [] - ), + DELETION_PROCESSORS: get_query_processors(config.get(DELETION_PROCESSORS, [])), MANDATORY_CONDITION_CHECKERS: get_mandatory_condition_checkers( - config[MANDATORY_CONDITION_CHECKERS] if MANDATORY_CONDITION_CHECKERS in config else [] + config.get(MANDATORY_CONDITION_CHECKERS, []) ), ALLOCATION_POLICIES: ( [ @@ -117,14 +115,14 @@ def __build_readable_storage_kwargs(config: dict[str, Any]) -> dict[str, Any]: if DELETE_ALLOCATION_POLICIES in config else [] ), - REQUIRED_TIME_COLUMN: config.get(REQUIRED_TIME_COLUMN, None), + REQUIRED_TIME_COLUMN: config.get(REQUIRED_TIME_COLUMN), } def __build_writable_storage_kwargs(config: dict[str, Any]) -> dict[str, Any]: return { STREAM_LOADER: build_stream_loader(config[STREAM_LOADER]), - WRITER_OPTIONS: config[WRITER_OPTIONS] if WRITER_OPTIONS in config else {}, + WRITER_OPTIONS: config.get(WRITER_OPTIONS, {}), REPLACER_PROCESSOR: ( ReplacerProcessor.get_from_name(config[REPLACER_PROCESSOR]["processor"]).from_kwargs( **config[REPLACER_PROCESSOR].get("args", {}) @@ -183,7 +181,7 @@ def build_stream_loader(loader_config: dict[str, Any]) -> KafkaStreamLoader: assert processor is not None default_topic = Topic(loader_config["default_topic"]) # optionals - pre_filter = None + pre_filter: StreamMessageFilter[KafkaPayload] | None = None if PRE_FILTER in loader_config and loader_config[PRE_FILTER] is not None: pre_filter = StreamMessageFilter.get_from_name( loader_config[PRE_FILTER]["type"] diff --git a/snuba/datasets/configuration/utils.py b/snuba/datasets/configuration/utils.py index 3567caf54f3..cbfc875df44 100644 --- a/snuba/datasets/configuration/utils.py +++ b/snuba/datasets/configuration/utils.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Type, TypedDict +from typing import Any, TypedDict from snuba.clickhouse.columns import ( Array, @@ -67,7 +67,7 @@ def get_mandatory_condition_checkers( "Float": Float, } -SIMPLE_COLUMN_TYPES: dict[str, Type[ColumnType[SchemaModifiers]]] = { +SIMPLE_COLUMN_TYPES: dict[str, type[ColumnType[SchemaModifiers]]] = { **NUMBER_COLUMN_TYPES, "String": String, "DateTime": DateTime, @@ -89,7 +89,7 @@ def __parse_number( col: dict[str, Any], modifiers: SchemaModifiers | None ) -> ColumnType[SchemaModifiers]: col_type = NUMBER_COLUMN_TYPES[col["type"]](col["args"]["size"], modifiers) - assert isinstance(col_type, UInt) or isinstance(col_type, Float) or isinstance(col_type, Int) + assert isinstance(col_type, (UInt, Float, Int)) return col_type diff --git a/snuba/datasets/dataset.py b/snuba/datasets/dataset.py index 697aadb7cad..0c58a54f5f8 100644 --- a/snuba/datasets/dataset.py +++ b/snuba/datasets/dataset.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Sequence +from collections.abc import Sequence from snuba.datasets.entities.entity_key import EntityKey from snuba.datasets.entities.factory import get_entity diff --git a/snuba/datasets/deletion_settings.py b/snuba/datasets/deletion_settings.py index 67e75678376..ecf7d42b37a 100644 --- a/snuba/datasets/deletion_settings.py +++ b/snuba/datasets/deletion_settings.py @@ -1,7 +1,8 @@ from __future__ import annotations +from collections.abc import Sequence from dataclasses import dataclass, field -from typing import Dict, List, Optional, Sequence +from typing import cast from sentry_protos.snuba.v1.request_common_pb2 import TraceItemType @@ -15,8 +16,8 @@ class DeletionSettings: bulk_delete_only: bool = False allowed_columns: Sequence[str] = field(default_factory=list) max_rows_to_delete: int = MAX_ROWS_TO_DELETE_DEFAULT - allowed_attributes_by_item_type: Dict[str, List[str]] = field(default_factory=dict) - partition_column: Optional[str] = None + allowed_attributes_by_item_type: dict[str, list[str]] = field(default_factory=dict) + partition_column: str | None = None def get_trace_item_type_name(item_type: int) -> str: @@ -38,7 +39,7 @@ def get_trace_item_type_name(item_type: int) -> str: try: # Get the full protobuf enum name (e.g., "TRACE_ITEM_TYPE_SPAN") # Cast to TraceItemType.ValueType to satisfy type checker - full_name = TraceItemType.Name(item_type) # type: ignore[arg-type] + full_name = TraceItemType.Name(cast("TraceItemType.ValueType", item_type)) # Strip the "TRACE_ITEM_TYPE_" prefix and convert to lowercase prefix = "TRACE_ITEM_TYPE_" diff --git a/snuba/datasets/entities/entity_data_model.py b/snuba/datasets/entities/entity_data_model.py index 18d8101cce6..6186ddcd6b9 100644 --- a/snuba/datasets/entities/entity_data_model.py +++ b/snuba/datasets/entities/entity_data_model.py @@ -7,4 +7,4 @@ class EntityColumnSet(ColumnSet): """ def __repr__(self) -> str: - return "EntityColumnSet({})".format(repr(self.columns)) + return f"EntityColumnSet({repr(self.columns)})" diff --git a/snuba/datasets/entities/entity_key.py b/snuba/datasets/entities/entity_key.py index 2c01b54a20f..a9e63ad0406 100644 --- a/snuba/datasets/entities/entity_key.py +++ b/snuba/datasets/entities/entity_key.py @@ -1,12 +1,13 @@ from __future__ import annotations -from typing import Any, Iterator +from collections.abc import Iterator +from typing import Any REGISTERED_ENTITY_KEYS: dict[str, str] = {} class _EntityKey(type): - def __getattr__(cls, attr: str) -> "EntityKey": + def __getattr__(cls, attr: str) -> EntityKey: if attr not in REGISTERED_ENTITY_KEYS: raise AttributeError(attr) diff --git a/snuba/datasets/entities/factory.py b/snuba/datasets/entities/factory.py index b34c228a1a7..17dd457a62a 100644 --- a/snuba/datasets/entities/factory.py +++ b/snuba/datasets/entities/factory.py @@ -1,7 +1,7 @@ from __future__ import annotations +from collections.abc import Sequence from glob import glob -from typing import Optional, Sequence, Type import sentry_sdk @@ -21,7 +21,7 @@ def __init__(self) -> None: with sentry_sdk.start_span(op="initialize", description="Entity Factory"): initialize_storage_factory() self._entity_map: dict[EntityKey, PluggableEntity] = {} - self._name_map: dict[Type[Entity], EntityKey] = {} + self._name_map: dict[type[Entity], EntityKey] = {} self.__initialize() def __initialize(self) -> None: @@ -41,7 +41,7 @@ def __initialize(self) -> None: self._name_map = {v.__class__: k for k, v in self._entity_map.items()} def all_names(self) -> Sequence[EntityKey]: - return [name for name in self._entity_map.keys()] + return list(self._entity_map.keys()) def get(self, name: EntityKey) -> Entity: try: @@ -63,7 +63,7 @@ class InvalidEntityError(SerializableException): """Exception raised on invalid entity access.""" -_ENT_FACTORY: Optional[_EntityFactory] = None +_ENT_FACTORY: _EntityFactory | None = None def _ent_factory() -> _EntityFactory: diff --git a/snuba/datasets/entities/storage_selectors/__init__.py b/snuba/datasets/entities/storage_selectors/__init__.py index bb8e7296f16..d2179ca2853 100644 --- a/snuba/datasets/entities/storage_selectors/__init__.py +++ b/snuba/datasets/entities/storage_selectors/__init__.py @@ -1,6 +1,7 @@ import os from abc import ABC, abstractmethod -from typing import Sequence, Type, cast +from collections.abc import Sequence +from typing import cast from snuba.datasets.storage import EntityStorageConnection from snuba.query.logical import Query @@ -19,8 +20,8 @@ def config_key(cls) -> str: return cls.__name__ @classmethod - def get_from_name(cls, name: str) -> Type["QueryStorageSelector"]: - return cast(Type["QueryStorageSelector"], cls.class_from_name(name)) + def get_from_name(cls, name: str) -> type["QueryStorageSelector"]: + return cast(type["QueryStorageSelector"], cls.class_from_name(name)) @abstractmethod def select_storage( diff --git a/snuba/datasets/entities/storage_selectors/eap_items.py b/snuba/datasets/entities/storage_selectors/eap_items.py index cc80c5cf7f1..87a966b274b 100644 --- a/snuba/datasets/entities/storage_selectors/eap_items.py +++ b/snuba/datasets/entities/storage_selectors/eap_items.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba import state from snuba.datasets.entities.storage_selectors import QueryStorageSelector diff --git a/snuba/datasets/entities/storage_selectors/errors.py b/snuba/datasets/entities/storage_selectors/errors.py index d5ec9df084b..1ec7d8fbb72 100644 --- a/snuba/datasets/entities/storage_selectors/errors.py +++ b/snuba/datasets/entities/storage_selectors/errors.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba import state from snuba.datasets.entities.storage_selectors import QueryStorageSelector diff --git a/snuba/datasets/entities/storage_selectors/outcomes.py b/snuba/datasets/entities/storage_selectors/outcomes.py index 4ad94ec2d03..82a3b3d9614 100644 --- a/snuba/datasets/entities/storage_selectors/outcomes.py +++ b/snuba/datasets/entities/storage_selectors/outcomes.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.datasets.entities.storage_selectors import QueryStorageSelector from snuba.datasets.entities.storage_selectors.selector import QueryStorageSelectorError diff --git a/snuba/datasets/entities/storage_selectors/selector.py b/snuba/datasets/entities/storage_selectors/selector.py index 2d368721d07..17cd023db37 100644 --- a/snuba/datasets/entities/storage_selectors/selector.py +++ b/snuba/datasets/entities/storage_selectors/selector.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.datasets.entities.storage_selectors import QueryStorageSelector from snuba.datasets.storage import EntityStorageConnection, ReadableTableStorage diff --git a/snuba/datasets/entity.py b/snuba/datasets/entity.py index d1fe8edcc58..b9f221ff332 100644 --- a/snuba/datasets/entity.py +++ b/snuba/datasets/entity.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Mapping, Optional, Sequence +from collections.abc import Mapping, Sequence from snuba.clickhouse.columns import ColumnSet from snuba.datasets.entity_subscriptions.processors import EntitySubscriptionProcessor @@ -32,11 +32,11 @@ def __init__( storages: Sequence[EntityStorageConnection], abstract_column_set: ColumnSet, join_relationships: Mapping[str, JoinRelationship], - validators: Optional[Sequence[QueryValidator]], - required_time_column: Optional[str], + validators: Sequence[QueryValidator] | None, + required_time_column: str | None, validate_data_model: ColumnValidationMode = ColumnValidationMode.ERROR, - subscription_processors: Optional[Sequence[EntitySubscriptionProcessor]], - subscription_validators: Optional[Sequence[EntitySubscriptionValidator]], + subscription_processors: Sequence[EntitySubscriptionProcessor] | None, + subscription_validators: Sequence[EntitySubscriptionValidator] | None, ) -> None: self.__storages = storages @@ -77,7 +77,7 @@ def get_data_model(self) -> ColumnSet: """ return self.__data_model - def get_join_relationship(self, relationship: str) -> Optional[JoinRelationship]: + def get_join_relationship(self, relationship: str) -> JoinRelationship | None: """ Fetch the join relationship specified by the relationship string. """ @@ -103,7 +103,7 @@ def get_all_storage_connections(self) -> Sequence[EntityStorageConnection]: """ return self.__storages - def get_writable_storage(self) -> Optional[WritableTableStorage]: + def get_writable_storage(self) -> WritableTableStorage | None: """ Temporarily support getting the writable storage from an entity. Once consumers/replacers no longer reference entity, this can be removed @@ -136,7 +136,7 @@ def get_validators(self) -> Sequence[QueryValidator]: def get_subscription_processors( self, - ) -> Optional[Sequence[EntitySubscriptionProcessor]]: + ) -> Sequence[EntitySubscriptionProcessor] | None: """ Provides an entity subscription processors to be run on on subscription queries. """ @@ -144,7 +144,7 @@ def get_subscription_processors( def get_subscription_validators( self, - ) -> Optional[Sequence[EntitySubscriptionValidator]]: + ) -> Sequence[EntitySubscriptionValidator] | None: """ Provides an entity subscription validators to be run on on subscription queries. """ diff --git a/snuba/datasets/entity_subscriptions/processors.py b/snuba/datasets/entity_subscriptions/processors.py index 290293fdf54..f7802771b7c 100644 --- a/snuba/datasets/entity_subscriptions/processors.py +++ b/snuba/datasets/entity_subscriptions/processors.py @@ -1,5 +1,6 @@ from abc import abstractmethod -from typing import Any, Mapping, Optional, Type, Union, cast +from collections.abc import Mapping +from typing import Any, cast from snuba.query.composite import CompositeQuery from snuba.query.conditions import ( @@ -20,8 +21,8 @@ def config_key(cls) -> str: return cls.__name__ @classmethod - def get_from_name(cls, name: str) -> Type["EntitySubscriptionProcessor"]: - return cast(Type["EntitySubscriptionProcessor"], cls.class_from_name(name)) + def get_from_name(cls, name: str) -> type["EntitySubscriptionProcessor"]: + return cast(type["EntitySubscriptionProcessor"], cls.class_from_name(name)) @abstractmethod def to_dict(self, metadata: Mapping[str, Any]) -> Mapping[str, Any]: @@ -30,9 +31,9 @@ def to_dict(self, metadata: Mapping[str, Any]) -> Mapping[str, Any]: @abstractmethod def process( self, - query: Union[CompositeQuery[Entity], Query], + query: CompositeQuery[Entity] | Query, metadata: Mapping[str, Any], - offset: Optional[int] = None, + offset: int | None = None, ) -> None: raise NotImplementedError @@ -51,9 +52,9 @@ def to_dict(self, metadata: Mapping[str, Any]) -> Mapping[str, Any]: def process( self, - query: Union[CompositeQuery[Entity], Query], + query: CompositeQuery[Entity] | Query, metadata: Mapping[str, Any], - offset: Optional[int] = None, + offset: int | None = None, ) -> None: if self.extra_condition_data_key not in metadata: raise InvalidQueryException( diff --git a/snuba/datasets/entity_subscriptions/validators.py b/snuba/datasets/entity_subscriptions/validators.py index c7da572380f..e7fe6ec89df 100644 --- a/snuba/datasets/entity_subscriptions/validators.py +++ b/snuba/datasets/entity_subscriptions/validators.py @@ -1,5 +1,6 @@ from abc import abstractmethod -from typing import Optional, Sequence, Type, Union, cast +from collections.abc import Sequence +from typing import cast from snuba.query.composite import CompositeQuery from snuba.query.data_source.simple import Entity @@ -21,11 +22,11 @@ def config_key(cls) -> str: return cls.__name__ @classmethod - def get_from_name(cls, name: str) -> Type["EntitySubscriptionValidator"]: - return cast(Type["EntitySubscriptionValidator"], cls.class_from_name(name)) + def get_from_name(cls, name: str) -> type["EntitySubscriptionValidator"]: + return cast(type["EntitySubscriptionValidator"], cls.class_from_name(name)) @abstractmethod - def validate(self, query: Union[CompositeQuery[Entity], Query]) -> None: + def validate(self, query: CompositeQuery[Entity] | Query) -> None: raise NotImplementedError @@ -34,7 +35,7 @@ def __init__( self, max_allowed_aggregations: int, disallowed_aggregations: Sequence[str], - required_time_column: Optional[str] = None, + required_time_column: str | None = None, allows_group_by_without_condition: bool = False, ): self.max_allowed_aggregations = max_allowed_aggregations @@ -42,7 +43,7 @@ def __init__( self.required_time_column = required_time_column self.allows_group_by_without_condition = allows_group_by_without_condition - def validate(self, query: Union[CompositeQuery[Entity], Query]) -> None: + def validate(self, query: CompositeQuery[Entity] | Query) -> None: SubscriptionAllowedClausesValidator( self.max_allowed_aggregations, self.disallowed_aggregations, diff --git a/snuba/datasets/events_format.py b/snuba/datasets/events_format.py index fb5aba81fa9..b7adfb41167 100644 --- a/snuba/datasets/events_format.py +++ b/snuba/datasets/events_format.py @@ -1,13 +1,7 @@ +from collections.abc import Callable, Iterable, Mapping, MutableMapping, Sequence from datetime import datetime, timedelta from typing import ( Any, - Callable, - Iterable, - Mapping, - MutableMapping, - Optional, - Sequence, - Tuple, TypeVar, ) @@ -53,13 +47,13 @@ def extract_http(output: MutableMapping[str, Any], request: Mapping[str, Any]) - def extract_extra_tags( nested_col: Mapping[str, Any], -) -> Tuple[Sequence[str], Sequence[str]]: +) -> tuple[Sequence[str], Sequence[str]]: return extract_nested(nested_col, lambda s: _unicodify(s) or None) def extract_nested( - nested_col: Mapping[str, Any], val_processor: Callable[[Any], Optional[TVal]] -) -> Tuple[Sequence[str], Sequence[TVal]]: + nested_col: Mapping[str, Any], val_processor: Callable[[Any], TVal | None] +) -> tuple[Sequence[str], Sequence[TVal]]: keys = [] values = [] for key, value in sorted(nested_col.items()): @@ -76,17 +70,17 @@ def extract_nested( def extract_extra_contexts( contexts: Mapping[str, Any], sort: bool = False, -) -> Tuple[Sequence[str], Sequence[str]]: +) -> tuple[Sequence[str], Sequence[str]]: context_keys = [] context_values = [] valid_types = (int, float, str) - contexts_iter: Iterable[Tuple[str, Any]] = contexts.items() + contexts_iter: Iterable[tuple[str, Any]] = contexts.items() if sort: contexts_iter = sorted(contexts_iter) for ctx_name, ctx_obj in contexts_iter: if isinstance(ctx_obj, dict): ctx_obj.pop("type", None) # ignore type alias - ctx_iter: Iterable[Tuple[str, Any]] = ctx_obj.items() + ctx_iter: Iterable[tuple[str, Any]] = ctx_obj.items() if sort: ctx_iter = sorted(ctx_iter) for inner_ctx_name, ctx_value in ctx_iter: @@ -106,17 +100,16 @@ def extract_extra_contexts( return (context_keys, context_values) -def enforce_retention(retention_days: Optional[int], timestamp: Optional[datetime]) -> int: +def enforce_retention(retention_days: int | None, timestamp: datetime | None) -> int: if not isinstance(retention_days, int): retention_days = settings.DEFAULT_RETENTION_DAYS - if settings.ENFORCE_RETENTION: - if retention_days not in settings.VALID_RETENTION_DAYS: - retention_days = ( - settings.LOWER_RETENTION_DAYS - if retention_days <= settings.LOWER_RETENTION_DAYS - else settings.DEFAULT_RETENTION_DAYS - ) + if settings.ENFORCE_RETENTION and retention_days not in settings.VALID_RETENTION_DAYS: + retention_days = ( + settings.LOWER_RETENTION_DAYS + if retention_days <= settings.LOWER_RETENTION_DAYS + else settings.DEFAULT_RETENTION_DAYS + ) # This is not ideal but it should never happen anyways timestamp = _ensure_valid_date(timestamp) diff --git a/snuba/datasets/factory.py b/snuba/datasets/factory.py index 146f82240a6..2a9dfe4305d 100644 --- a/snuba/datasets/factory.py +++ b/snuba/datasets/factory.py @@ -1,7 +1,6 @@ from __future__ import annotations from glob import glob -from typing import Type import sentry_sdk @@ -20,7 +19,7 @@ def __init__(self) -> None: with sentry_sdk.start_span(op="initialize", description="Dataset Factory"): initialize_entity_factory() self._dataset_map: dict[str, Dataset] = {} - self._name_map: dict[Type[Dataset], str] = {} + self._name_map: dict[type[Dataset], str] = {} self.__initialize() def __initialize(self) -> None: @@ -36,7 +35,7 @@ def __initialize(self) -> None: self._name_map = {v.__class__: k for k, v in self._dataset_map.items()} def all_names(self) -> list[str]: - return [name for name in self._dataset_map.keys() if name not in settings.DISABLED_DATASETS] + return [name for name in self._dataset_map if name not in settings.DISABLED_DATASETS] def get(self, name: str) -> Dataset: if name in settings.DISABLED_DATASETS: diff --git a/snuba/datasets/message_filters.py b/snuba/datasets/message_filters.py index 8c269820c98..72bd68892d4 100644 --- a/snuba/datasets/message_filters.py +++ b/snuba/datasets/message_filters.py @@ -25,11 +25,11 @@ def config_key(cls) -> str: return cls.__name__ @classmethod - def get_from_name(cls, name: str) -> "StreamMessageFilter[TPayload]": + def get_from_name(cls, name: str) -> StreamMessageFilter[TPayload]: return cast("StreamMessageFilter[TPayload]", cls.class_from_name(name)) @classmethod - def from_kwargs(cls, **kwargs: str) -> "StreamMessageFilter[TPayload]": + def from_kwargs(cls, **kwargs: str) -> StreamMessageFilter[TPayload]: return cls(**kwargs) @abstractmethod @@ -72,7 +72,7 @@ def __init__(self, postgres_table: str) -> None: self.__postgres_table = postgres_table def should_drop(self, message: Message[KafkaPayload]) -> bool: - assert [p.index for p in message.committable.keys()] == [KAFKA_ONLY_PARTITION], ( + assert [p.index for p in message.committable] == [KAFKA_ONLY_PARTITION], ( "CDC can only work with single partition topics for consistency" ) diff --git a/snuba/datasets/metrics_messages.py b/snuba/datasets/metrics_messages.py index 944d08e2649..2ef969eb840 100644 --- a/snuba/datasets/metrics_messages.py +++ b/snuba/datasets/metrics_messages.py @@ -1,5 +1,6 @@ +from collections.abc import Iterable, Mapping, MutableMapping from enum import Enum -from typing import Any, Iterable, Mapping, MutableMapping +from typing import Any class InputType(Enum): diff --git a/snuba/datasets/plans/cluster_selector.py b/snuba/datasets/plans/cluster_selector.py index a51d05c8f5c..21328d4130e 100644 --- a/snuba/datasets/plans/cluster_selector.py +++ b/snuba/datasets/plans/cluster_selector.py @@ -100,7 +100,6 @@ def select_cluster( logical_partition = map_org_id_to_logical_partition(org_id) if _should_use_mega_cluster(self.storage_set, logical_partition): return get_cluster(self.storage_set) - else: - slice_id = map_logical_partition_to_slice(self.storage_set, logical_partition) - cluster = get_cluster(self.storage_set, slice_id) - return cluster + slice_id = map_logical_partition_to_slice(self.storage_set, logical_partition) + cluster = get_cluster(self.storage_set, slice_id) + return cluster diff --git a/snuba/datasets/plans/entity_processing.py b/snuba/datasets/plans/entity_processing.py index 394baf26bd9..fafb71430b9 100644 --- a/snuba/datasets/plans/entity_processing.py +++ b/snuba/datasets/plans/entity_processing.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Optional, Sequence, cast +from collections.abc import Sequence +from typing import cast import sentry_sdk @@ -41,8 +42,8 @@ def __init__( self, storages: Sequence[EntityStorageConnection], selector: QueryStorageSelector, - post_processors: Optional[Sequence[ClickhouseQueryProcessor]] = None, - partition_key_column_name: Optional[str] = None, + post_processors: Sequence[ClickhouseQueryProcessor] | None = None, + partition_key_column_name: str | None = None, ) -> None: # A list of storages and the translation mappers they are associated with. # This list will only contain one storage and mappers for single storage entities. diff --git a/snuba/datasets/plans/entity_validation.py b/snuba/datasets/plans/entity_validation.py index 7ca3b941f74..bfb6ab64443 100644 --- a/snuba/datasets/plans/entity_validation.py +++ b/snuba/datasets/plans/entity_validation.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Union - import sentry_sdk from snuba.datasets.entities.factory import get_entity @@ -28,7 +26,7 @@ def _validate_query(query: Query) -> None: v.validate(exp, query.get_from_clause()) -def _validate_entities_with_query(query: Union[CompositeQuery[QueryEntity], EntityQuery]) -> None: +def _validate_entities_with_query(query: CompositeQuery[QueryEntity] | EntityQuery) -> None: """ Applies all validator defined on the entities in the query """ @@ -41,7 +39,7 @@ def _validate_entities_with_query(query: Union[CompositeQuery[QueryEntity], Enti raise ValidationException( f"Validation failed for entity {query.get_from_clause().key.value}: {e}", should_report=e.should_report, - ) + ) from e else: from_clause = query.get_from_clause() if isinstance(from_clause, JoinClause): @@ -56,14 +54,14 @@ def _validate_entities_with_query(query: Union[CompositeQuery[QueryEntity], Enti raise ValidationException( f"Validation failed for entity {node.data_source.key.value}: {e}", should_report=e.should_report, - ) + ) from e VALIDATORS = [_validate_query, _validate_entities_with_query] def run_entity_validators( - query: Union[CompositeQuery[QueryEntity], EntityQuery], + query: CompositeQuery[QueryEntity] | EntityQuery, settings: QuerySettings | None = None, ) -> None: """ diff --git a/snuba/datasets/plans/query_plan.py b/snuba/datasets/plans/query_plan.py index bbd6d7f8c11..ac142cddf5f 100644 --- a/snuba/datasets/plans/query_plan.py +++ b/snuba/datasets/plans/query_plan.py @@ -1,14 +1,17 @@ from __future__ import annotations from abc import ABC +from collections.abc import Mapping, Sequence from dataclasses import dataclass -from typing import Generic, Mapping, Optional, Sequence, Tuple, TypeVar, Union +from typing import Generic, NamedTuple, TypeVar from snuba.clickhouse.query import Query +from snuba.clickhouse.query import Query as ClickhouseQuery from snuba.clusters.storage_sets import StorageSetKey from snuba.query import ProcessableQuery from snuba.query import Query as AbstractQuery from snuba.query.composite import CompositeQuery +from snuba.query.data_source.join import JoinClause from snuba.query.data_source.simple import Table from snuba.query.processors.physical import ClickhouseQueryProcessor @@ -39,7 +42,7 @@ class QueryPlan(ABC, Generic[TQuery]): @dataclass(frozen=True) -class ClickhouseQueryPlan(QueryPlan[Union[Query, ProcessableQuery[Table]]]): +class ClickhouseQueryPlan(QueryPlan[Query | ProcessableQuery[Table]]): """ Query plan for a single entity, single storage query. @@ -68,12 +71,6 @@ class SubqueryProcessors: db_processors: Sequence[ClickhouseQueryProcessor] -from typing import NamedTuple - -from snuba.clickhouse.query import Query as ClickhouseQuery -from snuba.query.data_source.join import JoinClause - - class CompositeQueryPlan(NamedTuple): """ Intermediate query plan data structure maintained when visiting @@ -85,19 +82,16 @@ class CompositeQueryPlan(NamedTuple): (e.g. aliased processors). """ - translated_source: Union[ - ClickhouseQuery, - ProcessableQuery[Table], - CompositeQuery[Table], - JoinClause[Table], - ] + translated_source: ( + ClickhouseQuery | ProcessableQuery[Table] | CompositeQuery[Table] | JoinClause[Table] + ) storage_set_key: StorageSetKey - root_processors: Optional[SubqueryProcessors] = None - aliased_processors: Optional[Mapping[str, SubqueryProcessors]] = None + root_processors: SubqueryProcessors | None = None + aliased_processors: Mapping[str, SubqueryProcessors] | None = None def get_db_processors( self, - ) -> Tuple[ + ) -> tuple[ Sequence[ClickhouseQueryProcessor], Mapping[str, Sequence[ClickhouseQueryProcessor]], ]: @@ -115,7 +109,7 @@ def get_db_processors( def get_plan_processors( self, - ) -> Tuple[ + ) -> tuple[ Sequence[ClickhouseQueryProcessor], Mapping[str, Sequence[ClickhouseQueryProcessor]], ]: diff --git a/snuba/datasets/plans/storage_processing.py b/snuba/datasets/plans/storage_processing.py index 6e6f4352001..f7b0220e37b 100644 --- a/snuba/datasets/plans/storage_processing.py +++ b/snuba/datasets/plans/storage_processing.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Optional, Sequence, TypeVar, Union +from collections.abc import Sequence +from typing import TypeVar import sentry_sdk @@ -46,7 +47,7 @@ def get_query_data_source( relational_source: RelationalSource, allocation_policies: list[AllocationPolicy], final: bool, - sampling_rate: Optional[float], + sampling_rate: float | None, storage_key: StorageKey, ) -> Table: assert isinstance(relational_source, TableSource) @@ -74,7 +75,7 @@ def check_storage_readiness(storage: ReadableStorage) -> None: def build_best_plan( - physical_query: Union[Query, ProcessableQuery[Table]], + physical_query: Query | ProcessableQuery[Table], settings: QuerySettings, post_processors: Sequence[ClickhouseQueryProcessor] = [], ) -> ClickhouseQueryPlan: diff --git a/snuba/datasets/plans/translator/mapper.py b/snuba/datasets/plans/translator/mapper.py index 18f2ab84dcf..2c37fc4a0d2 100644 --- a/snuba/datasets/plans/translator/mapper.py +++ b/snuba/datasets/plans/translator/mapper.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod -from typing import Generic, Optional, Sequence, TypeVar +from collections.abc import Sequence +from typing import Generic, TypeVar TExpIn = TypeVar("TExpIn") TExpOut = TypeVar("TExpOut") @@ -28,7 +29,7 @@ def attempt_map( self, expression: TExpIn, children_translator: TTranslator, - ) -> Optional[TExpOut]: + ) -> TExpOut | None: """ Maps an expression if this rule matches such expression. If not, it returns None. """ diff --git a/snuba/datasets/pluggable_entity.py b/snuba/datasets/pluggable_entity.py index fac7190bddb..291f0730c60 100644 --- a/snuba/datasets/pluggable_entity.py +++ b/snuba/datasets/pluggable_entity.py @@ -1,7 +1,8 @@ from __future__ import annotations +from collections.abc import Mapping, Sequence from dataclasses import dataclass, field -from typing import Any, List, Mapping, Optional, Sequence +from typing import Any from snuba.clickhouse.columns import Column, ColumnSet from snuba.datasets.entities.entity_key import EntityKey @@ -41,20 +42,20 @@ class PluggableEntity(Entity): """ entity_key: EntityKey - storages: List[EntityStorageConnection] + storages: list[EntityStorageConnection] query_processors: Sequence[LogicalQueryProcessor] columns: Sequence[Column[SchemaModifiers]] validators: Sequence[QueryValidator] - required_time_column: Optional[str] + required_time_column: str | None storage_selector: QueryStorageSelector validate_data_model: ColumnValidationMode | None = None join_relationships: Mapping[str, JoinRelationship] = field(default_factory=dict) function_call_validators: Mapping[str, FunctionCallValidator] = field(default_factory=dict) # partition_key_column_name is used in data slicing (the value in this storage column # will be used to "choose" slices) - partition_key_column_name: Optional[str] = None - subscription_processors: Optional[Sequence[EntitySubscriptionProcessor]] = None - subscription_validators: Optional[Sequence[EntitySubscriptionValidator]] = None + partition_key_column_name: str | None = None + subscription_processors: Sequence[EntitySubscriptionProcessor] | None = None + subscription_validators: Sequence[EntitySubscriptionValidator] | None = None def _get_builtin_validators(self) -> Sequence[QueryValidator]: mappers = [s.translation_mappers for s in self.storages] @@ -73,7 +74,7 @@ def get_query_processors(self) -> Sequence[LogicalQueryProcessor]: def get_data_model(self) -> ColumnSet: return ColumnSet(self.columns) - def get_join_relationship(self, relationship: str) -> Optional[JoinRelationship]: + def get_join_relationship(self, relationship: str) -> JoinRelationship | None: return self.join_relationships.get(relationship) def get_all_join_relationships(self) -> Mapping[str, JoinRelationship]: @@ -97,7 +98,7 @@ def get_all_storages(self) -> Sequence[Storage]: def get_all_storage_connections(self) -> Sequence[EntityStorageConnection]: return self.storages - def get_writable_storage(self) -> Optional[WritableTableStorage]: + def get_writable_storage(self) -> WritableTableStorage | None: for storage_connection in self.storages: if storage_connection.is_writable and isinstance( storage_connection.storage, WritableTableStorage @@ -105,7 +106,7 @@ def get_writable_storage(self) -> Optional[WritableTableStorage]: return storage_connection.storage return None - def get_storage_selector(self) -> Optional[QueryStorageSelector]: + def get_storage_selector(self) -> QueryStorageSelector | None: return self.storage_selector def get_function_call_validators(self) -> Mapping[str, FunctionCallValidator]: @@ -116,12 +117,12 @@ def get_validators(self) -> Sequence[QueryValidator]: def get_subscription_processors( self, - ) -> Optional[Sequence[EntitySubscriptionProcessor]]: + ) -> Sequence[EntitySubscriptionProcessor] | None: return self.subscription_processors def get_subscription_validators( self, - ) -> Optional[Sequence[EntitySubscriptionValidator]]: + ) -> Sequence[EntitySubscriptionValidator] | None: return self.subscription_validators def __eq__(self, other: Any) -> bool: diff --git a/snuba/datasets/processors/__init__.py b/snuba/datasets/processors/__init__.py index 74bac4f9a33..084bfc71c0f 100644 --- a/snuba/datasets/processors/__init__.py +++ b/snuba/datasets/processors/__init__.py @@ -2,7 +2,7 @@ import os from abc import abstractmethod -from typing import Any, Optional, Type, cast +from typing import Any, cast from snuba.consumers.types import KafkaMessageMetadata from snuba.processor import MessageProcessor, ProcessedMessage @@ -27,7 +27,7 @@ def __init__(self) -> None: @classmethod def from_name(cls, name: str) -> DatasetMessageProcessor: - return cast(Type["DatasetMessageProcessor"], cls.class_from_name(name))() + return cast(type["DatasetMessageProcessor"], cls.class_from_name(name))() @classmethod def config_key(cls) -> str: @@ -36,7 +36,7 @@ def config_key(cls) -> str: @abstractmethod def process_message( self, message: Any, metadata: KafkaMessageMetadata - ) -> Optional[ProcessedMessage]: + ) -> ProcessedMessage | None: raise NotImplementedError diff --git a/snuba/datasets/processors/group_attributes_processor.py b/snuba/datasets/processors/group_attributes_processor.py index 42b9ddd395f..9757b6dfb17 100644 --- a/snuba/datasets/processors/group_attributes_processor.py +++ b/snuba/datasets/processors/group_attributes_processor.py @@ -1,5 +1,4 @@ from datetime import datetime -from typing import Optional from sentry_kafka_schemas.schema_types.group_attributes_v1 import ( GroupAttributesSnapshot, @@ -17,7 +16,7 @@ class GroupAttributesMessageProcessor(DatasetMessageProcessor): def process_message( self, message: GroupAttributesSnapshot, metadata: KafkaMessageMetadata - ) -> Optional[ProcessedMessage]: + ) -> ProcessedMessage | None: return InsertBatch( [ { diff --git a/snuba/datasets/processors/rust_compat_processor.py b/snuba/datasets/processors/rust_compat_processor.py index ceb85731f8c..61ddd8f4e87 100644 --- a/snuba/datasets/processors/rust_compat_processor.py +++ b/snuba/datasets/processors/rust_compat_processor.py @@ -1,8 +1,8 @@ from __future__ import annotations import logging -from datetime import timezone -from typing import Any, Optional +from datetime import UTC +from typing import Any import simplejson as json @@ -17,12 +17,12 @@ class RustCompatProcessor(DatasetMessageProcessor): def __init__(self, processor_name: str): import rust_snuba - self.__process_message = rust_snuba.process_message # type: ignore + self.__process_message = rust_snuba.process_message # type: ignore[attr-defined] self.__processor_name = processor_name def process_message( self, message: Any, metadata: KafkaMessageMetadata - ) -> Optional[ProcessedMessage]: + ) -> ProcessedMessage | None: if self.__processor_name == "EAPItemsProcessor": payload = message else: @@ -32,7 +32,7 @@ def process_message( payload, metadata.partition, metadata.offset, - int(metadata.timestamp.replace(tzinfo=timezone.utc).timestamp() * 1000), + int(metadata.timestamp.replace(tzinfo=UTC).timestamp() * 1000), ) if insert_payload is not None: @@ -45,11 +45,10 @@ def process_message( origin_timestamp=None, sentry_received_timestamp=None, ) - elif replacement_payload is not None: + if replacement_payload is not None: assert insert_payload is None key, values_bytes = replacement_payload values = [json.loads(line) for line in values_bytes.rstrip(b"\n").split(b"\n") if line] return ReplacementBatch(key=key.decode("utf8"), values=values) - else: - raise ValueError("unsupported return value from snuba_rust") + raise ValueError("unsupported return value from snuba_rust") diff --git a/snuba/datasets/processors/search_issues_processor.py b/snuba/datasets/processors/search_issues_processor.py index e3e76fef6d7..46c911c217e 100644 --- a/snuba/datasets/processors/search_issues_processor.py +++ b/snuba/datasets/processors/search_issues_processor.py @@ -1,14 +1,9 @@ import numbers import uuid -from datetime import datetime, timezone +from collections.abc import Mapping, MutableMapping, Sequence +from datetime import UTC, datetime from typing import ( Any, - Dict, - Mapping, - MutableMapping, - Optional, - Sequence, - Tuple, TypedDict, cast, ) @@ -54,7 +49,7 @@ class IssueOccurrenceData(TypedDict, total=False): subtitle: str culprit: str level: str - resource_id: Optional[str] + resource_id: str | None detection_time: float @@ -72,9 +67,9 @@ class IssueEventData(TypedDict, total=False): request: Mapping[str, Any] # http_method, http_referer # tag aliases - environment: Optional[str] # tags[environment] -> environment - release: Optional[str] # tags[sentry:release] -> release - dist: Optional[str] # tags[sentry:dist] -> dist + environment: str | None # tags[environment] -> environment + release: str | None # tags[sentry:release] -> release + dist: str | None # tags[sentry:dist] -> dist # (tags[sentry:user] or user[id]) -> user # contexts aliases @@ -140,7 +135,7 @@ def _process_tags( self, event_data: IssueEventData, processed: MutableMapping[str, Any] ) -> None: existing_tags = event_data.get("tags", None) - tags: Mapping[str, Any] = _as_dict_safe(cast(Dict[str, Any], existing_tags)) + tags: Mapping[str, Any] = _as_dict_safe(cast(dict[str, Any], existing_tags)) if not existing_tags: processed["tags.key"], processed["tags.value"] = [], [] else: @@ -227,7 +222,7 @@ def _process_timestamp_ms( # NOTE: we do this conversion because the JSONRowEncoder will strip out milliseconds out # of datetime objects specifically. To work around that, we convert the datetime to a # timestamp in milliseconds - client_timestamp = client_timestamp.replace(tzinfo=timezone.utc) + client_timestamp = client_timestamp.replace(tzinfo=UTC) processed["timestamp_ms"] = int(client_timestamp.timestamp() * 1000) def process_insert_v1( @@ -317,8 +312,8 @@ def process_insert_v1( ] def process_message( - self, message: Tuple[int, str, SearchIssueEvent], metadata: KafkaMessageMetadata - ) -> Optional[ProcessedMessage]: + self, message: tuple[int, str, SearchIssueEvent], metadata: KafkaMessageMetadata + ) -> ProcessedMessage | None: if not (isinstance(message, (list, tuple)) and len(message) >= 2): raise InvalidMessageFormat( f"Expected message format (, , >)), got {message} instead" diff --git a/snuba/datasets/processors/transactions_processor.py b/snuba/datasets/processors/transactions_processor.py index 859818d0b16..06cde78a096 100644 --- a/snuba/datasets/processors/transactions_processor.py +++ b/snuba/datasets/processors/transactions_processor.py @@ -1,9 +1,11 @@ +import contextlib import copy import logging import numbers import uuid +from collections.abc import Mapping, MutableMapping from datetime import datetime -from typing import Any, Dict, Mapping, MutableMapping, Optional, Tuple +from typing import Any from sentry_relay.consts import SPAN_STATUS_NAME_TO_CODE @@ -40,8 +42,8 @@ UNKNOWN_SPAN_STATUS = 2 GROUP_IDS_LIMIT = 10 -EventDict = Dict[str, Any] -SpanDict = Dict[str, Any] +EventDict = dict[str, Any] +SpanDict = dict[str, Any] RetentionDays = int @@ -54,7 +56,7 @@ class TransactionsMessageProcessor(DatasetMessageProcessor): "replayId", } - def __extract_timestamp(self, field: int) -> Tuple[datetime, int]: + def __extract_timestamp(self, field: int) -> tuple[datetime, int]: # We are purposely using a naive datetime here to work with the rest of the codebase. # We can be confident that clients are only sending UTC dates. timestamp = _ensure_valid_date(datetime.utcfromtimestamp(field)) @@ -64,8 +66,8 @@ def __extract_timestamp(self, field: int) -> Tuple[datetime, int]: return (timestamp, milliseconds) def _structure_and_validate_message( - self, message: Tuple[int, str, Dict[str, Any]] - ) -> Optional[Tuple[EventDict, RetentionDays]]: + self, message: tuple[int, str, dict[str, Any]] + ) -> tuple[EventDict, RetentionDays] | None: if not (isinstance(message, (list, tuple)) and len(message) >= 2): return None @@ -159,12 +161,10 @@ def _process_tags( replay_id = promoted_tags.get("replayId") if replay_id: - try: + # replay_id as a tag is not guarenteed to be UUID (user could set value in theory) + # so simply continue if not UUID. + with contextlib.suppress(ValueError): processed["replay_id"] = str(uuid.UUID(replay_id)) - except ValueError: - # replay_id as a tag is not guarenteed to be UUID (user could set value in theory) - # so simply continue if not UUID. - pass processed["dist"] = _unicodify( promoted_tags.get("sentry:dist", event_dict["data"].get("dist")), @@ -326,7 +326,7 @@ def _process_sdk_data( if processed["sdk_version"] == "": metrics.increment("missing_sdk_version") - def _process_span(self, span_dict: SpanDict) -> Optional[Tuple[str, int, float]]: + def _process_span(self, span_dict: SpanDict) -> tuple[str, int, float] | None: op = span_dict.get("op") group = span_dict.get("hash") exclusive_time = span_dict.get("exclusive_time") @@ -445,8 +445,8 @@ def _sanitize_contexts( return sanitized_context def process_message( - self, message: Tuple[int, str, Dict[Any, Any]], metadata: KafkaMessageMetadata - ) -> Optional[ProcessedMessage]: + self, message: tuple[int, str, dict[Any, Any]], metadata: KafkaMessageMetadata + ) -> ProcessedMessage | None: event_dict, retention_days = self._structure_and_validate_message(message) or ( None, None, diff --git a/snuba/datasets/schemas/__init__.py b/snuba/datasets/schemas/__init__.py index 9c9eca4e05c..455f0e4d4bf 100644 --- a/snuba/datasets/schemas/__init__.py +++ b/snuba/datasets/schemas/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import List, Mapping, Sequence +from collections.abc import Mapping, Sequence from snuba.clickhouse.columns import ColumnSet, ColumnType, Nullable, TModifiers from snuba.query.expressions import FunctionCall @@ -60,17 +60,15 @@ def get_columns(self) -> ColumnSet: def get_column_differences( self, expected_columns: Mapping[str, ColumnType[TModifiers]] - ) -> List[str]: + ) -> list[str]: """ Returns a list of differences between the expected_columns and the columns described in the schema. """ - errors: List[str] = [] + errors: list[str] = [] for column in self.get_columns(): if column.flattened not in expected_columns: - errors.append( - "Column '%s' exists in schema but not local ClickHouse!" % column.name - ) + errors.append(f"Column '{column.name}' exists in schema but not local ClickHouse!") continue expected_type = expected_columns[column.flattened] @@ -79,8 +77,7 @@ def get_column_differences( Nullable ) != expected_type.has_modifier(Nullable): errors.append( - "Column '%s' type differs between local ClickHouse and schema! (expected: %s, is: %s)" - % (column.name, expected_type, column) + f"Column '{column.name}' type differs between local ClickHouse and schema! (expected: {expected_type}, is: {column})" ) return errors diff --git a/snuba/datasets/schemas/tables.py b/snuba/datasets/schemas/tables.py index 0e2441cbed1..4b4733fbb35 100644 --- a/snuba/datasets/schemas/tables.py +++ b/snuba/datasets/schemas/tables.py @@ -1,8 +1,8 @@ from __future__ import annotations +from collections.abc import Sequence from dataclasses import dataclass from enum import Enum -from typing import Optional, Sequence from snuba import util from snuba.clickhouse.columns import ColumnSet @@ -26,7 +26,7 @@ class TableSource(RelationalSource): table_name: str columns: ColumnSet - mandatory_conditions: Optional[Sequence[FunctionCall]] = None + mandatory_conditions: Sequence[FunctionCall] | None = None def get_table_name(self) -> str: return self.table_name @@ -51,8 +51,8 @@ def __init__( local_table_name: str, dist_table_name: str, storage_set_key: StorageSetKey, - mandatory_conditions: Optional[Sequence[FunctionCall]] = None, - partition_format: Optional[Sequence[util.PartSegment]] = None, + mandatory_conditions: Sequence[FunctionCall] | None = None, + partition_format: Sequence[util.PartSegment] | None = None, ): self.__local_table_name = local_table_name self.__dist_table_name = dist_table_name @@ -98,7 +98,7 @@ def get_table_name(self) -> str: else self.__dist_table_name ) - def get_partition_format(self) -> Optional[Sequence[util.PartSegment]]: + def get_partition_format(self) -> Sequence[util.PartSegment] | None: """ Partition format required for cleanup and optimize. """ diff --git a/snuba/datasets/slicing.py b/snuba/datasets/slicing.py index d363ff5a0f3..9de005bff00 100644 --- a/snuba/datasets/slicing.py +++ b/snuba/datasets/slicing.py @@ -39,4 +39,4 @@ def is_storage_set_sliced(storage_set: StorageSetKey) -> bool: """ from snuba.settings import SLICED_STORAGE_SETS - return True if storage_set.value in SLICED_STORAGE_SETS.keys() else False + return storage_set.value in SLICED_STORAGE_SETS diff --git a/snuba/datasets/storage.py b/snuba/datasets/storage.py index f0ca6060735..b49bdfa1fcb 100644 --- a/snuba/datasets/storage.py +++ b/snuba/datasets/storage.py @@ -1,8 +1,9 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, Optional, Sequence +from typing import Any from snuba.clickhouse.translators.snuba.mapping import TranslationMappers from snuba.clusters.cluster import ( @@ -24,7 +25,7 @@ from snuba.replacers.replacer_processor import ReplacerProcessor -class Storage(ABC): +class Storage(ABC): # noqa: B024 abstract base for storage subclasses; not meant to be instantiated directly """ Storage is an abstraction that represent a DB object that stores data and has a schema. @@ -41,7 +42,7 @@ def __init__( storage_set_key: StorageSetKey, schema: Schema, readiness_state: ReadinessState, - required_time_column: Optional[str] = None, + required_time_column: str | None = None, ): self.__storage_set_key = storage_set_key self.__schema = schema @@ -51,7 +52,7 @@ def __init__( def get_storage_set_key(self) -> StorageSetKey: return self.__storage_set_key - def get_cluster(self, slice_id: Optional[int] = None) -> ClickhouseCluster: + def get_cluster(self, slice_id: int | None = None) -> ClickhouseCluster: return get_cluster(self.__storage_set_key, slice_id) def get_schema(self) -> Schema: @@ -130,17 +131,19 @@ def __init__( storage_set_key: StorageSetKey, schema: Schema, readiness_state: ReadinessState, - query_processors: Optional[Sequence[ClickhouseQueryProcessor]] = None, - deletion_settings: Optional[DeletionSettings] = None, - deletion_processors: Optional[Sequence[ClickhouseQueryProcessor]] = None, - mandatory_condition_checkers: Optional[Sequence[ConditionChecker]] = None, - allocation_policies: Optional[list[AllocationPolicy]] = None, - delete_allocation_policies: Optional[list[AllocationPolicy]] = None, - required_time_column: Optional[str] = None, + query_processors: Sequence[ClickhouseQueryProcessor] | None = None, + deletion_settings: DeletionSettings | None = None, + deletion_processors: Sequence[ClickhouseQueryProcessor] | None = None, + mandatory_condition_checkers: Sequence[ConditionChecker] | None = None, + allocation_policies: list[AllocationPolicy] | None = None, + delete_allocation_policies: list[AllocationPolicy] | None = None, + required_time_column: str | None = None, ) -> None: self.__storage_key = storage_key self.__query_processors = query_processors or [] - self.__deletion_settings = deletion_settings or DeletionSettings(0, [], [], 0) + self.__deletion_settings = deletion_settings or DeletionSettings( + is_enabled=0, tables=[], bulk_delete_only=False, allowed_columns=[] + ) self.__deletion_processors = deletion_processors or [] self.__mandatory_condition_checkers = mandatory_condition_checkers or [] self.__allocation_policies = allocation_policies or [] @@ -183,16 +186,16 @@ def __init__( schema: Schema, query_processors: Sequence[ClickhouseQueryProcessor], stream_loader: KafkaStreamLoader, - mandatory_condition_checkers: Optional[Sequence[ConditionChecker]] = None, - allocation_policies: Optional[list[AllocationPolicy]] = None, - delete_allocation_policies: Optional[list[AllocationPolicy]] = None, - replacer_processor: Optional[ReplacerProcessor[Any]] = None, - deletion_settings: Optional[DeletionSettings] = None, - deletion_processors: Optional[Sequence[ClickhouseQueryProcessor]] = None, + mandatory_condition_checkers: Sequence[ConditionChecker] | None = None, + allocation_policies: list[AllocationPolicy] | None = None, + delete_allocation_policies: list[AllocationPolicy] | None = None, + replacer_processor: ReplacerProcessor[Any] | None = None, + deletion_settings: DeletionSettings | None = None, + deletion_processors: Sequence[ClickhouseQueryProcessor] | None = None, writer_options: ClickhouseWriterOptions = None, write_format: WriteFormat = WriteFormat.JSON, ignore_write_errors: bool = False, - required_time_column: Optional[str] = None, + required_time_column: str | None = None, ) -> None: self.__storage_key = storage_key super().__init__( diff --git a/snuba/datasets/storages/factory.py b/snuba/datasets/storages/factory.py index 3b353351284..0c111acc568 100644 --- a/snuba/datasets/storages/factory.py +++ b/snuba/datasets/storages/factory.py @@ -1,7 +1,8 @@ from __future__ import annotations +import contextlib +from collections.abc import MutableSequence, Sequence from glob import glob -from typing import MutableSequence, Sequence import sentry_sdk @@ -88,10 +89,8 @@ def get_writable_storages() -> Sequence[WritableTableStorage]: writable_storages: MutableSequence[WritableTableStorage] = [] storage_keys = get_all_storage_keys() for storage_key in storage_keys: - try: + with contextlib.suppress(AssertionError): writable_storages.append(get_writable_storage(storage_key)) - except AssertionError: - pass return writable_storages diff --git a/snuba/datasets/storages/storage_key.py b/snuba/datasets/storages/storage_key.py index a10650a216a..556d4e35575 100644 --- a/snuba/datasets/storages/storage_key.py +++ b/snuba/datasets/storages/storage_key.py @@ -1,12 +1,13 @@ from __future__ import annotations -from typing import Any, Iterator +from collections.abc import Iterator +from typing import Any REGISTERED_STORAGE_KEYS: dict[str, str] = {} class _StorageKey(type): - def __getattr__(cls, attr: str) -> "StorageKey": + def __getattr__(cls, attr: str) -> StorageKey: if attr not in REGISTERED_STORAGE_KEYS: raise AttributeError(attr) diff --git a/snuba/datasets/storages/validator.py b/snuba/datasets/storages/validator.py index 7ecc6c75b2b..114aefbee43 100644 --- a/snuba/datasets/storages/validator.py +++ b/snuba/datasets/storages/validator.py @@ -1,13 +1,11 @@ from __future__ import annotations -from typing import Union - from snuba.datasets.storage import ReadableTableStorage, WritableTableStorage from snuba.migrations.groups import get_group_readiness_state_from_storage_set class StorageValidator: - def __init__(self, storage: Union[ReadableTableStorage, WritableTableStorage]) -> None: + def __init__(self, storage: ReadableTableStorage | WritableTableStorage) -> None: self.storage = storage def validate(self) -> None: diff --git a/snuba/datasets/table_storage.py b/snuba/datasets/table_storage.py index 28575a930f3..f24f4b8b469 100644 --- a/snuba/datasets/table_storage.py +++ b/snuba/datasets/table_storage.py @@ -1,8 +1,9 @@ +from collections.abc import Mapping, Sequence from functools import cached_property -from typing import Any, Mapping, Optional, Sequence +from typing import Any from arroyo.backends.kafka import KafkaPayload -from confluent_kafka.admin import ( +from confluent_kafka.admin import ( # type: ignore[attr-defined] # not re-exported in stubs AdminClient, ConfigResource, ResourceType, @@ -54,7 +55,7 @@ def topic_name(self) -> str: topic = self.__topic.value return settings.KAFKA_TOPIC_MAP.get(topic, topic) - def get_physical_topic_name(self, slice_id: Optional[int] = None) -> str: + def get_physical_topic_name(self, slice_id: int | None = None) -> str: """ Slice aware version of``topic_name``. """ @@ -104,15 +105,15 @@ def __init__( self, processor: MessageProcessor, default_topic_spec: KafkaTopicSpec, - pre_filter: Optional[StreamMessageFilter[KafkaPayload]] = None, - replacement_topic_spec: Optional[KafkaTopicSpec] = None, - commit_log_topic_spec: Optional[KafkaTopicSpec] = None, - subscription_scheduler_mode: Optional[SchedulingWatermarkMode] = None, - subscription_synchronization_timestamp: Optional[str] = None, - subscription_scheduled_topic_spec: Optional[KafkaTopicSpec] = None, - subscription_result_topic_spec: Optional[KafkaTopicSpec] = None, - subscription_delay_seconds: Optional[int] = None, - dlq_topic_spec: Optional[KafkaTopicSpec] = None, + pre_filter: StreamMessageFilter[KafkaPayload] | None = None, + replacement_topic_spec: KafkaTopicSpec | None = None, + commit_log_topic_spec: KafkaTopicSpec | None = None, + subscription_scheduler_mode: SchedulingWatermarkMode | None = None, + subscription_synchronization_timestamp: str | None = None, + subscription_scheduled_topic_spec: KafkaTopicSpec | None = None, + subscription_result_topic_spec: KafkaTopicSpec | None = None, + subscription_delay_seconds: int | None = None, + dlq_topic_spec: KafkaTopicSpec | None = None, ) -> None: subscription_values = [ bool(subscription_scheduled_topic_spec), @@ -140,7 +141,7 @@ def __init__( def get_processor(self) -> MessageProcessor: return self.__processor - def get_pre_filter(self) -> Optional[StreamMessageFilter[KafkaPayload]]: + def get_pre_filter(self) -> StreamMessageFilter[KafkaPayload] | None: """ Returns a filter (or none if none is defined) to be applied to the messages coming from the Kafka stream before parsing the content of the message. @@ -150,66 +151,66 @@ def get_pre_filter(self) -> Optional[StreamMessageFilter[KafkaPayload]]: def get_default_topic_spec(self) -> KafkaTopicSpec: return self.__default_topic_spec - def get_replacement_topic_spec(self) -> Optional[KafkaTopicSpec]: + def get_replacement_topic_spec(self) -> KafkaTopicSpec | None: return self.__replacement_topic_spec - def get_commit_log_topic_spec(self) -> Optional[KafkaTopicSpec]: + def get_commit_log_topic_spec(self) -> KafkaTopicSpec | None: return self.__commit_log_topic_spec - def get_subscription_scheduler_mode(self) -> Optional[SchedulingWatermarkMode]: + def get_subscription_scheduler_mode(self) -> SchedulingWatermarkMode | None: return self.__subscription_scheduler_mode - def get_subscription_sychronization_timestamp(self) -> Optional[str]: + def get_subscription_sychronization_timestamp(self) -> str | None: return self.__subscription_synchronization_timestamp - def get_subscription_scheduled_topic_spec(self) -> Optional[KafkaTopicSpec]: + def get_subscription_scheduled_topic_spec(self) -> KafkaTopicSpec | None: return self.__subscription_scheduled_topic_spec - def get_subscription_result_topic_spec(self) -> Optional[KafkaTopicSpec]: + def get_subscription_result_topic_spec(self) -> KafkaTopicSpec | None: return self.__subscription_result_topic_spec - def get_subscription_delay_seconds(self) -> Optional[int]: + def get_subscription_delay_seconds(self) -> int | None: return self.__subscription_delay_seconds - def get_dlq_topic_spec(self) -> Optional[KafkaTopicSpec]: + def get_dlq_topic_spec(self) -> KafkaTopicSpec | None: return self.__dlq_topic_spec def build_kafka_stream_loader_from_settings( processor: MessageProcessor, default_topic: Topic, - pre_filter: Optional[StreamMessageFilter[KafkaPayload]] = None, - replacement_topic: Optional[Topic] = None, - commit_log_topic: Optional[Topic] = None, - subscription_scheduler_mode: Optional[SchedulingWatermarkMode] = None, - subscription_scheduled_topic: Optional[Topic] = None, - subscription_result_topic: Optional[Topic] = None, - subscription_synchronization_timestamp: Optional[str] = None, - subscription_delay_seconds: Optional[int] = None, - dlq_topic: Optional[Topic] = None, + pre_filter: StreamMessageFilter[KafkaPayload] | None = None, + replacement_topic: Topic | None = None, + commit_log_topic: Topic | None = None, + subscription_scheduler_mode: SchedulingWatermarkMode | None = None, + subscription_scheduled_topic: Topic | None = None, + subscription_result_topic: Topic | None = None, + subscription_synchronization_timestamp: str | None = None, + subscription_delay_seconds: int | None = None, + dlq_topic: Topic | None = None, ) -> KafkaStreamLoader: default_topic_spec = KafkaTopicSpec(default_topic) - replacement_topic_spec: Optional[KafkaTopicSpec] + replacement_topic_spec: KafkaTopicSpec | None if replacement_topic is not None: replacement_topic_spec = KafkaTopicSpec(replacement_topic) else: replacement_topic_spec = None - commit_log_topic_spec: Optional[KafkaTopicSpec] + commit_log_topic_spec: KafkaTopicSpec | None if commit_log_topic is not None: commit_log_topic_spec = KafkaTopicSpec(commit_log_topic) else: commit_log_topic_spec = None - subscription_scheduled_topic_spec: Optional[KafkaTopicSpec] + subscription_scheduled_topic_spec: KafkaTopicSpec | None if subscription_scheduled_topic is not None: subscription_scheduled_topic_spec = KafkaTopicSpec(subscription_scheduled_topic) else: subscription_scheduled_topic_spec = None - subscription_result_topic_spec: Optional[KafkaTopicSpec] + subscription_result_topic_spec: KafkaTopicSpec | None if subscription_result_topic is not None: subscription_result_topic_spec = KafkaTopicSpec(subscription_result_topic) else: @@ -255,7 +256,7 @@ def __init__( storage_set: StorageSetKey, write_schema: WritableTableSchema, stream_loader: KafkaStreamLoader, - replacer_processor: Optional[ReplacerProcessor[Any]] = None, + replacer_processor: ReplacerProcessor[Any] | None = None, writer_options: ClickhouseWriterOptions = None, write_format: WriteFormat = WriteFormat.JSON, ) -> None: @@ -273,9 +274,9 @@ def get_batch_writer( self, metrics: MetricsBackend, options: ClickhouseWriterOptions = None, - table_name: Optional[str] = None, + table_name: str | None = None, chunk_size: int = settings.CLICKHOUSE_HTTP_CHUNK_SIZE, - slice_id: Optional[int] = None, + slice_id: int | None = None, ) -> BatchWriter[JSONRow]: table_name = table_name or self.__table_schema.get_table_name() if self.__write_format == WriteFormat.JSON: @@ -308,10 +309,10 @@ def get_writeable_columns(self) -> Sequence[str]: def get_bulk_writer( self, metrics: MetricsBackend, - encoding: Optional[str], + encoding: str | None, column_names: Sequence[str], options: ClickhouseWriterOptions = None, - table_name: Optional[str] = None, + table_name: str | None = None, ) -> BatchWriter[bytes]: table_name = table_name or self.__table_schema.get_table_name() @@ -331,7 +332,7 @@ def get_bulk_loader( source: BulkLoadSource, source_table: str, row_processor: CdcRowProcessor, - table_name: Optional[str] = None, + table_name: str | None = None, ) -> BulkLoader: """ Returns the instance of the bulk loader to populate the dataset from an @@ -351,7 +352,7 @@ def get_bulk_loader( def get_stream_loader(self) -> KafkaStreamLoader: return self.__stream_loader - def get_replacer_processor(self) -> Optional[ReplacerProcessor[Any]]: + def get_replacer_processor(self) -> ReplacerProcessor[Any] | None: """ Returns a replacement processor if this table writer knows how to do replacements on the table it manages. diff --git a/snuba/environment.py b/snuba/environment.py index 1850c346a5c..32dda5e8ad0 100644 --- a/snuba/environment.py +++ b/snuba/environment.py @@ -1,8 +1,5 @@ -from __future__ import absolute_import - import logging import os -from typing import Optional import sentry_sdk import structlog @@ -46,7 +43,7 @@ def drop_level(logger: logging.Logger, method_name: str, event_dict: EventDict) return event_dict -def setup_logging(level: Optional[str] = None) -> None: +def setup_logging(level: str | None = None) -> None: if level is None: level = settings.LOG_LEVEL diff --git a/snuba/lw_deletions/batching.py b/snuba/lw_deletions/batching.py index 0560de33f59..3365714dac0 100644 --- a/snuba/lw_deletions/batching.py +++ b/snuba/lw_deletions/batching.py @@ -1,7 +1,8 @@ from __future__ import annotations import time -from typing import Callable, Generic, MutableSequence, Optional, TypeVar, Union, cast +from collections.abc import Callable, MutableSequence +from typing import Generic, TypeVar, cast from arroyo.processing.strategies.abstract import ProcessingStrategy from arroyo.processing.strategies.buffer import Buffer @@ -27,7 +28,7 @@ def __init__( initial_value: Callable[[], TResult], max_batch_size: int, max_batch_time: float, - increment_by: Optional[Callable[[BaseValue[TPayload]], int]] = None, + increment_by: Callable[[BaseValue[TPayload]], int] | None = None, ): self.accumulator = accumulator self.initial_value = initial_value @@ -72,7 +73,7 @@ def append(self, message: BaseValue[TPayload]) -> None: self._buffer_size += buffer_increment self._buffered_messages += 1 - def new(self) -> "ReduceRowsBuffer[TPayload, TResult]": + def new(self) -> ReduceRowsBuffer[TPayload, TResult]: return ReduceRowsBuffer( accumulator=self.accumulator, initial_value=self.initial_value, @@ -82,9 +83,7 @@ def new(self) -> "ReduceRowsBuffer[TPayload, TResult]": ) -class ReduceCustom( - ProcessingStrategy[Union[FilteredPayload, TPayload]], Generic[TPayload, TResult] -): +class ReduceCustom(ProcessingStrategy[FilteredPayload | TPayload], Generic[TPayload, TResult]): def __init__( self, max_batch_size: int, @@ -92,7 +91,7 @@ def __init__( accumulator: Accumulator[TResult, TPayload], initial_value: Callable[[], TResult], next_step: ProcessingStrategy[TResult], - increment_by: Optional[Callable[[BaseValue[TPayload]], int]] = None, + increment_by: Callable[[BaseValue[TPayload]], int] | None = None, ) -> None: self.__buffer_step = Buffer( buffer=ReduceRowsBuffer( @@ -105,7 +104,7 @@ def __init__( next_step=next_step, ) - def submit(self, message: Message[Union[FilteredPayload, TPayload]]) -> None: + def submit(self, message: Message[FilteredPayload | TPayload]) -> None: self.__buffer_step.submit(message) def poll(self) -> None: @@ -117,18 +116,18 @@ def close(self) -> None: def terminate(self) -> None: self.__buffer_step.terminate() - def join(self, timeout: Optional[float] = None) -> None: + def join(self, timeout: float | None = None) -> None: self.__buffer_step.join(timeout) -class NoBatchStep(ProcessingStrategy[Union[FilteredPayload, TStrategyPayload]]): +class NoBatchStep(ProcessingStrategy[FilteredPayload | TStrategyPayload]): def __init__( self, next_step: ProcessingStrategy[ValuesBatch[TStrategyPayload]], ) -> None: self.__next_step = next_step - def submit(self, message: Message[Union[FilteredPayload, TStrategyPayload]]) -> None: + def submit(self, message: Message[FilteredPayload | TStrategyPayload]) -> None: if isinstance(message.payload, FilteredPayload): return value = cast(BaseValue[TStrategyPayload], message.value) @@ -143,17 +142,17 @@ def close(self) -> None: def terminate(self) -> None: self.__next_step.terminate() - def join(self, timeout: Optional[float] = None) -> None: + def join(self, timeout: float | None = None) -> None: self.__next_step.join(timeout) -class BatchStepCustom(ProcessingStrategy[Union[FilteredPayload, TStrategyPayload]]): +class BatchStepCustom(ProcessingStrategy[FilteredPayload | TStrategyPayload]): def __init__( self, max_batch_size: int, max_batch_time: float, next_step: ProcessingStrategy[ValuesBatch[TStrategyPayload]], - increment_by: Optional[Callable[[BaseValue[TStrategyPayload]], int]] = None, + increment_by: Callable[[BaseValue[TStrategyPayload]], int] | None = None, ) -> None: def accumulator( result: ValuesBatch[TStrategyPayload], value: BaseValue[TStrategyPayload] @@ -172,7 +171,7 @@ def accumulator( ) ) - def submit(self, message: Message[Union[FilteredPayload, TStrategyPayload]]) -> None: + def submit(self, message: Message[FilteredPayload | TStrategyPayload]) -> None: self.__reduce_step.submit(message) def poll(self) -> None: @@ -184,5 +183,5 @@ def close(self) -> None: def terminate(self) -> None: self.__reduce_step.terminate() - def join(self, timeout: Optional[float] = None) -> None: + def join(self, timeout: float | None = None) -> None: self.__reduce_step.join(timeout) diff --git a/snuba/lw_deletions/formatters.py b/snuba/lw_deletions/formatters.py index 962e80c1717..ce4c81a46e6 100644 --- a/snuba/lw_deletions/formatters.py +++ b/snuba/lw_deletions/formatters.py @@ -1,15 +1,8 @@ from abc import ABC, abstractmethod from collections import defaultdict +from collections.abc import Mapping, MutableMapping, Sequence from typing import ( Any, - Dict, - List, - Mapping, - MutableMapping, - Optional, - Sequence, - Tuple, - Type, ) from sentry_protos.snuba.v1.trace_item_attribute_pb2 import AttributeKey @@ -57,7 +50,7 @@ def format(self, messages: Sequence[DeleteQueryMessage]) -> Sequence[ConditionsB assert isinstance(project_id, int) mapping[project_id] = mapping[project_id].union( # using int() to make mypy happy - set([int(g_id) for g_id in condition["group_id"]]) + {int(g_id) for g_id in condition["group_id"]} ) return [ @@ -72,14 +65,14 @@ def format(self, messages: Sequence[DeleteQueryMessage]) -> Sequence[ConditionsB def _deserialize_attribute_conditions( - data: Optional[Dict[str, WireAttributeCondition]], - item_type: Optional[int] = None, -) -> Optional[AttributeConditions]: + data: dict[str, WireAttributeCondition] | None, + item_type: int | None = None, +) -> AttributeConditions | None: if data is None: return None assert item_type is not None, "attribute_conditions cannot be deserialized without item_type" - attributes: Dict[str, Tuple[AttributeKey, List[Any]]] = {} + attributes: dict[str, tuple[AttributeKey, list[Any]]] = {} for key, wire_condition in data.items(): attr_key_type = wire_condition["attr_key_type"] @@ -110,7 +103,7 @@ def format(self, messages: Sequence[DeleteQueryMessage]) -> Sequence[ConditionsB ] -STORAGE_FORMATTER: Mapping[str, Type[Formatter]] = { +STORAGE_FORMATTER: Mapping[str, type[Formatter]] = { StorageKey.SEARCH_ISSUES.value: SearchIssuesFormatter, StorageKey.EAP_ITEMS.value: EAPItemsFormatter, } diff --git a/snuba/lw_deletions/off_peak.py b/snuba/lw_deletions/off_peak.py index 68631c1df5b..3d1e97823dd 100644 --- a/snuba/lw_deletions/off_peak.py +++ b/snuba/lw_deletions/off_peak.py @@ -1,6 +1,5 @@ import time -from datetime import datetime, timezone -from typing import Optional +from datetime import UTC, datetime from arroyo.backends.kafka import KafkaPayload from arroyo.processing.strategies import ProcessingStrategy @@ -32,7 +31,7 @@ def __init__( ) -> None: self.__next_step = next_step self.__metrics = metrics - self.__cached_result: Optional[bool] = None + self.__cached_result: bool | None = None self.__cached_at: float = 0.0 def poll(self) -> None: @@ -57,7 +56,7 @@ def _is_off_peak(self) -> bool: start = get_int_config("lw_deletions_offpeak_start", default=0) or 0 end = get_int_config("lw_deletions_offpeak_end", default=24) or 24 - current_hour = datetime.now(timezone.utc).hour + current_hour = datetime.now(UTC).hour if start == end: result = False @@ -77,5 +76,5 @@ def close(self) -> None: def terminate(self) -> None: self.__next_step.terminate() - def join(self, timeout: Optional[float] = None) -> None: + def join(self, timeout: float | None = None) -> None: self.__next_step.join(timeout) diff --git a/snuba/lw_deletions/strategy.py b/snuba/lw_deletions/strategy.py index 6676c5a8f62..e11900e1caf 100644 --- a/snuba/lw_deletions/strategy.py +++ b/snuba/lw_deletions/strategy.py @@ -3,8 +3,9 @@ import logging import time import typing +from collections.abc import Mapping, Sequence from datetime import datetime, timedelta -from typing import List, Mapping, Optional, Sequence, TypeVar +from typing import TypeVar import rapidjson from arroyo.backends.kafka import KafkaPayload @@ -72,7 +73,7 @@ def __init__( self.__partition_column = deletion_settings.partition_column self.__formatter: Formatter = formatter self.__metrics = metrics - self.__last_ongoing_mutations_check: Optional[float] = None + self.__last_ongoing_mutations_check: float | None = None self.__redis_client = get_redis_client(RedisClientKey.CONFIG) def poll(self) -> None: @@ -89,7 +90,7 @@ def _filter_allowed_conditions( if not str_config: return conditions # allowlist not set → allow all - org_ids_delete_allowlist = set(int(org_id) for org_id in str_config.split(",")) + org_ids_delete_allowlist = {int(org_id) for org_id in str_config.split(",")} allowed = [] for cond in conditions: @@ -116,12 +117,12 @@ def submit(self, message: Message[ValuesBatch[KafkaPayload]]) -> None: # backpressure is applied while we wait for the # currently ongoing mutations to finish self.__metrics.increment("too_many_ongoing_mutations") - raise MessageRejected + raise MessageRejected from None except QueryException as err: cause = err.__cause__ if isinstance(cause, AllocationPolicyViolations): self.__metrics.increment("allocation_policy_violation") - raise MessageRejected + raise MessageRejected from None self.__next_step.submit(message) @@ -156,7 +157,7 @@ def _conditions_hash(self, conditions: Sequence[ConditionsBag]) -> str: parts.sort() return hashlib.md5("|".join(parts).encode()).hexdigest()[:16] - def _get_partition_dates(self, table: str) -> List[str]: + def _get_partition_dates(self, table: str) -> list[str]: from snuba.util import decode_part_str cluster = self.__storage.get_cluster() @@ -224,6 +225,7 @@ def _execute_delete_by_partition( query_settings: HTTPQuerySettings, conditions: Sequence[ConditionsBag], ) -> None: + assert self.__partition_column is not None partition_dates = self._get_partition_dates(table) if not partition_dates: logger.warning( @@ -256,7 +258,7 @@ def _execute_delete_by_partition( self._check_ongoing_mutations(skip_throttle=True) partition_condition = equals( - FunctionCall(None, "toMonday", (column(self.__partition_column),)), # type: ignore[arg-type] + FunctionCall(None, "toMonday", (column(self.__partition_column),)), literal(partition_date), ) partition_where = combine_and_conditions([where_clause, partition_condition]) @@ -275,7 +277,7 @@ def _execute_single_delete( table: str, query: Query, query_settings: HTTPQuerySettings, - partition_week: Optional[str] = None, + partition_week: str | None = None, ) -> None: tags = {"table": table} if partition_week: @@ -302,7 +304,7 @@ def _execute_single_delete( if cause.code in LW_DELETE_NON_RETRYABLE_CLICKHOUSE_ERROR_CODES: logger.exception("Error running delete query %r", exc) else: - raise LWDeleteQueryException(exc.message) + raise LWDeleteQueryException(exc.message) from exc def _check_ongoing_mutations(self, skip_throttle: bool = False) -> None: now = time.time() @@ -336,7 +338,7 @@ def close(self) -> None: def terminate(self) -> None: self.__next_step.terminate() - def join(self, timeout: Optional[float] = None) -> None: + def join(self, timeout: float | None = None) -> None: self.__next_step.join(timeout) diff --git a/snuba/lw_deletions/types.py b/snuba/lw_deletions/types.py index 28720346fc8..18bfb4428a8 100644 --- a/snuba/lw_deletions/types.py +++ b/snuba/lw_deletions/types.py @@ -1,5 +1,6 @@ +from collections.abc import Mapping, Sequence from dataclasses import dataclass -from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple +from typing import Any from sentry_protos.snuba.v1.trace_item_attribute_pb2 import AttributeKey @@ -9,10 +10,10 @@ @dataclass class AttributeConditions: item_type: int - attributes: Dict[str, Tuple[AttributeKey, List[Any]]] + attributes: dict[str, tuple[AttributeKey, list[Any]]] @dataclass class ConditionsBag: column_conditions: ConditionsType - attribute_conditions: Optional[AttributeConditions] = None + attribute_conditions: AttributeConditions | None = None diff --git a/snuba/manual_jobs/__init__.py b/snuba/manual_jobs/__init__.py index 3bac77a6321..f9180e27a3b 100644 --- a/snuba/manual_jobs/__init__.py +++ b/snuba/manual_jobs/__init__.py @@ -1,8 +1,9 @@ import logging import os from abc import ABC, abstractmethod +from collections.abc import MutableMapping from dataclasses import dataclass -from typing import Any, MutableMapping, Optional, cast +from typing import Any, cast from snuba.manual_jobs.redis import _set_job_type from snuba.utils.registered_class import RegisteredClass, import_submodules_in_directory @@ -36,8 +37,8 @@ def error(self, line: str) -> None: class JobSpec: job_id: str job_type: str - is_async: Optional[bool] = False - params: Optional[MutableMapping[Any, Any]] = None + is_async: bool | None = False + params: MutableMapping[Any, Any] | None = None class Job(ABC, metaclass=RegisteredClass): diff --git a/snuba/manual_jobs/delete_events_by_tag_key_value.py b/snuba/manual_jobs/delete_events_by_tag_key_value.py index ea614cf738f..717c42487ef 100644 --- a/snuba/manual_jobs/delete_events_by_tag_key_value.py +++ b/snuba/manual_jobs/delete_events_by_tag_key_value.py @@ -1,5 +1,6 @@ +from collections.abc import Mapping from datetime import datetime -from typing import Any, Mapping, Optional +from typing import Any from snuba.clickhouse.escaping import escape_string from snuba.clusters.cluster import ClickhouseClientSettings, get_cluster @@ -12,14 +13,14 @@ def __init__(self, job_spec: JobSpec) -> None: self.__validate_job_params(job_spec.params) super().__init__(job_spec) - def __validate_job_params(self, params: Optional[Mapping[Any, Any]]) -> None: + def __validate_job_params(self, params: Mapping[Any, Any] | None) -> None: assert params assert isinstance(params["project_ids"], list) assert len(params["project_ids"]) > 0 assert isinstance(params["tag_key"], str) assert isinstance(params["tag_value"], str) assert params["tag_key"] and params["tag_value"] - assert all([isinstance(p, int) for p in params["project_ids"]]) + assert all(isinstance(p, int) for p in params["project_ids"]) self._project_ids = params["project_ids"] self._tag_key = params["tag_key"] self._tag_value = params["tag_value"] diff --git a/snuba/manual_jobs/extract_span_data.py b/snuba/manual_jobs/extract_span_data.py index d7ca7767cc0..65cc07ccfce 100644 --- a/snuba/manual_jobs/extract_span_data.py +++ b/snuba/manual_jobs/extract_span_data.py @@ -1,4 +1,5 @@ -from typing import Any, Mapping, Optional +from collections.abc import Mapping +from typing import Any from snuba.clusters.cluster import ClickhouseClientSettings, get_cluster from snuba.clusters.storage_sets import StorageSetKey @@ -10,7 +11,7 @@ def __init__(self, job_spec: JobSpec) -> None: self.__validate_job_params(job_spec.params) super().__init__(job_spec) - def __validate_job_params(self, params: Optional[Mapping[Any, Any]]) -> None: + def __validate_job_params(self, params: Mapping[Any, Any] | None) -> None: assert params required_params = [ "organization_ids", diff --git a/snuba/manual_jobs/redis.py b/snuba/manual_jobs/redis.py index aa609b35482..25b38cf2ed5 100644 --- a/snuba/manual_jobs/redis.py +++ b/snuba/manual_jobs/redis.py @@ -1,6 +1,6 @@ import typing +from collections.abc import Sequence from datetime import datetime -from typing import List, Sequence from snuba.manual_jobs.job_status import JobStatus from snuba.redis import RedisClientKey, get_redis_client @@ -61,7 +61,7 @@ def _get_job_type(job_id: str) -> str: return typing.cast(str, _redis_client.get(name=_build_job_type_key(job_id)).decode()) -def _get_job_types_multi(job_ids_keys: Sequence[str]) -> List[str]: +def _get_job_types_multi(job_ids_keys: Sequence[str]) -> list[str]: with _redis_client.pipeline(transaction=False) as pipeline: for job_id_key in job_ids_keys: pipeline.get(job_id_key) @@ -70,7 +70,7 @@ def _get_job_types_multi(job_ids_keys: Sequence[str]) -> List[str]: return [job_type.decode() for job_type in redis_statuses] -def _get_job_status_multi(job_ids_keys: Sequence[str]) -> List[JobStatus]: +def _get_job_status_multi(job_ids_keys: Sequence[str]) -> list[JobStatus]: if len(job_ids_keys) == 0: return [] diff --git a/snuba/manual_jobs/rerun_idempotent_migration.py b/snuba/manual_jobs/rerun_idempotent_migration.py index c9eda465525..0aad7b65297 100644 --- a/snuba/manual_jobs/rerun_idempotent_migration.py +++ b/snuba/manual_jobs/rerun_idempotent_migration.py @@ -1,4 +1,5 @@ -from typing import Any, Mapping, Optional +from collections.abc import Mapping +from typing import Any from snuba.clusters.storage_sets import StorageSetKey from snuba.manual_jobs import Job, JobLogger, JobSpec @@ -21,7 +22,7 @@ def __init__(self, job_spec: JobSpec) -> None: self.__validate_job_params(job_spec.params) super().__init__(job_spec) - def __validate_job_params(self, params: Optional[Mapping[Any, Any]]) -> None: + def __validate_job_params(self, params: Mapping[Any, Any] | None) -> None: assert params is not None, "storage_set and migration_id parameters required" assert params.get("storage_set"), "storage_set required" assert params.get("migration_id"), "migration_id required" diff --git a/snuba/manual_jobs/runner.py b/snuba/manual_jobs/runner.py index f9eca8f4f33..09567d4260e 100644 --- a/snuba/manual_jobs/runner.py +++ b/snuba/manual_jobs/runner.py @@ -1,7 +1,8 @@ import logging import os import traceback -from typing import Any, Mapping, Sequence, Union +from collections.abc import Mapping, Sequence +from typing import Any import simplejson @@ -92,7 +93,7 @@ def list_job_specs( def list_job_specs_with_status( manifest_filename: str = MANIFEST_FILENAME, -) -> Mapping[str, Mapping[str, Union[JobSpec, JobStatus]]]: +) -> Mapping[str, Mapping[str, JobSpec | JobStatus]]: specs = list_job_specs(manifest_filename) job_ids = list(specs.keys()) statuses = _get_job_status_multi([_build_job_status_key(job_id) for job_id in job_ids]) diff --git a/snuba/manual_jobs/scrub_ips_from_eap_spans.py b/snuba/manual_jobs/scrub_ips_from_eap_spans.py index 5a3a503c65d..d75373c17c7 100644 --- a/snuba/manual_jobs/scrub_ips_from_eap_spans.py +++ b/snuba/manual_jobs/scrub_ips_from_eap_spans.py @@ -1,5 +1,6 @@ +from collections.abc import Mapping from datetime import datetime -from typing import Any, Mapping, Optional +from typing import Any from snuba.clusters.cluster import ClickhouseClientSettings, get_cluster from snuba.clusters.storage_sets import StorageSetKey @@ -11,10 +12,10 @@ def __init__(self, job_spec: JobSpec) -> None: self.__validate_job_params(job_spec.params) super().__init__(job_spec) - def __validate_job_params(self, params: Optional[Mapping[Any, Any]]) -> None: + def __validate_job_params(self, params: Mapping[Any, Any] | None) -> None: assert params assert isinstance(params["organization_ids"], list) - assert all([isinstance(p, int) for p in params["organization_ids"]]) + assert all(isinstance(p, int) for p in params["organization_ids"]) self._organization_ids = params["organization_ids"] self._start_datetime = datetime.fromisoformat(params["start_datetime"]) self._end_datetime = datetime.fromisoformat(params["end_datetime"]) diff --git a/snuba/manual_jobs/scrub_users_from_eap_spans.py b/snuba/manual_jobs/scrub_users_from_eap_spans.py index 4c2d8f0792a..53a727555b6 100644 --- a/snuba/manual_jobs/scrub_users_from_eap_spans.py +++ b/snuba/manual_jobs/scrub_users_from_eap_spans.py @@ -1,5 +1,6 @@ +from collections.abc import Mapping from datetime import datetime -from typing import Any, Mapping, Optional +from typing import Any from snuba.clusters.cluster import ClickhouseClientSettings, get_cluster from snuba.clusters.storage_sets import StorageSetKey @@ -14,10 +15,10 @@ def __init__(self, job_spec: JobSpec) -> None: self.__validate_job_params(job_spec.params) super().__init__(job_spec) - def __validate_job_params(self, params: Optional[Mapping[Any, Any]]) -> None: + def __validate_job_params(self, params: Mapping[Any, Any] | None) -> None: assert params assert isinstance(params["organization_ids"], list) - assert all([isinstance(p, int) for p in params["organization_ids"]]) + assert all(isinstance(p, int) for p in params["organization_ids"]) self._organization_ids = params["organization_ids"] self._start_datetime = datetime.fromisoformat(params["start_datetime"]) self._end_datetime = datetime.fromisoformat(params["end_datetime"]) diff --git a/snuba/manual_jobs/scrub_users_from_eap_spans_str_attrs.py b/snuba/manual_jobs/scrub_users_from_eap_spans_str_attrs.py index 2ee9139aa65..87de060aec2 100644 --- a/snuba/manual_jobs/scrub_users_from_eap_spans_str_attrs.py +++ b/snuba/manual_jobs/scrub_users_from_eap_spans_str_attrs.py @@ -1,5 +1,6 @@ +from collections.abc import Mapping from datetime import datetime -from typing import Any, Mapping, Optional +from typing import Any from snuba.clusters.cluster import ClickhouseClientSettings, get_cluster from snuba.clusters.storage_sets import StorageSetKey @@ -13,10 +14,10 @@ def __init__(self, job_spec: JobSpec) -> None: self.__validate_job_params(job_spec.params) super().__init__(job_spec) - def __validate_job_params(self, params: Optional[Mapping[Any, Any]]) -> None: + def __validate_job_params(self, params: Mapping[Any, Any] | None) -> None: assert params assert isinstance(params["organization_ids"], list) - assert all([isinstance(p, int) for p in params["organization_ids"]]) + assert all(isinstance(p, int) for p in params["organization_ids"]) self._organization_ids = params["organization_ids"] self._start_datetime = datetime.fromisoformat(params["start_datetime"]) self._end_datetime = datetime.fromisoformat(params["end_datetime"]) diff --git a/snuba/manual_jobs/update_migration_status.py b/snuba/manual_jobs/update_migration_status.py index 3433dab26d2..a6d4f90fc22 100644 --- a/snuba/manual_jobs/update_migration_status.py +++ b/snuba/manual_jobs/update_migration_status.py @@ -1,5 +1,6 @@ +from collections.abc import Mapping from datetime import datetime -from typing import Any, Mapping, Optional +from typing import Any from snuba.clusters.cluster import ClickhouseClientSettings, get_cluster from snuba.clusters.storage_sets import StorageSetKey @@ -29,7 +30,7 @@ def __init__(self, job_spec: JobSpec): self.__validate_job_params(job_spec.params) super().__init__(job_spec) - def __validate_job_params(self, params: Optional[Mapping[Any, Any]]) -> None: + def __validate_job_params(self, params: Mapping[Any, Any] | None) -> None: assert params is not None, "group, migration_id, old_status, new_status parameters required" assert params["group"], "group required" assert params["migration_id"], "migration_id required" diff --git a/snuba/migrations/autogeneration/diff.py b/snuba/migrations/autogeneration/diff.py index 5c28d9e2ca6..53d3286ec16 100644 --- a/snuba/migrations/autogeneration/diff.py +++ b/snuba/migrations/autogeneration/diff.py @@ -1,4 +1,5 @@ -from typing import Any, Sequence, cast +from collections.abc import Sequence +from typing import Any, cast from snuba.clusters.storage_sets import StorageSetKey from snuba.datasets.configuration.utils import parse_columns @@ -40,7 +41,7 @@ def _storage_diff_to_migration_ops( if not valid: raise ValueError(reason) - oldcol_names = set(col["name"] for col in oldstorage["schema"]["columns"]) + oldcol_names = {col["name"] for col in oldstorage["schema"]["columns"]} newcols = newstorage["schema"]["columns"] forwardops: list[AddColumn] = [] @@ -114,7 +115,7 @@ def _is_valid_add_column( # verify nothing changed but the columns t1 = oldstorage["schema"].pop("columns") t2 = newstorage["schema"].pop("columns") - if not (oldstorage == newstorage): + if oldstorage != newstorage: return ( False, "Expected the only change to the storage to be the columns, but that is not true", @@ -126,8 +127,8 @@ def _is_valid_add_column( oldstorage_cols = oldstorage["schema"]["columns"] newstorage_cols = newstorage["schema"]["columns"] - colnames_old = set(e["name"] for e in oldstorage_cols) - colnames_new = set(e["name"] for e in newstorage_cols) + colnames_old = {e["name"] for e in oldstorage_cols} + colnames_new = {e["name"] for e in newstorage_cols} if not colnames_old.issubset(colnames_new): return (False, "Column removal is not supported") @@ -150,8 +151,7 @@ def _is_valid_add_column( False, "Adding a column to the beginning is currently unsupported, please add it anywhere else.", ) - else: - pnew += 1 + pnew += 1 assert pold == len(oldstorage_cols) # should always hold return True, "" diff --git a/snuba/migrations/autogeneration/main.py b/snuba/migrations/autogeneration/main.py index c71525385b2..491f0e02bc0 100644 --- a/snuba/migrations/autogeneration/main.py +++ b/snuba/migrations/autogeneration/main.py @@ -1,7 +1,6 @@ import ast import os import subprocess -from typing import Optional from yaml import safe_load @@ -10,7 +9,7 @@ from snuba.migrations.autogeneration.diff import generate_python_migration -def generate(storage_path: str, migration_name: Optional[str] = None) -> str: +def generate(storage_path: str, migration_name: str | None = None) -> str: # load into memory the given storage and the version of it at HEAD tmpnew, tmpold = get_working_and_head(storage_path) new_storage = safe_load(tmpnew) @@ -63,7 +62,7 @@ def get_working_and_head(path: str) -> tuple[str, str]: raise ValueError(e.stderr.decode("utf-8")) from e # working - with open(path, "r") as f: + with open(path) as f: working_file = f.read() return (working_file, head_file) diff --git a/snuba/migrations/check_dangerous.py b/snuba/migrations/check_dangerous.py index 2436d70910c..def961be85f 100644 --- a/snuba/migrations/check_dangerous.py +++ b/snuba/migrations/check_dangerous.py @@ -1,5 +1,4 @@ import re -from typing import Type from snuba.migrations.columns import LowCardinality, MigrationModifiers from snuba.migrations.operations import ModifyColumn, SqlOperation @@ -64,32 +63,32 @@ def _has_codec(codec_str: str, modifiers: MigrationModifiers) -> bool: return True return False - if _has_codec("Delta", modifiers): - if not (_has_codec("ZSTD", modifiers) or _has_codec("LZ4", modifiers)): - raise DangerousOperationError( - f"Changing column type from {old_type} to {new_col_type} is dangerous.\n" - "Clickhouse 21 doesn't support Delta codec without ZSTD or LZ4. " - "To attempt to run this migration set blocking=True" - ) + if _has_codec("Delta", modifiers) and not ( + _has_codec("ZSTD", modifiers) or _has_codec("LZ4", modifiers) + ): + raise DangerousOperationError( + f"Changing column type from {old_type} to {new_col_type} is dangerous.\n" + "Clickhouse 21 doesn't support Delta codec without ZSTD or LZ4. " + "To attempt to run this migration set blocking=True" + ) def _check_modifiers( modifier_str: str, - modifier: Type[TypeModifier], + modifier: type[TypeModifier], old_type_str: str, new_col_type: ColumnType[MigrationModifiers], ) -> None: modifier_str = modifier_str.lower() new_modifiers = new_col_type.get_modifiers() - if isinstance(new_modifiers, MigrationModifiers): - if ( - modifier_str in old_type_str - and not new_modifiers.has_modifier(modifier) - or modifier_str not in old_type_str - and new_modifiers.has_modifier(modifier) - ): - raise DangerousOperationError( - f"Changing column type from {old_type_str} to {new_col_type} is dangerous " - f"because only one has {modifier_str} type in it.\nChanging it will block or isn't supported. " - "To attempt to run this migration set blocking=True" - ) + if isinstance(new_modifiers, MigrationModifiers) and ( + modifier_str in old_type_str + and not new_modifiers.has_modifier(modifier) + or modifier_str not in old_type_str + and new_modifiers.has_modifier(modifier) + ): + raise DangerousOperationError( + f"Changing column type from {old_type_str} to {new_col_type} is dangerous " + f"because only one has {modifier_str} type in it.\nChanging it will block or isn't supported. " + "To attempt to run this migration set blocking=True" + ) diff --git a/snuba/migrations/columns.py b/snuba/migrations/columns.py index fb6df195065..f0bd82721ed 100644 --- a/snuba/migrations/columns.py +++ b/snuba/migrations/columns.py @@ -1,7 +1,7 @@ from __future__ import annotations +from collections.abc import Sequence from dataclasses import dataclass -from typing import List, Optional, Sequence from snuba.clickhouse.columns import ( Nullable, @@ -19,13 +19,13 @@ class MigrationModifiers(TypeModifiers): nullable: bool = False low_cardinality: bool = False - default: Optional[str] = None - materialized: Optional[str] = None - codecs: Optional[Sequence[str]] = None - ttl: Optional[str] = None + default: str | None = None + materialized: str | None = None + codecs: Sequence[str] | None = None + ttl: str | None = None def _get_modifiers(self) -> Sequence[TypeModifier]: - ret: List[TypeModifier] = [] + ret: list[TypeModifier] = [] if self.nullable: ret.append(Nullable()) if self.low_cardinality: @@ -53,7 +53,7 @@ def merge(self, other: MigrationModifiers) -> MigrationModifiers: def __eq__(self, other: object) -> bool: if isinstance(other, SchemaModifiers): return self.nullable == other.nullable - elif isinstance(other, MigrationModifiers): + if isinstance(other, MigrationModifiers): return ( self.nullable == other.nullable and self.low_cardinality == other.low_cardinality @@ -70,7 +70,7 @@ class Materialized(TypeModifier): expression: str def for_schema(self, content: str) -> str: - return "{} MATERIALIZED {}".format(content, self.expression) + return f"{content} MATERIALIZED {self.expression}" @dataclass(frozen=True) @@ -86,13 +86,13 @@ class WithDefault(TypeModifier): default: str def for_schema(self, content: str) -> str: - return "{} DEFAULT {}".format(content, self.default) + return f"{content} DEFAULT {self.default}" @dataclass(frozen=True) class LowCardinality(TypeModifier): def for_schema(self, content: str) -> str: - return "LowCardinality({})".format(content) + return f"LowCardinality({content})" @dataclass(frozen=True) diff --git a/snuba/migrations/connect.py b/snuba/migrations/connect.py index 05044c68a28..50e325cd1d0 100644 --- a/snuba/migrations/connect.py +++ b/snuba/migrations/connect.py @@ -1,6 +1,6 @@ import re import time -from typing import List, Sequence, Tuple +from collections.abc import Sequence import structlog from packaging import version @@ -35,7 +35,7 @@ def get_clickhouse_clusters_for_migration_group( migration_group: MigrationGroup, -) -> List[ClickhouseCluster]: +) -> list[ClickhouseCluster]: storage_set_keys = get_storage_set_keys(migration_group) return list({get_cluster(storage_set_key) for storage_set_key in storage_set_keys}) @@ -110,15 +110,15 @@ def check_clickhouse(clickhouse: ClickhousePool) -> None: Checks that the clickhouse version is at least the min version and at most the max version """ ver = clickhouse.execute("SELECT version()").results[0][0] - ver = re.search(r"(\d+.\d+.\d+.\d+)", ver) - if ver is None or version.parse(ver.group()) < version.parse(CLICKHOUSE_SERVER_MIN_VERSION): + match = re.search(r"(\d+.\d+.\d+.\d+)", ver) + if match is None or version.parse(match.group()) < version.parse(CLICKHOUSE_SERVER_MIN_VERSION): raise InvalidClickhouseVersion( - f"Snuba requires minimum Clickhouse version {CLICKHOUSE_SERVER_MIN_VERSION} ({clickhouse.host}:{clickhouse.port} - {version.parse(ver.group())})" + f"Snuba requires minimum Clickhouse version {CLICKHOUSE_SERVER_MIN_VERSION} ({clickhouse.host}:{clickhouse.port} - {version.parse(match.group()) if match else None})" ) - if version.parse(ver.group()) > version.parse(CLICKHOUSE_SERVER_MAX_VERSION): + if version.parse(match.group()) > version.parse(CLICKHOUSE_SERVER_MAX_VERSION): logger.warning( - f"Snuba has only been tested on Clickhouse versions up to {CLICKHOUSE_SERVER_MAX_VERSION} ({clickhouse.host}:{clickhouse.port} - {version.parse(ver.group())}). Higher versions might not be supported." + f"Snuba has only been tested on Clickhouse versions up to {CLICKHOUSE_SERVER_MAX_VERSION} ({clickhouse.host}:{clickhouse.port} - {version.parse(match.group())}). Higher versions might not be supported." ) @@ -136,7 +136,7 @@ def _get_all_storage_keys() -> Sequence[StorageKey]: def _get_all_nodes_for_storage( storage_key: StorageKey, -) -> Tuple[Sequence[ClickhouseNode], Sequence[ClickhouseNode], ClickhouseNode]: +) -> tuple[Sequence[ClickhouseNode], Sequence[ClickhouseNode], ClickhouseNode]: """ Returns all nodes for a given storage key. """ @@ -153,7 +153,7 @@ def _get_all_nodes_for_storage( return (local_nodes, distributed_nodes, query_node) -def check_for_inactive_replicas(clusters: List[ClickhouseCluster]) -> None: +def check_for_inactive_replicas(clusters: list[ClickhouseCluster]) -> None: """ Checks for inactive replicas and raise InactiveClickhouseReplica if any are found. """ diff --git a/snuba/migrations/context.py b/snuba/migrations/context.py index a8dc3eb222b..33f6ac6266e 100644 --- a/snuba/migrations/context.py +++ b/snuba/migrations/context.py @@ -1,5 +1,6 @@ import logging -from typing import Callable, NamedTuple +from collections.abc import Callable +from typing import NamedTuple from snuba.migrations.status import Status diff --git a/snuba/migrations/group_loader.py b/snuba/migrations/group_loader.py index 1636cabfa93..53330e68209 100644 --- a/snuba/migrations/group_loader.py +++ b/snuba/migrations/group_loader.py @@ -2,9 +2,9 @@ import os from abc import ABC, abstractmethod +from collections.abc import Sequence from glob import glob from importlib import import_module -from typing import Sequence from snuba.migrations.errors import MigrationDoesNotExist from snuba.migrations.migration import Migration @@ -53,10 +53,8 @@ def get_migrations(self) -> Sequence[str]: return [] # grab the migrations, ignore all other files migration_filenames = sorted( - map( - lambda x: os.path.basename(x)[:-3], - glob(os.path.join(migration_folder, "[0-9][0-9][0-9][0-9]_*.py")), - ) + os.path.basename(x)[:-3] + for x in glob(os.path.join(migration_folder, "[0-9][0-9][0-9][0-9]_*.py")) ) # validate no duplicate migration numbers last = None @@ -79,9 +77,10 @@ def get_migrations(self) -> Sequence[str]: def load_migration(self, migration_id: str) -> Migration: try: module = import_module(f"{self.__module}.{migration_id}") - return module.Migration() # type: ignore - except ModuleNotFoundError: - raise MigrationDoesNotExist("Invalid migration ID") + migration: Migration = module.Migration() + return migration + except ModuleNotFoundError as e: + raise MigrationDoesNotExist("Invalid migration ID") from e class SystemLoader(DirectoryLoader): diff --git a/snuba/migrations/groups.py b/snuba/migrations/groups.py index a47de41e8c9..f5347e057d5 100644 --- a/snuba/migrations/groups.py +++ b/snuba/migrations/groups.py @@ -1,5 +1,4 @@ from enum import Enum -from typing import Dict, Set from snuba.clusters.storage_sets import StorageSetKey from snuba.datasets.readiness_state import ReadinessState @@ -68,7 +67,7 @@ class _MigrationGroup: def __init__( self, loader: GroupLoader, - storage_sets_keys: Set[StorageSetKey], + storage_sets_keys: set[StorageSetKey], readiness_state: ReadinessState, ) -> None: self.loader = loader @@ -76,7 +75,7 @@ def __init__( self.readiness_state = readiness_state -_REGISTERED_MIGRATION_GROUPS: Dict[MigrationGroup, _MigrationGroup] = { +_REGISTERED_MIGRATION_GROUPS: dict[MigrationGroup, _MigrationGroup] = { MigrationGroup.SYSTEM: _MigrationGroup( loader=SystemLoader(), storage_sets_keys={StorageSetKey.MIGRATIONS}, @@ -187,7 +186,7 @@ class DuplicateStorageSetFoundInGroup(Exception): pass -def build_storage_set_to_group_mapping() -> Dict[StorageSetKey, MigrationGroup]: +def build_storage_set_to_group_mapping() -> dict[StorageSetKey, MigrationGroup]: result = {} for migration_group, _migration_group in _REGISTERED_MIGRATION_GROUPS.items(): for storage_set_key in _migration_group.storage_set_keys: @@ -199,7 +198,7 @@ def build_storage_set_to_group_mapping() -> Dict[StorageSetKey, MigrationGroup]: return result -_STORAGE_SET_TO_MIGRATION_GROUP_MAPPING: Dict[StorageSetKey, MigrationGroup] = ( +_STORAGE_SET_TO_MIGRATION_GROUP_MAPPING: dict[StorageSetKey, MigrationGroup] = ( build_storage_set_to_group_mapping() ) @@ -208,17 +207,17 @@ def get_group_loader(group: MigrationGroup) -> GroupLoader: return _REGISTERED_MIGRATION_GROUPS[group].loader -def get_storage_set_keys(group: MigrationGroup) -> Set[StorageSetKey]: +def get_storage_set_keys(group: MigrationGroup) -> set[StorageSetKey]: return _REGISTERED_MIGRATION_GROUPS[group].storage_set_keys def get_group_readiness_state_from_storage_set( storage_set_key: StorageSetKey, ) -> ReadinessState: - migration_group = _STORAGE_SET_TO_MIGRATION_GROUP_MAPPING.get(storage_set_key, None) + migration_group = _STORAGE_SET_TO_MIGRATION_GROUP_MAPPING.get(storage_set_key) if not migration_group: return ReadinessState.LIMITED - registered_migration_group = _REGISTERED_MIGRATION_GROUPS.get(migration_group, None) + registered_migration_group = _REGISTERED_MIGRATION_GROUPS.get(migration_group) if registered_migration_group: return registered_migration_group.readiness_state return ReadinessState.LIMITED diff --git a/snuba/migrations/migration.py b/snuba/migrations/migration.py index 43156f152fe..dfd0cd8da46 100644 --- a/snuba/migrations/migration.py +++ b/snuba/migrations/migration.py @@ -1,6 +1,6 @@ import warnings from abc import ABC, abstractmethod, abstractproperty -from typing import Optional, Sequence +from collections.abc import Sequence from snuba import settings from snuba.clusters.cluster import get_cluster @@ -149,7 +149,7 @@ def forwards( self, context: Context, dry_run: bool = False, - columns_state_to_check: Optional[ColumnStatesMapType] = None, + columns_state_to_check: ColumnStatesMapType | None = None, ) -> None: ops = self.forwards_ops() @@ -181,7 +181,7 @@ def backwards( self, context: Context, dry_run: bool, - columns_state_to_check: Optional[ColumnStatesMapType] = None, + columns_state_to_check: ColumnStatesMapType | None = None, ) -> None: ops = self.backwards_ops() if dry_run: @@ -257,6 +257,7 @@ def backwards_ops(self) -> Sequence[SqlOperation]: warnings.warn( "backwards_local and backwards_dist are deprecated. Use backwards_ops instead.", DeprecationWarning, + stacklevel=2, ) local_ops, dist_ops = self.backwards_local(), self.backwards_dist() self._set_targets(local_ops, OperationTarget.LOCAL) @@ -264,13 +265,13 @@ def backwards_ops(self) -> Sequence[SqlOperation]: if self.backwards_local_first: return (*local_ops, *dist_ops) - else: - return (*dist_ops, *local_ops) + return (*dist_ops, *local_ops) def forwards_ops(self) -> Sequence[SqlOperation]: warnings.warn( "forwards_local and forwards_dist are deprecated. Use forwards_ops instead.", DeprecationWarning, + stacklevel=2, ) local_ops, dist_ops = self.forwards_local(), self.forwards_dist() self._set_targets(local_ops, OperationTarget.LOCAL) @@ -278,5 +279,4 @@ def forwards_ops(self) -> Sequence[SqlOperation]: if self.forwards_local_first: return (*local_ops, *dist_ops) - else: - return (*dist_ops, *local_ops) + return (*dist_ops, *local_ops) diff --git a/snuba/migrations/migration_utilities.py b/snuba/migrations/migration_utilities.py index 84723a5e2d2..28982bf3538 100644 --- a/snuba/migrations/migration_utilities.py +++ b/snuba/migrations/migration_utilities.py @@ -1,14 +1,12 @@ -from typing import Optional, Set, Tuple - from snuba.clickhouse.native import ClickhousePool from snuba.clusters.cluster import ClickhouseClientSettings, get_cluster from snuba.clusters.storage_sets import StorageSetKey -ClickhouseVersion = Tuple[int, int] +ClickhouseVersion = tuple[int, int] def get_clickhouse_version_for_storage_set( - storage_set: StorageSetKey, clickhouse: Optional[ClickhousePool] + storage_set: StorageSetKey, clickhouse: ClickhousePool | None ) -> ClickhouseVersion: """ Determine the clickhouse version for a storage set. Assumes (and verifies) @@ -24,7 +22,7 @@ def get_clickhouse_version_for_storage_set( for node in cluster.get_local_nodes() ] - versions: Set[ClickhouseVersion] = set() + versions: set[ClickhouseVersion] = set() for connection in connections: ver = connection.execute("SELECT version()").results[0][0] diff --git a/snuba/migrations/operations.py b/snuba/migrations/operations.py index fea17db450b..5369e6e1d12 100644 --- a/snuba/migrations/operations.py +++ b/snuba/migrations/operations.py @@ -3,9 +3,10 @@ import logging import time from abc import ABC, abstractmethod +from collections.abc import Callable, Mapping, Sequence from dataclasses import dataclass from enum import Enum -from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, Union +from typing import Any import structlog @@ -47,7 +48,7 @@ def __init__( self, storage_set: StorageSetKey, target: OperationTarget, - settings: Optional[Mapping[str, Any]] = None, + settings: Mapping[str, Any] | None = None, ): self._storage_set = storage_set self._settings = settings @@ -78,7 +79,7 @@ def _get_on_cluster_clause(self) -> str: return f" ON CLUSTER '{cluster_name}'" return "" - def _get_execution_node(self) -> Optional[ClickhouseNode]: + def _get_execution_node(self) -> ClickhouseNode | None: """Returns a single node to execute DDL on (ON CLUSTER handles distribution).""" cluster = get_cluster(self._storage_set) if self.target == OperationTarget.DISTRIBUTED: @@ -86,7 +87,7 @@ def _get_execution_node(self) -> Optional[ClickhouseNode]: return None nodes = cluster.get_distributed_nodes() return nodes[0] if nodes else None - elif self.target == OperationTarget.LOCAL: + if self.target == OperationTarget.LOCAL: nodes = cluster.get_local_nodes() return nodes[0] if nodes else None raise ValueError(f"Target not set for {self}") @@ -147,9 +148,8 @@ def _execute_per_node(self) -> None: nodes = self.get_nodes() cluster = get_cluster(self._storage_set) sql = self.format_sql() - if nodes: - if settings.LOG_MIGRATIONS: - logger.info(f"Executing op: {sql}") + if nodes and settings.LOG_MIGRATIONS: + logger.info(f"Executing op: {sql}") for node in nodes: connection = cluster.get_node_connection(ClickhouseClientSettings.MIGRATE, node) if settings.LOG_MIGRATIONS: @@ -206,11 +206,11 @@ class RetryOnSyncError: def execute(self) -> None: for i in range(30, -1, -1): # wait at most ~30 seconds try: - super().execute() # type: ignore + super().execute() # type: ignore[misc] break except Exception as e: # Metadata on replica is not up to date with common metadata in Zookeeper (status code = 517) - if i and e.code == 517: # type: ignore + if i and e.code == 517: # type: ignore[attr-defined] time.sleep(1) else: raise @@ -230,7 +230,7 @@ def __init__( columns: Sequence[Column[MigrationModifiers]], engine: TableEngine, target: OperationTarget = OperationTarget.UNSET, - settings: Optional[Mapping[str, Any]] = None, + settings: Mapping[str, Any] | None = None, ): super().__init__(storage_set, target=target, settings=settings) self.table_name = table_name @@ -385,7 +385,7 @@ def __init__( storage_set: StorageSetKey, table_name: str, column: Column[MigrationModifiers], - after: Optional[str] = None, + after: str | None = None, target: OperationTarget = OperationTarget.UNSET, ): super().__init__(storage_set, target=target) @@ -457,7 +457,7 @@ def __init__( storage_set: StorageSetKey, table_name: str, column: Column[MigrationModifiers], - ttl_month: Optional[Tuple[str, int]] = None, + ttl_month: tuple[str, int] | None = None, target: OperationTarget = OperationTarget.UNSET, ): super().__init__(storage_set, target=target) @@ -547,7 +547,7 @@ def __init__( index_expression: str, index_type: str, granularity: int, - after: Optional[str] = None, + after: str | None = None, target: OperationTarget = OperationTarget.UNSET, ): super().__init__(storage_set, target=target) @@ -678,11 +678,11 @@ def __init__( dest_columns: Sequence[str], src_table_name: str, src_columns: Sequence[str], - prewhere: Optional[str] = None, - order_by: Optional[str] = None, - limit: Optional[int] = None, - offset: Optional[int] = None, - where: Optional[str] = None, + prewhere: str | None = None, + order_by: str | None = None, + limit: int | None = None, + offset: int | None = None, + where: str | None = None, target: OperationTarget = OperationTarget.UNSET, ): super().__init__(storage_set, target=target) @@ -738,7 +738,7 @@ def execute_new_node( raise NotImplementedError @abstractmethod - def description(self) -> Optional[str]: + def description(self) -> str | None: raise NotImplementedError @@ -752,8 +752,8 @@ class RunPython(GenericOperation): def __init__( self, func: Callable[[logging.Logger], None], - new_node_func: Optional[Callable[[Sequence[StorageSetKey]], None]] = None, - description: Optional[str] = None, + new_node_func: Callable[[Sequence[StorageSetKey]], None] | None = None, + description: str | None = None, ) -> None: self.__func = func self.__new_node_func = new_node_func @@ -771,7 +771,7 @@ def execute_new_node( if self.__new_node_func is not None: self.__new_node_func(storage_sets) - def description(self) -> Optional[str]: + def description(self) -> str | None: return self.__description @@ -784,15 +784,14 @@ class RunSqlAsCode(GenericOperation): def __init__( self, - operation_function: Union[SqlOperation, Callable[[Optional[ClickhousePool]], SqlOperation]], + operation_function: SqlOperation | Callable[[ClickhousePool | None], SqlOperation], ) -> None: self.__operation_function = operation_function - def _get_operation(self, clickhouse: Optional[ClickhousePool]) -> SqlOperation: + def _get_operation(self, clickhouse: ClickhousePool | None) -> SqlOperation: if callable(self.__operation_function): return self.__operation_function(clickhouse) - else: - return self.__operation_function + return self.__operation_function def execute(self, logger: logging.Logger) -> None: self._get_operation(None).execute() @@ -816,5 +815,5 @@ def execute_new_node( logger.info(f"Executing {sql}") clickhouse.execute(sql) - def description(self) -> Optional[str]: + def description(self) -> str | None: return self._get_operation(None).format_sql() diff --git a/snuba/migrations/parse_schema.py b/snuba/migrations/parse_schema.py index 2294adf3e40..f80010e65ec 100644 --- a/snuba/migrations/parse_schema.py +++ b/snuba/migrations/parse_schema.py @@ -1,7 +1,8 @@ from __future__ import annotations import re -from typing import Any, Iterable, Mapping, Sequence +from collections.abc import Iterable, Mapping, Sequence +from typing import Any from clickhouse_driver import Client from parsimonious.grammar import Grammar @@ -89,8 +90,7 @@ def merge_modifiers( existing_modifiers = col_type.get_modifiers() if existing_modifiers is None: return col_type.set_modifiers(modifiers) - else: - return col_type.set_modifiers(existing_modifiers.merge(modifiers)) + return col_type.set_modifiers(existing_modifiers.merge(modifiers)) _TYPES: dict[str, type[ColumnType[MigrationModifiers]]] = { @@ -105,7 +105,7 @@ def merge_modifiers( } -class Visitor(NodeVisitor): # type: ignore +class Visitor(NodeVisitor): # type: ignore[misc] def visit_basic_type( self, node: Node, visited_children: Iterable[Any] ) -> ColumnType[MigrationModifiers]: @@ -420,6 +420,6 @@ def get_local_schema(conn: Client, table_name: str) -> Mapping[str, ColumnType[M return { column_name: _get_column(column_type, default_type, default_expr, codec_expr) for column_name, column_type, default_type, default_expr, _comment, codec_expr in [ - cols[:6] for cols in conn.execute("DESCRIBE TABLE %s" % table_name).results + cols[:6] for cols in conn.execute(f"DESCRIBE TABLE {table_name}").results ] } diff --git a/snuba/migrations/policies.py b/snuba/migrations/policies.py index d7104f073dc..ab8d48cd582 100644 --- a/snuba/migrations/policies.py +++ b/snuba/migrations/policies.py @@ -64,7 +64,7 @@ def can_run(self, migration_key: MigrationKey) -> bool: return True migration = get_group_loader(migration_key.group).load_migration(migration_key.migration_id) - return False if migration.blocking else True + return not migration.blocking def can_reverse(self, migration_key: MigrationKey) -> bool: if get_group_readiness_state(migration_key.group) == ReadinessState.EXPERIMENTAL: @@ -73,14 +73,14 @@ def can_reverse(self, migration_key: MigrationKey) -> bool: status, timestamp = Runner().get_status(migration_key) migration = get_group_loader(migration_key.group).load_migration(migration_key.migration_id) if status == Status.IN_PROGRESS: - return False if migration.blocking else True + return not migration.blocking if status == Status.COMPLETED and timestamp: oldest_allowed_timestamp = datetime.now() + timedelta( hours=-MAX_MIGRATIONS_REVERT_TIME_WINDOW_HRS ) if timestamp >= oldest_allowed_timestamp: - return False if migration.blocking else True + return not migration.blocking return False diff --git a/snuba/migrations/runner.py b/snuba/migrations/runner.py index 86d661a7eea..f71e70efcfe 100644 --- a/snuba/migrations/runner.py +++ b/snuba/migrations/runner.py @@ -1,7 +1,8 @@ from collections import defaultdict +from collections.abc import Mapping, MutableMapping, Sequence from datetime import datetime from functools import partial -from typing import List, Mapping, MutableMapping, NamedTuple, Optional, Sequence, Tuple +from typing import NamedTuple import structlog from clickhouse_driver import errors @@ -77,9 +78,9 @@ def __init__(self) -> None: ClickhouseClientSettings.MIGRATE ) - self.__status: MutableMapping[MigrationKey, Tuple[Status, Optional[datetime]]] = {} + self.__status: MutableMapping[MigrationKey, tuple[Status, datetime | None]] = {} - def get_status(self, migration_key: MigrationKey) -> Tuple[Status, Optional[datetime]]: + def get_status(self, migration_key: MigrationKey) -> tuple[Status, datetime | None]: """ Returns the status and timestamp of a migration. """ @@ -128,12 +129,12 @@ def force_overwrite_status( ) def show_all( - self, groups: Optional[Sequence[str]] = None, include_nonexistent: bool = False - ) -> List[Tuple[MigrationGroup, List[MigrationDetails]]]: + self, groups: Sequence[str] | None = None, include_nonexistent: bool = False + ) -> list[tuple[MigrationGroup, list[MigrationDetails]]]: """ Returns the list of migrations and their statuses for each group. """ - migrations: List[Tuple[MigrationGroup, List[MigrationDetails]]] = [] + migrations: list[tuple[MigrationGroup, list[MigrationDetails]]] = [] if groups: migration_groups: Sequence[MigrationGroup] = [MigrationGroup(group) for group in groups] @@ -142,14 +143,14 @@ def show_all( migration_status = self._get_migration_status(migration_groups) clickhouse_group_migrations = defaultdict(set) - for group, migration_id in migration_status.keys(): + for group, migration_id in migration_status: clickhouse_group_migrations[group].add(migration_id) def get_status(migration_key: MigrationKey) -> Status: return migration_status.get(migration_key, Status.NOT_STARTED) for group in migration_groups: - group_migrations: List[MigrationDetails] = [] + group_migrations: list[MigrationDetails] = [] group_loader = get_group_loader(group) migration_ids = group_loader.get_migrations() @@ -185,8 +186,8 @@ def run_all( through: str = "all", fake: bool = False, force: bool = False, - group: Optional[MigrationGroup] = None, - readiness_states: Optional[Sequence[ReadinessState]] = None, + group: MigrationGroup | None = None, + readiness_states: Sequence[ReadinessState] | None = None, check_dangerous: bool = False, ) -> None: """ @@ -215,7 +216,7 @@ def run_all( if get_group_readiness_state(m.group) in readiness_states ] - use_through = False if through == "all" else True + use_through = through != "all" def exact_migration_exists(through: str) -> bool: migration_ids = [ @@ -223,9 +224,7 @@ def exact_migration_exists(through: str) -> bool: for key in pending_migrations if key.migration_id.startswith(through) ] - if len(migration_ids) == 1: - return True - return False + return len(migration_ids) == 1 if use_through and not exact_migration_exists(through): raise MigrationError(f"No exact match for: {through}") @@ -376,8 +375,8 @@ def reverse_all( fake: bool = False, force: bool = False, include_system: bool = False, - group: Optional[MigrationGroup] = None, - readiness_states: Optional[Sequence[ReadinessState]] = None, + group: MigrationGroup | None = None, + readiness_states: Sequence[ReadinessState] | None = None, ) -> None: if not force: raise MigrationError("Requires force to reverse migrations") @@ -426,7 +425,7 @@ def exact_migration_exists(through: str) -> bool: def reverse_in_progress( self, fake: bool = False, - group: Optional[MigrationGroup] = None, + group: MigrationGroup | None = None, dry_run: bool = False, ) -> None: """ @@ -446,7 +445,7 @@ def get_status(migration_key: MigrationKey) -> Status: else: migration_groups = get_active_migration_groups() - def get_in_progress_migration(group: MigrationGroup) -> Optional[MigrationKey]: + def get_in_progress_migration(group: MigrationGroup) -> MigrationKey | None: group_migrations = get_group_loader(group).get_migrations() for migration_id in group_migrations: migration_key = MigrationKey(group, migration_id) @@ -485,11 +484,11 @@ def _reverse_migration_impl( migration.backwards(context, dry_run) - def _get_pending_migrations(self) -> List[MigrationKey]: + def _get_pending_migrations(self) -> list[MigrationKey]: """ Gets pending migration list. """ - migrations: List[MigrationKey] = [] + migrations: list[MigrationKey] = [] for group in get_active_migration_groups(): group_migrations = self._get_pending_migrations_for_group(group) @@ -497,7 +496,7 @@ def _get_pending_migrations(self) -> List[MigrationKey]: return migrations - def _get_pending_migrations_for_group(self, group: MigrationGroup) -> List[MigrationKey]: + def _get_pending_migrations_for_group(self, group: MigrationGroup) -> list[MigrationKey]: """ Gets pending migrations list for a specific group """ @@ -507,7 +506,7 @@ def get_status(migration_key: MigrationKey) -> Status: return migration_status.get(migration_key, Status.NOT_STARTED) group_loader = get_group_loader(group) - group_migrations: List[MigrationKey] = [] + group_migrations: list[MigrationKey] = [] for migration_id in group_loader.get_migrations(): migration_key = MigrationKey(group, migration_id) @@ -523,13 +522,13 @@ def get_status(migration_key: MigrationKey) -> Status: return group_migrations - def _get_completed_migrations(self, groups: Sequence[MigrationGroup]) -> List[MigrationKey]: + def _get_completed_migrations(self, groups: Sequence[MigrationGroup]) -> list[MigrationKey]: """ Get a list of completed migrations for a list of groups """ migration_status = self._get_migration_status() - group_migrations: List[MigrationKey] = [] + group_migrations: list[MigrationKey] = [] for group in groups: group_loader = get_group_loader(group) completed_migrations = 0 @@ -539,7 +538,7 @@ def _get_completed_migrations(self, groups: Sequence[MigrationGroup]) -> List[Mi if status == Status.IN_PROGRESS: # can't reverse migrations if one is stuck pending raise MigrationInProgress(str(migration_key)) - elif status == Status.COMPLETED: + if status == Status.COMPLETED: group_migrations.append(migration_key) completed_migrations += 1 elif completed_migrations > 0: @@ -583,7 +582,7 @@ def _get_next_version(self, migration_key: MigrationKey) -> int: return 1 def _get_migration_status( - self, groups: Optional[Sequence[MigrationGroup]] = None + self, groups: Sequence[MigrationGroup] | None = None ) -> Mapping[MigrationKey, Status]: data: MutableMapping[MigrationKey, Status] = {} diff --git a/snuba/migrations/system_migrations/0001_migrations.py b/snuba/migrations/system_migrations/0001_migrations.py index affe7954793..e3dd5399ccb 100644 --- a/snuba/migrations/system_migrations/0001_migrations.py +++ b/snuba/migrations/system_migrations/0001_migrations.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, DateTime, Enum, String, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/migrations/table_engines.py b/snuba/migrations/table_engines.py index 7e4df2f6110..461204bdf4c 100644 --- a/snuba/migrations/table_engines.py +++ b/snuba/migrations/table_engines.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Mapping, Optional, Union +from collections.abc import Mapping from snuba import settings from snuba.clickhouse.escaping import escape_string @@ -37,11 +37,11 @@ def __init__( self, storage_set: StorageSetKey, order_by: str, - primary_key: Optional[str] = None, - partition_by: Optional[str] = None, - sample_by: Optional[str] = None, - ttl: Optional[str] = None, - settings: Optional[Mapping[str, Union[str, int]]] = None, + primary_key: str | None = None, + partition_by: str | None = None, + sample_by: str | None = None, + ttl: str | None = None, + settings: Mapping[str, str | int] | None = None, unsharded: bool = False, ) -> None: self._storage_set_value = storage_set.value @@ -90,9 +90,8 @@ def _get_zookeeper_path(self, cluster: ClickhouseCluster, table_name: str) -> st def _get_engine_type(self, cluster: ClickhouseCluster, table_name: str) -> str: if cluster.is_single_node(): return "MergeTree()" - else: - zoo_path = self._get_zookeeper_path(cluster, table_name) - return f"ReplicatedMergeTree({zoo_path}, '{{replica}}')" + zoo_path = self._get_zookeeper_path(cluster, table_name) + return f"ReplicatedMergeTree({zoo_path}, '{{replica}}')" class ReplacingMergeTree(MergeTree): @@ -100,12 +99,12 @@ def __init__( self, storage_set: StorageSetKey, order_by: str, - primary_key: Optional[str] = None, - version_column: Optional[str] = None, - partition_by: Optional[str] = None, - sample_by: Optional[str] = None, - ttl: Optional[str] = None, - settings: Optional[Mapping[str, str]] = None, + primary_key: str | None = None, + version_column: str | None = None, + partition_by: str | None = None, + sample_by: str | None = None, + ttl: str | None = None, + settings: Mapping[str, str] | None = None, unsharded: bool = False, ) -> None: super().__init__( @@ -125,11 +124,12 @@ def _get_engine_type(self, cluster: ClickhouseCluster, table_name: str) -> str: if self.__version_column: return f"ReplacingMergeTree({self.__version_column})" return "ReplacingMergeTree()" - else: - zoo_path = self._get_zookeeper_path(cluster, table_name) - if self.__version_column: - return f"ReplicatedReplacingMergeTree({zoo_path}, '{{replica}}', {self.__version_column})" - return f"ReplicatedReplacingMergeTree({zoo_path}, '{{replica}}')" + zoo_path = self._get_zookeeper_path(cluster, table_name) + if self.__version_column: + return ( + f"ReplicatedReplacingMergeTree({zoo_path}, '{{replica}}', {self.__version_column})" + ) + return f"ReplicatedReplacingMergeTree({zoo_path}, '{{replica}}')" class CollapsingMergeTree(MergeTree): @@ -137,12 +137,12 @@ def __init__( self, storage_set: StorageSetKey, order_by: str, - primary_key: Optional[str] = None, - sign_column: Optional[str] = None, - partition_by: Optional[str] = None, - sample_by: Optional[str] = None, - ttl: Optional[str] = None, - settings: Optional[Mapping[str, str]] = None, + primary_key: str | None = None, + sign_column: str | None = None, + partition_by: str | None = None, + sample_by: str | None = None, + ttl: str | None = None, + settings: Mapping[str, str] | None = None, unsharded: bool = False, ) -> None: super().__init__( @@ -160,31 +160,28 @@ def __init__( def _get_engine_type(self, cluster: ClickhouseCluster, table_name: str) -> str: if cluster.is_single_node(): return f"CollapsingMergeTree({self.__sign_column})" - else: - zoo_path = self._get_zookeeper_path(cluster, table_name) - return f"ReplicatedCollapsingMergeTree({zoo_path}, '{{replica}}', {self.__sign_column})" + zoo_path = self._get_zookeeper_path(cluster, table_name) + return f"ReplicatedCollapsingMergeTree({zoo_path}, '{{replica}}', {self.__sign_column})" class SummingMergeTree(MergeTree): def _get_engine_type(self, cluster: ClickhouseCluster, table_name: str) -> str: if cluster.is_single_node(): return "SummingMergeTree()" - else: - zoo_path = self._get_zookeeper_path(cluster, table_name) - return f"ReplicatedSummingMergeTree({zoo_path}, '{{replica}}')" + zoo_path = self._get_zookeeper_path(cluster, table_name) + return f"ReplicatedSummingMergeTree({zoo_path}, '{{replica}}')" class AggregatingMergeTree(MergeTree): def _get_engine_type(self, cluster: ClickhouseCluster, table_name: str) -> str: if cluster.is_single_node(): return "AggregatingMergeTree()" - else: - zoo_path = self._get_zookeeper_path(cluster, table_name) - return f"ReplicatedAggregatingMergeTree({zoo_path}, '{{replica}}')" + zoo_path = self._get_zookeeper_path(cluster, table_name) + return f"ReplicatedAggregatingMergeTree({zoo_path}, '{{replica}}')" class Distributed(TableEngine): - def __init__(self, local_table_name: str, sharding_key: Optional[str]) -> None: + def __init__(self, local_table_name: str, sharding_key: str | None) -> None: self.__local_table_name = local_table_name self.__sharding_key = sharding_key diff --git a/snuba/migrations/validator.py b/snuba/migrations/validator.py index 23a7a418687..d3ad72f02d9 100644 --- a/snuba/migrations/validator.py +++ b/snuba/migrations/validator.py @@ -1,5 +1,5 @@ import re -from typing import Sequence, Union +from collections.abc import Sequence from snuba.clusters.cluster import UndefinedClickhouseCluster, get_cluster from snuba.datasets.schemas.tables import TableSchema @@ -54,9 +54,9 @@ class InvalidDistributedOperation(Exception): def _conflicts_ops(local_op: SqlOperation, dist_op: SqlOperation) -> bool: if isinstance(local_op, CreateTable) and isinstance(dist_op, CreateTable): return conflicts_create_table_op(local_op, dist_op) - elif isinstance(local_op, AddColumn) and isinstance(dist_op, AddColumn): + if isinstance(local_op, AddColumn) and isinstance(dist_op, AddColumn): return conflicts_add_column_op(local_op, dist_op) - elif isinstance(local_op, DropColumn) and isinstance(dist_op, DropColumn): + if isinstance(local_op, DropColumn) and isinstance(dist_op, DropColumn): return conflicts_drop_column_op(local_op, dist_op) return False @@ -64,26 +64,28 @@ def _conflicts_ops(local_op: SqlOperation, dist_op: SqlOperation) -> bool: def _validate_add_col_or_create_table( local_op: SqlOperation, dist_ops: Sequence[SqlOperation] ) -> None: - if isinstance(local_op, (CreateTable, AddColumn)): - if any(_conflicts_ops(local_op, dist_op) for dist_op in dist_ops): - op_name = ( - f"{local_op.table_name}.{local_op.column.name}" - if isinstance(local_op, AddColumn) - else local_op.table_name - ) - raise InvalidMigrationOrderError( - f"{type(local_op).__name__} {op_name} operation " - "must be applied on local table before dist" - ) + if isinstance(local_op, (CreateTable, AddColumn)) and any( + _conflicts_ops(local_op, dist_op) for dist_op in dist_ops + ): + op_name = ( + f"{local_op.table_name}.{local_op.column.name}" + if isinstance(local_op, AddColumn) + else local_op.table_name + ) + raise InvalidMigrationOrderError( + f"{type(local_op).__name__} {op_name} operation " + "must be applied on local table before dist" + ) def _validate_drop_col(dist_op: SqlOperation, local_ops: Sequence[SqlOperation]) -> None: - if isinstance(dist_op, (DropColumn)): - if any(_conflicts_ops(local_op, dist_op) for local_op in local_ops): - raise InvalidMigrationOrderError( - f"{type(dist_op).__name__} {dist_op.table_name}.{dist_op.column_name} " - "operation must be applied on dist table before local" - ) + if isinstance(dist_op, DropColumn) and any( + _conflicts_ops(local_op, dist_op) for local_op in local_ops + ): + raise InvalidMigrationOrderError( + f"{type(dist_op).__name__} {dist_op.table_name}.{dist_op.column_name} " + "operation must be applied on dist table before local" + ) def _validate_order_old( @@ -112,12 +114,10 @@ def _validate_order_new( for i, op in enumerate(ops): local_ops_before = [op for op in ops[:i] if op.target != OperationTarget.DISTRIBUTED] dist_ops_before = [op for op in ops[:i] if op.target != OperationTarget.LOCAL] - if isinstance(op, (CreateTable, AddColumn)): - if op.target == OperationTarget.LOCAL: - _validate_add_col_or_create_table(op, dist_ops_before) - elif isinstance(op, DropColumn): - if op.target == OperationTarget.DISTRIBUTED: - _validate_drop_col(op, local_ops_before) + if isinstance(op, (CreateTable, AddColumn)) and op.target == OperationTarget.LOCAL: + _validate_add_col_or_create_table(op, dist_ops_before) + elif isinstance(op, DropColumn) and op.target == OperationTarget.DISTRIBUTED: + _validate_drop_col(op, local_ops_before) def validate_migration_order(migration: ClickhouseNodeMigration) -> None: @@ -205,7 +205,7 @@ def conflicts_drop_column_op(local_drop: DropColumn, dist_drop: DropColumn) -> b return False -def _get_local_table_name(dist_op: Union[CreateTable, AddColumn, DropColumn]) -> str: +def _get_local_table_name(dist_op: CreateTable | AddColumn | DropColumn) -> str: """ Returns the local table name for a distributed table. """ @@ -215,12 +215,14 @@ def _get_local_table_name(dist_op: Union[CreateTable, AddColumn, DropColumn]) -> if storage.get_storage_set_key() != dist_op._storage_set: continue schema = storage.get_schema() - if isinstance(schema, TableSchema): - # In local mode we want to verify that the pairing of - # dist/local tables is correct, so using get_dist_table_name - # instead of get_table_name here - if schema.get_dist_table_name() == dist_op.table_name: - return schema.get_local_table_name() + # In local mode we want to verify that the pairing of + # dist/local tables is correct, so using get_dist_table_name + # instead of get_table_name here + if ( + isinstance(schema, TableSchema) + and schema.get_dist_table_name() == dist_op.table_name + ): + return schema.get_local_table_name() except UndefinedClickhouseCluster: continue raise InvalidDistributedOperation( diff --git a/snuba/pipeline/composite_entity_processing.py b/snuba/pipeline/composite_entity_processing.py index dbfbceb97ae..336e32af63c 100644 --- a/snuba/pipeline/composite_entity_processing.py +++ b/snuba/pipeline/composite_entity_processing.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Mapping, Union +from collections.abc import Mapping from snuba.clickhouse.query import Query as ClickhouseQuery from snuba.datasets.plans.entity_processing import run_entity_processing_executor @@ -38,11 +38,12 @@ def _pre_entity_query_processing(query: CompositeQuery[Entity]) -> None: if isinstance(from_clause, JoinClause): nodes = from_clause.get_alias_node_map() for node in nodes.values(): - if isinstance(node.data_source, Entity): - if not node.data_source.key.value.startswith( - "generic_metrics" - ) and not node.data_source.key.value.startswith("metrics"): - is_gen_metrics_join_query = False + if ( + isinstance(node.data_source, Entity) + and not node.data_source.key.value.startswith("generic_metrics") + and not node.data_source.key.value.startswith("metrics") + ): + is_gen_metrics_join_query = False else: is_gen_metrics_join_query = False @@ -81,11 +82,7 @@ def _translate_logical_composite_query( class CompositeDataSourceTransformer( DataSourceVisitor[ - Union[ - ProcessableQuery[Table], - CompositeQuery[Table], - JoinClause[Table], - ], + ProcessableQuery[Table] | CompositeQuery[Table] | JoinClause[Table], Entity, ] ): @@ -103,7 +100,7 @@ def _visit_simple_source(self, data_source: Entity) -> ClickhouseQuery: def _visit_join( self, data_source: JoinClause[Entity] - ) -> Union[ProcessableQuery[Table], CompositeQuery[Table], JoinClause[Table]]: + ) -> ProcessableQuery[Table] | CompositeQuery[Table] | JoinClause[Table]: alias_to_query_mappings = data_source.accept(JoinQueryVisitor(self.__settings)) check_sub_query_storage_sets(alias_to_query_mappings) @@ -114,7 +111,7 @@ def _visit_join( def _visit_simple_query( self, data_source: ProcessableQuery[Entity] - ) -> Union[ProcessableQuery[Table], CompositeQuery[Table], JoinClause[Table]]: + ) -> ProcessableQuery[Table] | CompositeQuery[Table] | JoinClause[Table]: assert isinstance(data_source, LogicalQuery), ( f"Only subqueries are allowed at query planning stage. {type(data_source)} found." ) @@ -123,7 +120,7 @@ def _visit_simple_query( def _visit_composite_query( self, data_source: CompositeQuery[Entity] - ) -> Union[ProcessableQuery[Table], CompositeQuery[Table], JoinClause[Table]]: + ) -> ProcessableQuery[Table] | CompositeQuery[Table] | JoinClause[Table]: return _translate_logical_composite_query(data_source, self.__settings) @@ -150,9 +147,7 @@ def visit_join_clause(self, node: JoinClause[Entity]) -> Mapping[str, Clickhouse } -class JoinDataSourceTransformer( - JoinVisitor[Union[JoinClause[Table], IndividualNode[Table]], Entity] -): +class JoinDataSourceTransformer(JoinVisitor[JoinClause[Table] | IndividualNode[Table], Entity]): """ A visitor class responsible for producing a join data source. """ @@ -167,7 +162,7 @@ def __init__( def visit_individual_node( self, node: IndividualNode[Entity] - ) -> Union[JoinClause[Table], IndividualNode[Table]]: + ) -> JoinClause[Table] | IndividualNode[Table]: assert isinstance(node.data_source, ProcessableQuery), ( "Invalid composite query. All nodes must be subqueries." ) @@ -177,7 +172,7 @@ def visit_individual_node( def visit_join_clause( self, node: JoinClause[Entity] - ) -> Union[JoinClause[Table], IndividualNode[Table]]: + ) -> JoinClause[Table] | IndividualNode[Table]: left_node = node.left_node.accept(self) right_node = self.visit_individual_node(node.right_node) diff --git a/snuba/pipeline/composite_storage_processing.py b/snuba/pipeline/composite_storage_processing.py index d9deb75d0c5..f11f1989d73 100644 --- a/snuba/pipeline/composite_storage_processing.py +++ b/snuba/pipeline/composite_storage_processing.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Mapping, NamedTuple, Sequence, Union +from collections.abc import Mapping, Sequence +from typing import NamedTuple import sentry_sdk @@ -187,7 +188,7 @@ class JoinDataSourcePlan(NamedTuple): an execution strategy that is not used. """ - translated_source: Union[JoinClause[Table], IndividualNode[Table]] + translated_source: JoinClause[Table] | IndividualNode[Table] processors: Mapping[str, SubqueryProcessors] storage_set_key: StorageSetKey diff --git a/snuba/pipeline/query_pipeline.py b/snuba/pipeline/query_pipeline.py index a0a496c9709..d734ca983fb 100644 --- a/snuba/pipeline/query_pipeline.py +++ b/snuba/pipeline/query_pipeline.py @@ -3,7 +3,7 @@ import logging from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Generic, Optional, TypeVar, Union, cast +from typing import Generic, TypeVar, cast from snuba.query.query_settings import QuerySettings from snuba.utils.metrics.timer import Timer @@ -46,7 +46,7 @@ class QueryPipelineStage(Generic[Tin, Tout]): >>> print("PhysicalQuery: ", stage_2.data) """ - def _process_error(self, pipe_input: QueryPipelineError[Tin]) -> Union[Tout, Exception]: + def _process_error(self, pipe_input: QueryPipelineError[Tin]) -> Tout | Exception: """default behaviour is to just pass through to the next stage of the pipeline Can be overridden to do something else""" logging.exception(pipe_input.error) @@ -66,13 +66,12 @@ def execute(self, pipe_input: QueryPipelineResult[Tin]) -> QueryPipelineResult[T error=res, timer=pipe_input.timer, ) - else: - return QueryPipelineResult( - data=res, - query_settings=pipe_input.query_settings, - error=None, - timer=pipe_input.timer, - ) + return QueryPipelineResult( + data=res, + query_settings=pipe_input.query_settings, + error=None, + timer=pipe_input.timer, + ) try: return QueryPipelineResult( data=self._process_data(pipe_input.as_data()), @@ -99,8 +98,8 @@ class QueryPipelineResult(ABC, Generic[T]): A container to represent the result of a query pipeline stage. """ - data: Optional[T] - error: Optional[Exception] + data: T | None + error: Exception | None query_settings: QuerySettings timer: Timer diff --git a/snuba/pipeline/stages/query_execution.py b/snuba/pipeline/stages/query_execution.py index b7b200984ea..1bedabdb052 100644 --- a/snuba/pipeline/stages/query_execution.py +++ b/snuba/pipeline/stages/query_execution.py @@ -3,9 +3,10 @@ import logging import textwrap from collections import defaultdict +from collections.abc import MutableMapping from dataclasses import replace from math import floor -from typing import Any, MutableMapping, Optional +from typing import Any import sentry_sdk @@ -57,7 +58,7 @@ def __init__( attribution_info: AttributionInfo, query_metadata: SnubaQueryMetadata, robust: bool = False, - concurrent_queries_gauge: Optional[Gauge] = None, + concurrent_queries_gauge: Gauge | None = None, ): self._attribution_info = attribution_info self._query_metadata = query_metadata @@ -86,18 +87,17 @@ def _process_data( cluster, "get_clickhouse_cluster_name", lambda: "no_cluster_name" )(), ) - else: - return _run_and_apply_column_names( - timer=pipe_input.timer, - query_metadata=self._query_metadata, - attribution_info=self._attribution_info, - robust=self._robust, - concurrent_queries_gauge=None, - clickhouse_query=pipe_input.data, - query_settings=pipe_input.query_settings, - reader=cluster.get_reader(), - cluster_name=cluster.get_clickhouse_cluster_name() or "", - ) + return _run_and_apply_column_names( + timer=pipe_input.timer, + query_metadata=self._query_metadata, + attribution_info=self._attribution_info, + robust=self._robust, + concurrent_queries_gauge=None, + clickhouse_query=pipe_input.data, + query_settings=pipe_input.query_settings, + reader=cluster.get_reader(), + cluster_name=cluster.get_clickhouse_cluster_name() or "", + ) def _dry_run_query_runner( @@ -123,7 +123,7 @@ def _run_and_apply_column_names( query_metadata: SnubaQueryMetadata, attribution_info: AttributionInfo, robust: bool, - concurrent_queries_gauge: Optional[Gauge], + concurrent_queries_gauge: Gauge | None, clickhouse_query: ClickhouseQuery | CompositeQuery[Table], query_settings: QuerySettings, reader: Reader, @@ -185,7 +185,7 @@ def _format_storage_query_and_run( query_settings: QuerySettings, reader: Reader, robust: bool, - concurrent_queries_gauge: Optional[Gauge] = None, + concurrent_queries_gauge: Gauge | None = None, cluster_name: str = "", ) -> QueryResult: """ @@ -252,7 +252,7 @@ def _format_storage_query_and_run( cause.__class__.__name__, str(cause), extra=QueryExtraData( - stats=stats, + stats=dict(stats), sql=formatted_sql, experiments=clickhouse_query.get_experiments(), ), @@ -300,9 +300,8 @@ def get_query_size_group(query_size_bytes: int) -> str: """ if query_size_bytes == _max_query_size_bytes(): return "100%" - else: - query_size_group = int(floor(query_size_bytes / _max_query_size_bytes() * 10)) * 10 - return f">={query_size_group}%" + query_size_group = int(floor(query_size_bytes / _max_query_size_bytes() * 10)) * 10 + return f">={query_size_group}%" def _apply_turbo_sampling_if_needed( @@ -313,11 +312,14 @@ def _apply_turbo_sampling_if_needed( TODO: Remove this method entirely and move the sampling logic into a query processor. """ - if isinstance(clickhouse_query, ClickhouseQuery): - if query_settings.get_turbo() and not clickhouse_query.get_from_clause().sampling_rate: - clickhouse_query.set_from_clause( - replace( - clickhouse_query.get_from_clause(), - sampling_rate=snuba_settings.TURBO_SAMPLE_RATE, - ) + if ( + isinstance(clickhouse_query, ClickhouseQuery) + and query_settings.get_turbo() + and not clickhouse_query.get_from_clause().sampling_rate + ): + clickhouse_query.set_from_clause( + replace( + clickhouse_query.get_from_clause(), + sampling_rate=snuba_settings.TURBO_SAMPLE_RATE, ) + ) diff --git a/snuba/pipeline/stages/query_processing.py b/snuba/pipeline/stages/query_processing.py index 909bc3279f8..6d42adec3de 100644 --- a/snuba/pipeline/stages/query_processing.py +++ b/snuba/pipeline/stages/query_processing.py @@ -35,15 +35,14 @@ def _process_data( run_entity_validators(cast(EntityQuery, query), pipe_input.query_settings) if isinstance(query, LogicalQuery) and isinstance(query.get_from_clause(), Entity): return run_entity_processing_executor(query, pipe_input.query_settings) - elif isinstance(query, CompositeQuery): + if isinstance(query, CompositeQuery): # if we were not able to translate the storage query earlier and we got to this point, this is # definitely a composite entity query return translate_composite_query( cast(CompositeQuery[Entity], query), pipe_input.query_settings, ) - else: - raise NotImplementedError(f"Unknown query type {type(query)}, {query}") + raise NotImplementedError(f"Unknown query type {type(query)}, {query}") class StorageProcessingStage( @@ -68,10 +67,7 @@ def _process_data( if isinstance(pipe_input.data, ClickhouseQuery): query_plan = build_best_plan(pipe_input.data, pipe_input.query_settings, []) return apply_storage_processors(query_plan, pipe_input.query_settings) - else: - composite_query_plan = build_best_plan_for_composite_query( - pipe_input.data, pipe_input.query_settings, [] - ) - return apply_composite_storage_processors( - composite_query_plan, pipe_input.query_settings - ) + composite_query_plan = build_best_plan_for_composite_query( + pipe_input.data, pipe_input.query_settings, [] + ) + return apply_composite_storage_processors(composite_query_plan, pipe_input.query_settings) diff --git a/snuba/pipeline/storage_query_identity_translate.py b/snuba/pipeline/storage_query_identity_translate.py index e0ddf74a87f..963165a23fe 100644 --- a/snuba/pipeline/storage_query_identity_translate.py +++ b/snuba/pipeline/storage_query_identity_translate.py @@ -21,16 +21,14 @@ def try_translate_storage_query( if finder.is_storage_query: if not finder.has_join: return _translate_storage_query(query) - else: - raise InvalidQueryException("Joins not supported for storage queries") - elif finder.is_mixed_data_source_query: + raise InvalidQueryException("Joins not supported for storage queries") + if finder.is_mixed_data_source_query: raise InvalidQueryException( "Queries on storages and entities are not supported", data_sources=[str(d) for d in finder.simple_data_sources], ) - else: - # this is not a storage query, pass this on to the rest of the pipeline - return None + # this is not a storage query, pass this on to the rest of the pipeline + return None def _translate_storage_query( @@ -51,21 +49,20 @@ def _translate_storage_query( ) ) return res - else: - assert isinstance(from_clause, (LogicalQuery, CompositeQuery)) - return CompositeQuery( - from_clause=_translate_storage_query(from_clause), - selected_columns=query.get_selected_columns(), - array_join=query.get_arrayjoin(), - condition=query.get_condition(), - groupby=query.get_groupby(), - having=query.get_having(), - order_by=query.get_orderby(), - limitby=query.get_limitby(), - offset=query.get_offset(), - totals=query.has_totals(), - granularity=query.get_granularity(), - ) + assert isinstance(from_clause, (LogicalQuery, CompositeQuery)) + return CompositeQuery( + from_clause=_translate_storage_query(from_clause), + selected_columns=query.get_selected_columns(), + array_join=query.get_arrayjoin(), + condition=query.get_condition(), + groupby=query.get_groupby(), + having=query.get_having(), + order_by=query.get_orderby(), + limitby=query.get_limitby(), + offset=query.get_offset(), + totals=query.has_totals(), + granularity=query.get_granularity(), + ) class _LogicalDataSourceFinder( @@ -93,12 +90,12 @@ class _LogicalDataSourceFinder( @property def is_storage_query(self) -> bool: - return all((isinstance(d, Storage) for d in self.simple_data_sources)) + return all(isinstance(d, Storage) for d in self.simple_data_sources) @property def is_mixed_data_source_query(self) -> bool: - return any((isinstance(d, Storage) for d in self.simple_data_sources)) and any( - (isinstance(d, Entity) for d in self.simple_data_sources) + return any(isinstance(d, Storage) for d in self.simple_data_sources) and any( + isinstance(d, Entity) for d in self.simple_data_sources ) def __init__(self) -> None: diff --git a/snuba/pipeline/utils/storage_finder.py b/snuba/pipeline/utils/storage_finder.py index 336d3e24275..ed8eda046c2 100644 --- a/snuba/pipeline/utils/storage_finder.py +++ b/snuba/pipeline/utils/storage_finder.py @@ -19,8 +19,7 @@ class StorageKeyJoinFinder(JoinVisitor[StorageKey, Table]): def visit_individual_node(self, node: IndividualNode[Table]) -> StorageKey: if isinstance(node.data_source, ProcessableQuery): return node.data_source.get_from_clause().storage_key - else: - return node.data_source.storage_key + return node.data_source.storage_key def visit_join_clause(self, node: JoinClause[Table]) -> StorageKey: left_storage_key = node.left_node.accept(self) diff --git a/snuba/processor.py b/snuba/processor.py index fff58092614..3022ddaf52c 100644 --- a/snuba/processor.py +++ b/snuba/processor.py @@ -1,22 +1,15 @@ import ipaddress import re from abc import ABC, abstractmethod +from collections.abc import Iterable, MutableMapping, Sequence from datetime import datetime, timedelta from enum import Enum from hashlib import md5 from typing import ( Any, - Dict, - FrozenSet, - Iterable, - MutableMapping, NamedTuple, - Optional, - Sequence, - Tuple, TypedDict, TypeVar, - Union, ) import simplejson as json @@ -35,10 +28,10 @@ class InsertBatch(NamedTuple): rows: Sequence[WriterTableRow] # origin_timestamp is the timestamp of the event when was received by Relay - origin_timestamp: Optional[datetime] + origin_timestamp: datetime | None # sentry_received_timestamp is the timestamp of the event when received by the ingest # consumer in Sentry - sentry_received_timestamp: Optional[datetime] = None + sentry_received_timestamp: datetime | None = None # Indicates that we need an encoder that will interpolate @@ -53,7 +46,7 @@ class ReplacementBatch(NamedTuple): values: Sequence[Any] -ProcessedMessage = Union[InsertBatch, ReplacementBatch] +ProcessedMessage = InsertBatch | ReplacementBatch class MessageProcessor(ABC): @@ -65,7 +58,7 @@ class MessageProcessor(ABC): @abstractmethod def process_message( self, message: Any, metadata: KafkaMessageMetadata - ) -> Optional[ProcessedMessage]: + ) -> ProcessedMessage | None: raise NotImplementedError @@ -91,13 +84,13 @@ class ReplacementType(str, Enum): EXCLUDE_GROUPS = "exclude_groups" -REPLACEMENT_EVENT_TYPES: FrozenSet[ReplacementType] = frozenset( +REPLACEMENT_EVENT_TYPES: frozenset[ReplacementType] = frozenset( ReplacementType.__members__.values() ) class InsertEvent(TypedDict): - group_id: Optional[int] + group_id: int | None event_id: str organization_id: int project_id: int @@ -114,7 +107,7 @@ class InsertEvent(TypedDict): def _as_dict_safe( - value: Union[None, Iterable[Optional[Tuple[Optional[TKey], TValue]]], Dict[TKey, TValue]], + value: None | Iterable[tuple[TKey | None, TValue] | None] | dict[TKey, TValue], ) -> MutableMapping[TKey, TValue]: if value is None: return {} @@ -127,7 +120,7 @@ def _as_dict_safe( return rv -def _collapse_uint16(n: Any) -> Optional[int]: +def _collapse_uint16(n: Any) -> int | None: if n is None: return None @@ -138,7 +131,7 @@ def _collapse_uint16(n: Any) -> Optional[int]: return i -def _collapse_uint32(n: Any) -> Optional[int]: +def _collapse_uint32(n: Any) -> int | None: if n is None: return None @@ -149,7 +142,7 @@ def _collapse_uint32(n: Any) -> Optional[int]: return i -def _boolify(s: Any) -> Optional[bool]: +def _boolify(s: Any) -> bool | None: if s is None: return None @@ -160,17 +153,17 @@ def _boolify(s: Any) -> Optional[bool]: if s in ("yes", "true", "1"): return True - elif s in ("false", "no", "0"): + if s in ("false", "no", "0"): return False return None -def _unicodify(s: Any) -> Optional[str]: +def _unicodify(s: Any) -> str | None: if s is None: return None - if isinstance(s, dict) or isinstance(s, list): + if isinstance(s, (dict, list)): return json.dumps(s) return str(s).encode("utf8", errors="backslashreplace").decode("utf8") @@ -185,7 +178,7 @@ def _hashify(h: str) -> str: epoch = datetime(1970, 1, 1) -def _ensure_valid_date(dt: Optional[datetime]) -> Optional[datetime]: +def _ensure_valid_date(dt: datetime | None) -> datetime | None: if dt is None: return None seconds = (dt - epoch).total_seconds() @@ -196,7 +189,7 @@ def _ensure_valid_date(dt: Optional[datetime]) -> Optional[datetime]: def _ensure_valid_ip( ip: Any, -) -> Optional[Union[ipaddress.IPv4Address, ipaddress.IPv6Address]]: +) -> ipaddress.IPv4Address | ipaddress.IPv6Address | None: """ IP addresses in e.g. `user.ip_address` might be invalid due to PII stripping. """ diff --git a/snuba/protos/common.py b/snuba/protos/common.py index 3ae0b3e0fb6..28046f0d7b1 100644 --- a/snuba/protos/common.py +++ b/snuba/protos/common.py @@ -1,5 +1,6 @@ from collections import defaultdict -from typing import Final, Mapping, Sequence +from collections.abc import Mapping, Sequence +from typing import Final from sentry_conventions.attributes import ATTRIBUTE_METADATA from sentry_protos.snuba.v1.trace_item_attribute_pb2 import AttributeKey @@ -246,12 +247,11 @@ def attribute_key_to_expression(attr_key: AttributeKey) -> Expression: *expressions, alias=alias, ) - else: - return _generate_subscriptable_reference( - attr_key.name, - attr_key.type, - alias, - ) + return _generate_subscriptable_reference( + attr_key.name, + attr_key.type, + alias, + ) if attr_key.type == AttributeKey.Type.TYPE_ARRAY: # Tagged array under attributes_array.* as Array(JSON). Select toJSONString(...) diff --git a/snuba/query/__init__.py b/snuba/query/__init__.py index 4df6e3551b8..cff941e6b91 100644 --- a/snuba/query/__init__.py +++ b/snuba/query/__init__.py @@ -1,20 +1,13 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Callable, Iterable, MutableMapping, Sequence from dataclasses import dataclass, replace from enum import Enum from itertools import chain from typing import Any as AnyType from typing import ( - Callable, Generic, - Iterable, - MutableMapping, - Optional, - Sequence, - Set, - Tuple, - Type, TypeVar, cast, ) @@ -56,7 +49,7 @@ class OrderBy: class SelectedExpression: # The name of this column in the resultset. # TODO: Make this non nullable - name: Optional[str] + name: str | None expression: Expression @@ -94,19 +87,19 @@ def __init__( # TODO: Consider if to remove the defaults and make some of # these fields mandatory. This impacts a lot of code so it # would be done on its own. - selected_columns: Optional[Sequence[SelectedExpression]] = None, - array_join: Optional[Sequence[Expression]] = None, - condition: Optional[Expression] = None, - groupby: Optional[Sequence[Expression]] = None, - having: Optional[Expression] = None, - order_by: Optional[Sequence[OrderBy]] = None, - limitby: Optional[LimitBy] = None, - limit: Optional[int] = None, + selected_columns: Sequence[SelectedExpression] | None = None, + array_join: Sequence[Expression] | None = None, + condition: Expression | None = None, + groupby: Sequence[Expression] | None = None, + having: Expression | None = None, + order_by: Sequence[OrderBy] | None = None, + limitby: LimitBy | None = None, + limit: int | None = None, offset: int = 0, totals: bool = False, - granularity: Optional[int] = None, - experiments: Optional[MutableMapping[str, AnyType]] = None, - on_cluster: Optional[Expression] = None, + granularity: int | None = None, + experiments: MutableMapping[str, AnyType] | None = None, + on_cluster: Expression | None = None, is_delete: bool = False, ): self.__selected_columns = selected_columns or [] @@ -165,10 +158,10 @@ def get_groupby(self) -> Sequence[Expression]: def set_ast_groupby(self, groupby: Sequence[Expression]) -> None: self.__groupby = groupby - def get_condition(self) -> Optional[Expression]: + def get_condition(self) -> Expression | None: return self.__condition - def set_ast_condition(self, condition: Optional[Expression]) -> None: + def set_ast_condition(self, condition: Expression | None) -> None: self.__condition = condition def add_condition_to_ast(self, condition: Expression) -> None: @@ -177,16 +170,16 @@ def add_condition_to_ast(self, condition: Expression) -> None: else: self.__condition = binary_condition(BooleanFunctions.AND, condition, self.__condition) - def get_arrayjoin(self) -> Optional[Sequence[Expression]]: + def get_arrayjoin(self) -> Sequence[Expression] | None: return self.__array_join - def set_arrayjoin(self, arrayjoin: Optional[Sequence[Expression]]) -> None: + def set_arrayjoin(self, arrayjoin: Sequence[Expression] | None) -> None: self.__array_join = arrayjoin - def get_having(self) -> Optional[Expression]: + def get_having(self) -> Expression | None: return self.__having - def set_ast_having(self, condition: Optional[Expression]) -> None: + def set_ast_having(self, condition: Expression | None) -> None: self.__having = condition def get_orderby(self) -> Sequence[OrderBy]: @@ -195,13 +188,13 @@ def get_orderby(self) -> Sequence[OrderBy]: def set_ast_orderby(self, orderby: Sequence[OrderBy]) -> None: self.__order_by = orderby - def get_limitby(self) -> Optional[LimitBy]: + def get_limitby(self) -> LimitBy | None: return self.__limitby def set_limitby(self, limitby: LimitBy) -> None: self.__limitby = limitby - def get_limit(self) -> Optional[int]: + def get_limit(self) -> int | None: return self.__limit def set_limit(self, limit: int) -> None: @@ -222,7 +215,7 @@ def has_totals(self) -> bool: def set_granularity(self, granularity: int) -> None: self.__granularity = granularity - def get_granularity(self) -> Optional[int]: + def get_granularity(self) -> int | None: return self.__granularity def add_experiment(self, name: str, value: AnyType) -> None: @@ -240,7 +233,7 @@ def get_experiment_value(self, name: str) -> AnyType: def is_delete(self) -> bool: return self.__is_delete - def get_on_cluster(self) -> Optional[Expression]: + def get_on_cluster(self) -> Expression | None: return self.__on_cluster @abstractmethod @@ -259,12 +252,12 @@ def get_all_expressions(self) -> Iterable[Expression]: deduplicate any of the expressions found. """ return chain( - chain.from_iterable(map(lambda selected: selected.expression, self.__selected_columns)), + chain.from_iterable(selected.expression for selected in self.__selected_columns), self.__array_join or [], self.__condition or [], chain.from_iterable(self.__groupby), self.__having or [], - chain.from_iterable(map(lambda orderby: orderby.expression, self.__order_by)), + chain.from_iterable(orderby.expression for orderby in self.__order_by), self.__limitby.columns if self.__limitby else [], self._get_expressions_impl(), ) @@ -300,33 +293,24 @@ def transform_expressions( def transform_expression_list( expressions: Sequence[Expression], ) -> Sequence[Expression]: - return list( - map(lambda exp: exp.transform(func), expressions), - ) + return [exp.transform(func) for exp in expressions] - self.__selected_columns = list( - map( - lambda selected: replace(selected, expression=selected.expression.transform(func)), - self.__selected_columns, - ) - ) - if not skip_array_join: - if self.__array_join: - self.__array_join = [ - join_element.transform(func) for join_element in self.__array_join - ] + self.__selected_columns = [ + replace(selected, expression=selected.expression.transform(func)) + for selected in self.__selected_columns + ] + if not skip_array_join and self.__array_join: + self.__array_join = [join_element.transform(func) for join_element in self.__array_join] if not skip_transform_condition: self.__condition = self.__condition.transform(func) if self.__condition else None self.__groupby = transform_expression_list(self.__groupby) self.__having = self.__having.transform(func) if self.__having else None if not skip_transform_order_by: - self.__order_by = list( - map( - lambda clause: replace(clause, expression=clause.expression.transform(func)), - self.__order_by, - ) - ) + self.__order_by = [ + replace(clause, expression=clause.expression.transform(func)) + for clause in self.__order_by + ] if self.__limitby is not None: self.__limitby = LimitBy( @@ -355,12 +339,10 @@ def transform(self, visitor: ExpressionVisitor[Expression]) -> None: The transformation happens in place. """ - self.__selected_columns = list( - map( - lambda selected: replace(selected, expression=selected.expression.accept(visitor)), - self.__selected_columns, - ) - ) + self.__selected_columns = [ + replace(selected, expression=selected.expression.accept(visitor)) + for selected in self.__selected_columns + ] if self.__array_join is not None: self.__array_join = [join_element.accept(visitor) for join_element in self.__array_join] if self.__condition is not None: @@ -368,12 +350,10 @@ def transform(self, visitor: ExpressionVisitor[Expression]) -> None: self.__groupby = [e.accept(visitor) for e in (self.__groupby or [])] if self.__having is not None: self.__having = self.__having.accept(visitor) - self.__order_by = list( - map( - lambda clause: replace(clause, expression=clause.expression.accept(visitor)), - self.__order_by, - ) - ) + self.__order_by = [ + replace(clause, expression=clause.expression.accept(visitor)) + for clause in self.__order_by + ] if self.__limitby is not None: self.__limitby = LimitBy( self.__limitby.limit, @@ -382,27 +362,27 @@ def transform(self, visitor: ExpressionVisitor[Expression]) -> None: self._transform_impl(visitor) def __get_all_ast_referenced_expressions( - self, expressions: Iterable[Expression], exp_type: Type[TExp] - ) -> Set[TExp]: - ret: Set[TExp] = set() + self, expressions: Iterable[Expression], exp_type: type[TExp] + ) -> set[TExp]: + ret: set[TExp] = set() for expression in expressions: ret |= {c for c in expression if isinstance(c, exp_type)} return ret - def get_all_ast_referenced_columns(self) -> Set[Column]: + def get_all_ast_referenced_columns(self) -> set[Column]: return self.__get_all_ast_referenced_expressions(self.get_all_expressions(), Column) - def get_all_ast_referenced_subscripts(self) -> Set[SubscriptableReference]: + def get_all_ast_referenced_subscripts(self) -> set[SubscriptableReference]: return self.__get_all_ast_referenced_expressions( self.get_all_expressions(), SubscriptableReference ) - def get_columns_referenced_in_conditions_ast(self) -> Set[Column]: + def get_columns_referenced_in_conditions_ast(self) -> set[Column]: return self.__get_all_ast_referenced_expressions( [self.__condition] if self.__condition is not None else [], Column ) - def get_columns_referenced_in_select(self) -> Set[Column]: + def get_columns_referenced_in_select(self) -> set[Column]: return self.__get_all_ast_referenced_expressions( [selected.expression for selected in self.__selected_columns], Column ) @@ -417,8 +397,8 @@ def validate_aliases(self) -> bool: Caution: for this to work, data_source needs to be already populated, otherwise it would throw. """ - declared_symbols: Set[str] = set() - referenced_symbols: Set[str] = set() + declared_symbols: set[str] = set() + referenced_symbols: set[str] = set() for e in self.get_all_expressions(): # SELECT f(g(x)) as A -> declared_symbols = {A} # SELECT a as B -> declared_symbols = {B} referenced_symbols = {a} @@ -455,7 +435,7 @@ def _eq_functions(self) -> Sequence[str]: "get_granularity", ) - def equals(self, other: object) -> Tuple[bool, str]: + def equals(self, other: object) -> tuple[bool, str]: if self.__class__ != other.__class__: return False, f"{self.__class__} != {other.__class__}" @@ -484,23 +464,23 @@ class ProcessableQuery(Query, ABC, Generic[TSimpleDataSource]): def __init__( self, - from_clause: Optional[TSimpleDataSource], + from_clause: TSimpleDataSource | None, # TODO: Consider if to remove the defaults and make some of # these fields mandatory. This impacts a lot of code so it # would be done on its own. - selected_columns: Optional[Sequence[SelectedExpression]] = None, - array_join: Optional[Sequence[Expression]] = None, - condition: Optional[Expression] = None, - groupby: Optional[Sequence[Expression]] = None, - having: Optional[Expression] = None, - order_by: Optional[Sequence[OrderBy]] = None, - limitby: Optional[LimitBy] = None, - limit: Optional[int] = None, + selected_columns: Sequence[SelectedExpression] | None = None, + array_join: Sequence[Expression] | None = None, + condition: Expression | None = None, + groupby: Sequence[Expression] | None = None, + having: Expression | None = None, + order_by: Sequence[OrderBy] | None = None, + limitby: LimitBy | None = None, + limit: int | None = None, offset: int = 0, totals: bool = False, - granularity: Optional[int] = None, + granularity: int | None = None, is_delete: bool = False, - on_cluster: Optional[Expression] = None, + on_cluster: Expression | None = None, ): super().__init__( selected_columns=selected_columns, diff --git a/snuba/query/accessors.py b/snuba/query/accessors.py index 5eeccd0bac9..e4f4ac330e3 100644 --- a/snuba/query/accessors.py +++ b/snuba/query/accessors.py @@ -1,10 +1,8 @@ -from typing import Set - from snuba.query.expressions import Column, Expression -def get_columns_in_expression(exp: Expression) -> Set[Column]: +def get_columns_in_expression(exp: Expression) -> set[Column]: """ Returns all the columns referenced in an arbitrary AST expression. """ - return set(e for e in exp if isinstance(e, Column)) + return {e for e in exp if isinstance(e, Column)} diff --git a/snuba/query/allocation_policies/__init__.py b/snuba/query/allocation_policies/__init__.py index ee39e0bf62d..d5868154c2c 100644 --- a/snuba/query/allocation_policies/__init__.py +++ b/snuba/query/allocation_policies/__init__.py @@ -106,7 +106,7 @@ def from_args( tenant_ids: dict[str, str | int], policy_name: str, description: str | None = None, - ) -> "InvalidTenantsForAllocationPolicy": + ) -> InvalidTenantsForAllocationPolicy: return cls( description or "Invalid tenants for allocation policy", tenant_ids=tenant_ids, @@ -126,7 +126,7 @@ def __str__(self) -> str: @property def violations(self) -> dict[str, dict[str, Any]]: details = cast(dict[str, Any], self.quota_allowance.get("details")) - return {k: v for k, v in details.items() if v["can_run"] == False} + return {k: v for k, v in details.items() if not v["can_run"]} @property def quota_allowance(self) -> dict[str, dict[str, Any]]: @@ -140,7 +140,7 @@ def summary(self) -> dict[str, Any]: def from_args( cls, quota_allowances: dict[str, Any], - ) -> "AllocationPolicyViolations": + ) -> AllocationPolicyViolations: return cls( "Query on could not be run due to allocation policies", quota_allowances=quota_allowances, @@ -375,7 +375,7 @@ def __init__( ) @classmethod - def create_minimal_instance(cls, resource_identifier: str) -> "ConfigurableComponent": + def create_minimal_instance(cls, resource_identifier: str) -> ConfigurableComponent: return cls( storage_key=ResourceIdentifier(resource_identifier), required_tenant_types=[], @@ -425,7 +425,7 @@ def is_cross_org_query(self, tenant_ids: dict[str, str | int]) -> bool: return bool(tenant_ids.get("cross_org_query", False)) @classmethod - def from_kwargs(cls, **kwargs: str) -> "AllocationPolicy": + def from_kwargs(cls, **kwargs: str) -> AllocationPolicy: required_tenant_types = kwargs.pop("required_tenant_types", None) storage_key = kwargs.pop("storage_key", None) default_config_overrides: dict[str, Any] = cast( @@ -548,7 +548,7 @@ def update_quota_balance( ) -> None: try: if not self.is_active: - return + return None return self._update_quota_balance(tenant_ids, query_id, result_or_error) except InvalidTenantsForAllocationPolicy: # the policy did not do anything because the tenants were invalid, updating is also not necessary @@ -584,7 +584,7 @@ def query_type(self) -> QueryType: def to_dict(self) -> PolicyData: base_data = super().to_dict() - return PolicyData(**base_data, query_type=self.query_type.value) # type: ignore + return PolicyData(**base_data, query_type=self.query_type.value) class PassthroughPolicy(AllocationPolicy): diff --git a/snuba/query/allocation_policies/bytes_scanned_rejecting_policy.py b/snuba/query/allocation_policies/bytes_scanned_rejecting_policy.py index 5d099a1b720..9bbb29164bb 100644 --- a/snuba/query/allocation_policies/bytes_scanned_rejecting_policy.py +++ b/snuba/query/allocation_policies/bytes_scanned_rejecting_policy.py @@ -30,11 +30,9 @@ # we don't limit the amount of bytes subscriptions can scan at this time -_PASS_THROUGH_REFERRERS = set( - [ - "subscriptions_executor", - ] -) +_PASS_THROUGH_REFERRERS = { + "subscriptions_executor", +} UNREASONABLY_LARGE_NUMBER_OF_BYTES_SCANNED_PER_QUERY = int(1e12) @@ -186,7 +184,7 @@ def __get_scan_limit( if override == DEFAULT_OVERRIDE_LIMIT: return int(self.get_config_value("project_referrer_scan_limit")) return int(override) - elif customer_tenant_key == "organization_id": + if customer_tenant_key == "organization_id": org_referrer_override = self.get_config_value( "organization_referrer_scan_limit_override", {"organization_id": customer_tenant_value, "referrer": referrer}, @@ -351,32 +349,31 @@ def _get_quota_allowance( suggestion=SUGGESTION, ) - else: - explanation[ - "reason" - ] = f"""{customer_tenant_key} {customer_tenant_value} is over the bytes scanned limit of {scan_limit} for referrer {referrer}. + explanation[ + "reason" + ] = f"""{customer_tenant_key} {customer_tenant_value} is over the bytes scanned limit of {scan_limit} for referrer {referrer}. This policy is exceeded when a customer is abusing a specific feature in a way that puts load on clickhouse. If this is happening to "many customers, that may mean the feature is written in an inefficient way""" - explanation["granted_quota"] = granted_quota.granted - explanation["limit"] = scan_limit - # This is technically a high cardinality tag value however these rejections - # should not happen often therefore it should be safe to output these rejections as metris + explanation["granted_quota"] = granted_quota.granted + explanation["limit"] = scan_limit + # This is technically a high cardinality tag value however these rejections + # should not happen often therefore it should be safe to output these rejections as metris - self.metrics.increment( - "bytes_scanned_rejection", - tags={"tenant": f"{customer_tenant_key}__{customer_tenant_value}__{referrer}"}, - ) - return QuotaAllowance( - can_run=False, - max_threads=0, - explanation=explanation, - is_throttled=True, - throttle_threshold=throttle_threshold, - rejection_threshold=scan_limit, - quota_used=used_quota, - quota_unit=QUOTA_UNIT, - suggestion=SUGGESTION, - ) + self.metrics.increment( + "bytes_scanned_rejection", + tags={"tenant": f"{customer_tenant_key}__{customer_tenant_value}__{referrer}"}, + ) + return QuotaAllowance( + can_run=False, + max_threads=0, + explanation=explanation, + is_throttled=True, + throttle_threshold=throttle_threshold, + rejection_threshold=scan_limit, + quota_used=used_quota, + quota_unit=QUOTA_UNIT, + suggestion=SUGGESTION, + ) # this checks to see if you reached the throttle threshold if granted_quota.granted < scan_limit - throttle_threshold: @@ -420,13 +417,11 @@ def _get_bytes_scanned_in_query( and result_or_error.error.__cause__.code == errors.ErrorCodes.TIMEOUT_EXCEEDED ): return int(self.get_config_value("clickhouse_timeout_bytes_scanned_penalization")) - else: - return 0 + return 0 + assert result_or_error.query_result is not None progress_bytes_scanned = cast( int, - result_or_error.query_result.result.get("profile", {}).get( # type: ignore[union-attr] - "progress_bytes", None - ), + (result_or_error.query_result.result.get("profile") or {}).get("progress_bytes", None), ) if isinstance(progress_bytes_scanned, (int, float)): self.metrics.increment( diff --git a/snuba/query/allocation_policies/bytes_scanned_window_policy.py b/snuba/query/allocation_policies/bytes_scanned_window_policy.py index 93770d66fbf..e6e64b3077a 100644 --- a/snuba/query/allocation_policies/bytes_scanned_window_policy.py +++ b/snuba/query/allocation_policies/bytes_scanned_window_policy.py @@ -29,61 +29,53 @@ # A hardcoded list of referrers which do not have an organization_id associated with them # purposefully not in config because we don't want that to be easily changeable -_ORG_LESS_REFERRERS = set( - [ - "subscriptions_executor", - "weekly_reports.outcomes", - "reports.key_errors", - "reports.key_performance_issues", - "weekly_reports.key_transactions.this_week", - "weekly_reports.key_transactions.last_week", - "dynamic_sampling.distribution.fetch_projects_with_count_per_root_total_volumes", - "dynamic_sampling.distribution.fetch_orgs_with_count_per_root_total_volumes", - "dynamic_sampling.counters.fetch_projects_with_count_per_transaction_volumes", - "dynamic_sampling.counters.fetch_projects_with_transaction_totals", - "dynamic_sampling.counters.get_org_transaction_volumes", - "dynamic_sampling.counters.get_active_orgs", - "migration.backfill_perf_issue_events_issue_platform", - "api.vroom", - "replays.query.download_replay_segments", - "release_monitor.fetch_projects_with_recent_sessions", - "http://localhost:1219/", - "reprocessing2.start_group_reprocessing", - "metric_validation", - ] -) +_ORG_LESS_REFERRERS = { + "subscriptions_executor", + "weekly_reports.outcomes", + "reports.key_errors", + "reports.key_performance_issues", + "weekly_reports.key_transactions.this_week", + "weekly_reports.key_transactions.last_week", + "dynamic_sampling.distribution.fetch_projects_with_count_per_root_total_volumes", + "dynamic_sampling.distribution.fetch_orgs_with_count_per_root_total_volumes", + "dynamic_sampling.counters.fetch_projects_with_count_per_transaction_volumes", + "dynamic_sampling.counters.fetch_projects_with_transaction_totals", + "dynamic_sampling.counters.get_org_transaction_volumes", + "dynamic_sampling.counters.get_active_orgs", + "migration.backfill_perf_issue_events_issue_platform", + "api.vroom", + "replays.query.download_replay_segments", + "release_monitor.fetch_projects_with_recent_sessions", + "http://localhost:1219/", + "reprocessing2.start_group_reprocessing", + "metric_validation", +} # referrers which do not serve the UI and are given low capacity by default -_SINGLE_THREAD_REFERRERS = set( - [ - "delete-events-from-file", - "delete-event-user-data", - "scrub-nodestore", - "fetch_events_for_deletion", - "delete-events-by-tag-value", - "delete.fetch_last_group", - "forward-events", - "_insert_transaction.verify_transaction", - "tasks.update_user_reports", - "test.wait_for_event_count", - ] -) +_SINGLE_THREAD_REFERRERS = { + "delete-events-from-file", + "delete-event-user-data", + "scrub-nodestore", + "fetch_events_for_deletion", + "delete-events-by-tag-value", + "delete.fetch_last_group", + "forward-events", + "_insert_transaction.verify_transaction", + "tasks.update_user_reports", + "test.wait_for_event_count", +} # subscriptions currently do not undergo rate limiting in any way. # having subscriptions be too slow means there is an incident -_PASS_THROUGH_REFERRERS = set( - [ - "subscriptions_executor", - ] -) +_PASS_THROUGH_REFERRERS = { + "subscriptions_executor", +} UNREASONABLY_LARGE_NUMBER_OF_BYTES_SCANNED_PER_QUERY = int(1e10) -_RATE_LIMITER = RedisSlidingWindowRateLimiter( - get_redis_client(RedisClientKey.RATE_LIMITER) -) +_RATE_LIMITER = RedisSlidingWindowRateLimiter(get_redis_client(RedisClientKey.RATE_LIMITER)) DEFAULT_OVERRIDE_LIMIT = -1 DEFAULT_BYTES_SCANNED_LIMIT = 10000000 QUOTA_UNIT = "bytes" @@ -117,17 +109,15 @@ def _additional_config_definitions(self) -> list[Configuration]: ), ] - def _are_tenant_ids_valid( - self, tenant_ids: dict[str, str | int] - ) -> tuple[bool, str]: + def _are_tenant_ids_valid(self, tenant_ids: dict[str, str | int]) -> tuple[bool, str]: if self.is_cross_org_query(tenant_ids): return True, "cross org query" if tenant_ids.get("referrer") is None: return False, "no referrer" if ( tenant_ids.get("organization_id") is None - and tenant_ids.get("referrer", None) not in _ORG_LESS_REFERRERS - and tenant_ids.get("referrer", None) not in _SINGLE_THREAD_REFERRERS + and tenant_ids.get("referrer") not in _ORG_LESS_REFERRERS + and tenant_ids.get("referrer") not in _SINGLE_THREAD_REFERRERS ): return False, f"no organization_id for referrer {tenant_ids['referrer']}" @@ -137,19 +127,18 @@ def _get_quota_allowance( self, tenant_ids: dict[str, str | int], query_id: str ) -> QuotaAllowance: ids_are_valid, why = self._are_tenant_ids_valid(tenant_ids) - if not ids_are_valid: - if self.is_enforced: - return QuotaAllowance( - can_run=False, - max_threads=0, - explanation={"reason": why}, - is_throttled=False, - throttle_threshold=0, - rejection_threshold=0, - quota_used=0, - quota_unit=NO_UNITS, - suggestion=NO_SUGGESTION, - ) + if not ids_are_valid and self.is_enforced: + return QuotaAllowance( + can_run=False, + max_threads=0, + explanation={"reason": why}, + is_throttled=False, + throttle_threshold=0, + rejection_threshold=0, + quota_used=0, + quota_unit=NO_UNITS, + suggestion=NO_SUGGESTION, + ) if self.is_cross_org_query(tenant_ids): return QuotaAllowance( can_run=True, @@ -163,7 +152,7 @@ def _get_quota_allowance( suggestion=CROSS_ORG_SUGGESTION, ) referrer = tenant_ids.get("referrer", "no_referrer") - org_id = tenant_ids.get("organization_id", None) + org_id = tenant_ids.get("organization_id") if referrer in _PASS_THROUGH_REFERRERS: return QuotaAllowance( can_run=True, @@ -217,9 +206,9 @@ def _get_quota_allowance( is_throttled = False if granted_quota.granted <= 0: is_throttled = True - explanation[ - "reason" - ] = f"organization {org_id} is over the bytes scanned limit of {org_limit_bytes_scanned}" + explanation["reason"] = ( + f"organization {org_id} is over the bytes scanned limit of {org_limit_bytes_scanned}" + ) explanation["is_enforced"] = self.is_enforced explanation["granted_quota"] = granted_quota.granted explanation["limit"] = org_limit_bytes_scanned @@ -253,7 +242,10 @@ def _get_quota_allowance( def _get_bytes_scanned_in_query( self, tenant_ids: dict[str, str | int], result_or_error: QueryResultOrError ) -> int: - progress_bytes_scanned = cast(int, result_or_error.query_result.result.get("profile", {}).get("progress_bytes", None)) # type: ignore + progress_bytes_scanned = cast( + int, + result_or_error.query_result.result.get("profile", {}).get("progress_bytes", None), # type: ignore[union-attr] + ) if isinstance(progress_bytes_scanned, (int, float)): self.metrics.increment( "progress_bytes_scanned", diff --git a/snuba/query/allocation_policies/concurrent_rate_limit.py b/snuba/query/allocation_policies/concurrent_rate_limit.py index 2867a9c805b..bf4cce405f8 100644 --- a/snuba/query/allocation_policies/concurrent_rate_limit.py +++ b/snuba/query/allocation_policies/concurrent_rate_limit.py @@ -1,12 +1,16 @@ from __future__ import annotations import logging -from typing import Callable, cast +import typing +from collections.abc import Callable +from typing import cast from snuba import state from snuba.configs.configuration import Configuration from snuba.query.allocation_policies import ( CROSS_ORG_SUGGESTION, + MAX_THRESHOLD, + NO_SUGGESTION, PASS_THROUGH_REFERRERS_SUGGESTION, AllocationPolicy, AllocationPolicyViolations, @@ -27,23 +31,19 @@ logger = logging.getLogger("snuba.query.allocation_policy_rate_limit") -_PASS_THROUGH_REFERRERS = set( - [ - # these referrers are tied to ingest and are better limited by the ReferrerGuardRailPolicy - "subscriptions_executor", - "tsdb-modelid:4.batch_alert_event_frequency", - "tsdb-modelid:4.batch_alert_event_uniq_user_frequency", - "tsdb-modelid:4.batch_alert_event_frequency_percent", - "tsdb-modelid:4.wf_batch_alert_event_frequency", - "tsdb-modelid:300.wf_batch_alert_event_uniq_user_frequency", - "tsdb-modelid:4.wf_batch_alert_event_frequency_percent", - ] -) -from snuba.query.allocation_policies import MAX_THRESHOLD, NO_SUGGESTION +_PASS_THROUGH_REFERRERS = { + # these referrers are tied to ingest and are better limited by the ReferrerGuardRailPolicy + "subscriptions_executor", + "tsdb-modelid:4.batch_alert_event_frequency", + "tsdb-modelid:4.batch_alert_event_uniq_user_frequency", + "tsdb-modelid:4.batch_alert_event_frequency_percent", + "tsdb-modelid:4.wf_batch_alert_event_frequency", + "tsdb-modelid:300.wf_batch_alert_event_uniq_user_frequency", + "tsdb-modelid:4.wf_batch_alert_event_frequency_percent", +} QUOTA_UNIT = "concurrent_queries" SUGGESTION = "A customer is sending too many queries to snuba. The customer may be abusing an API or the queries may be innefficient" -import typing class BaseConcurrentRateLimitAllocationPolicy(AllocationPolicy): diff --git a/snuba/query/allocation_policies/cross_org.py b/snuba/query/allocation_policies/cross_org.py index 5120cf3a98a..c9425790159 100644 --- a/snuba/query/allocation_policies/cross_org.py +++ b/snuba/query/allocation_policies/cross_org.py @@ -1,10 +1,17 @@ from __future__ import annotations import logging +import typing from typing import Any, cast from snuba.configs.configuration import Configuration, InvalidConfig, ResourceIdentifier -from snuba.query.allocation_policies import QueryResultOrError, QuotaAllowance +from snuba.query.allocation_policies import ( + MAX_THRESHOLD, + NO_SUGGESTION, + NO_UNITS, + QueryResultOrError, + QuotaAllowance, +) from snuba.query.allocation_policies.concurrent_rate_limit import ( BaseConcurrentRateLimitAllocationPolicy, ) @@ -22,11 +29,9 @@ _RATE_LIMIT_NAME = "concurrent_limit_policy" _UNREGISTERED_REFERRER_MAX_THREADS = 1 _UNREGISTERED_REFERRER_CONCURRENT_QUERIES = 1 -from snuba.query.allocation_policies import MAX_THRESHOLD, NO_SUGGESTION, NO_UNITS QUOTA_UNIT = "concurrent_queries" SUGGESTION = "scan less concurrent queries" -import typing class CrossOrgQueryAllocationPolicy(BaseConcurrentRateLimitAllocationPolicy): @@ -61,15 +66,17 @@ def set_config_value( self, config_key: str, value: Any, - params: dict[str, Any] = {}, + params: dict[str, Any] | None = None, user: str | None = None, ) -> None: """makes sure only registered referrers can be overridden""" + if params is None: + params = {} if config_key in ( "referrer_concurrent_override", "referrer_max_threads_override", ): - referrer = params.get("referrer", None) + referrer = params.get("referrer") if referrer is not None and not self._referrer_is_registered(referrer): raise InvalidConfig( f"Referrer {referrer} is not registered in the the {self._resource_identifier.value} yaml. Register it first to be able to override its limits" diff --git a/snuba/query/allocation_policies/utils.py b/snuba/query/allocation_policies/utils.py index 93348d73f1b..17f7d2e9a49 100644 --- a/snuba/query/allocation_policies/utils.py +++ b/snuba/query/allocation_policies/utils.py @@ -1,9 +1,7 @@ -from typing import List - from snuba.query.allocation_policies import QuotaAllowance -def get_max_bytes_to_read(quota_allowances: List[QuotaAllowance]) -> int: +def get_max_bytes_to_read(quota_allowances: list[QuotaAllowance]) -> int: max_bytes_to_read = min( [qa.max_bytes_to_read for qa in quota_allowances], key=lambda mb: float("inf") if mb == 0 else mb, diff --git a/snuba/query/composite.py b/snuba/query/composite.py index 8477d9a3480..782d4bb454b 100644 --- a/snuba/query/composite.py +++ b/snuba/query/composite.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Callable, Generic, Iterable, Optional, Sequence, Union, cast +from collections.abc import Callable, Iterable, Sequence +from typing import Generic, cast from snuba.query import ( LimitBy, @@ -26,27 +27,24 @@ class CompositeQuery(Query, Generic[TSimpleDataSource]): def __init__( self, - from_clause: Optional[ - Union[ - ProcessableQuery[TSimpleDataSource], - CompositeQuery[TSimpleDataSource], - JoinClause[TSimpleDataSource], - ] - ], + from_clause: ProcessableQuery[TSimpleDataSource] + | CompositeQuery[TSimpleDataSource] + | JoinClause[TSimpleDataSource] + | None, # TODO: Consider if to remove the defaults and make some of # these fields mandatory. This impacts a lot of code so it # would be done on its own. - selected_columns: Optional[Sequence[SelectedExpression]] = None, - array_join: Optional[Sequence[Expression]] = None, - condition: Optional[Expression] = None, - groupby: Optional[Sequence[Expression]] = None, - having: Optional[Expression] = None, - order_by: Optional[Sequence[OrderBy]] = None, - limitby: Optional[LimitBy] = None, - limit: Optional[int] = None, + selected_columns: Sequence[SelectedExpression] | None = None, + array_join: Sequence[Expression] | None = None, + condition: Expression | None = None, + groupby: Sequence[Expression] | None = None, + having: Expression | None = None, + order_by: Sequence[OrderBy] | None = None, + limitby: LimitBy | None = None, + limit: int | None = None, offset: int = 0, totals: bool = False, - granularity: Optional[int] = None, + granularity: int | None = None, ): super().__init__( selected_columns=selected_columns, @@ -88,21 +86,19 @@ def __repr__(self) -> str: def get_from_clause( self, - ) -> Union[ - ProcessableQuery[TSimpleDataSource], - CompositeQuery[TSimpleDataSource], - JoinClause[TSimpleDataSource], - ]: + ) -> ( + ProcessableQuery[TSimpleDataSource] + | CompositeQuery[TSimpleDataSource] + | JoinClause[TSimpleDataSource] + ): assert self.__from_clause is not None, "Data source has not been provided yet." return self.__from_clause def set_from_clause( self, - from_clause: Union[ - ProcessableQuery[TSimpleDataSource], - CompositeQuery[TSimpleDataSource], - JoinClause[TSimpleDataSource], - ], + from_clause: ProcessableQuery[TSimpleDataSource] + | CompositeQuery[TSimpleDataSource] + | JoinClause[TSimpleDataSource], ) -> None: self.__from_clause = from_clause diff --git a/snuba/query/conditions.py b/snuba/query/conditions.py index 637b2eec078..5c30190a997 100644 --- a/snuba/query/conditions.py +++ b/snuba/query/conditions.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Any, Mapping, Optional, Sequence, Set, Union +from collections.abc import Mapping, Sequence +from typing import Any from snuba.query.dsl import literals_tuple from snuba.query.expressions import Expression, FunctionCall, Literal @@ -109,13 +110,15 @@ def __set_condition_pattern(lhs: Pattern[Expression], operator: str) -> Function def __is_set_condition(exp: Expression, operator: str) -> bool: - if is_any_binary_condition(exp, operator): - if operator in set_condition_pattern: - if set_condition_pattern[operator].match(exp) is not None: - assert isinstance(exp, FunctionCall) # mypy - assert isinstance(exp.parameters[1], FunctionCall) # mypy - # Matchers can't currently match arbitrary numbers of parameters, so test this directly - return all(isinstance(c, Literal) for c in exp.parameters[1].parameters) + if ( + is_any_binary_condition(exp, operator) + and operator in set_condition_pattern + and set_condition_pattern[operator].match(exp) is not None + ): + assert isinstance(exp, FunctionCall) # mypy + assert isinstance(exp.parameters[1], FunctionCall) # mypy + # Matchers can't currently match arbitrary numbers of parameters, so test this directly + return all(isinstance(c, Literal) for c in exp.parameters[1].parameters) return False @@ -168,7 +171,7 @@ def binary_condition(function_name: str, lhs: Expression, rhs: Expression) -> Fu def condition_pattern( - operators: Set[str], + operators: set[str], lhs_pattern: Pattern[Expression], rhs_pattern: Pattern[Expression], commutative: bool, @@ -268,8 +271,7 @@ def _get_first_level_conditions(condition: Expression, function: str) -> Sequenc *_get_first_level_conditions(match.expression("left"), function), *_get_first_level_conditions(match.expression("right"), function), ] - else: - return [condition] + return [condition] def combine_or_conditions(conditions: Sequence[Expression]) -> Expression: @@ -322,13 +324,13 @@ def is_condition(exp: Expression) -> bool: def build_match( - col: Optional[str] = None, - subscriptable: Optional[str] = None, - ops: Optional[Sequence[str]] = None, - array_ops: Optional[Sequence[str]] = None, - param_type: Optional[Any] = None, - alias: Optional[str] = None, - key: Optional[str] = None, + col: str | None = None, + subscriptable: str | None = None, + ops: Sequence[str] | None = None, + array_ops: Sequence[str] | None = None, + param_type: Any | None = None, + alias: str | None = None, + key: str | None = None, ) -> Or[Expression]: """ There is a common use case of matching a specific condition in our code base. @@ -342,7 +344,7 @@ def build_match( on a successful match. """ alias_match = AnyOptionalString() if alias is None else String(alias) - pattern: Union[ColumnPattern, SubscriptableReferencePattern] + pattern: ColumnPattern | SubscriptableReferencePattern assert subscriptable is not None or col is not None if subscriptable is not None: diff --git a/snuba/query/data_source/join.py b/snuba/query/data_source/join.py index f663fdafa1a..c8c757c6d5d 100644 --- a/snuba/query/data_source/join.py +++ b/snuba/query/data_source/join.py @@ -1,18 +1,13 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Mapping, MutableSequence, Sequence from dataclasses import dataclass from enum import Enum from typing import ( Generic, - Mapping, - MutableSequence, NamedTuple, - Optional, - Sequence, - Tuple, TypeVar, - Union, ) from snuba.datasets.entities.entity_key import EntityKey @@ -45,7 +40,7 @@ class JoinRelationship(NamedTuple): rhs_entity: EntityKey join_type: JoinType - columns: Sequence[Tuple[str, str]] + columns: Sequence[tuple[str, str]] # Keeps track of the semantically equivalent columns between the two # related entities. Example transaction_name on the transactions table # and transaction_name on the spans table. These columns are not part @@ -84,7 +79,7 @@ class IndividualNode(JoinNode[TSimpleDataSource], Generic[TSimpleDataSource]): """ alias: str - data_source: Union[TSimpleDataSource, ProcessableQuery[TSimpleDataSource]] + data_source: TSimpleDataSource | ProcessableQuery[TSimpleDataSource] def get_alias_node_map(self) -> Mapping[str, IndividualNode[TSimpleDataSource]]: return {self.alias: self} @@ -140,7 +135,7 @@ class JoinClause(DataSource, JoinNode[TSimpleDataSource], Generic[TSimpleDataSou right_node: IndividualNode[TSimpleDataSource] keys: Sequence[JoinCondition] join_type: JoinType - join_modifier: Optional[JoinModifier] = None + join_modifier: JoinModifier | None = None def get_column_sets(self) -> Mapping[str, ColumnSet]: return { diff --git a/snuba/query/data_source/simple.py b/snuba/query/data_source/simple.py index e40cea06375..93b66ea50a3 100644 --- a/snuba/query/data_source/simple.py +++ b/snuba/query/data_source/simple.py @@ -1,8 +1,8 @@ from __future__ import annotations from abc import ABC +from collections.abc import Sequence from dataclasses import dataclass, field -from typing import Optional, Sequence from snuba.clickhouse.columns import ColumnSet from snuba.datasets.entities.entity_key import EntityKey @@ -36,7 +36,7 @@ def human_readable_id(self) -> str: class LogicalDataSource(SimpleDataSource): key: EntityKey | StorageKey schema: ColumnSet - sample: Optional[float] = None + sample: float | None = None def get_columns(self) -> ColumnSet: return self.schema @@ -54,7 +54,7 @@ class Entity(LogicalDataSource): key: EntityKey schema: ColumnSet - sample: Optional[float] = None + sample: float | None = None def get_columns(self) -> ColumnSet: return self.schema @@ -71,7 +71,7 @@ class Storage(LogicalDataSource): key: StorageKey schema: ColumnSet = field(default_factory=lambda: ColumnSet([])) - sample: Optional[float] = None + sample: float | None = None @property def human_readable_id(self) -> str: @@ -100,7 +100,7 @@ class Table(SimpleDataSource): default_factory=lambda: [DEFAULT_PASSTHROUGH_POLICY] ) final: bool = False - sampling_rate: Optional[float] = None + sampling_rate: float | None = None # TODO: Move mandatory connditions out of # here as they are structural property of a storage. This requires # the processors that consume these fields to access the storage. diff --git a/snuba/query/data_source/visitor.py b/snuba/query/data_source/visitor.py index 0ed86811d49..6757be7b416 100644 --- a/snuba/query/data_source/visitor.py +++ b/snuba/query/data_source/visitor.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Generic, TypeVar, Union +from typing import Generic, TypeVar from snuba.query import ProcessableQuery, TSimpleDataSource from snuba.query.composite import CompositeQuery @@ -21,24 +21,21 @@ class DataSourceVisitor(ABC, Generic[TReturn, TSimpleDataSource]): def visit( self, - data_source: Union[ - TSimpleDataSource, - JoinClause[TSimpleDataSource], - ProcessableQuery[TSimpleDataSource], - CompositeQuery[TSimpleDataSource], - ], + data_source: TSimpleDataSource + | JoinClause[TSimpleDataSource] + | ProcessableQuery[TSimpleDataSource] + | CompositeQuery[TSimpleDataSource], ) -> TReturn: if isinstance(data_source, JoinClause): return self._visit_join(data_source) - elif isinstance(data_source, ProcessableQuery): + if isinstance(data_source, ProcessableQuery): return self._visit_simple_query(data_source) - elif isinstance(data_source, CompositeQuery): + if isinstance(data_source, CompositeQuery): return self._visit_composite_query(data_source) - else: - # It must be a simple data source according to the type - # signature, we cannot do that via the isinstance call - # since that type does not exist at runtime. - return self._visit_simple_source(data_source) + # It must be a simple data source according to the type + # signature, we cannot do that via the isinstance call + # since that type does not exist at runtime. + return self._visit_simple_source(data_source) @abstractmethod def _visit_simple_source(self, data_source: TSimpleDataSource) -> TReturn: diff --git a/snuba/query/dsl.py b/snuba/query/dsl.py index bc237b67f24..8600c89348b 100644 --- a/snuba/query/dsl.py +++ b/snuba/query/dsl.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Optional, Sequence +from collections.abc import Sequence from snuba.query.expressions import ( Column, @@ -129,52 +129,52 @@ def literal(value: OptionalScalarType, alias: str | None = None) -> Literal: return Literal(alias, value) -def literals_tuple(alias: Optional[str], literals: Sequence[Literal]) -> FunctionCall: +def literals_tuple(alias: str | None, literals: Sequence[Literal]) -> FunctionCall: return FunctionCall(alias, "tuple", tuple(literals)) -def literals_array(alias: Optional[str], literals: Sequence[Literal]) -> FunctionCall: +def literals_array(alias: str | None, literals: Sequence[Literal]) -> FunctionCall: return FunctionCall(alias, "array", tuple(literals)) # Array functions -def arrayElement(alias: Optional[str], array: Expression, index: Expression) -> FunctionCall: +def arrayElement(alias: str | None, array: Expression, index: Expression) -> FunctionCall: return FunctionCall(alias, "arrayElement", (array, index)) -def arrayJoin(alias: Optional[str], content: Expression) -> Expression: +def arrayJoin(alias: str | None, content: Expression) -> Expression: return FunctionCall(alias, "arrayJoin", (content,)) # Tuple functions -def tupleElement(alias: Optional[str], tuple_expr: Expression, index: Expression) -> FunctionCall: +def tupleElement(alias: str | None, tuple_expr: Expression, index: Expression) -> FunctionCall: return FunctionCall(alias, "tupleElement", (tuple_expr, index)) # arithmetic function -def plus(lhs: Expression, rhs: Expression, alias: Optional[str] = None) -> FunctionCall: +def plus(lhs: Expression, rhs: Expression, alias: str | None = None) -> FunctionCall: return FunctionCall(alias, "plus", (lhs, rhs)) -def minus(lhs: Expression, rhs: Expression, alias: Optional[str] = None) -> FunctionCall: +def minus(lhs: Expression, rhs: Expression, alias: str | None = None) -> FunctionCall: return FunctionCall(alias, "minus", (lhs, rhs)) -def multiply(lhs: Expression, rhs: Expression, alias: Optional[str] = None) -> FunctionCall: +def multiply(lhs: Expression, rhs: Expression, alias: str | None = None) -> FunctionCall: return FunctionCall(alias, "multiply", (lhs, rhs)) -def divide(lhs: Expression, rhs: Expression, alias: Optional[str] = None) -> FunctionCall: +def divide(lhs: Expression, rhs: Expression, alias: str | None = None) -> FunctionCall: return FunctionCall(alias, "divide", (lhs, rhs)) -def if_in(lhs: Expression, rhs: Expression, alias: Optional[str] = None) -> FunctionCall: +def if_in(lhs: Expression, rhs: Expression, alias: str | None = None) -> FunctionCall: return FunctionCall(alias, "in", (lhs, rhs)) # boolean functions def binary_condition( - function_name: str, lhs: Expression, rhs: Expression, alias: Optional[str] = None + function_name: str, lhs: Expression, rhs: Expression, alias: str | None = None ) -> FunctionCall: return FunctionCall(alias, function_name, (lhs, rhs)) @@ -195,23 +195,23 @@ def or_cond(lhs: Expression, rhs: Expression, *args: Expression) -> FunctionCall return FunctionCall(None, "or", (lhs, rhs, *args)) -def in_cond(lhs: Expression, rhs: Expression, alias: Optional[str] = None) -> FunctionCall: +def in_cond(lhs: Expression, rhs: Expression, alias: str | None = None) -> FunctionCall: return binary_condition("in", lhs, rhs, alias) -def not_cond(expr: Expression, alias: Optional[str] = None) -> FunctionCall: +def not_cond(expr: Expression, alias: str | None = None) -> FunctionCall: return FunctionCall(alias, "not", (expr,)) # aggregate functions -def count(column: Optional[Column] = None, alias: Optional[str] = None) -> FunctionCall: +def count(column: Column | None = None, alias: str | None = None) -> FunctionCall: return FunctionCall(alias, "count", (column,) if column else ()) def countIf( condition: FunctionCall, - column: Optional[Column] = None, - alias: Optional[str] = None, + column: Column | None = None, + alias: str | None = None, ) -> FunctionCall: return FunctionCall(alias, "countIf", (condition, column) if column else (condition,)) @@ -220,10 +220,10 @@ def if_cond( condition: FunctionCall, then_clause: Expression, else_clause: Expression, - alias: Optional[str] = None, + alias: str | None = None, ) -> FunctionCall: return FunctionCall(alias, "if", (condition, then_clause, else_clause)) -def identity(expression: Expression, alias: Optional[str]) -> FunctionCall: +def identity(expression: Expression, alias: str | None) -> FunctionCall: return FunctionCall(alias, "identity", (expression,)) diff --git a/snuba/query/dsl_mapper.py b/snuba/query/dsl_mapper.py index fc337c33aed..0d654014762 100644 --- a/snuba/query/dsl_mapper.py +++ b/snuba/query/dsl_mapper.py @@ -1,4 +1,4 @@ -from typing import Callable, Sequence +from collections.abc import Callable, Sequence from snuba.clickhouse.query import Query as ClickhouseQuery from snuba.query import LimitBy, OrderBy, SelectedExpression @@ -212,9 +212,9 @@ def ast_repr( ) -> str: if not exp: return "None" - elif isinstance(exp, Expression): + if isinstance(exp, Expression): return exp.accept(visitor) - elif isinstance(exp, LimitBy): + if isinstance(exp, LimitBy): return visitor.visit_limitby(exp) strings = [] diff --git a/snuba/query/exceptions.py b/snuba/query/exceptions.py index 9b23f6add56..88615eeeaf2 100644 --- a/snuba/query/exceptions.py +++ b/snuba/query/exceptions.py @@ -1,5 +1,3 @@ -from typing import Optional - from snuba.query.expressions import Expression from snuba.utils.serializable_exception import SerializableException @@ -40,8 +38,8 @@ class QueryPlanException(SerializableException): def __init__( self, - exception_type: Optional[str] = None, - message: Optional[str] = None, + exception_type: str | None = None, + message: str | None = None, should_report: bool = True, ) -> None: self.exception_type = exception_type diff --git a/snuba/query/expressions.py b/snuba/query/expressions.py index 6180edfd107..4fb250200d3 100644 --- a/snuba/query/expressions.py +++ b/snuba/query/expressions.py @@ -1,9 +1,10 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Callable, Iterator from dataclasses import dataclass, replace from datetime import date, datetime -from typing import Callable, Generic, Iterator, Optional, Tuple, TypeVar, Union +from typing import Generic, TypeVar, cast from snuba import settings @@ -22,7 +23,7 @@ @dataclass(frozen=True, repr=_AUTO_REPR) class _Expression: # TODO: Make it impossible to assign empty string as an alias. - alias: Optional[str] + alias: str | None class Expression(_Expression, ABC): @@ -77,8 +78,7 @@ def __repr__(self) -> str: if settings.PRETTY_FORMAT_EXPRESSIONS: visitor = StringifyVisitor() return self.accept(visitor) - else: - return super().__repr__() + return super().__repr__() def functional_eq(self, other: Expression) -> bool: """Returns if an expression is functionally equivalent to the other. i.e. performs an equality @@ -154,7 +154,7 @@ def visit_subscriptable_reference(self, exp: SubscriptableReference) -> None: def visit_function_call(self, exp: FunctionCall) -> None: for param in exp.parameters: param.accept(self) - return None + return def visit_curried_function_call(self, exp: CurriedFunctionCall) -> None: for param in exp.parameters: @@ -317,7 +317,7 @@ def visit_json_path(self, exp: JsonPath) -> set[str]: return exp.base.accept(self) -OptionalScalarType = Union[None, bool, str, float, int, date, datetime] +OptionalScalarType = None | bool | str | float | int | date | datetime @dataclass(frozen=True, repr=_AUTO_REPR) @@ -349,7 +349,7 @@ class Column(Expression): Represent a column in the schema of the dataset. """ - table_name: Optional[str] + table_name: str | None column_name: str def transform(self, func: Callable[[Expression], Expression]) -> Expression: @@ -388,8 +388,8 @@ def accept(self, visitor: ExpressionVisitor[TVisited]) -> TVisited: def transform(self, func: Callable[[Expression], Expression]) -> Expression: transformed = replace( self, - column=self.column.transform(func), - key=self.key.transform(func), + column=cast(Column, self.column.transform(func)), + key=cast(Literal, self.key.transform(func)), ) return func(transformed) @@ -424,7 +424,7 @@ class FunctionCall(Expression): function_name: str # This is a tuple with variable size and not a Sequence to enforce it is hashable - parameters: Tuple[Expression, ...] + parameters: tuple[Expression, ...] def transform(self, func: Callable[[Expression], Expression]) -> Expression: """ @@ -441,7 +441,7 @@ def transform(self, func: Callable[[Expression], Expression]) -> Expression: """ transformed = replace( self, - parameters=tuple(map(lambda child: child.transform(func), self.parameters)), + parameters=tuple(child.transform(func) for child in self.parameters), ) return func(transformed) @@ -452,8 +452,7 @@ def __iter__(self) -> Iterator[Expression]: order we have in the transform method. """ for child in self.parameters: - for sub in child: - yield sub + yield from child yield self def accept(self, visitor: ExpressionVisitor[TVisited]) -> TVisited: @@ -490,7 +489,7 @@ class CurriedFunctionCall(Expression): internal_function: FunctionCall # The parameters to apply to the result of internal_function. # This is a tuple with variable size and not a Sequence to enforce it is hashable - parameters: Tuple[Expression, ...] + parameters: tuple[Expression, ...] def transform(self, func: Callable[[Expression], Expression]) -> Expression: """ @@ -501,8 +500,8 @@ def transform(self, func: Callable[[Expression], Expression]) -> Expression: """ transformed = replace( self, - internal_function=self.internal_function.transform(func), - parameters=tuple(map(lambda child: child.transform(func), self.parameters)), + internal_function=cast(FunctionCall, self.internal_function.transform(func)), + parameters=tuple(child.transform(func) for child in self.parameters), ) return func(transformed) @@ -513,8 +512,7 @@ def __iter__(self) -> Iterator[Expression]: for child in self.internal_function: yield child for child in self.parameters: - for sub in child: - yield sub + yield from child yield self def accept(self, visitor: ExpressionVisitor[TVisited]) -> TVisited: @@ -566,7 +564,7 @@ class Lambda(Expression): # the parameters in the expressions. These are intentionally not expressions # since they are variable names and cannot have aliases # This is a tuple with variable size and not a Sequence to enforce it is hashable - parameters: Tuple[str, ...] + parameters: tuple[str, ...] transformation: Expression def transform(self, func: Callable[[Expression], Expression]) -> Expression: @@ -581,8 +579,7 @@ def __iter__(self) -> Iterator[Expression]: """ Traverse the subtree in a postfix order. """ - for child in self.transformation: - yield child + yield from self.transformation yield self def accept(self, visitor: ExpressionVisitor[TVisited]) -> TVisited: @@ -593,9 +590,7 @@ def functional_eq(self, other: Expression) -> bool: return False if self.parameters != other.parameters: return False - if not self.transformation.functional_eq(other.transformation): - return False - return True + return self.transformation.functional_eq(other.transformation) @dataclass(frozen=True, repr=_AUTO_REPR) @@ -639,15 +634,14 @@ class JsonPath(Expression): base: Expression path: str - return_type: Optional[str] = None + return_type: str | None = None def transform(self, func: Callable[[Expression], Expression]) -> Expression: transformed = replace(self, base=self.base.transform(func)) return func(transformed) def __iter__(self) -> Iterator[Expression]: - for sub in self.base: - yield sub + yield from self.base yield self def accept(self, visitor: ExpressionVisitor[TVisited]) -> TVisited: diff --git a/snuba/query/formatters/tracing.py b/snuba/query/formatters/tracing.py index 069f2719856..e3d04d640eb 100644 --- a/snuba/query/formatters/tracing.py +++ b/snuba/query/formatters/tracing.py @@ -1,4 +1,5 @@ -from typing import Any, List, Mapping, Sequence, Union +from collections.abc import Mapping, Sequence +from typing import Any from snuba.query import ProcessableQuery from snuba.query.composite import CompositeQuery @@ -12,17 +13,17 @@ from snuba.query.data_source.visitor import DataSourceVisitor from snuba.query.expressions import StringifyVisitor -TExpression = Union[str, Mapping[str, Any], Sequence[Any]] +TExpression = str | Mapping[str, Any] | Sequence[Any] -def _indent_str_list(str_list: List[str], levels: int) -> List[str]: +def _indent_str_list(str_list: list[str], levels: int) -> list[str]: indent = " " * levels return [f"{indent}{s}" for s in str_list] def format_query( - query: Union[ProcessableQuery[SimpleDataSource], CompositeQuery[SimpleDataSource]], -) -> List[str]: + query: ProcessableQuery[SimpleDataSource] | CompositeQuery[SimpleDataSource], +) -> list[str]: """ Formats a query as a list of strings with each element being a new line @@ -96,14 +97,14 @@ def format_query( class TracingQueryFormatter( - DataSourceVisitor[List[str], SimpleDataSource], - JoinVisitor[List[str], SimpleDataSource], + DataSourceVisitor[list[str], SimpleDataSource], + JoinVisitor[list[str], SimpleDataSource], ): - def _indent_str_list(self, str_list: List[str], levels: int) -> List[str]: + def _indent_str_list(self, str_list: list[str], levels: int) -> list[str]: indent = " " * levels return [f"{indent}{s}" for s in str_list] - def _visit_simple_source(self, data_source: SimpleDataSource) -> List[str]: + def _visit_simple_source(self, data_source: SimpleDataSource) -> list[str]: # Entity and Table define their sampling rates with slightly different # terms and renaming it would introduce a lot of code changes down the line # so we use this dynamic workaround @@ -111,40 +112,39 @@ def _visit_simple_source(self, data_source: SimpleDataSource) -> List[str]: sample_str = f" SAMPLE {sample_val}" if sample_val is not None else "" return [f"{data_source.human_readable_id}{sample_str}"] - def _visit_join(self, data_source: JoinClause[SimpleDataSource]) -> List[str]: + def _visit_join(self, data_source: JoinClause[SimpleDataSource]) -> list[str]: return self.visit_join_clause(data_source) - def _visit_simple_query(self, data_source: ProcessableQuery[SimpleDataSource]) -> List[str]: + def _visit_simple_query(self, data_source: ProcessableQuery[SimpleDataSource]) -> list[str]: return format_query(data_source) - def _visit_composite_query(self, data_source: CompositeQuery[SimpleDataSource]) -> List[str]: + def _visit_composite_query(self, data_source: CompositeQuery[SimpleDataSource]) -> list[str]: return format_query(data_source) - def visit_individual_node(self, node: IndividualNode[SimpleDataSource]) -> List[str]: + def visit_individual_node(self, node: IndividualNode[SimpleDataSource]) -> list[str]: return [f"{self.visit(node.data_source)} AS `{node.alias}`"] - def visit_join_clause(self, node: JoinClause[SimpleDataSource]) -> List[str]: + def visit_join_clause(self, node: JoinClause[SimpleDataSource]) -> list[str]: if node.join_type == JoinType.CROSS: return [ *_indent_str_list(node.left_node.accept(self), 1), f"{node.join_type.name.upper()} JOIN", *_indent_str_list(node.right_node.accept(self), 1), ] - else: - on_list = [ - [ - f"{c.left.table_alias}.{c.left.column}", - f"{c.right.table_alias}.{c.right.column}", - ] - for c in node.keys - ][0] - return [ - *_indent_str_list(node.left_node.accept(self), 1), - f"{node.join_type.name.upper()} JOIN", - *_indent_str_list(node.right_node.accept(self), 1), - "ON", - *_indent_str_list( - on_list, - 1, - ), + on_list = [ + [ + f"{c.left.table_alias}.{c.left.column}", + f"{c.right.table_alias}.{c.right.column}", ] + for c in node.keys + ][0] + return [ + *_indent_str_list(node.left_node.accept(self), 1), + f"{node.join_type.name.upper()} JOIN", + *_indent_str_list(node.right_node.accept(self), 1), + "ON", + *_indent_str_list( + on_list, + 1, + ), + ] diff --git a/snuba/query/indexer/resolver.py b/snuba/query/indexer/resolver.py index c4bdacca0d3..583a6ed83bb 100644 --- a/snuba/query/indexer/resolver.py +++ b/snuba/query/indexer/resolver.py @@ -31,8 +31,7 @@ def resolve(value: str, mapping: dict[str, str | int]) -> str | int: def resolve_tag_column_name(value: str, mapping: dict[str, str | int], dataset: Dataset) -> str: if get_dataset_name(dataset) == "metrics": return f"tags[{resolve(value, mapping)}]" - else: - return f"tags_raw[{resolve(value, mapping)}]" + return f"tags_raw[{resolve(value, mapping)}]" def resolve_tag_key_mappings( diff --git a/snuba/query/joins/classifier.py b/snuba/query/joins/classifier.py index a7f021ef41a..17019286942 100644 --- a/snuba/query/joins/classifier.py +++ b/snuba/query/joins/classifier.py @@ -1,18 +1,9 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Callable, Generator, Mapping, MutableMapping, Sequence from dataclasses import dataclass, replace from functools import partial -from typing import ( - Callable, - Generator, - List, - Mapping, - MutableMapping, - Optional, - Sequence, - Set, -) from snuba.query.expressions import ( Argument, @@ -29,7 +20,7 @@ ) from snuba.query.functions import is_aggregation_function -AliasGenerator = Generator[str, None, None] +AliasGenerator = Generator[str] # This is a workaround for a mypy bug, found here: https://github.com/python/mypy/issues/5374 @@ -64,7 +55,7 @@ class MainQueryExpression(SubExpression): expression for the main query. """ - cut_branches: Mapping[str, Set[Expression]] + cut_branches: Mapping[str, set[Expression]] def cut_branch(self, alias_generator: AliasGenerator) -> MainQueryExpression: return self @@ -133,7 +124,7 @@ def cut_branch(self, alias_generator: AliasGenerator) -> MainQueryExpression: def _merge_subexpressions( - builder: Callable[[List[Expression]], Expression], + builder: Callable[[list[Expression]], Expression], sub_expressions: Sequence[SubExpression], alias_generator: AliasGenerator, ) -> SubExpression: @@ -159,24 +150,22 @@ def _merge_subexpressions( # All parameters are not classified. This function is also # not classified. return UnclassifiedExpression(builder([v.main_expression for v in sub_expressions])) - else: - # All parameters are either not classified or in a single - # subquery. This function is also referencing that subquery - # only. - return SubqueryExpression( - builder([v.main_expression for v in sub_expressions]), - subquery_alias=subqueries.pop(), - ) - else: - return _merge_and_cut(builder, sub_expressions, alias_generator) + # All parameters are either not classified or in a single + # subquery. This function is also referencing that subquery + # only. + return SubqueryExpression( + builder([v.main_expression for v in sub_expressions]), + subquery_alias=subqueries.pop(), + ) + return _merge_and_cut(builder, sub_expressions, alias_generator) def _merge_and_cut( - builder: Callable[[List[Expression]], Expression], + builder: Callable[[list[Expression]], Expression], sub_expressions: Sequence[SubExpression], alias_generator: AliasGenerator, ) -> SubExpression: - cut_branches: MutableMapping[str, Set[Expression]] = {} + cut_branches: MutableMapping[str, set[Expression]] = {} parameters = [] for v in sub_expressions: cut = v.cut_branch(alias_generator) @@ -244,7 +233,7 @@ def visit_subscriptable_reference(self, exp: SubscriptableReference) -> SubExpre def visit_function_call(self, exp: FunctionCall) -> SubExpression: def builder( - alias: Optional[str], func_name: str, params: Sequence[Expression] + alias: str | None, func_name: str, params: Sequence[Expression] ) -> FunctionCall: return FunctionCall(alias, func_name, tuple(params)) @@ -272,7 +261,7 @@ def builder( ) def visit_curried_function_call(self, exp: CurriedFunctionCall) -> SubExpression: - def builder(alias: Optional[str], params: List[Expression]) -> Expression: + def builder(alias: str | None, params: list[Expression]) -> Expression: # The first element in the sequence is the inner function. # Unfortunately I could not find a better way to reuse this # between FunctionCall and CurriedFunctionCall. @@ -321,7 +310,7 @@ def __init__(self, alias_generator: AliasGenerator) -> None: def visit_function_call(self, exp: FunctionCall) -> SubExpression: def builder( - alias: Optional[str], func_name: str, params: Sequence[Expression] + alias: str | None, func_name: str, params: Sequence[Expression] ) -> FunctionCall: return FunctionCall(alias, func_name, tuple(params)) diff --git a/snuba/query/joins/equivalence_adder.py b/snuba/query/joins/equivalence_adder.py index 87078b2882e..a37e05bb949 100644 --- a/snuba/query/joins/equivalence_adder.py +++ b/snuba/query/joins/equivalence_adder.py @@ -1,5 +1,5 @@ +from collections.abc import Mapping, MutableMapping from functools import partial -from typing import Mapping, MutableMapping, Optional, Set, Tuple from snuba.datasets.entities.entity_key import EntityKey from snuba.query import ProcessableQuery @@ -43,7 +43,7 @@ def add_equivalent_conditions(query: CompositeQuery[Entity]) -> None: if isinstance(from_clause, CompositeQuery): add_equivalent_conditions(from_clause) return - elif isinstance(from_clause, ProcessableQuery): + if isinstance(from_clause, ProcessableQuery): return # Now this has to be a join, so we can work with it. @@ -54,7 +54,7 @@ def add_equivalent_conditions(query: CompositeQuery[Entity]) -> None: alias_to_entity = { alias: entity_from_node(node) for alias, node in from_clause.get_alias_node_map().items() } - entity_to_alias: MutableMapping[EntityKey, Set[str]] = {} + entity_to_alias: MutableMapping[EntityKey, set[str]] = {} for alias, entity in alias_to_entity.items(): entity_to_alias.setdefault(entity, set()).add(alias) @@ -110,12 +110,12 @@ def add_equivalent_conditions(query: CompositeQuery[Entity]) -> None: def _classify_single_column_condition( condition: Expression, alias_entity_map: Mapping[str, EntityKey] -) -> Optional[Tuple[QualifiedCol, str]]: +) -> tuple[QualifiedCol, str] | None: """ Inspects a condition to check if it is a condition on a single column on a single entity """ - qualified_col: Optional[Tuple[QualifiedCol, str]] = None + qualified_col: tuple[QualifiedCol, str] | None = None for e in condition: if isinstance(e, Column): if not e.table_name: diff --git a/snuba/query/joins/metrics_subquery_generator.py b/snuba/query/joins/metrics_subquery_generator.py index 03912f5fe07..efb024e4c04 100644 --- a/snuba/query/joins/metrics_subquery_generator.py +++ b/snuba/query/joins/metrics_subquery_generator.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Generator, Mapping +from collections.abc import Generator, Mapping from snuba.query import ProcessableQuery, SelectedExpression from snuba.query.composite import CompositeQuery @@ -120,7 +120,7 @@ def _push_down_groupby_branches( return cut_subexpression.main_expression -def _alias_generator() -> Generator[str, None, None]: +def _alias_generator() -> Generator[str]: i = 0 while True: i += 1 @@ -201,7 +201,7 @@ def generate_metrics_subqueries(query: CompositeQuery[Entity]) -> None: if isinstance(from_clause, CompositeQuery): generate_subqueries(from_clause) return - elif isinstance(from_clause, ProcessableQuery): + if isinstance(from_clause, ProcessableQuery): return # Now this has to be a join, so we can work with it. diff --git a/snuba/query/joins/pre_processor.py b/snuba/query/joins/pre_processor.py index 6f95e7a1870..506279efaeb 100644 --- a/snuba/query/joins/pre_processor.py +++ b/snuba/query/joins/pre_processor.py @@ -1,7 +1,8 @@ from __future__ import annotations +from collections.abc import Mapping, MutableMapping from copy import copy -from typing import Mapping, MutableMapping, NamedTuple, Set +from typing import NamedTuple from snuba.datasets.entities.entity_key import EntityKey from snuba.datasets.entities.factory import get_entity @@ -25,7 +26,7 @@ class QualifiedCol(NamedTuple): # Each node is a QualifiedCol instance, which represents entity # and column. # Each edge represent an equivalence between two nodes. -EquivalenceGraph = MutableMapping[QualifiedCol, Set[QualifiedCol]] +EquivalenceGraph = MutableMapping[QualifiedCol, set[QualifiedCol]] class EquivalenceExtractor(JoinVisitor[EquivalenceGraph, Entity]): @@ -43,7 +44,7 @@ class EquivalenceExtractor(JoinVisitor[EquivalenceGraph, Entity]): An EquivalenceGraph is produced. """ - def __init__(self, entities_in_join: Set[EntityKey]) -> None: + def __init__(self, entities_in_join: set[EntityKey]) -> None: # We initialize the visitor with the list of entities present # in the join to filter the graph just because extracting this # list inside the visitor before we start processing would @@ -106,7 +107,7 @@ def merge_into_graph(node: JoinNode[Entity]) -> None: def get_equivalent_columns( join: JoinClause[Entity], -) -> Mapping[QualifiedCol, Set[QualifiedCol]]: +) -> Mapping[QualifiedCol, set[QualifiedCol]]: """ Given a Join, it returns the set of all the semantically equivalent columns across the entities involved in the join. @@ -126,7 +127,7 @@ def get_equivalent_columns( same connected component """ - def traverse_graph(node: QualifiedCol, visited_nodes: Set[QualifiedCol]) -> Set[QualifiedCol]: + def traverse_graph(node: QualifiedCol, visited_nodes: set[QualifiedCol]) -> set[QualifiedCol]: """ Traverse the whole connected component in with a depth first algorithm starting from the node provided. @@ -140,7 +141,7 @@ def traverse_graph(node: QualifiedCol, visited_nodes: Set[QualifiedCol]) -> Set[ entities_in_join = {entity_from_node(node) for node in join.get_alias_node_map().values()} adjacency_sets = join.accept(EquivalenceExtractor(entities_in_join)) - connected_components: MutableMapping[QualifiedCol, Set[QualifiedCol]] = {} + connected_components: MutableMapping[QualifiedCol, set[QualifiedCol]] = {} for node in adjacency_sets: if node not in connected_components: diff --git a/snuba/query/joins/semi_joins.py b/snuba/query/joins/semi_joins.py index cd51e9cfa78..c53b73c7a74 100644 --- a/snuba/query/joins/semi_joins.py +++ b/snuba/query/joins/semi_joins.py @@ -1,5 +1,3 @@ -from typing import Set - from snuba.query import ProcessableQuery from snuba.query.composite import CompositeQuery from snuba.query.data_source.join import ( @@ -44,7 +42,7 @@ def process_query(self, query: CompositeQuery[Table], query_settings: QuerySetti if isinstance(from_clause, CompositeQuery): self.process_query(from_clause, query_settings) return - elif isinstance(from_clause, ProcessableQuery): + if isinstance(from_clause, ProcessableQuery): return # Now this has to be a join, so we can work with it. @@ -54,7 +52,7 @@ def process_query(self, query: CompositeQuery[Table], query_settings: QuerySetti class SemiJoinGenerator(JoinVisitor[JoinNode[Table], Table]): - def __init__(self, referenced_columns: Set[Column]) -> None: + def __init__(self, referenced_columns: set[Column]) -> None: self.__referenced_columns = referenced_columns def visit_individual_node(self, node: IndividualNode[Table]) -> IndividualNode[Table]: diff --git a/snuba/query/joins/subquery_generator.py b/snuba/query/joins/subquery_generator.py index 7dcb2d87d58..fe6a5fa1d1f 100644 --- a/snuba/query/joins/subquery_generator.py +++ b/snuba/query/joins/subquery_generator.py @@ -1,7 +1,8 @@ from __future__ import annotations +from collections.abc import Generator, Mapping from dataclasses import replace -from typing import Generator, Mapping, cast +from typing import cast from snuba.query import ProcessableQuery, SelectedExpression from snuba.query.composite import CompositeQuery @@ -64,11 +65,9 @@ def build_query(self) -> ProcessableQuery[Entity]: ProcessableQuery[Entity], LogicalQuery( from_clause=self.__data_source, - selected_columns=list( - sorted( - self.__selected_expressions, - key=lambda selected: selected.name or "", - ) + selected_columns=sorted( + self.__selected_expressions, + key=lambda selected: selected.name or "", ), condition=( combine_and_conditions(self.__conditions) if self.__conditions else None @@ -198,7 +197,7 @@ def _push_down_branches( return cut_subexpression.main_expression -def _alias_generator() -> Generator[str, None, None]: +def _alias_generator() -> Generator[str]: i = 0 while True: i += 1 @@ -242,7 +241,7 @@ def generate_subqueries(query: CompositeQuery[Entity]) -> None: if isinstance(from_clause, CompositeQuery): generate_subqueries(from_clause) return - elif isinstance(from_clause, ProcessableQuery): + if isinstance(from_clause, ProcessableQuery): return # Now this has to be a join, so we can work with it. diff --git a/snuba/query/logical.py b/snuba/query/logical.py index 53f3d9ae193..1d98255fbc7 100644 --- a/snuba/query/logical.py +++ b/snuba/query/logical.py @@ -1,7 +1,8 @@ from __future__ import annotations from abc import ABCMeta -from typing import Any, Callable, Iterable, Optional, Sequence, Type, Union, cast +from collections.abc import Callable, Iterable, Sequence +from typing import Any, cast from snuba.query import LimitBy, OrderBy, ProcessableQuery, SelectedExpression from snuba.query.composite import CompositeQuery @@ -23,21 +24,21 @@ class Query(ProcessableQuery[LogicalDataSource]): def __init__( self, - from_clause: Optional[LogicalDataSource], + from_clause: LogicalDataSource | None, # New data model to replace the one based on the dictionary - selected_columns: Optional[Sequence[SelectedExpression]] = None, - array_join: Optional[Sequence[Expression]] = None, - condition: Optional[Expression] = None, - prewhere: Optional[Expression] = None, - groupby: Optional[Sequence[Expression]] = None, - having: Optional[Expression] = None, - order_by: Optional[Sequence[OrderBy]] = None, - limitby: Optional[LimitBy] = None, - sample: Optional[float] = None, - limit: Optional[int] = None, + selected_columns: Sequence[SelectedExpression] | None = None, + array_join: Sequence[Expression] | None = None, + condition: Expression | None = None, + prewhere: Expression | None = None, + groupby: Sequence[Expression] | None = None, + having: Expression | None = None, + order_by: Sequence[OrderBy] | None = None, + limitby: LimitBy | None = None, + sample: float | None = None, + limit: int | None = None, offset: int = 0, totals: bool = False, - granularity: Optional[int] = None, + granularity: int | None = None, ): """ Expects an already parsed query body. @@ -68,7 +69,7 @@ def get_final(self) -> bool: def set_final(self, final: bool) -> None: self.__final = final - def get_sample(self) -> Optional[float]: + def get_sample(self) -> float | None: return self.__sample def _eq_functions(self) -> Sequence[str]: @@ -95,8 +96,7 @@ def __instancecheck__(self, instance: Any) -> bool: data_source_type = cast(type, getattr(self, "data_source", object)()) instance_data_source = instance.get_from_clause() return isinstance(instance_data_source, data_source_type) - else: - return False + return False """ @@ -135,7 +135,7 @@ def __instancecheck__(self, instance: Any) -> bool: class EntityQuery(Query, metaclass=_FlexibleQueryType): @classmethod - def data_source(cls) -> Type[Entity]: + def data_source(cls) -> type[Entity]: return Entity def get_from_clause(self) -> Entity: @@ -144,13 +144,11 @@ def get_from_clause(self) -> Entity: @classmethod def check_data_source( cls, - data_source: Union[ - Query, - ProcessableQuery[Entity], - CompositeQuery[Entity], - JoinClause[Entity], - IndividualNode[Entity], - ], + data_source: Query + | ProcessableQuery[Entity] + | CompositeQuery[Entity] + | JoinClause[Entity] + | IndividualNode[Entity], ) -> None: if isinstance(data_source, JoinClause): if isinstance(data_source.left_node, IndividualNode): @@ -165,20 +163,20 @@ def check_data_source( assert isinstance(data_source.get_from_clause(), cls.data_source()) @classmethod - def from_query(cls, query: Union[Query, CompositeQuery[Entity]]) -> "EntityQuery": + def from_query(cls, query: Query | CompositeQuery[Entity]) -> EntityQuery: cls.check_data_source(query) return cast("EntityQuery", query) class StorageQuery(Query, metaclass=_FlexibleQueryType): @classmethod - def data_source(cls) -> Type[Storage]: + def data_source(cls) -> type[Storage]: return Storage def get_from_clause(self) -> Storage: return cast(Storage, super().get_from_clause()) @classmethod - def from_query(cls, query: Query) -> "StorageQuery": + def from_query(cls, query: Query) -> StorageQuery: assert isinstance(query.get_from_clause(), cls.data_source()) return cast("StorageQuery", query) diff --git a/snuba/query/matchers.py b/snuba/query/matchers.py index 055f6e3cb25..992446e759b 100644 --- a/snuba/query/matchers.py +++ b/snuba/query/matchers.py @@ -1,10 +1,11 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Mapping, Sequence from dataclasses import dataclass, field from datetime import date, datetime from typing import Any as AnyType -from typing import Generic, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union +from typing import Generic, TypeVar from snuba.query.expressions import Column as ColumnExpr from snuba.query.expressions import Expression, OptionalScalarType @@ -12,7 +13,7 @@ from snuba.query.expressions import Literal as LiteralExpr from snuba.query.expressions import SubscriptableReference as SubscriptableReferenceExpr -MatchType = Union[Expression, OptionalScalarType] +MatchType = Expression | OptionalScalarType TMatchedType = TypeVar("TMatchedType", covariant=True) @@ -65,7 +66,7 @@ def string(self, name: str) -> str: assert isinstance(ret, str), type(ret) return ret - def optional_string(self, name: str) -> Optional[str]: + def optional_string(self, name: str) -> str | None: """ Returns a string present in the result or it is None. """ @@ -114,7 +115,7 @@ class Pattern(ABC, Generic[TMatchedType]): """ @abstractmethod - def match(self, node: AnyType) -> Optional[MatchResult]: + def match(self, node: AnyType) -> MatchResult | None: """ Returns a MatchResult if the node provided matches this pattern otherwise it returns None. @@ -146,7 +147,7 @@ class Param(Pattern[TMatchedType]): name: str pattern: Pattern[TMatchedType] - def match(self, node: AnyType) -> Optional[MatchResult]: + def match(self, node: AnyType) -> MatchResult | None: result = self.pattern.match(node) if result is None: return None @@ -161,7 +162,7 @@ class AnyExpression(Pattern[Expression]): match abstract classes (like Expression) """ - def match(self, node: AnyType) -> Optional[MatchResult]: + def match(self, node: AnyType) -> MatchResult | None: return MatchResult() if isinstance(node, Expression) else None @@ -171,9 +172,9 @@ class Any(Pattern[TMatchedType]): Match any concrete expression/scalar of the type provided. """ - type: Type[TMatchedType] + type: type[TMatchedType] - def match(self, node: AnyType) -> Optional[MatchResult]: + def match(self, node: AnyType) -> MatchResult | None: return MatchResult() if isinstance(node, self.type) else None @@ -185,7 +186,7 @@ class String(Pattern[str]): value: str - def match(self, node: AnyType) -> Optional[MatchResult]: + def match(self, node: AnyType) -> MatchResult | None: return MatchResult() if node == self.value else None @@ -197,30 +198,30 @@ class Integer(Pattern[int]): value: int - def match(self, node: AnyType) -> Optional[MatchResult]: + def match(self, node: AnyType) -> MatchResult | None: return MatchResult() if node == self.value else None @dataclass(frozen=True) -class OptionalString(Pattern[Optional[str]]): +class OptionalString(Pattern[str | None]): """ Matches one specific string (or None). """ - value: Optional[str] + value: str | None - def match(self, node: AnyType) -> Optional[MatchResult]: + def match(self, node: AnyType) -> MatchResult | None: return MatchResult() if node == self.value else None @dataclass(frozen=True) -class AnyOptionalString(Pattern[Optional[str]]): +class AnyOptionalString(Pattern[str | None]): """ Matches any string including the None value. This cannot be done with Any(type) because that cannot match Union[str, None]. """ - def match(self, node: AnyType) -> Optional[MatchResult]: + def match(self, node: AnyType) -> MatchResult | None: return MatchResult() if node is None or isinstance(node, str) else None @@ -233,7 +234,7 @@ class Or(Pattern[TMatchedType]): patterns: Sequence[Pattern[TMatchedType]] - def match(self, node: AnyType) -> Optional[MatchResult]: + def match(self, node: AnyType) -> MatchResult | None: for p in self.patterns: ret = p.match(node) if ret: @@ -251,10 +252,10 @@ class Column(Pattern[ColumnExpr]): (equivalent to Any, but less verbose). """ - table_name: Optional[Pattern[Optional[str]]] = None - column_name: Optional[Pattern[str]] = None + table_name: Pattern[str | None] | None = None + column_name: Pattern[str] | None = None - def match(self, node: AnyType) -> Optional[MatchResult]: + def match(self, node: AnyType) -> MatchResult | None: if not isinstance(node, ColumnExpr): return None @@ -274,16 +275,15 @@ def match(self, node: AnyType) -> Optional[MatchResult]: @dataclass(frozen=True) class Literal(Pattern[LiteralExpr]): - value: Optional[Pattern[OptionalScalarType]] = None + value: Pattern[OptionalScalarType] | None = None - def match(self, node: AnyType) -> Optional[MatchResult]: + def match(self, node: AnyType) -> MatchResult | None: if not isinstance(node, LiteralExpr): return None if self.value is not None: return self.value.match(node.value) - else: - return MatchResult() + return MatchResult() @dataclass(frozen=True) @@ -294,11 +294,11 @@ class FunctionCall(Pattern[FunctionCallExpr]): are provided, they have to match, otherwise they are ignored. """ - function_name: Optional[Pattern[str]] = None + function_name: Pattern[str] | None = None # This is a tuple instead of a sequence to match the data structure # we use in the actual FunctionCall class. There it has to be a tuple # to be hashable. - parameters: Optional[Tuple[Pattern[Expression], ...]] = None + parameters: tuple[Pattern[Expression], ...] | None = None # Specifies whether we allow optional parameters when matching. # if this is False, all patterns of the function to match must match # one by one. If with_optionals is True, this will allow additional @@ -311,9 +311,9 @@ class FunctionCall(Pattern[FunctionCallExpr]): # If it is set, then it will iterate through the parameters and # check them against the type. If this is set, it's not necessary # to also specify the parameters field. - all_parameters: Optional[Pattern[Expression]] = None + all_parameters: Pattern[Expression] | None = None - def match(self, node: AnyType) -> Optional[MatchResult]: + def match(self, node: AnyType) -> MatchResult | None: if not isinstance(node, FunctionCallExpr): return None @@ -337,8 +337,7 @@ def match(self, node: AnyType) -> Optional[MatchResult]: p_result = param_pattern.match(node.parameters[index]) if p_result is None: return None - else: - result = result.merge(p_result) + result = result.merge(p_result) if self.all_parameters: for p in node.parameters: @@ -355,11 +354,11 @@ class SubscriptableReference(Pattern[SubscriptableReferenceExpr]): If column_name and key arguments are provided, they have to match, otherwise they are ignored. """ - table_name: Optional[Pattern[Optional[str]]] = None - column_name: Optional[Pattern[str]] = None - key: Optional[Pattern[str]] = None + table_name: Pattern[str | None] | None = None + column_name: Pattern[str] | None = None + key: Pattern[str] | None = None - def match(self, node: AnyType) -> Optional[MatchResult]: + def match(self, node: AnyType) -> MatchResult | None: if not isinstance(node, SubscriptableReferenceExpr): return None diff --git a/snuba/query/mql/mql_context.py b/snuba/query/mql/mql_context.py index 68650da92b6..a93cdb5bba1 100644 --- a/snuba/query/mql/mql_context.py +++ b/snuba/query/mql/mql_context.py @@ -1,7 +1,8 @@ from __future__ import annotations +from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, Sequence +from typing import Any from snuba.query.parser.exceptions import ParsingException @@ -63,7 +64,7 @@ def from_dict(mql_context_dict: dict[str, Any]) -> MQLContext: extrapolate=mql_context_dict.get("extrapolate", False), ) except KeyError as e: - raise ParsingException(f"MQL context: missing required field {e}") + raise ParsingException(f"MQL context: missing required field {e}") from e @dataclass(frozen=True) diff --git a/snuba/query/mql/parser.py b/snuba/query/mql/parser.py index 6aa3d3fc63b..a1dc2ed3802 100644 --- a/snuba/query/mql/parser.py +++ b/snuba/query/mql/parser.py @@ -1,8 +1,9 @@ from __future__ import annotations import logging +from collections.abc import Callable, Sequence from dataclasses import dataclass, replace -from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union +from typing import Any import sentry_sdk from parsimonious.exceptions import IncompleteParseError @@ -74,7 +75,7 @@ # The parser returns a bunch of different types, so create a single aggregate type to # capture everything. -MQLSTUFF = Dict[str, Union[str, list[SelectedExpression], list[Expression]]] +MQLSTUFF = dict[str, str | list[SelectedExpression] | list[Expression]] logger = logging.getLogger("snuba.mql.parser") @@ -89,7 +90,7 @@ class InitialParseResult: table_alias: str | None = None -FormulaParameter = Union[InitialParseResult, int, float] +FormulaParameter = InitialParseResult | int | float ARITHMETIC_OPERATORS_MAPPING = { "+": "plus", @@ -103,7 +104,7 @@ class InitialParseResult: } -class MQLVisitor(NodeVisitor): # type: ignore +class MQLVisitor(NodeVisitor): # type: ignore[misc] """ Builds the arguments for a Snuba AST from the MQL Parsimonious parse tree. """ @@ -114,7 +115,7 @@ def __init__(self) -> None: def visit_expression( self, node: Node, - children: Tuple[ + children: tuple[ InitialParseResult, Any, ], @@ -141,7 +142,7 @@ def visit_unary_op(self, node: Node, children: Sequence[Any]) -> Any: def visit_term( self, node: Node, - children: Tuple[InitialParseResult, Any], + children: tuple[InitialParseResult, Any], ) -> InitialParseResult: term, zero_or_more_others = children if zero_or_more_others: @@ -161,25 +162,24 @@ def visit_term( def visit_unary(self, node: Node, children: Sequence[Any]) -> Any: unary_op, coefficient = children if unary_op: - if isinstance(coefficient, float) or isinstance(coefficient, int): + if isinstance(coefficient, (float, int)): return -coefficient - elif isinstance(coefficient, InitialParseResult): + if isinstance(coefficient, InitialParseResult): return InitialParseResult( expression=None, formula=unary_op[0], parameters=[coefficient], ) - else: - raise InvalidQueryException( - f"Unary expression not supported for type {type(coefficient)}" - ) + raise InvalidQueryException( + f"Unary expression not supported for type {type(coefficient)}" + ) return coefficient def visit_coefficient( self, node: Node, - children: Tuple[InitialParseResult], + children: tuple[InitialParseResult], ) -> InitialParseResult: return children[0] @@ -189,7 +189,7 @@ def visit_number(self, node: Node, children: Sequence[Any]) -> float: def visit_filter( self, node: Node, - children: Tuple[ + children: tuple[ InitialParseResult, Any, Sequence[Any], @@ -256,9 +256,8 @@ def _filter(self, children: Sequence[Any], operator: str) -> FunctionCall: ] if len(filters) == 1: return filters[0] - else: - # We flatten all filters into a single condition since Snuba supports it. - return FunctionCall(None, operator, tuple(filters)) + # We flatten all filters into a single condition since Snuba supports it. + return FunctionCall(None, operator, tuple(filters)) def visit_filter_expr(self, node: Node, children: Sequence[Any]) -> Any: return self._filter(children, BooleanFunctions.OR) @@ -269,7 +268,7 @@ def visit_filter_term(self, node: Node, children: Sequence[Any]) -> Any: def visit_filter_factor( self, node: Node, - children: Tuple[Sequence[str | Sequence[str] | FilterFactorValue] | FunctionCall, Any], + children: tuple[Sequence[str | Sequence[str] | FilterFactorValue] | FunctionCall, Any], ) -> FunctionCall: factor, *_ = children if isinstance(factor, FunctionCall): @@ -280,7 +279,7 @@ def visit_filter_factor( lhs: str filter_factor_value: FilterFactorValue - condition_op, lhs, _, _, _, filter_factor_value = factor # type: ignore + condition_op, lhs, _, _, _, filter_factor_value = factor # type: ignore[assignment] condition_op_value = "!" if len(condition_op) == 1 and condition_op[0] == "!" else "" contains_wildcard = filter_factor_value.contains_wildcard @@ -320,20 +319,19 @@ def visit_filter_factor( ), ), ) - else: - assert isinstance(rhs, str) - if not condition_op_value: - op = ConditionFunctions.EQ - elif condition_op_value == "!": - op = ConditionFunctions.NEQ - return FunctionCall( - None, - op, - ( - Column(None, None, lhs[0]), - Literal(None, rhs), - ), - ) + assert isinstance(rhs, str) + if not condition_op_value: + op = ConditionFunctions.EQ + elif condition_op_value == "!": + op = ConditionFunctions.NEQ + return FunctionCall( + None, + op, + ( + Column(None, None, lhs[0]), + Literal(None, rhs), + ), + ) def visit_nested_expr(self, node: Node, children: Sequence[Any]) -> Any: _, _, filter_expr, *_ = children @@ -342,8 +340,8 @@ def visit_nested_expr(self, node: Node, children: Sequence[Any]) -> Any: def visit_function( self, node: Node, - children: Tuple[ - Tuple[InitialParseResult,], + children: tuple[ + tuple[InitialParseResult,], Sequence[list[SelectedExpression]], ], ) -> InitialParseResult: @@ -358,7 +356,7 @@ def visit_function( def visit_group_by( self, node: Node, - children: Tuple[Any, Any, Any, Sequence[Sequence[str]]], + children: tuple[Any, Any, Any, Sequence[Sequence[str]]], ) -> list[SelectedExpression]: *_, groupbys = children groupby = groupbys[0] @@ -439,7 +437,7 @@ def visit_group_by_name_tuple(self, node: Node, children: Sequence[Any]) -> Sequ def visit_target( self, node: Node, - children: Sequence[Union[InitialParseResult, Sequence[InitialParseResult]]], + children: Sequence[InitialParseResult | Sequence[InitialParseResult]], ) -> InitialParseResult: target = children[0] if isinstance(children[0], list): @@ -452,16 +450,16 @@ def visit_variable(self, node: Node, children: Sequence[Any]) -> str: raise InvalidQueryException("Variables are not supported yet") def visit_nested_expression( - self, node: Node, children: Tuple[Any, Any, InitialParseResult] + self, node: Node, children: tuple[Any, Any, InitialParseResult] ) -> InitialParseResult: return children[2] def visit_aggregate( self, node: Node, - children: Tuple[ + children: tuple[ str, - Tuple[ + tuple[ Any, Any, InitialParseResult, @@ -485,10 +483,10 @@ def visit_aggregate( def visit_curried_aggregate( self, node: Node, - children: Tuple[ + children: tuple[ str, - Tuple[Any, Any, Sequence[Sequence[Union[str, int, float]]], Any, Any], - Tuple[Any, Any, InitialParseResult, Any, Any], + tuple[Any, Any, Sequence[Sequence[str | int | float]], Any, Any], + tuple[Any, Any, InitialParseResult, Any, Any], ], ) -> InitialParseResult: aggregate_name, agg_params, zero_or_one = children @@ -513,12 +511,12 @@ def visit_curried_aggregate( def visit_arbitrary_function( self, node: Node, - children: Tuple[ + children: tuple[ str, - Tuple[ + tuple[ Any, Sequence[InitialParseResult], - Sequence[Sequence[Union[str, int, float]]], + Sequence[Sequence[str | int | float]], Any, ], ], @@ -596,23 +594,21 @@ def visit_inner_filter(self, node: Node, children: Sequence[Any]) -> InitialPars target.groupby = group_by return target - def visit_param( - self, node: Node, children: Tuple[Union[str, int, float], Any] - ) -> Union[str, int, float]: + def visit_param(self, node: Node, children: tuple[str | int | float, Any]) -> str | int | float: param, *_ = children return param def visit_param_expression( - self, node: Node, children: Tuple[Union[str, int, float], Any] - ) -> Union[str, int, float]: + self, node: Node, children: tuple[str | int | float, Any] + ) -> str | int | float: param = children[0] return param def visit_aggregate_list( self, node: Node, - children: Tuple[list[Union[str, int, float]], Optional[Union[str, int, float]]], - ) -> Sequence[Union[str, int, float]]: + children: tuple[list[str | int | float], str | int | float | None], + ) -> Sequence[str | int | float]: agg_params, param = children if param is not None: agg_params.append(param) @@ -701,12 +697,12 @@ def parse_mql_query_body(body: str, dataset: Dataset) -> EntityQuery: idx = e.column() prefix = line[max(0, idx - 3) : idx] suffix = line[idx : (idx + 10)] - raise ParsingException(f"Parsing error on line {e.line()} at '{prefix}{suffix}'") + raise ParsingException(f"Parsing error on line {e.line()} at '{prefix}{suffix}'") from e except Exception as e: message = str(e) if "\n" in message: message, _ = message.split("\n", 1) - raise ParsingException(message) + raise ParsingException(message) from e if not parsed.expression and not parsed.formula: raise ParsingException("No aggregate/expression or formula specified in MQL query") @@ -740,15 +736,14 @@ def build_formula_query_from_clause( def find_all_leaf_nodes(tree: FormulaParameter) -> list[InitialParseResult] | None: if isinstance(tree, InitialParseResult) and tree.formula is None: return [tree] - elif isinstance(tree, InitialParseResult) and tree.formula is not None: + if isinstance(tree, InitialParseResult) and tree.formula is not None: nodes = [] for p in tree.parameters or []: found = find_all_leaf_nodes(p) if found: nodes.extend(found) return nodes - else: - return None + return None join_nodes = find_all_leaf_nodes(parsed) if join_nodes is None: @@ -759,9 +754,8 @@ def find_all_leaf_nodes(tree: FormulaParameter) -> list[InitialParseResult] | No # Example: sum(`transactions.duration`) by transaction / sum(`transactions.duration`) groupbys = join_nodes[0].groupby for node in join_nodes: - if node.groupby is not None: - if node.groupby != groupbys: - raise InvalidQueryException("All terms in a formula must have the same groupby") + if node.groupby is not None and node.groupby != groupbys: + raise InvalidQueryException("All terms in a formula must have the same groupby") entity_keys = [select_entity(node.mri or "", dataset) for node in join_nodes] if len(entity_keys) == 1: @@ -922,7 +916,7 @@ def extract_expression(param: InitialParseResult | Any) -> Expression: if leaf_node.groupby: for group_exp in leaf_node.groupby: if isinstance(group_exp.expression, Column): - alias: Optional[str] + alias: str | None if alias_wrap(leaf_node.table_alias): alias = f"{alias_wrap(leaf_node.table_alias)}.{group_exp.expression.alias}" else: @@ -960,7 +954,7 @@ def wrap_condition_columns(fn_call: FunctionCall) -> FunctionCall: if not isinstance(param, InitialParseResult): return [] - elif param.expression is not None: + if param.expression is not None: conditions = [] for c in param.conditions or []: assert isinstance(c, FunctionCall) @@ -973,13 +967,12 @@ def wrap_condition_columns(fn_call: FunctionCall) -> FunctionCall: ) ) return conditions - elif param.formula: + if param.formula: conditions = [] for p in param.parameters or []: conditions.extend(extract_filters(p)) return conditions - else: - raise InvalidQueryException("Could not extract valid filters for formula") + raise InvalidQueryException("Could not extract valid filters for formula") conditions = [] for p in parsed.parameters or []: @@ -1033,9 +1026,10 @@ def select_entity(mri: str, dataset: Dataset) -> EntityKey: if get_dataset_name(dataset) == "metrics": if entity := METRICS_ENTITIES.get(mri[0]): return entity - elif get_dataset_name(dataset) == "generic_metrics": - if entity := GENERIC_ENTITIES.get(mri[0]): - return entity + elif get_dataset_name(dataset) == "generic_metrics" and ( + entity := GENERIC_ENTITIES.get(mri[0]) + ): + return entity raise ParsingException(f"invalid metric type {mri[0]}") @@ -1087,6 +1081,9 @@ def populate_query_from_mql_context( query.set_totals(with_totals) if isinstance(query, CompositeQuery): + join_clause = query.get_from_clause() + assert isinstance(join_clause, JoinClause) + alias_node_map = join_clause.get_alias_node_map() def add_time_join_keys(join_clause: JoinClause[Any]) -> str: match (join_clause.left_node, join_clause.right_node): @@ -1145,9 +1142,7 @@ def convert_to_cross_join(join_clause: JoinClause[Any]) -> JoinClause[Any]: number_of_groupbys = len(query.get_groupby()) no_groupby_or_one_sided_groupby = False - if number_of_groupbys == 0: - no_groupby_or_one_sided_groupby = True - elif number_of_groupbys % number_of_joins != 0: + if number_of_groupbys == 0 or number_of_groupbys % number_of_joins != 0: no_groupby_or_one_sided_groupby = True if selected_time: @@ -1217,7 +1212,7 @@ def convert_cols_to_extrapolated(expr: Expression) -> Expression: return query, mql_context -def quantiles_to_quantile(query: Union[CompositeQuery[LogicalDataSource], LogicalQuery]) -> None: +def quantiles_to_quantile(query: CompositeQuery[LogicalDataSource] | LogicalQuery) -> None: """ Changes quantiles(0.5)(...) to arrayElement(quantiles(0.5)(...), 1). This is to simplify the API (so that the arrays don't need to be unwrapped) and also avoids bugs where comparing @@ -1225,19 +1220,19 @@ def quantiles_to_quantile(query: Union[CompositeQuery[LogicalDataSource], Logica """ def transform(exp: Expression) -> Expression: - if isinstance(exp, CurriedFunctionCall): - if exp.internal_function.function_name in ("quantiles", "quantilesIf"): - if len(exp.internal_function.parameters) == 1: - return arrayElement(exp.alias, replace(exp, alias=None), Literal(None, 1)) + if ( + isinstance(exp, CurriedFunctionCall) + and exp.internal_function.function_name in ("quantiles", "quantilesIf") + and len(exp.internal_function.parameters) == 1 + ): + return arrayElement(exp.alias, replace(exp, alias=None), Literal(None, 1)) return exp query.transform_expressions(transform) -CustomProcessors = Sequence[ - Callable[[Union[CompositeQuery[LogicalDataSource], LogicalQuery]], None] -] +CustomProcessors = Sequence[Callable[[CompositeQuery[LogicalDataSource] | LogicalQuery], None]] MQL_POST_PROCESSORS: CustomProcessors = POST_PROCESSORS + [ quantiles_to_quantile, @@ -1248,9 +1243,9 @@ def parse_mql_query( body: str, mql_context_dict: dict[str, Any], dataset: Dataset, - custom_processing: Optional[CustomProcessors] = None, + custom_processing: CustomProcessors | None = None, settings: QuerySettings | None = None, -) -> Union[CompositeQuery[LogicalDataSource], LogicalQuery]: +) -> CompositeQuery[LogicalDataSource] | LogicalQuery: # dummy variables that dont matter dummy_timer = Timer("mql_pipeline") dummy_settings = HTTPQuerySettings() @@ -1367,6 +1362,6 @@ def _process_data( @dataclass -class FilterFactorValue(object): +class FilterFactorValue: value: str | Sequence[str] | Condition | BooleanCondition contains_wildcard: bool diff --git a/snuba/query/parser/__init__.py b/snuba/query/parser/__init__.py index b69126e81a0..19da1740e2a 100644 --- a/snuba/query/parser/__init__.py +++ b/snuba/query/parser/__init__.py @@ -1,5 +1,5 @@ +from collections.abc import Mapping, MutableMapping, Sequence from dataclasses import replace -from typing import Mapping, MutableMapping, Optional, Sequence, Tuple, Union from snuba import environment from snuba.query.composite import CompositeQuery @@ -25,7 +25,7 @@ metrics = MetricsWrapper(environment.metrics, "parser") -def validate_aliases(query: Union[CompositeQuery[LogicalDataSource], Query]) -> None: +def validate_aliases(query: CompositeQuery[LogicalDataSource] | Query) -> None: """ Ensures that no alias has been defined multiple times for different expressions in the query. Thus rejecting queries with shadowing. @@ -46,11 +46,10 @@ def validate_aliases(query: Union[CompositeQuery[LogicalDataSource], Query]) -> ), should_report=False, ) - else: - all_declared_aliases[exp.alias] = exp + all_declared_aliases[exp.alias] = exp -def parse_subscriptables(query: Union[CompositeQuery[LogicalDataSource], Query]) -> None: +def parse_subscriptables(query: CompositeQuery[LogicalDataSource] | Query) -> None: """ Turns columns formatted as tags[asd] into SubscriptableReference. """ @@ -74,7 +73,7 @@ def transform(exp: Expression) -> Expression: query.transform_expressions(transform) -def apply_column_aliases(query: Union[CompositeQuery[LogicalDataSource], Query]) -> None: +def apply_column_aliases(query: CompositeQuery[LogicalDataSource] | Query) -> None: """ Applies an alias to all the columns in the query equal to the column name unless a column already has one or the alias is already defined. @@ -89,13 +88,12 @@ def apply_column_aliases(query: Union[CompositeQuery[LogicalDataSource], Query]) def apply_aliases(exp: Expression) -> Expression: if not isinstance(exp, Column) or exp.alias or exp.column_name in current_aliases: return exp - else: - return replace(exp, alias=exp.column_name) + return replace(exp, alias=exp.column_name) query.transform_expressions(apply_aliases) -def expand_aliases(query: Union[CompositeQuery[LogicalDataSource], Query]) -> None: +def expand_aliases(query: CompositeQuery[LogicalDataSource] | Query) -> None: """ Recursively expand all the references to aliases in the query. This makes life easy to query processors and translators that only have to @@ -179,10 +177,9 @@ def visit_column(self, exp: Column) -> Expression: if self.__expand_nested: # The expanded expression may contain more alias references to expand. return self.__alias_lookup_table[name].accept(self) - else: - return self.__alias_lookup_table[name] + return self.__alias_lookup_table[name] - def __append_alias(self, alias: Optional[str]) -> Sequence[str]: + def __append_alias(self, alias: str | None) -> Sequence[str]: return [*self.__visited_stack, alias] if alias is not None else self.__visited_stack def visit_subscriptable_reference(self, exp: SubscriptableReference) -> Expression: @@ -196,21 +193,25 @@ def visit_subscriptable_reference(self, exp: SubscriptableReference) -> Expressi assert isinstance(expanded_column, Column), ( "A subscriptable column cannot be resolved to anything other than a column" ) + expanded_key = exp.key.accept( + AliasExpanderVisitor( + self.__alias_lookup_table, + self.__append_alias(exp.alias), + self.__expand_nested, + ) + ) + assert isinstance(expanded_key, Literal), ( + "A subscriptable key cannot be resolved to anything other than a literal" + ) return replace( exp, column=expanded_column, - key=exp.key.accept( - AliasExpanderVisitor( - self.__alias_lookup_table, - self.__append_alias(exp.alias), - self.__expand_nested, - ) - ), + key=expanded_key, ) def __visit_sequence( - self, alias: Optional[str], parameters: Sequence[Expression] - ) -> Tuple[Expression, ...]: + self, alias: str | None, parameters: Sequence[Expression] + ) -> tuple[Expression, ...]: return tuple( p.accept( AliasExpanderVisitor( @@ -226,15 +227,20 @@ def visit_function_call(self, exp: FunctionCall) -> Expression: return replace(exp, parameters=self.__visit_sequence(exp.alias, exp.parameters)) def visit_curried_function_call(self, exp: CurriedFunctionCall) -> Expression: + internal_function = exp.internal_function.accept( + AliasExpanderVisitor( + self.__alias_lookup_table, + self.__append_alias(exp.alias), + self.__expand_nested, + ) + ) + assert isinstance(internal_function, FunctionCall), ( + "The internal function of a curried function call cannot be resolved " + "to anything other than a function call" + ) return replace( exp, - internal_function=exp.internal_function.accept( - AliasExpanderVisitor( - self.__alias_lookup_table, - self.__append_alias(exp.alias), - self.__expand_nested, - ) - ), + internal_function=internal_function, parameters=self.__visit_sequence(exp.alias, exp.parameters), ) diff --git a/snuba/query/parser/expressions.py b/snuba/query/parser/expressions.py index 5992945ae02..9a50c9a5650 100644 --- a/snuba/query/parser/expressions.py +++ b/snuba/query/parser/expressions.py @@ -1,4 +1,5 @@ -from typing import Any, Iterable, List, Tuple, Union +from collections.abc import Iterable +from typing import Any from parsimonious.grammar import Grammar from parsimonious.nodes import Node, NodeVisitor @@ -66,13 +67,13 @@ # parsimonious isn't properly type hinted yet, NodeVisitor has a type of Any # Add an ignore until parsimonious is properly typed. -class ClickhouseVisitor(NodeVisitor): # type: ignore +class ClickhouseVisitor(NodeVisitor): # type: ignore[misc] """ Builds Snuba AST expressions from the Parsimonious parse tree. """ def visit_root_element( - self, node: Node, visited_children: Tuple[Expression, Any] + self, node: Node, visited_children: tuple[Expression, Any] ) -> Expression: ret, _ = visited_children return ret @@ -84,12 +85,12 @@ def visit_column_name(self, node: Node, visited_children: Iterable[Any]) -> Colu return visit_column_name(node, visited_children) def visit_low_pri_tuple( - self, node: Node, visited_children: Tuple[LowPriOperator, Any, Expression] + self, node: Node, visited_children: tuple[LowPriOperator, Any, Expression] ) -> LowPriTuple: return visit_low_pri_tuple(node, visited_children) def visit_high_pri_tuple( - self, node: Node, visited_children: Tuple[HighPriOperator, Any, Expression] + self, node: Node, visited_children: tuple[HighPriOperator, Any, Expression] ) -> HighPriTuple: return visit_high_pri_tuple(node, visited_children) @@ -100,46 +101,46 @@ def visit_high_pri_op(self, node: Node, visited_children: Iterable[Any]) -> High return visit_high_pri_op(node, visited_children) def visit_arithmetic_term( - self, node: Node, visited_children: Tuple[Any, Expression] + self, node: Node, visited_children: tuple[Any, Expression] ) -> Expression: return visit_arithmetic_term(node, visited_children) def visit_low_pri_arithmetic( self, node: Node, - visited_children: Tuple[Any, Expression, LowPriArithmetic], + visited_children: tuple[Any, Expression, LowPriArithmetic], ) -> Expression: return visit_low_pri_arithmetic(node, visited_children) def visit_high_pri_arithmetic( self, node: Node, - visited_children: Tuple[Any, Expression, HighPriArithmetic], + visited_children: tuple[Any, Expression, HighPriArithmetic], ) -> Expression: return visit_high_pri_arithmetic(node, visited_children) def visit_numeric_literal(self, node: Node, visited_children: Iterable[Any]) -> Literal: return visit_numeric_literal(node, visited_children) - def visit_quoted_literal(self, node: Node, visited_children: Tuple[Node]) -> Literal: + def visit_quoted_literal(self, node: Node, visited_children: tuple[Node]) -> Literal: return visit_quoted_literal(node, visited_children) def visit_parameter( - self, node: Node, visited_children: Tuple[Expression, Any, Any, Any] + self, node: Node, visited_children: tuple[Expression, Any, Any, Any] ) -> Expression: return visit_parameter(node, visited_children) def visit_parameters_list( self, node: Node, - visited_children: Tuple[Union[Expression, List[Expression]], Expression], - ) -> List[Expression]: + visited_children: tuple[Expression | list[Expression], Expression], + ) -> list[Expression]: return visit_parameters_list(node, visited_children) def visit_function_call( self, node: Node, - visited_children: Tuple[str, Any, List[Expression], Any, Union[Node, List[Expression]]], + visited_children: tuple[str, Any, list[Expression], Any, Node | list[Expression]], ) -> Expression: return visit_function_call(node, visited_children) @@ -155,4 +156,4 @@ def parse_clickhouse_function(function: str) -> Expression: f"Cannot parse aggregation {function}", should_report=False ) from cause - return ClickhouseVisitor().visit(expression_tree) # type: ignore + return ClickhouseVisitor().visit(expression_tree) # type: ignore[no-any-return] diff --git a/snuba/query/parser/validation/functions.py b/snuba/query/parser/validation/functions.py index 393e8a36960..19595ecf7f9 100644 --- a/snuba/query/parser/validation/functions.py +++ b/snuba/query/parser/validation/functions.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import List, Mapping +from collections.abc import Mapping from snuba.clickhouse.columns import Array, DateTime, String from snuba.datasets.entities.factory import get_entity @@ -32,12 +32,12 @@ "like": SignatureValidator([Column({Array, String}), Any()]), "notLike": SignatureValidator([Column({Array, String}), Any()]), } -global_validators: List[FunctionCallValidator] = [AllowedFunctionValidator()] +global_validators: list[FunctionCallValidator] = [AllowedFunctionValidator()] class QueryEntityFinder( - DataSourceVisitor[List[QueryEntity], QueryEntity], - JoinVisitor[List[QueryEntity], QueryEntity], + DataSourceVisitor[list[QueryEntity], QueryEntity], + JoinVisitor[list[QueryEntity], QueryEntity], ): """ Finds the QueryEntity from the data source. The QueryEntity is passed @@ -65,22 +65,22 @@ class QueryEntityFinder( a single QueryEntity. """ - def _visit_simple_source(self, data_source: QueryEntity) -> List[QueryEntity]: + def _visit_simple_source(self, data_source: QueryEntity) -> list[QueryEntity]: return [data_source] - def _visit_join(self, data_source: JoinClause[QueryEntity]) -> List[QueryEntity]: + def _visit_join(self, data_source: JoinClause[QueryEntity]) -> list[QueryEntity]: return self.visit_join_clause(data_source) - def _visit_simple_query(self, data_source: ProcessableQuery[QueryEntity]) -> List[QueryEntity]: + def _visit_simple_query(self, data_source: ProcessableQuery[QueryEntity]) -> list[QueryEntity]: return self.visit(data_source.get_from_clause()) - def _visit_composite_query(self, data_source: CompositeQuery[QueryEntity]) -> List[QueryEntity]: + def _visit_composite_query(self, data_source: CompositeQuery[QueryEntity]) -> list[QueryEntity]: return [] - def visit_individual_node(self, node: IndividualNode[QueryEntity]) -> List[QueryEntity]: + def visit_individual_node(self, node: IndividualNode[QueryEntity]) -> list[QueryEntity]: return self.visit(node.data_source) - def visit_join_clause(self, node: JoinClause[QueryEntity]) -> List[QueryEntity]: + def visit_join_clause(self, node: JoinClause[QueryEntity]) -> list[QueryEntity]: return node.right_node.accept(self) + node.left_node.accept(self) diff --git a/snuba/query/parsing.py b/snuba/query/parsing.py index f8fb9bff686..6e0fc336610 100644 --- a/snuba/query/parsing.py +++ b/snuba/query/parsing.py @@ -1,6 +1,3 @@ -from typing import List - - class ParsingContext: """ This class is passed around during the query parsing process @@ -9,7 +6,7 @@ class ParsingContext: """ def __init__(self) -> None: - self.__alias_cache: List[str] = [] + self.__alias_cache: list[str] = [] def add_alias(self, alias: str) -> None: self.__alias_cache.append(alias) diff --git a/snuba/query/processors/condition_checkers/__init__.py b/snuba/query/processors/condition_checkers/__init__.py index f7725b1ea78..b3e5d1bc1cb 100644 --- a/snuba/query/processors/condition_checkers/__init__.py +++ b/snuba/query/processors/condition_checkers/__init__.py @@ -2,7 +2,7 @@ import os from abc import ABC, abstractmethod -from typing import Type, cast +from typing import cast from snuba.clickhouse.query import Expression from snuba.utils.registered_class import RegisteredClass, import_submodules_in_directory @@ -30,8 +30,8 @@ def check(self, expression: Expression) -> bool: raise NotImplementedError @classmethod - def get_from_name(cls, name: str) -> Type["ConditionChecker"]: - return cast(Type["ConditionChecker"], cls.class_from_name(name)) + def get_from_name(cls, name: str) -> type[ConditionChecker]: + return cast(type["ConditionChecker"], cls.class_from_name(name)) @classmethod def from_kwargs(cls, **kwargs: str) -> ConditionChecker: diff --git a/snuba/query/processors/logical/basic_functions.py b/snuba/query/processors/logical/basic_functions.py index 234e791c87a..715b373dd58 100644 --- a/snuba/query/processors/logical/basic_functions.py +++ b/snuba/query/processors/logical/basic_functions.py @@ -48,12 +48,14 @@ def process_functions(exp: Expression) -> Expression: Literal(None, 0), ), ) - if isinstance(exp, CurriedFunctionCall): - if exp.internal_function.function_name == "top": - return replace( - exp, - internal_function=replace(exp.internal_function, function_name="topK"), - ) + if ( + isinstance(exp, CurriedFunctionCall) + and exp.internal_function.function_name == "top" + ): + return replace( + exp, + internal_function=replace(exp.internal_function, function_name="topK"), + ) return exp query.transform_expressions(process_functions) diff --git a/snuba/query/processors/logical/custom_function.py b/snuba/query/processors/logical/custom_function.py index e096dd11dc6..9198a98b467 100644 --- a/snuba/query/processors/logical/custom_function.py +++ b/snuba/query/processors/logical/custom_function.py @@ -1,5 +1,6 @@ +from collections.abc import Mapping, Sequence from dataclasses import replace -from typing import Any, Mapping, Sequence, Tuple +from typing import Any from sentry_relay.consts import SPAN_STATUS_NAME_TO_CODE @@ -41,7 +42,7 @@ def simple_function(body: str) -> Expression: return parse_clickhouse_function(body) -def partial_function(body: str, constants: Sequence[Tuple[str, Any]]) -> Expression: +def partial_function(body: str, constants: Sequence[tuple[str, Any]]) -> Expression: parsed = parse_clickhouse_function(body) constants_lookup = {name: Literal(None, value) for (name, value) in constants} return replace_in_expression(parsed, constants_lookup) @@ -92,14 +93,14 @@ class _CustomFunction(LogicalQueryProcessor): def __init__( self, name: str, - signature: Sequence[Tuple[str, ParamType]], + signature: Sequence[tuple[str, ParamType]], body: Expression, ) -> None: self.__function_name = name self.__param_names: Sequence[str] = [] param_types: Sequence[ParamType] = [] if len(signature) > 0: - self.__param_names, param_types = zip(*signature) + self.__param_names, param_types = zip(*signature, strict=False) self.__body = body self.__validator = SignatureValidator(param_types) @@ -132,15 +133,11 @@ def apply_function(expression: Expression) -> Expression: should_report=False, ) from exception - resolved_params = { - name: expression - for (name, expression) in zip(self.__param_names, expression.parameters) - } + resolved_params = dict(zip(self.__param_names, expression.parameters, strict=False)) ret = replace_in_expression(self.__body, resolved_params) return replace(ret, alias=expression.alias) - else: - return expression + return expression query.transform_expressions(apply_function) diff --git a/snuba/query/processors/logical/granularity_processor.py b/snuba/query/processors/logical/granularity_processor.py index 04420a28173..fa2f7e64c95 100644 --- a/snuba/query/processors/logical/granularity_processor.py +++ b/snuba/query/processors/logical/granularity_processor.py @@ -1,5 +1,6 @@ from abc import abstractmethod -from typing import List, Mapping, NamedTuple, Optional +from collections.abc import Mapping +from typing import NamedTuple from snuba.query.conditions import ( BooleanFunctions, @@ -47,14 +48,12 @@ def get_highest_common_available_granularity_multiple( ) -> int: raise NotImplementedError - def find_granularities_in_expression( - self, expression: Optional[Expression] - ) -> List[MatchResult]: + def find_granularities_in_expression(self, expression: Expression | None) -> list[MatchResult]: """ Finds all granularity conditions in an expression. Returns List[Tuple[MatchResult, int]] where [0] is the matched condition and [1] is highest common available granularity multiple """ - matches: List[MatchResult] = [] + matches: list[MatchResult] = [] match = FunctionCall( String(ConditionFunctions.EQ), ( diff --git a/snuba/query/processors/logical/hash_bucket_functions.py b/snuba/query/processors/logical/hash_bucket_functions.py index a55d3098c49..9328f96eb95 100644 --- a/snuba/query/processors/logical/hash_bucket_functions.py +++ b/snuba/query/processors/logical/hash_bucket_functions.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.query.expressions import Column, Expression, FunctionCall, Literal from snuba.query.logical import Query diff --git a/snuba/query/processors/logical/low_cardinality_processor.py b/snuba/query/processors/logical/low_cardinality_processor.py index 07aa77b5e48..2d306edc7ac 100644 --- a/snuba/query/processors/logical/low_cardinality_processor.py +++ b/snuba/query/processors/logical/low_cardinality_processor.py @@ -29,7 +29,7 @@ class LowCardinalityProcessor(LogicalQueryProcessor): def __init__(self, columns: list[str]) -> None: self.low_card_columns = set() - self.subscriptable_columns: dict[str, set[str]] = dict() + self.subscriptable_columns: dict[str, set[str]] = {} for c in columns: if c.startswith("tags") or c.startswith("contexts"): column, key = c.split("[") diff --git a/snuba/query/processors/logical/optional_attribute_aggregation.py b/snuba/query/processors/logical/optional_attribute_aggregation.py index 39e216c1d16..43db4626e1a 100644 --- a/snuba/query/processors/logical/optional_attribute_aggregation.py +++ b/snuba/query/processors/logical/optional_attribute_aggregation.py @@ -46,12 +46,11 @@ def find_subscriptable_reference( and exp.column.column_name in self._attribute_column_names ): return exp - elif isinstance(exp, FunctionCall) and exp.parameters: - for param in exp.parameters: - result = find_subscriptable_reference(param) - if result: - return result - elif isinstance(exp, CurriedFunctionCall): + if ( + isinstance(exp, FunctionCall) + and exp.parameters + or isinstance(exp, CurriedFunctionCall) + ): for param in exp.parameters: result = find_subscriptable_reference(param) if result: diff --git a/snuba/query/processors/logical/timeseries_processor.py b/snuba/query/processors/logical/timeseries_processor.py index dce5ed5e3c2..225e9dde513 100644 --- a/snuba/query/processors/logical/timeseries_processor.py +++ b/snuba/query/processors/logical/timeseries_processor.py @@ -1,4 +1,4 @@ -from typing import Mapping, Optional, Sequence +from collections.abc import Mapping, Sequence from snuba.query.conditions import ConditionFunctions from snuba.query.dsl import multiply @@ -91,21 +91,18 @@ def __init__( ), ) - def __group_time_column(self, exp: Expression, granularity: Optional[int]) -> Expression: - if isinstance(exp, Column): - if exp.column_name in self.__time_replace_columns: - real_column_name = self.__time_replace_columns[exp.column_name] - if granularity is None: - granularity = 3600 - time_column_fn = self.__group_time_function( - real_column_name, granularity, exp.alias - ) - return time_column_fn + def __group_time_column(self, exp: Expression, granularity: int | None) -> Expression: + if isinstance(exp, Column) and exp.column_name in self.__time_replace_columns: + real_column_name = self.__time_replace_columns[exp.column_name] + if granularity is None: + granularity = 3600 + time_column_fn = self.__group_time_function(real_column_name, granularity, exp.alias) + return time_column_fn return exp def __group_time_function( - self, column_name: str, granularity: int, alias: Optional[str] + self, column_name: str, granularity: int, alias: str | None ) -> FunctionCall: function_call = { 3600: FunctionCall( @@ -193,7 +190,7 @@ def process_query(self, query: Query, query_settings: QuerySettings) -> None: } -def extract_granularity_from_query(query: Query, column: str) -> Optional[int]: +def extract_granularity_from_query(query: Query, column: str) -> int | None: """ This extracts the `granularity` from the `groupby` statement of the query. The matches are essentially the reverse of `TimeSeriesProcessor.__group_time_function`. diff --git a/snuba/query/processors/physical/abstract_array_join_optimizer.py b/snuba/query/processors/physical/abstract_array_join_optimizer.py index 3bc0f0a57ba..749afd43783 100644 --- a/snuba/query/processors/physical/abstract_array_join_optimizer.py +++ b/snuba/query/processors/physical/abstract_array_join_optimizer.py @@ -1,5 +1,6 @@ +from collections.abc import Callable, Sequence from itertools import combinations -from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple, TypeVar, Union +from typing import TypeVar from snuba.clickhouse.query import Query from snuba.query.conditions import ( @@ -40,21 +41,21 @@ def __init__( self.__val_names = val_names @property - def key_columns(self) -> List[str]: + def key_columns(self) -> list[str]: """ The full name of all the nested key columns """ return [f"{self.column_name}.{column}" for column in self.__key_names] @property - def val_column(self) -> List[str]: + def val_column(self) -> list[str]: """ The full name of all the nested value columns """ return [f"{self.column_name}.{column}" for column in self.__val_names] @property - def all_columns(self) -> List[str]: + def all_columns(self) -> list[str]: """ The full name of all the nested columns """ @@ -62,7 +63,7 @@ def all_columns(self) -> List[str]: def get_filtered_arrays( self, query: Query, all_column_names: Sequence[str] - ) -> Tuple[Dict[str, Sequence[str]], Dict[Tuple[str, ...], Sequence[Tuple[str, ...]]]]: + ) -> tuple[dict[str, Sequence[str]], dict[tuple[str, ...], Sequence[tuple[str, ...]]]]: # Check which array joins have been selected selected_array_joins = { column_name @@ -124,8 +125,8 @@ def array_join_pattern(*column_names: str) -> FunctionCall: ) -T = TypeVar("T", bound=Union[str, Tuple[str, ...]]) -Extractor = Callable[[Expression], Set[T]] +T = TypeVar("T", bound=str | tuple[str, ...]) +Extractor = Callable[[Expression], set[T]] def skippable_condition_pattern(*column_names: str) -> Callable[[Expression], bool]: @@ -161,8 +162,8 @@ def get_single_column_filters(query: Query, column_name: str) -> Sequence[str]: def get_multiple_columns_filters( - query: Query, column_names: Tuple[str, ...] -) -> Sequence[Tuple[str, ...]]: + query: Query, column_names: tuple[str, ...] +) -> Sequence[tuple[str, ...]]: pattern = array_join_pattern(*column_names) return get_filtered_mapping_keys( @@ -188,7 +189,7 @@ def get_filtered_mapping_keys( in the query. """ ast_condition = query.get_condition() - cond_keys: Optional[Set[T]] = ( + cond_keys: set[T] | None = ( get_mapping_keys_in_condition(ast_condition, extractors, is_skippable_condition) if ast_condition is not None else set() @@ -199,7 +200,7 @@ def get_filtered_mapping_keys( return [] ast_having = query.get_having() - having_keys: Optional[Set[T]] = ( + having_keys: set[T] | None = ( get_mapping_keys_in_condition(ast_having, extractors, is_skippable_condition) if ast_having is not None else set() @@ -209,14 +210,14 @@ def get_filtered_mapping_keys( return [] keys = cond_keys | having_keys - return sorted(list(keys)) + return sorted(keys) def get_mapping_keys_in_condition( conditions: Expression, extractors: Sequence[Extractor[T]], is_skippable_condition: Callable[[Expression], bool], -) -> Optional[Set[T]]: +) -> set[T] | None: """ Examines the top level AND conditions and applies the extractor functions to extract the matching keys. @@ -224,7 +225,7 @@ def get_mapping_keys_in_condition( If any we find any OR conditions, we exit early though there could be possible optimizations to be done in these situations. """ - keys_found: Set[T] = set() + keys_found: set[T] = set() for c in get_first_level_and_conditions(conditions): if is_skippable_condition(c): @@ -242,7 +243,7 @@ def get_mapping_keys_in_condition( def string_literal_equal_condition_extractor( key_pattern: Pattern[Expression], ) -> Extractor[str]: - def extractor(condition: Expression) -> Set[str]: + def extractor(condition: Expression) -> set[str]: match = FunctionCall( String(ConditionFunctions.EQ), (key_pattern, Literal(Param("key", Any(str)))), @@ -259,7 +260,7 @@ def extractor(condition: Expression) -> Set[str]: def string_literal_in_condition_extractor( key_pattern: Pattern[Expression], ) -> Extractor[str]: - def extractor(condition: Expression) -> Set[str]: + def extractor(condition: Expression) -> set[str]: match = is_in_condition_pattern(key_pattern).match(condition) if match is None: @@ -279,8 +280,8 @@ def extractor(condition: Expression) -> Set[str]: def tuple_literal_equal_condition_extractor( key_pattern: Pattern[Expression], -) -> Extractor[Tuple[str, ...]]: - def extractor(condition: Expression) -> Set[Tuple[str, ...]]: +) -> Extractor[tuple[str, ...]]: + def extractor(condition: Expression) -> set[tuple[str, ...]]: match = FunctionCall( String(ConditionFunctions.EQ), (key_pattern, Param("tuple", FunctionCall(String("tuple"), None))), @@ -306,8 +307,8 @@ def extractor(condition: Expression) -> Set[Tuple[str, ...]]: def tuple_literal_in_condition_extractor( key_pattern: Pattern[Expression], -) -> Extractor[Tuple[str, ...]]: - def extractor(condition: Expression) -> Set[Tuple[str, ...]]: +) -> Extractor[tuple[str, ...]]: + def extractor(condition: Expression) -> set[tuple[str, ...]]: match = is_in_condition_pattern(key_pattern).match(condition) if match is None: @@ -320,7 +321,7 @@ def extractor(condition: Expression) -> Set[Tuple[str, ...]]: ): return set() - parameters: Set[Tuple[str, ...]] = set() + parameters: set[tuple[str, ...]] = set() for tuple_param in function.parameters: if ( diff --git a/snuba/query/processors/physical/array_has_optimizer.py b/snuba/query/processors/physical/array_has_optimizer.py index d06b3694b14..4241c55cfd8 100644 --- a/snuba/query/processors/physical/array_has_optimizer.py +++ b/snuba/query/processors/physical/array_has_optimizer.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.query import Query from snuba.query.expressions import Expression diff --git a/snuba/query/processors/physical/arrayjoin_keyvalue_optimizer.py b/snuba/query/processors/physical/arrayjoin_keyvalue_optimizer.py index fe319be17b6..64fd9d8f4d0 100644 --- a/snuba/query/processors/physical/arrayjoin_keyvalue_optimizer.py +++ b/snuba/query/processors/physical/arrayjoin_keyvalue_optimizer.py @@ -1,4 +1,4 @@ -from typing import Optional, Sequence, Set +from collections.abc import Sequence from snuba.clickhouse.query import Query from snuba.query.conditions import ( @@ -34,7 +34,7 @@ def array_join_pattern(column_name: str) -> FunctionCall: ) -def _get_mapping_keys_in_condition(condition: Expression, column_name: str) -> Optional[Set[str]]: +def _get_mapping_keys_in_condition(condition: Expression, column_name: str) -> set[str] | None: """ Finds the top level conditions that include filter based on the arrayJoin. This is meant to be used to find the keys the query is filtering the arrayJoin @@ -92,7 +92,7 @@ def get_filtered_mapping_keys(query: Query, column_name: str) -> Sequence[str]: ) if not array_join_found: - return list() + return [] ast_condition = query.get_condition() cond_keys = ( @@ -114,7 +114,7 @@ def get_filtered_mapping_keys(query: Query, column_name: str) -> Sequence[str]: return [] keys = cond_keys | having_keys - return sorted(list(keys)) + return sorted(keys) class ArrayJoinKeyValueOptimizer(ClickhouseQueryProcessor): @@ -196,29 +196,27 @@ def replace_expression(expr: Expression) -> Expression: return _unfiltered_mapping_pairs( expr.alias, self.__column_name, pair_alias, array_index ) - else: - return _filtered_mapping_pairs( - expr.alias, - self.__column_name, - pair_alias, - filtered_keys, - array_index, - ) + return _filtered_mapping_pairs( + expr.alias, + self.__column_name, + pair_alias, + filtered_keys, + array_index, + ) - elif filtered_keys: + if filtered_keys: # Only one between arrayJoin(col.key) and arrayJoin(col.value) # is present, and it is arrayJoin(col.key) since we found # filtered keys. return _filtered_mapping_keys(expr.alias, self.__column_name, filtered_keys) - else: - # No viable optimization - return expr + # No viable optimization + return expr query.transform_expressions(replace_expression) def _unfiltered_mapping_pairs( - alias: Optional[str], column_name: str, pair_alias: str, tuple_index: LiteralExpr + alias: str | None, column_name: str, pair_alias: str, tuple_index: LiteralExpr ) -> Expression: # (arrayJoin( # arrayMap((x,y) -> (x,y), tags.key, tags.value) @@ -237,7 +235,7 @@ def _unfiltered_mapping_pairs( def _filtered_mapping_pairs( - alias: Optional[str], + alias: str | None, column_name: str, pair_alias: str, filtered_tags: Sequence[LiteralExpr], @@ -264,7 +262,7 @@ def _filtered_mapping_pairs( def _filtered_mapping_keys( - alias: Optional[str], column_name: str, filtered_tags: Sequence[LiteralExpr] + alias: str | None, column_name: str, filtered_tags: Sequence[LiteralExpr] ) -> Expression: # arrayJoin(arrayFilter( # tag -> tag IN (tags), diff --git a/snuba/query/processors/physical/arrayjoin_optimizer.py b/snuba/query/processors/physical/arrayjoin_optimizer.py index 37d27a3b104..46c5401a463 100644 --- a/snuba/query/processors/physical/arrayjoin_optimizer.py +++ b/snuba/query/processors/physical/arrayjoin_optimizer.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Sequence, Set, Tuple +from collections.abc import Sequence from snuba.clickhouse.query import Query from snuba.query.conditions import ( @@ -54,11 +54,11 @@ def __find_tuple_index(self, column_name: str) -> LiteralExpr: return LiteralExpr(None, i + 1) raise ValueError(f"Unknown column: {column_name}") - def __get_array_joins_in_query(self, query: Query) -> Set[str]: + def __get_array_joins_in_query(self, query: Query) -> set[str]: """ Get all of the arrayJoins on the possible columns that are present in the query. """ - array_joins_in_query: Set[str] = set() + array_joins_in_query: set[str] = set() for e in query.get_all_expressions(): match = self.__array_join_pattern.match(e) @@ -125,7 +125,7 @@ def replace_expression(expr: Expression) -> Expression: ) # Only array join present is one of the key columns - elif len(array_joins_in_query) == 1 and any( + if len(array_joins_in_query) == 1 and any( column in array_joins_in_query for column in self.key_columns ): column_name = array_joins_in_query.pop() @@ -141,12 +141,12 @@ def replace_expression(expr: Expression) -> Expression: def filtered_mapping_tuples( - alias: Optional[str], + alias: str | None, tuple_alias: str, tuple_index: LiteralExpr, column_names: Sequence[str], - single_filtered: Dict[LiteralExpr, Sequence[str]], - multiple_filtered: Dict[Tuple[LiteralExpr, ...], Sequence[Tuple[str, ...]]], + single_filtered: dict[LiteralExpr, Sequence[str]], + multiple_filtered: dict[tuple[LiteralExpr, ...], Sequence[tuple[str, ...]]], ) -> Expression: return tupleElement( alias, @@ -164,13 +164,13 @@ def filtered_mapping_tuples( def filter_expression( columns: Expression, - single_filtered: Dict[LiteralExpr, Sequence[str]], - multiple_filtered: Dict[Tuple[LiteralExpr, ...], Sequence[Tuple[str, ...]]], + single_filtered: dict[LiteralExpr, Sequence[str]], + multiple_filtered: dict[tuple[LiteralExpr, ...], Sequence[tuple[str, ...]]], ) -> Expression: argument_name = "arg" argument = Argument(None, argument_name) - conditions: List[Expression] = [] + conditions: list[Expression] = [] for index in single_filtered: conditions.append( @@ -217,7 +217,7 @@ def filter_expression( def unfiltered_mapping_tuples( - alias: Optional[str], + alias: str | None, tuple_alias: str, tuple_index: LiteralExpr, column_names: Sequence[str], @@ -233,7 +233,7 @@ def unfiltered_mapping_tuples( def filtered_mapping_keys( - alias: Optional[str], column_name: str, filtered: Sequence[str] + alias: str | None, column_name: str, filtered: Sequence[str] ) -> Expression: return arrayJoin( alias, diff --git a/snuba/query/processors/physical/bloom_filter_optimizer.py b/snuba/query/processors/physical/bloom_filter_optimizer.py index 7d28bb960a0..2f869b19f42 100644 --- a/snuba/query/processors/physical/bloom_filter_optimizer.py +++ b/snuba/query/processors/physical/bloom_filter_optimizer.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Dict, Optional, Sequence, Set, Tuple +from collections.abc import Sequence from snuba.clickhouse.query import Query from snuba.query.conditions import combine_and_conditions, combine_or_conditions @@ -27,9 +27,9 @@ def process_query(self, query: Query, query_settings: QuerySettings) -> None: def generate_bloom_filter_condition( column_name: str, - single_filtered: Dict[str, Sequence[str]], - multiple_filtered: Dict[Tuple[str, ...], Sequence[Tuple[str, ...]]], -) -> Optional[Expression]: + single_filtered: dict[str, Sequence[str]], + multiple_filtered: dict[tuple[str, ...], Sequence[tuple[str, ...]]], +) -> Expression | None: """ Generate the filters on the array columns to use the bloom filter index on the spans.op and spans.group columns in order to filter the transactions @@ -39,7 +39,7 @@ def generate_bloom_filter_condition( the final condition is built up from a series of has conditions. """ - per_key_vals: Dict[str, Set[str]] = defaultdict(set) + per_key_vals: dict[str, set[str]] = defaultdict(set) for key, single_filter in single_filtered.items(): for val in single_filter: @@ -47,7 +47,7 @@ def generate_bloom_filter_condition( for keys, multiple_filter in multiple_filtered.items(): for val_tuple in multiple_filter: - for key, val in zip(keys, val_tuple): + for key, val in zip(keys, val_tuple, strict=False): per_key_vals[key].add(val) conditions = [ diff --git a/snuba/query/processors/physical/clickhouse_settings_override.py b/snuba/query/processors/physical/clickhouse_settings_override.py index 90d15854043..e27f1c8d9e8 100644 --- a/snuba/query/processors/physical/clickhouse_settings_override.py +++ b/snuba/query/processors/physical/clickhouse_settings_override.py @@ -1,4 +1,5 @@ -from typing import Any, MutableMapping +from collections.abc import MutableMapping +from typing import Any from snuba.clickhouse.query import Query from snuba.query.processors.physical import ClickhouseQueryProcessor diff --git a/snuba/query/processors/physical/column_filter_processor.py b/snuba/query/processors/physical/column_filter_processor.py index 666acc2401c..40e2e934b08 100644 --- a/snuba/query/processors/physical/column_filter_processor.py +++ b/snuba/query/processors/physical/column_filter_processor.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.query import Query from snuba.query.exceptions import InvalidQueryException diff --git a/snuba/query/processors/physical/conditions_enforcer.py b/snuba/query/processors/physical/conditions_enforcer.py index 4b1cb9fa406..ee25accfe3c 100644 --- a/snuba/query/processors/physical/conditions_enforcer.py +++ b/snuba/query/processors/physical/conditions_enforcer.py @@ -1,5 +1,5 @@ import logging -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.query import Expression, Query from snuba.query.conditions import get_first_level_and_conditions @@ -29,15 +29,14 @@ def __init__(self, condition_checkers: Sequence[ConditionChecker]) -> None: self.__condition_checkers = condition_checkers def process_query(self, query: Query, query_settings: QuerySettings) -> None: - missing_checkers = {checker for checker in self.__condition_checkers} + missing_checkers = set(self.__condition_checkers) def inspect_expression(condition: Expression) -> None: top_level = get_first_level_and_conditions(condition) for condition in top_level: for checker in self.__condition_checkers: - if checker in missing_checkers: - if checker.check(condition): - missing_checkers.remove(checker) + if checker in missing_checkers and checker.check(condition): + missing_checkers.remove(checker) condition = query.get_condition() if condition is not None: diff --git a/snuba/query/processors/physical/fixedstring_array_column_processor.py b/snuba/query/processors/physical/fixedstring_array_column_processor.py index 8259e471b42..9aaf71921d5 100644 --- a/snuba/query/processors/physical/fixedstring_array_column_processor.py +++ b/snuba/query/processors/physical/fixedstring_array_column_processor.py @@ -1,5 +1,3 @@ -from typing import Set - from snuba.query.expressions import Expression, FunctionCall, Literal from snuba.query.processors.physical.type_converters import ( BaseTypeConverter, @@ -8,7 +6,7 @@ class FixedStringArrayColumnProcessor(BaseTypeConverter): - def __init__(self, columns: Set[str], fixed_length: int): + def __init__(self, columns: set[str], fixed_length: int): self.fixed_length = fixed_length super().__init__(columns, optimize_ordering=True) @@ -20,8 +18,8 @@ def _translate_literal(self, exp: Literal) -> Expression: "toFixedString", (Literal(None, value=exp.value), Literal(None, self.fixed_length)), ) - except (AssertionError, ValueError): - raise ColumnTypeError("Not a valid UUID string", should_report=False) + except (AssertionError, ValueError) as e: + raise ColumnTypeError("Not a valid UUID string", should_report=False) from e def _process_expressions(self, exp: Expression) -> Expression: # FixedString is converted to regular string just fine in query return diff --git a/snuba/query/processors/physical/group_id_column_processor.py b/snuba/query/processors/physical/group_id_column_processor.py index 9625e1a24f0..f43d1db118d 100644 --- a/snuba/query/processors/physical/group_id_column_processor.py +++ b/snuba/query/processors/physical/group_id_column_processor.py @@ -7,16 +7,15 @@ class GroupIdColumnProcessor(ClickhouseQueryProcessor): def process_query(self, query: Query, query_settings: QuerySettings) -> None: def process_column(exp: Expression) -> Expression: - if isinstance(exp, Column): - if exp.column_name == "group_id": - return FunctionCall( - exp.alias, - "nullIf", - ( - Column(None, exp.table_name, exp.column_name), - Literal(None, 0), - ), - ) + if isinstance(exp, Column) and exp.column_name == "group_id": + return FunctionCall( + exp.alias, + "nullIf", + ( + Column(None, exp.table_name, exp.column_name), + Literal(None, 0), + ), + ) return exp diff --git a/snuba/query/processors/physical/hexint_column_processor.py b/snuba/query/processors/physical/hexint_column_processor.py index 26342507e90..92fe1e0d7bb 100644 --- a/snuba/query/processors/physical/hexint_column_processor.py +++ b/snuba/query/processors/physical/hexint_column_processor.py @@ -1,5 +1,3 @@ -from typing import Set - from snuba.query.dsl import Functions as f from snuba.query.dsl import column, if_cond, literal from snuba.query.expressions import ( @@ -17,7 +15,7 @@ class HexIntColumnProcessor(BaseTypeConverter): - def __init__(self, columns: Set[str], size: int = 16) -> None: + def __init__(self, columns: set[str], size: int = 16) -> None: """ size is the number of characters in the hex string representation of the integer (e.g. 32 for 128 bit integers) """ @@ -31,8 +29,8 @@ def _translate_literal(self, exp: Literal) -> Literal: if self._size == 32: return Literal(alias=exp.alias, value=str(int(exp.value, 16))) return Literal(alias=exp.alias, value=int(exp.value, 16)) - except (AssertionError, ValueError): - raise ColumnTypeError("Invalid hexint", should_report=False) + except (AssertionError, ValueError) as e: + raise ColumnTypeError("Invalid hexint", should_report=False) from e def _process_expressions(self, exp: Expression) -> Expression: if isinstance(exp, Column) and exp.column_name in self.columns: @@ -63,8 +61,8 @@ def _translate_literal(self, exp: Literal) -> Literal: try: assert isinstance(exp.value, str) return Literal(alias=exp.alias, value=int(exp.value, 16)) - except (AssertionError, ValueError): - raise ColumnTypeError("Invalid hexint", should_report=False) + except (AssertionError, ValueError) as e: + raise ColumnTypeError("Invalid hexint", should_report=False) from e def _process_expressions(self, exp: Expression) -> Expression: if isinstance(exp, Column) and exp.column_name in self.columns: diff --git a/snuba/query/processors/physical/mapping_optimizer.py b/snuba/query/processors/physical/mapping_optimizer.py index 18887e20c51..169533090e6 100644 --- a/snuba/query/processors/physical/mapping_optimizer.py +++ b/snuba/query/processors/physical/mapping_optimizer.py @@ -1,6 +1,6 @@ import logging from enum import Enum -from typing import Optional, Tuple, cast +from typing import cast from snuba import environment from snuba.clickhouse.query import Query @@ -149,7 +149,7 @@ def __init__( def __classify_combined_conditions(self, condition: Expression) -> ConditionClass: if not isinstance(condition, FunctionExpr): return ConditionClass.IRRELEVANT - elif condition.function_name in (BooleanFunctions.AND, BooleanFunctions.OR): + if condition.function_name in (BooleanFunctions.AND, BooleanFunctions.OR): conditions = ( get_first_level_and_conditions(condition) if condition.function_name == BooleanFunctions.AND @@ -158,12 +158,10 @@ def __classify_combined_conditions(self, condition: Expression) -> ConditionClas classified = {self.__classify_combined_conditions(c) for c in conditions} if ConditionClass.NOT_OPTIMIZABLE in classified: return ConditionClass.NOT_OPTIMIZABLE - elif ConditionClass.OPTIMIZABLE in classified: + if ConditionClass.OPTIMIZABLE in classified: return ConditionClass.OPTIMIZABLE - else: - return ConditionClass.IRRELEVANT - else: - return self.__classify_condition(condition) + return ConditionClass.IRRELEVANT + return self.__classify_condition(condition) def __classify_condition(self, condition: Expression) -> ConditionClass: # Expects this to be an individual condition @@ -199,7 +197,7 @@ def __classify_condition(self, condition: Expression) -> ConditionClass: return ConditionClass.NOT_OPTIMIZABLE return ConditionClass.OPTIMIZABLE - elif equals_condition_match is None and in_condition_match is None: + if equals_condition_match is None and in_condition_match is None: # If this condition is not matching an optimizable condition, # check that it does not reference the optimizable column. # If it does, it means we should not optimize this query. @@ -210,8 +208,7 @@ def __classify_condition(self, condition: Expression) -> ConditionClass: ): return ConditionClass.NOT_OPTIMIZABLE return ConditionClass.IRRELEVANT - else: - return ConditionClass.IRRELEVANT + return ConditionClass.IRRELEVANT def __replace_with_hash(self, condition: Expression) -> Expression: equals_condition_match = self.__equals_condition_pattern.match(condition) @@ -245,7 +242,7 @@ def __replace_with_hash(self, condition: Expression) -> Expression: ), ), ) - elif ( + if ( in_condition_match is not None and in_condition_match.string(KEY_COL_MAPPING_PARAM) == f"{self.__column_name}.key" ): @@ -322,13 +319,13 @@ def _get_condition_without_redundant_checks( """ if not isinstance(condition, FunctionExpr): return condition - elif condition.function_name == BooleanFunctions.OR: + if condition.function_name == BooleanFunctions.OR: sub_conditions = get_first_level_or_conditions(condition) pruned_conditions = [ self._get_condition_without_redundant_checks(c, query) for c in sub_conditions ] return combine_or_conditions(pruned_conditions) - elif condition.function_name == BooleanFunctions.AND: + if condition.function_name == BooleanFunctions.AND: sub_conditions = get_first_level_and_conditions(condition) tag_eq_match_keys = set() matched_tag_exists_conditions = {} @@ -344,7 +341,7 @@ def _get_condition_without_redundant_checks( tag_eq_match_keys.add(eq_match.scalar(KEY_MAPPING_PARAM)) useful_conditions = [] for condition_id, cond in enumerate(sub_conditions): - tag_exist_match = matched_tag_exists_conditions.get(condition_id, None) + tag_exist_match = matched_tag_exists_conditions.get(condition_id) if tag_exist_match: requested_tag = tag_exist_match.scalar("key") if requested_tag in tag_eq_match_keys: @@ -353,19 +350,17 @@ def _get_condition_without_redundant_checks( continue useful_conditions.append(self._get_condition_without_redundant_checks(cond, query)) return combine_and_conditions(useful_conditions) - else: - return condition + return condition def __get_reduced_and_classified_query_clause( - self, clause: Optional[Expression], query: Query - ) -> Tuple[Optional[Expression], ConditionClass]: + self, clause: Expression | None, query: Query + ) -> tuple[Expression | None, ConditionClass]: cond_class = ConditionClass.IRRELEVANT if clause is not None: new_clause = self._get_condition_without_redundant_checks(clause, query) cond_class = self.__classify_combined_conditions(new_clause) return new_clause, cond_class - else: - return clause, cond_class + return clause, cond_class def process_query(self, query: Query, query_settings: QuerySettings) -> None: if not get_config(self.__killswitch, 1): diff --git a/snuba/query/processors/physical/mapping_promoter.py b/snuba/query/processors/physical/mapping_promoter.py index 8f5db34c40b..eb6536a0b76 100644 --- a/snuba/query/processors/physical/mapping_promoter.py +++ b/snuba/query/processors/physical/mapping_promoter.py @@ -1,4 +1,5 @@ -from typing import Mapping, NamedTuple, Optional +from collections.abc import Mapping +from typing import NamedTuple from snuba.clickhouse.query import Query from snuba.clickhouse.translators.snuba.mappers import ( @@ -15,14 +16,14 @@ class SubscriptableMatch(NamedTuple): # The table name associated with the nested column found in the query. - table_name: Optional[str] + table_name: str | None # The nested column name column_name: str # The key found in the query (like key in tags[key]) key: str -def match_subscriptable_reference(exp: Expression) -> Optional[SubscriptableMatch]: +def match_subscriptable_reference(exp: Expression) -> SubscriptableMatch | None: """ Finds the expression, in the Clickhouse query, that loads the value of a specific tag (or any nested column that represents a mapping, @@ -112,12 +113,11 @@ def transform_nested_column(exp: Expression) -> Expression: and "FixedString" not in col_type_name ): return Column(exp.alias, subscript.table_name, promoted_col_name) - else: - return FunctionCall( - exp.alias, - "toString", - (Column(None, subscript.table_name, promoted_col_name),), - ) + return FunctionCall( + exp.alias, + "toString", + (Column(None, subscript.table_name, promoted_col_name),), + ) return exp diff --git a/snuba/query/processors/physical/null_column_caster.py b/snuba/query/processors/physical/null_column_caster.py index 1eb6d980531..82c19da8d9b 100644 --- a/snuba/query/processors/physical/null_column_caster.py +++ b/snuba/query/processors/physical/null_column_caster.py @@ -1,4 +1,4 @@ -from typing import Dict, Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import FlattenedColumn, SchemaModifiers from snuba.clickhouse.query import Query @@ -60,7 +60,7 @@ class NullColumnCaster(ClickhouseQueryProcessor): """ - def _find_mismatched_null_columns(self) -> Dict[str, FlattenedColumn]: + def _find_mismatched_null_columns(self) -> dict[str, FlattenedColumn]: # This has to be imported here since the storage factory will also initialize this query processor # and importing it at the top will create an import cycle @@ -68,8 +68,8 @@ def _find_mismatched_null_columns(self) -> Dict[str, FlattenedColumn]: # good first-class support for merge tables in snuba atm (12/06/2022) which makes us rely on this hack from snuba.datasets.storages.factory import get_storage - mismatched_col_name_to_col: Dict[str, FlattenedColumn] = {} - col_name_to_nullable: Dict[str, bool] = {} + mismatched_col_name_to_col: dict[str, FlattenedColumn] = {} + col_name_to_nullable: dict[str, bool] = {} for table_storage_key in self.__merge_table_sources_keys: table_storage = get_storage(StorageKey(table_storage_key)) for col in table_storage.get_schema().get_columns(): @@ -92,10 +92,10 @@ def __init__(self, merge_table_sources: Sequence[StorageKeyStr]): """ self.__merge_table_sources_keys = merge_table_sources - self.__mismatched_null_columns: Dict[str, FlattenedColumn] = {} + self.__mismatched_null_columns: dict[str, FlattenedColumn] = {} @property - def mismatched_null_columns(self) -> Dict[str, FlattenedColumn]: + def mismatched_null_columns(self) -> dict[str, FlattenedColumn]: # The first time the query processor is run, we calculate the mismatched null columns # which never change. We don't do this at initialization time because there is no guarantee that # all the storages will be loaded at the time this query processor is @@ -106,29 +106,28 @@ def mismatched_null_columns(self) -> Dict[str, FlattenedColumn]: def process_query(self, query: Query, query_settings: QuerySettings) -> None: def cast_column_to_nullable(exp: Expression) -> Expression: - if isinstance(exp, Column): - if exp.column_name in self.mismatched_null_columns: - # depending on the order of the storage, this dictionary will contain - # either the nullable or non-nullable version of the column. No matter - # which one is in there, due to the mismatch on the merge table it needs to - # be cast as nullable anyways - mismatched_column = self.mismatched_null_columns[exp.column_name] - col_is_nullable = _col_is_nullable(mismatched_column) - col_type = mismatched_column.type.for_schema() - cast_str = col_type if col_is_nullable else f"Nullable({col_type})" - return FunctionCall( - exp.alias, - "cast", - ( - # move the alias up to the cast function - Column( - None, - table_name=exp.table_name, - column_name=exp.column_name, - ), - Literal(None, cast_str), + if isinstance(exp, Column) and exp.column_name in self.mismatched_null_columns: + # depending on the order of the storage, this dictionary will contain + # either the nullable or non-nullable version of the column. No matter + # which one is in there, due to the mismatch on the merge table it needs to + # be cast as nullable anyways + mismatched_column = self.mismatched_null_columns[exp.column_name] + col_is_nullable = _col_is_nullable(mismatched_column) + col_type = mismatched_column.type.for_schema() + cast_str = col_type if col_is_nullable else f"Nullable({col_type})" + return FunctionCall( + exp.alias, + "cast", + ( + # move the alias up to the cast function + Column( + None, + table_name=exp.table_name, + column_name=exp.column_name, ), - ) + Literal(None, cast_str), + ), + ) return exp def transform_aggregate_functions_with_mismatched_nullable_parameters( diff --git a/snuba/query/processors/physical/prewhere.py b/snuba/query/processors/physical/prewhere.py index c81687a34fc..b9168f149f7 100644 --- a/snuba/query/processors/physical/prewhere.py +++ b/snuba/query/processors/physical/prewhere.py @@ -1,4 +1,4 @@ -from typing import Optional, Sequence, Set +from collections.abc import Sequence from snuba import environment, settings from snuba.clickhouse.query import Query @@ -44,12 +44,12 @@ class PrewhereProcessor(ClickhouseQueryProcessor): def __init__( self, prewhere_candidates: Sequence[str], - omit_if_final: Optional[Sequence[str]] = None, - max_prewhere_conditions: Optional[int] = None, + omit_if_final: Sequence[str] | None = None, + max_prewhere_conditions: int | None = None, ) -> None: self.__prewhere_candidates = prewhere_candidates self.__omit_if_final = omit_if_final - self.__max_prewhere_conditions: Optional[int] = max_prewhere_conditions + self.__max_prewhere_conditions: int | None = max_prewhere_conditions def process_query(self, query: Query, query_settings: QuerySettings) -> None: max_prewhere_conditions: int = ( @@ -60,7 +60,7 @@ def process_query(self, query: Query, query_settings: QuerySettings) -> None: # We remove the candidates that appear in a uniq or -If aggregations # because a query like `countIf(col=x) .. PREWHERE col=x` can make # the Clickhouse server crash. - uniq_cols: Set[str] = set() + uniq_cols: set[str] = set() expressions = query.get_all_expressions() for exp in expressions: if isinstance(exp, FunctionCall) and ( diff --git a/snuba/query/processors/physical/replaced_groups.py b/snuba/query/processors/physical/replaced_groups.py index 4c49b1ac0d9..670e94f1242 100644 --- a/snuba/query/processors/physical/replaced_groups.py +++ b/snuba/query/processors/physical/replaced_groups.py @@ -1,6 +1,6 @@ +from collections.abc import MutableMapping from dataclasses import replace from datetime import datetime -from typing import MutableMapping, Optional, Set from snuba import environment, settings from snuba.clickhouse.query import Query @@ -32,7 +32,7 @@ class PostReplacementConsistencyEnforcer(ClickhouseQueryProcessor): have to remove those rows manually or to run the query in FINAL mode. """ - def __init__(self, project_column: str, replacer_state_name: Optional[str]) -> None: + def __init__(self, project_column: str, replacer_state_name: str | None) -> None: self.__project_column = project_column self.__groups_column = "group_id" # This is used to allow us to keep the replacement state in redis for multiple @@ -139,7 +139,7 @@ def _initialize_tags( """ Initialize tags dictionary for DataDog metrics. """ - tags = {replacement_type: "True" for replacement_type in flags.replacement_types} + tags = dict.fromkeys(flags.replacement_types, "True") tags["referrer"] = query_settings.referrer return tags @@ -155,7 +155,7 @@ def _set_query_final(self, query: Query, final: bool) -> None: def _query_overlaps_replacements( self, query: Query, - latest_replacement_time: Optional[datetime], + latest_replacement_time: datetime | None, ) -> bool: """ Given a Query and the latest replacement time for any project @@ -167,7 +167,7 @@ def _query_overlaps_replacements( latest_replacement_time > query_from if latest_replacement_time and query_from else True ) - def _groups_to_exclude(self, query: Query, group_ids_to_exclude: Set[int]) -> Set[int]: + def _groups_to_exclude(self, query: Query, group_ids_to_exclude: set[int]) -> set[int]: """ Given a Query and the group ids to exclude for any project this query touches, returns the intersection of the group ids diff --git a/snuba/query/processors/physical/tuple_unaliaser.py b/snuba/query/processors/physical/tuple_unaliaser.py index edcac160f80..dbd1d1e703c 100644 --- a/snuba/query/processors/physical/tuple_unaliaser.py +++ b/snuba/query/processors/physical/tuple_unaliaser.py @@ -43,9 +43,14 @@ def visit_function_call(self, exp: FunctionCall) -> Expression: def visit_curried_function_call(self, exp: CurriedFunctionCall) -> Expression: self.__level += 1 transfomed_params = tuple([param.accept(self) for param in exp.parameters]) + internal_function = exp.internal_function.accept(self) + assert isinstance(internal_function, FunctionCall), ( + "The internal function of a curried function call cannot be resolved " + "to anything other than a function call" + ) res = replace( exp, - internal_function=exp.internal_function.accept(self), + internal_function=internal_function, parameters=transfomed_params, ) self.__level -= 1 diff --git a/snuba/query/processors/physical/type_converters.py b/snuba/query/processors/physical/type_converters.py index 5c4005de5c7..1c587177c70 100644 --- a/snuba/query/processors/physical/type_converters.py +++ b/snuba/query/processors/physical/type_converters.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod -from typing import Set +from collections.abc import Callable +from typing import cast from snuba.clickhouse.query import Query from snuba.query.conditions import ConditionFunctions @@ -19,7 +20,7 @@ class ColumnTypeError(ValidationException): class BaseTypeConverter(ClickhouseQueryProcessor, ABC): - def __init__(self, columns: Set[str], optimize_ordering: bool = False): + def __init__(self, columns: set[str], optimize_ordering: bool = False): self.columns = columns self.optimize_ordering = optimize_ordering column_match = Or([String(col) for col in columns]) @@ -181,15 +182,20 @@ def assert_literal(lit: Expression) -> Literal: for param in params: assert isinstance(param, Literal) - wrapper = tuple if collection_func.function_name == "tuple" else list + wrapper: Callable[[list[Expression]], tuple[Expression, ...] | list[Expression]] = ( + tuple if collection_func.function_name == "tuple" else list + ) new_collection_func = FunctionCall( collection_func.alias, collection_func.function_name, - parameters=wrapper( - [ - self._translate_literal(assert_literal(lit)) - for lit in collection_func.parameters - ] + parameters=cast( + "tuple[Expression, ...]", + wrapper( + [ + self._translate_literal(assert_literal(lit)) + for lit in collection_func.parameters + ] + ), ), ) return FunctionCall( diff --git a/snuba/query/processors/physical/uniq_in_select_and_having.py b/snuba/query/processors/physical/uniq_in_select_and_having.py index 8f98e16b499..29e92b3decb 100644 --- a/snuba/query/processors/physical/uniq_in_select_and_having.py +++ b/snuba/query/processors/physical/uniq_in_select_and_having.py @@ -7,7 +7,8 @@ """ import logging -from typing import Any, Dict, Sequence, cast +from collections.abc import Sequence +from typing import Any, cast from snuba.clickhouse.query import Query from snuba.query.exceptions import InvalidQueryException @@ -40,7 +41,7 @@ class UniqInSelectAndHavingProcessor(ClickhouseQueryProcessor): def process_query(self, query: Query, query_settings: QuerySettings) -> None: having_clause = query.get_having() if not having_clause: - return None + return selected_columns = query.get_selected_columns() uniq_matcher = Param("function", FunctionCallMatch(String("uniq"))) found_functions = [] @@ -59,9 +60,8 @@ def process_query(self, query: Query, query_settings: QuerySettings) -> None: ) if should_throw: raise error - else: - logging.warning( - "Aggregation is in HAVING clause but not SELECT", - exc_info=True, - extra=cast(Dict[str, Any], error.to_dict()), - ) + logging.warning( + "Aggregation is in HAVING clause but not SELECT", + exc_info=True, + extra=cast(dict[str, Any], error.to_dict()), + ) diff --git a/snuba/query/processors/physical/user_column_processor.py b/snuba/query/processors/physical/user_column_processor.py index eaccacc67b0..6be45d73eea 100644 --- a/snuba/query/processors/physical/user_column_processor.py +++ b/snuba/query/processors/physical/user_column_processor.py @@ -13,13 +13,12 @@ class UserColumnProcessor(ClickhouseQueryProcessor): def process_query(self, query: Query, query_settings: QuerySettings) -> None: def process_column(exp: Expression) -> Expression: - if isinstance(exp, Column): - if exp.column_name == "user": - return FunctionCall( - exp.alias, - "nullIf", - (Column(None, None, "user"), Literal(None, "")), - ) + if isinstance(exp, Column) and exp.column_name == "user": + return FunctionCall( + exp.alias, + "nullIf", + (Column(None, None, "user"), Literal(None, "")), + ) return exp diff --git a/snuba/query/processors/physical/uuid_array_column_processor.py b/snuba/query/processors/physical/uuid_array_column_processor.py index bdb128f5b84..53b23c91310 100644 --- a/snuba/query/processors/physical/uuid_array_column_processor.py +++ b/snuba/query/processors/physical/uuid_array_column_processor.py @@ -1,5 +1,4 @@ import uuid -from typing import Set from snuba.query.expressions import ( Argument, @@ -16,7 +15,7 @@ class UUIDArrayColumnProcessor(BaseTypeConverter): - def __init__(self, columns: Set[str]) -> None: + def __init__(self, columns: set[str]) -> None: super().__init__(columns) def _translate_literal(self, exp: Literal) -> Expression: @@ -24,8 +23,8 @@ def _translate_literal(self, exp: Literal) -> Expression: assert isinstance(exp.value, str) new_val = str(uuid.UUID(exp.value)) return FunctionCall(exp.alias, "toUUID", (Literal(None, value=new_val),)) - except (AssertionError, ValueError): - raise ColumnTypeError("Not a valid UUID string", should_report=False) + except (AssertionError, ValueError) as e: + raise ColumnTypeError("Not a valid UUID string", should_report=False) from e def _process_expressions(self, exp: Expression) -> Expression: if isinstance(exp, Column) and exp.column_name in self.columns: diff --git a/snuba/query/processors/physical/uuid_column_processor.py b/snuba/query/processors/physical/uuid_column_processor.py index 071b88bdb3c..687c310d7db 100644 --- a/snuba/query/processors/physical/uuid_column_processor.py +++ b/snuba/query/processors/physical/uuid_column_processor.py @@ -1,5 +1,4 @@ import uuid -from typing import Set from snuba.query.expressions import Column, Expression, FunctionCall, Literal from snuba.query.processors.physical.type_converters import ( @@ -9,7 +8,7 @@ class UUIDColumnProcessor(BaseTypeConverter): - def __init__(self, columns: Set[str]) -> None: + def __init__(self, columns: set[str]) -> None: super().__init__(columns, optimize_ordering=False) def _translate_literal(self, exp: Literal) -> Literal: @@ -17,8 +16,8 @@ def _translate_literal(self, exp: Literal) -> Literal: assert isinstance(exp.value, str) new_val = str(uuid.UUID(exp.value)) return Literal(alias=exp.alias, value=new_val) - except (AssertionError, ValueError): - raise ColumnTypeError("Not a valid UUID string", should_report=False) + except (AssertionError, ValueError) as e: + raise ColumnTypeError("Not a valid UUID string", should_report=False) from e def _process_expressions(self, exp: Expression) -> Expression: if isinstance(exp, Column) and exp.column_name in self.columns: diff --git a/snuba/query/query_settings.py b/snuba/query/query_settings.py index e07c5547e6b..0000948f616 100644 --- a/snuba/query/query_settings.py +++ b/snuba/query/query_settings.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod -from typing import Any, MutableMapping, Optional +from collections.abc import MutableMapping +from typing import Any from snuba.downsampled_storage_tiers import Tier from snuba.state.quota import ResourceQuota @@ -37,7 +38,7 @@ def get_legacy(self) -> bool: pass @abstractmethod - def get_resource_quota(self) -> Optional[ResourceQuota]: + def get_resource_quota(self) -> ResourceQuota | None: pass @abstractmethod @@ -99,7 +100,7 @@ def __init__( self.__debug = debug self.__dry_run = dry_run self.__legacy = legacy - self.__resource_quota: Optional[ResourceQuota] = None + self.__resource_quota: ResourceQuota | None = None self.__clickhouse_settings: MutableMapping[str, Any] = {} self.referrer = referrer self.__asynchronous = asynchronous @@ -121,7 +122,7 @@ def get_dry_run(self) -> bool: def get_legacy(self) -> bool: return self.__legacy - def get_resource_quota(self) -> Optional[ResourceQuota]: + def get_resource_quota(self) -> ResourceQuota | None: return self.__resource_quota def set_resource_quota(self, quota: ResourceQuota) -> None: @@ -214,7 +215,7 @@ def get_app_id(self) -> str: def get_feature(self) -> str: return self.__feature - def get_resource_quota(self) -> Optional[ResourceQuota]: + def get_resource_quota(self) -> ResourceQuota | None: return None def set_resource_quota(self, quota: ResourceQuota) -> None: diff --git a/snuba/query/snql/discover_entity_selection.py b/snuba/query/snql/discover_entity_selection.py index cc33ff0cac1..b25d80bf431 100644 --- a/snuba/query/snql/discover_entity_selection.py +++ b/snuba/query/snql/discover_entity_selection.py @@ -179,9 +179,11 @@ def match_query_to_entity( if result: if result.string("function") == ConditionFunctions.EQ: event_types.add(event_type) - elif result.string("function") == ConditionFunctions.NEQ: - if event_type == "transaction": - return EntityKey.DISCOVER_EVENTS + elif ( + result.string("function") == ConditionFunctions.NEQ + and event_type == "transaction" + ): + return EntityKey.DISCOVER_EVENTS if len(event_types) == 1 and "transaction" in event_types: return EntityKey.DISCOVER_TRANSACTIONS @@ -227,12 +229,11 @@ def match_query_to_entity( if has_event_columns and has_transaction_columns: # Impossible query, use the merge table return EntityKey.DISCOVER - elif has_event_columns: + if has_event_columns: return EntityKey.DISCOVER_EVENTS - elif has_transaction_columns: + if has_transaction_columns: return EntityKey.DISCOVER_TRANSACTIONS - else: - return EntityKey.DISCOVER + return EntityKey.DISCOVER def _track_bad_query( diff --git a/snuba/query/snql/expression_visitor.py b/snuba/query/snql/expression_visitor.py index 0a0704d432c..13d60464d0e 100644 --- a/snuba/query/snql/expression_visitor.py +++ b/snuba/query/snql/expression_visitor.py @@ -1,15 +1,10 @@ import re +from collections.abc import Callable, Iterable, Sequence from enum import Enum from typing import ( Any, - Callable, - Iterable, - List, NamedTuple, - Optional, - Sequence, - Tuple, - Union, + TypeAlias, ) from parsimonious.nodes import Node @@ -44,8 +39,8 @@ class HighPriTuple(NamedTuple): arithm: Expression -HighPriArithmetic = Union[Node, HighPriTuple, Sequence[HighPriTuple]] -LowPriArithmetic = Union[Node, LowPriTuple, Sequence[LowPriTuple]] +HighPriArithmetic: TypeAlias = Node | HighPriTuple | Sequence[HighPriTuple] +LowPriArithmetic: TypeAlias = Node | LowPriTuple | Sequence[LowPriTuple] ARITHMETIC_OP_TO_FUNCTION = { @@ -58,19 +53,19 @@ class HighPriTuple(NamedTuple): def get_arithmetic_function( operator: Enum, -) -> Callable[[Expression, Expression, Optional[str]], FunctionCall]: +) -> Callable[[Expression, Expression, str | None], FunctionCall]: return ARITHMETIC_OP_TO_FUNCTION[operator] def get_arithmetic_expression( term: Expression, - exp: Union[LowPriArithmetic, HighPriArithmetic, Sequence[Any]], + exp: LowPriArithmetic | HighPriArithmetic | Sequence[Any], ) -> Expression: if isinstance(exp, Node): return term if isinstance(exp, (LowPriTuple, HighPriTuple)): return get_arithmetic_function(exp.op)(term, exp.arithm, None) - elif isinstance(exp, list): + if isinstance(exp, list): for elem in exp: if isinstance(elem, (LowPriTuple, HighPriTuple)): term = get_arithmetic_function(elem.op)(term, elem.arithm, None) @@ -89,14 +84,14 @@ def visit_column_name(node: Node, visited_children: Iterable[Any]) -> Column: def visit_low_pri_tuple( - node: Node, visited_children: Tuple[LowPriOperator, Any, Expression] + node: Node, visited_children: tuple[LowPriOperator, Any, Expression] ) -> LowPriTuple: left, _, right = visited_children return LowPriTuple(op=left, arithm=right) def visit_high_pri_tuple( - node: Node, visited_children: Tuple[HighPriOperator, Any, Expression] + node: Node, visited_children: tuple[HighPriOperator, Any, Expression] ) -> HighPriTuple: left, _, right = visited_children return HighPriTuple(op=left, arithm=right) @@ -110,14 +105,14 @@ def visit_high_pri_op(node: Node, visited_children: Iterable[Any]) -> HighPriOpe return HighPriOperator(node.text) -def visit_arithmetic_term(node: Node, visited_children: Tuple[Any, Expression]) -> Expression: +def visit_arithmetic_term(node: Node, visited_children: tuple[Any, Expression]) -> Expression: _, term = visited_children return term def visit_low_pri_arithmetic( node: Node, - visited_children: Tuple[Any, Expression, LowPriArithmetic], + visited_children: tuple[Any, Expression, LowPriArithmetic], ) -> Expression: _, term, exp = visited_children return get_arithmetic_expression(term, exp) @@ -125,7 +120,7 @@ def visit_low_pri_arithmetic( def visit_high_pri_arithmetic( node: Node, - visited_children: Tuple[Any, Expression, HighPriArithmetic], + visited_children: tuple[Any, Expression, HighPriArithmetic], ) -> Expression: _, term, exp = visited_children @@ -142,24 +137,24 @@ def visit_numeric_literal(node: Node, visited_children: Iterable[Any]) -> Litera newline_re = re.compile("((?:\\{2})*)(\\n)") -def visit_quoted_literal(node: Node, visited_children: Tuple[Any]) -> Literal: +def visit_quoted_literal(node: Node, visited_children: tuple[Any]) -> Literal: text = node.text[1:-1] text = newline_re.sub(text, "\n") match = text.replace("\\'", "'") return Literal(None, match) -def visit_parameter(node: Node, visited_children: Tuple[Expression, Any, Any, Any]) -> Expression: +def visit_parameter(node: Node, visited_children: tuple[Expression, Any, Any, Any]) -> Expression: param, _, _, _ = visited_children return param def visit_parameters_list( node: Node, - visited_children: Tuple[Union[Expression, List[Expression]], Expression], -) -> List[Expression]: + visited_children: tuple[Expression | list[Expression], Expression], +) -> list[Expression]: left_section, right_section = visited_children - ret: List[Expression] = [] + ret: list[Expression] = [] if not isinstance(left_section, Node): # We get a Node when the parameter rule is empty. Thus # no parameters @@ -168,14 +163,14 @@ def visit_parameters_list( # thus the generic visitor method removes the list. ret = [left_section] else: - ret = [p for p in left_section] + ret = list(left_section) ret.append(right_section) return ret def visit_function_call( node: Node, - visited_children: Tuple[str, Any, List[Expression], Any, Union[Node, List[Expression]]], + visited_children: tuple[str, Any, list[Expression], Any, Node | list[Expression]], ) -> Expression: name, _, params1, _, params2 = visited_children param_list1 = tuple(params1) diff --git a/snuba/query/snql/joins.py b/snuba/query/snql/joins.py index 9d119665525..72811233a76 100644 --- a/snuba/query/snql/joins.py +++ b/snuba/query/snql/joins.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import MutableMapping, NamedTuple, Optional, Sequence, Union +from collections.abc import MutableMapping, Sequence +from typing import NamedTuple from snuba.datasets.entities.entity_key import EntityKey from snuba.query.data_source.join import ( @@ -36,11 +37,11 @@ class Node: def __init__( self, entity_data: IndividualNode[QueryEntity], - relationship: Optional[JoinRelationship] = None, + relationship: JoinRelationship | None = None, ) -> None: self.entity_data = entity_data self.relationship = relationship - self.child: Optional[Node] = None + self.child: Node | None = None self.join_conditions: Sequence[JoinCondition] = [] @property @@ -120,7 +121,7 @@ def build_list(relationships: Sequence[RelationshipTuple]) -> Node: roots: MutableMapping[EntityKey, Node] = {} children: MutableMapping[EntityKey, Node] = {} - def update_children(child: Optional[Node]) -> None: + def update_children(child: Node | None) -> None: while child is not None: children[child.entity] = child child = child.child @@ -129,12 +130,11 @@ def update_children(child: Optional[Node]) -> None: lhs = Node(rel.lhs) rhs = Node(rel.rhs, rel.data) orphan = roots.get(rhs.entity) - if orphan: - if not orphan.has_child(lhs.entity): - # The orphan is a child of this join. Combine them. - if orphan.child: - rhs.push_child(orphan.child) - del roots[orphan.entity] + if orphan and not orphan.has_child(lhs.entity): + # The orphan is a child of this join. Combine them. + if orphan.child: + rhs.push_child(orphan.child) + del roots[orphan.entity] if lhs.entity in roots: roots[lhs.entity].push_child(rhs) @@ -159,8 +159,8 @@ def update_children(child: Optional[Node]) -> None: def build_join_clause_loop( node_list: Node, - lhs: Optional[Union[IndividualNode[QueryEntity], JoinClause[QueryEntity]]], -) -> Union[IndividualNode[QueryEntity], JoinClause[QueryEntity]]: + lhs: IndividualNode[QueryEntity] | JoinClause[QueryEntity] | None, +) -> IndividualNode[QueryEntity] | JoinClause[QueryEntity]: rhs = node_list.entity_data if lhs is None: lhs = rhs diff --git a/snuba/query/snql/parser.py b/snuba/query/snql/parser.py index 03e7d3f1541..ed6dc431930 100644 --- a/snuba/query/snql/parser.py +++ b/snuba/query/snql/parser.py @@ -1,21 +1,13 @@ from __future__ import annotations import logging +from collections.abc import Callable, Iterable, MutableMapping, Sequence from dataclasses import replace from datetime import datetime, timedelta from functools import partial from typing import ( Any, - Callable, - Iterable, - List, - MutableMapping, NamedTuple, - Optional, - Sequence, - Set, - Tuple, - Union, cast, ) @@ -236,13 +228,13 @@ class OrTuple(NamedTuple): exp: Expression -class SnQLVisitor(NodeVisitor): # type: ignore +class SnQLVisitor(NodeVisitor): # type: ignore[misc] """ Builds Snuba AST expressions from the SnQL Parsimonious parse tree. """ @staticmethod - def __extract_alias_from_match(alias: Union[Node, List[Node]]) -> str: + def __extract_alias_from_match(alias: Node | list[Node]) -> str: extracted_alias: str if isinstance(alias, list): # Validate that we are parsing an expression that is @@ -256,7 +248,7 @@ def __extract_alias_from_match(alias: Union[Node, List[Node]]) -> str: def visit_query_exp( self, node: Node, visited_children: Iterable[Any] - ) -> Union[LogicalQuery, CompositeQuery[LogicalDataSource]]: + ) -> LogicalQuery | CompositeQuery[LogicalDataSource]: args: MutableMapping[str, Any] = {} ( data_source, @@ -302,29 +294,21 @@ def visit_query_exp( def visit_match_clause( self, node: Node, - visited_children: Tuple[ + visited_children: tuple[ Any, Any, Any, - Union[ - QueryEntity, - CompositeQuery[LogicalDataSource], - LogicalQuery, - RelationshipTuple, - Sequence[RelationshipTuple], - ], + QueryEntity + | CompositeQuery[LogicalDataSource] + | LogicalQuery + | RelationshipTuple + | Sequence[RelationshipTuple], ], - ) -> Union[ - CompositeQuery[LogicalDataSource], - LogicalQuery, - QueryEntity, - # joins not availble for storage queries as of 2024-04-12 - JoinClause[QueryEntity], - ]: + ) -> CompositeQuery[LogicalDataSource] | LogicalQuery | QueryEntity | JoinClause[QueryEntity]: _, _, _, match = visited_children if isinstance(match, (CompositeQuery, LogicalQuery)): return match - elif isinstance(match, RelationshipTuple): + if isinstance(match, RelationshipTuple): join_clause = build_join_clause([match]) return join_clause if isinstance(match, list) and all(isinstance(m, RelationshipTuple) for m in match): @@ -337,7 +321,7 @@ def visit_match_clause( def visit_entity_single( self, node: Node, - visited_children: Tuple[Any, Any, EntityKey, Union[Optional[float], Node], Any, Any], + visited_children: tuple[Any, Any, EntityKey, float | None | Node, Any, Any], ) -> QueryEntity: _, _, name, sample, _, _ = visited_children if isinstance(sample, Node): @@ -348,9 +332,7 @@ def visit_entity_single( def visit_storage_single( self, node: Node, - visited_children: Tuple[ - Any, Any, Any, Any, StorageKey, Union[Optional[float], Node], Any, Any - ], + visited_children: tuple[Any, Any, Any, Any, StorageKey, float | None | Node, Any, Any], ) -> QueryStorage: _, _, _, _, storage_key, sample, _, _ = visited_children if isinstance(sample, Node): @@ -361,9 +343,7 @@ def visit_storage_single( def visit_entity_match( self, node: Node, - visited_children: Tuple[ - Any, str, Any, Any, EntityKey, Union[Optional[float], Node], Any, Any - ], + visited_children: tuple[Any, str, Any, Any, EntityKey, float | None | Node, Any, Any], ) -> IndividualNode[QueryEntity]: _, alias, _, _, name, sample, _, _ = visited_children if isinstance(sample, Node): @@ -371,25 +351,25 @@ def visit_entity_match( return IndividualNode(alias, QueryEntity(name, get_entity(name).get_data_model(), sample)) - def visit_entity_alias(self, node: Node, visited_children: Tuple[Any]) -> str: + def visit_entity_alias(self, node: Node, visited_children: tuple[Any]) -> str: return str(node.text) - def visit_entity_name(self, node: Node, visited_children: Tuple[Any]) -> EntityKey: + def visit_entity_name(self, node: Node, visited_children: tuple[Any]) -> EntityKey: try: return EntityKey(node.text) - except Exception: - raise ParsingException(f"{node.text} is not a valid entity name") + except Exception as e: + raise ParsingException(f"{node.text} is not a valid entity name") from e - def visit_storage_name(self, node: Node, visited_children: Tuple[Any]) -> StorageKey: + def visit_storage_name(self, node: Node, visited_children: tuple[Any]) -> StorageKey: try: return StorageKey(node.text) - except Exception: - raise ParsingException(f"{node.text} is not a valid Storage name") + except Exception as e: + raise ParsingException(f"{node.text} is not a valid Storage name") from e def visit_relationships( self, node: Node, - visited_children: Tuple[RelationshipTuple, Any], + visited_children: tuple[RelationshipTuple, Any], ) -> Sequence[RelationshipTuple]: relationships = [visited_children[0]] if isinstance(visited_children[1], Node): @@ -406,7 +386,7 @@ def visit_relationships( def visit_relationship_match( self, node: Node, - visited_children: Tuple[ + visited_children: tuple[ Any, IndividualNode[QueryEntity], Any, @@ -424,20 +404,20 @@ def visit_relationship_match( raise ParsingException( f"{lhs.data_source.key.value} does not have a join relationship -[{relationship}]->" ) - elif data.rhs_entity != rhs.data_source.key: + if data.rhs_entity != rhs.data_source.key: raise ParsingException( f"-[{relationship}]-> cannot be used to join {lhs.data_source.key.value} to {rhs.data_source.key.value}" ) return RelationshipTuple(lhs, relationship, rhs, data) - def visit_relationship_link(self, node: Node, visited_children: Tuple[Any, Node, Any]) -> str: + def visit_relationship_link(self, node: Node, visited_children: tuple[Any, Node, Any]) -> str: _, relationship, _ = visited_children return str(relationship.text) def visit_subquery( - self, node: Node, visited_children: Tuple[Any, Node, Any] - ) -> Union[LogicalQuery, CompositeQuery[LogicalDataSource]]: + self, node: Node, visited_children: tuple[Any, Node, Any] + ) -> LogicalQuery | CompositeQuery[LogicalDataSource]: _, query, _ = visited_children assert isinstance(query, (CompositeQuery, LogicalQuery)) # mypy return query @@ -460,34 +440,34 @@ def visit_flag_column(self, node: Node, visited_children: Iterable[Any]) -> Colu return x def visit_and_tuple( - self, node: Node, visited_children: Tuple[Any, Node, Expression] + self, node: Node, visited_children: tuple[Any, Node, Expression] ) -> AndTuple: _, and_string, exp = visited_children return AndTuple(and_string.text, exp) - def visit_or_tuple(self, node: Node, visited_children: Tuple[Any, Node, Expression]) -> OrTuple: + def visit_or_tuple(self, node: Node, visited_children: tuple[Any, Node, Expression]) -> OrTuple: _, or_string, exp = visited_children return OrTuple(or_string.text, exp) def visit_parenthesized_cdn( - self, node: Node, visited_children: Tuple[Any, Any, Expression, Any] + self, node: Node, visited_children: tuple[Any, Any, Expression, Any] ) -> Expression: _, _, condition, _ = visited_children return condition def visit_parenthesized_arithm( - self, node: Node, visited_children: Tuple[Any, Expression, Any] + self, node: Node, visited_children: tuple[Any, Expression, Any] ) -> Expression: _, arithm, _ = visited_children return arithm def visit_low_pri_tuple( - self, node: Node, visited_children: Tuple[LowPriOperator, Any, Expression] + self, node: Node, visited_children: tuple[LowPriOperator, Any, Expression] ) -> LowPriTuple: return visit_low_pri_tuple(node, visited_children) def visit_high_pri_tuple( - self, node: Node, visited_children: Tuple[HighPriOperator, Any, Expression] + self, node: Node, visited_children: tuple[HighPriOperator, Any, Expression] ) -> HighPriTuple: return visit_high_pri_tuple(node, visited_children) @@ -498,21 +478,21 @@ def visit_high_pri_op(self, node: Node, visited_children: Iterable[Any]) -> High return visit_high_pri_op(node, visited_children) def visit_arithmetic_term( - self, node: Node, visited_children: Tuple[Any, Expression] + self, node: Node, visited_children: tuple[Any, Expression] ) -> Expression: return visit_arithmetic_term(node, visited_children) def visit_low_pri_arithmetic( self, node: Node, - visited_children: Tuple[Any, Expression, LowPriArithmetic], + visited_children: tuple[Any, Expression, LowPriArithmetic], ) -> Expression: return visit_low_pri_arithmetic(node, visited_children) def visit_high_pri_arithmetic( self, node: Node, - visited_children: Tuple[Any, Expression, HighPriArithmetic], + visited_children: tuple[Any, Expression, HighPriArithmetic], ) -> Expression: return visit_high_pri_arithmetic(node, visited_children) @@ -531,17 +511,17 @@ def visit_boolean_literal(self, node: Node, visited_children: Iterable[Any]) -> def visit_null_literal(self, node: Node, visited_children: Iterable[Any]) -> Literal: return Literal(None, None) - def visit_quoted_literal(self, node: Node, visited_children: Tuple[Node]) -> Literal: + def visit_quoted_literal(self, node: Node, visited_children: tuple[Node]) -> Literal: return visit_quoted_literal(node, visited_children) def visit_where_clause( - self, node: Node, visited_children: Tuple[Any, Any, Any, Expression] + self, node: Node, visited_children: tuple[Any, Any, Any, Expression] ) -> Expression: _, _, _, conditions = visited_children return conditions def visit_having_clause( - self, node: Node, visited_children: Tuple[Any, Any, Any, Expression] + self, node: Node, visited_children: tuple[Any, Any, Any, Expression] ) -> Expression: _, _, _, conditions = visited_children return conditions @@ -549,7 +529,7 @@ def visit_having_clause( def visit_and_expression( self, node: Node, - visited_children: Tuple[Any, Expression, Node], + visited_children: tuple[Any, Expression, Node], ) -> Expression: _, left_condition, and_condition = visited_children args = [left_condition] @@ -560,16 +540,16 @@ def visit_and_expression( if isinstance(and_condition, (AndTuple, OrTuple)): _, exp = and_condition return combine_and_conditions([left_condition, exp]) - elif isinstance(and_condition, list): + if isinstance(and_condition, list): for elem in and_condition: if isinstance(elem, Node): continue - elif isinstance(elem, (AndTuple, OrTuple)): + if isinstance(elem, (AndTuple, OrTuple)): args.append(elem.exp) return combine_and_conditions(args) def visit_or_expression( - self, node: Node, visited_children: Tuple[Any, Expression, Node] + self, node: Node, visited_children: tuple[Any, Expression, Node] ) -> Expression: _, left_condition, or_condition = visited_children args = [left_condition] @@ -580,16 +560,16 @@ def visit_or_expression( if isinstance(or_condition, (AndTuple, OrTuple)): _, exp = or_condition return combine_or_conditions([left_condition, exp]) - elif isinstance(or_condition, list): + if isinstance(or_condition, list): for elem in or_condition: if isinstance(elem, Node): continue - elif isinstance(elem, (AndTuple, OrTuple)): + if isinstance(elem, (AndTuple, OrTuple)): args.append(elem.exp) return combine_or_conditions(args) def visit_unary_condition( - self, node: Node, visited_children: Tuple[Expression, Any, str] + self, node: Node, visited_children: tuple[Expression, Any, str] ) -> Expression: exp, _, op = visited_children return unary_condition(op, exp) @@ -600,7 +580,7 @@ def visit_unary_op(self, node: Node, visited_children: Iterable[Any]) -> str: def visit_main_condition( self, node: Node, - visited_children: Tuple[Expression, Any, str, Any, Expression], + visited_children: tuple[Expression, Any, str, Any, Expression], ) -> Expression: exp, _, op, _, literal = visited_children return binary_condition(op, exp, literal) @@ -609,16 +589,16 @@ def visit_condition_op(self, node: Node, visited_children: Iterable[Any]) -> str return OPERATOR_TO_FUNCTION[node.text] def visit_order_by_clause( - self, node: Node, visited_children: Tuple[Any, Any, Any, Sequence[OrderBy]] + self, node: Node, visited_children: tuple[Any, Any, Any, Sequence[OrderBy]] ) -> Sequence[OrderBy]: _, _, _, order_columns = visited_children return order_columns def visit_order_list( - self, node: Node, visited_children: Tuple[OrderBy, Expression, Any, Node] + self, node: Node, visited_children: tuple[OrderBy, Expression, Any, Node] ) -> Sequence[OrderBy]: left_order_list, right_order, _, order = visited_children - ret: List[OrderBy] = [] + ret: list[OrderBy] = [] # in the case of one OrderBy # left_order_list will be an empty node @@ -635,7 +615,7 @@ def visit_order_list( return ret def visit_order_columns( - self, node: Node, visited_children: Tuple[Expression, Any, Node, Any, Any, Any] + self, node: Node, visited_children: tuple[Expression, Any, Node, Any, Any, Any] ) -> OrderBy: column, _, order, _, _, _ = visited_children @@ -643,21 +623,21 @@ def visit_order_columns( return OrderBy(direction, column) def visit_sample_clause( - self, node: Node, visited_children: Tuple[Any, Any, Any, Literal] + self, node: Node, visited_children: tuple[Any, Any, Any, Literal] ) -> float: _, _, _, sample = visited_children assert isinstance(sample.value, float) # mypy return sample.value def visit_granularity_clause( - self, node: Node, visited_children: Tuple[Any, Any, Any, Literal] + self, node: Node, visited_children: tuple[Any, Any, Any, Literal] ) -> float: _, _, _, granularity = visited_children assert isinstance(granularity.value, int) # mypy return granularity.value def visit_totals_clause( - self, node: Node, visited_children: Tuple[Any, Any, Any, Literal] + self, node: Node, visited_children: tuple[Any, Any, Any, Literal] ) -> float: _, _, _, totals = visited_children assert isinstance(totals.value, bool) # mypy @@ -666,8 +646,8 @@ def visit_totals_clause( def visit_limit_by_clause( self, node: Node, - visited_children: Tuple[ - Any, Any, Any, Literal, Any, Any, Any, Column, Optional[Sequence[Column]] + visited_children: tuple[ + Any, Any, Any, Literal, Any, Any, Any, Column, Sequence[Column] | None ], ) -> LimitBy: _, _, _, limit, _, _, _, column_one, columns_rest = visited_children @@ -678,23 +658,23 @@ def visit_limit_by_clause( return LimitBy(limit.value, columns) def visit_limit_by_columns( - self, node: Node, visited_children: Sequence[Tuple[Any, Any, Any, Column]] + self, node: Node, visited_children: Sequence[tuple[Any, Any, Any, Column]] ) -> Sequence[Column]: - columns: List[Column] = [] + columns: list[Column] = [] for column_visit in visited_children: _, _, _, column_inst = column_visit columns.append(column_inst) return columns def visit_limit_clause( - self, node: Node, visited_children: Tuple[Any, Any, Any, Literal] + self, node: Node, visited_children: tuple[Any, Any, Any, Literal] ) -> int: _, _, _, limit = visited_children assert isinstance(limit.value, int) # mypy return limit.value def visit_offset_clause( - self, node: Node, visited_children: Tuple[Any, Any, Any, Literal] + self, node: Node, visited_children: tuple[Any, Any, Any, Literal] ) -> int: _, _, _, offset = visited_children assert isinstance(offset.value, int) # mypy @@ -703,13 +683,13 @@ def visit_offset_clause( def visit_group_by_clause( self, node: Node, - visited_children: Tuple[Any, Any, Any, Sequence[SelectedExpression]], + visited_children: tuple[Any, Any, Any, Sequence[SelectedExpression]], ) -> Sequence[SelectedExpression]: _, _, _, group_columns = visited_children return group_columns def visit_group_columns( - self, node: Node, visited_children: Tuple[SelectedExpression, Any, Any] + self, node: Node, visited_children: tuple[SelectedExpression, Any, Any] ) -> SelectedExpression: columns, _, _ = visited_children return columns @@ -717,10 +697,10 @@ def visit_group_columns( def visit_group_list( self, node: Node, - visited_children: Tuple[SelectedExpression, SelectedExpression], + visited_children: tuple[SelectedExpression, SelectedExpression], ) -> Sequence[SelectedExpression]: left_group_list, right_group = visited_children - ret: List[SelectedExpression] = [] + ret: list[SelectedExpression] = [] # in the case of one GroupBy / By # left_group_list will be an empty node @@ -737,7 +717,7 @@ def visit_group_list( def visit_select_clause( self, node: Node, - visited_children: Tuple[Any, Any, Any, Sequence[SelectedExpression]], + visited_children: tuple[Any, Any, Any, Sequence[SelectedExpression]], ) -> Sequence[SelectedExpression]: _, _, _, selected_columns = visited_children return selected_columns @@ -745,7 +725,7 @@ def visit_select_clause( def visit_selected_expression( self, node: Node, - visited_children: Tuple[Any, Union[SelectedExpression, Expression]], + visited_children: tuple[Any, SelectedExpression | Expression], ) -> SelectedExpression: _, exp = visited_children if isinstance(exp, SelectedExpression): @@ -757,7 +737,7 @@ def visit_selected_expression( return SelectedExpression(alias, exp) def visit_select_columns( - self, node: Node, visited_children: Tuple[SelectedExpression, Any, Any] + self, node: Node, visited_children: tuple[SelectedExpression, Any, Any] ) -> SelectedExpression: columns, _, _ = visited_children return columns @@ -765,10 +745,10 @@ def visit_select_columns( def visit_select_list( self, node: Node, - visited_children: Tuple[SelectedExpression, SelectedExpression], + visited_children: tuple[SelectedExpression, SelectedExpression], ) -> Sequence[SelectedExpression]: column_list, right_column = visited_children - ret: List[SelectedExpression] = [] + ret: list[SelectedExpression] = [] # in the case of one Collect # column_list will be an empty node @@ -785,7 +765,7 @@ def visit_select_list( def visit_arrayjoin_clause( self, node: Node, - visited_children: Tuple[Any, Any, Any, Expression, Optional[List[Expression]]], + visited_children: tuple[Any, Any, Any, Expression, list[Expression] | None], ) -> Sequence[Expression]: _, _, _, join_first, join_rest = visited_children exprs = [join_first] @@ -798,9 +778,9 @@ def visit_arrayjoin_clause( def visit_arrayjoin_optional( self, node: Node, - visited_children: List[Tuple[Any, Any, Any, Expression]], - ) -> List[Expression]: - exprs: List[Expression] = list() + visited_children: list[tuple[Any, Any, Any, Expression]], + ) -> list[Expression]: + exprs: list[Expression] = [] if visited_children is not None: for child in visited_children: _, _, _, exp = child @@ -808,27 +788,27 @@ def visit_arrayjoin_optional( return exprs def visit_parameter( - self, node: Node, visited_children: Tuple[Expression, Any, Any, Any] + self, node: Node, visited_children: tuple[Expression, Any, Any, Any] ) -> Expression: return visit_parameter(node, visited_children) def visit_parameters_list( self, node: Node, - visited_children: Tuple[Union[Expression, List[Expression]], Expression], - ) -> List[Expression]: + visited_children: tuple[Expression | list[Expression], Expression], + ) -> list[Expression]: return visit_parameters_list(node, visited_children) def visit_function_call( self, node: Node, - visited_children: Tuple[ + visited_children: tuple[ str, Any, - List[Expression], + list[Expression], Any, - Union[Node, List[Expression]], - Union[Node, List[Any]], + Node | list[Expression], + Node | list[Any], ], ) -> Expression: name, _, params1, _, params2, alias = visited_children @@ -856,7 +836,7 @@ def visit_function_call( def visit_aliased_tag_column( self, node: Node, - visited_children: Tuple[Column, Any, Any, Any, Node], + visited_children: tuple[Column, Any, Any, Any, Node], ) -> SelectedExpression: column, _, _, _, alias = visited_children return SelectedExpression(self.__extract_alias_from_match(alias), column) @@ -864,7 +844,7 @@ def visit_aliased_tag_column( def visit_aliased_subscriptable( self, node: Node, - visited_children: Tuple[Column, Any, Any, Any, Node], + visited_children: tuple[Column, Any, Any, Any, Node], ) -> SelectedExpression: column, _, _, _, alias = visited_children return SelectedExpression(self.__extract_alias_from_match(alias), column) @@ -872,22 +852,22 @@ def visit_aliased_subscriptable( def visit_aliased_column_name( self, node: Node, - visited_children: Tuple[Column, Any, Any, Any, Node], + visited_children: tuple[Column, Any, Any, Any, Node], ) -> SelectedExpression: column, _, _, _, alias = visited_children return SelectedExpression(self.__extract_alias_from_match(alias), column) - def visit_identifier(self, node: Node, visited_children: Tuple[Any, Node, Any]) -> Argument: + def visit_identifier(self, node: Node, visited_children: tuple[Any, Node, Any]) -> Argument: return Argument(None, visited_children[1].text) def visit_lambda( self, node: Node, - visited_children: Tuple[ + visited_children: tuple[ Any, Any, Argument, - Union[Node, List[Node | Argument]], + Node | list[Node | Argument], Any, Any, Any, @@ -915,7 +895,7 @@ def generic_visit(self, node: Node, visited_children: Any) -> Any: def parse_snql_query_initial( body: str, -) -> Union[CompositeQuery[LogicalDataSource], LogicalQuery]: +) -> CompositeQuery[LogicalDataSource] | LogicalQuery: """ Parses the query body generating the AST. This only takes into account the initial query body. Extensions are parsed by extension @@ -937,12 +917,12 @@ def parse_snql_query_initial( idx = e.column() prefix = line[max(0, idx - 3) : idx] suffix = line[idx : (idx + 10)] - raise ParsingException(f"Parsing error on line {e.line()} at '{prefix}{suffix}'") + raise ParsingException(f"Parsing error on line {e.line()} at '{prefix}{suffix}'") from e except Exception as e: message = str(e) if "\n" in message: message, _ = message.split("\n", 1) - raise ParsingException(message) + raise ParsingException(message) from e assert isinstance(parsed, (CompositeQuery, LogicalQuery)) # mypy @@ -959,7 +939,7 @@ def parse_snql_query_initial( return parsed -def _qualify_columns(query: Union[CompositeQuery[LogicalDataSource], LogicalQuery]) -> None: +def _qualify_columns(query: CompositeQuery[LogicalDataSource] | LogicalQuery) -> None: """ All columns in a join query should be qualified with the entity alias, e.g. e.event_id Take those aliases and put them in the table name. This has to be done in a post @@ -988,7 +968,7 @@ def transform(exp: Expression) -> Expression: def _treeify_or_and_conditions( - query: Union[CompositeQuery[LogicalDataSource], LogicalQuery], + query: CompositeQuery[LogicalDataSource] | LogicalQuery, ) -> None: """ look for expressions like or(a, b, c) and turn them into or(a, or(b, c)) @@ -1009,10 +989,9 @@ def transform(exp: Expression) -> Expression: if exp.function_name == "and": return combine_and_conditions(exp.parameters) - elif exp.function_name == "or": + if exp.function_name == "or": return combine_or_conditions(exp.parameters) - else: - return exp + return exp query.transform_expressions(transform) @@ -1022,7 +1001,7 @@ def transform(exp: Expression) -> Expression: ) -def _parse_datetime_literals(query: Union[CompositeQuery[LogicalDataSource], LogicalQuery]) -> None: +def _parse_datetime_literals(query: CompositeQuery[LogicalDataSource] | LogicalQuery) -> None: def parse(exp: Expression) -> Expression: result = DATETIME_MATCH.match(exp) if result is not None: @@ -1047,7 +1026,7 @@ def parse(exp: Expression) -> Expression: def _array_join_transformation( - query: Union[CompositeQuery[LogicalDataSource], LogicalQuery], + query: CompositeQuery[LogicalDataSource] | LogicalQuery, ) -> None: def parse(exp: Expression) -> Expression: result = ARRAY_JOIN_MATCH.match(exp) @@ -1091,10 +1070,8 @@ def parse(exp: Expression) -> Expression: query.transform_expressions(parse) -def _transform_array_condition(array_columns: Set[str], exp: Expression) -> Expression: - if not is_condition(exp) or not isinstance(exp, FunctionCall): - return exp - elif len(exp.parameters) < 2: +def _transform_array_condition(array_columns: set[str], exp: Expression) -> Expression: + if not is_condition(exp) or not isinstance(exp, FunctionCall) or len(exp.parameters) < 2: return exp lhs = exp.parameters[0] @@ -1138,11 +1115,11 @@ def _transform_array_condition(array_columns: Set[str], exp: Expression) -> Expr def _unpack_array_conditions( - query: Union[CompositeQuery[LogicalDataSource], LogicalQuery], + query: CompositeQuery[LogicalDataSource] | LogicalQuery, schema: ColumnSet, - entity_alias: Optional[str] = None, + entity_alias: str | None = None, ) -> None: - array_columns: Set[str] = set() + array_columns: set[str] = set() array_join_arguments = query.get_arrayjoin() array_join_columns = set() if array_join_arguments is not None: @@ -1170,7 +1147,7 @@ def _unpack_array_conditions( ) -def _array_column_conditions(query: Union[CompositeQuery[LogicalDataSource], LogicalQuery]) -> None: +def _array_column_conditions(query: CompositeQuery[LogicalDataSource] | LogicalQuery) -> None: """ Find conditions on array columns, and if those columns are not in the array join, convert them to the appropriate higher order function. @@ -1191,7 +1168,7 @@ def _array_column_conditions(query: Union[CompositeQuery[LogicalDataSource], Log def _mangle_query_aliases( - query: Union[CompositeQuery[LogicalDataSource], LogicalQuery], + query: CompositeQuery[LogicalDataSource] | LogicalQuery, ) -> None: """ If a query has a subquery, the inner query will get its aliases mangled. This is @@ -1227,14 +1204,14 @@ def mangle_column_value(exp: Expression) -> Expression: def validate_identifiers_in_lambda( - query: Union[CompositeQuery[LogicalDataSource], LogicalQuery], + query: CompositeQuery[LogicalDataSource] | LogicalQuery, ) -> None: """ Check to make sure that any identifiers referenced in a lambda were defined in that lambda or in an outer lambda. """ - identifiers: Set[str] = set() - unseen_identifiers: Set[str] = set() + identifiers: set[str] = set() + unseen_identifiers: set[str] = set() def validate_lambda(exp: Lambda) -> None: for p in exp.parameters: @@ -1260,10 +1237,7 @@ def validate_lambda(exp: Lambda) -> None: def _replace_time_condition( - query: Union[ - CompositeQuery[LogicalDataSource], - LogicalQuery, - ], + query: CompositeQuery[LogicalDataSource] | LogicalQuery, ) -> None: condition = query.get_condition() top_level = get_first_level_and_conditions(condition) if condition is not None else [] @@ -1295,9 +1269,9 @@ def _replace_time_condition( def _align_max_days_date_align( key: EntityKey | StorageKey, old_top_level: Sequence[Expression], - max_days: Optional[int], + max_days: int | None, date_align: int, - alias: Optional[str] = None, + alias: str | None = None, ) -> Sequence[Expression]: data_source: Entity | Storage | None = None data_source_name = "entity" @@ -1330,7 +1304,7 @@ def _align_max_days_date_align( f"Missing >= condition with a datetime literal on column {data_source.required_time_column} for {data_source_name} {key.value}. " f"Example: {data_source.required_time_column} >= toDateTime('2023-05-16 00:00')" ) - elif not upper: + if not upper: raise ParsingException( f"Missing < condition with a datetime literal on column {data_source.required_time_column} for {data_source_name} {key.value}. " f"Example: {data_source.required_time_column} < toDateTime('2023-05-16 00:00')" @@ -1353,12 +1327,12 @@ def _align_max_days_date_align( def replace_cond(exp: Expression) -> Expression: if not isinstance(exp, FunctionCall): return exp - elif exp == from_exp: + if exp == from_exp: return replace( exp, parameters=(from_exp.parameters[0], Literal(None, from_date)), ) - elif exp == to_exp: + if exp == to_exp: return replace(exp, parameters=(to_exp.parameters[0], Literal(None, to_date))) return exp @@ -1368,8 +1342,8 @@ def replace_cond(exp: Expression) -> Expression: def _select_entity_for_dataset( dataset: Dataset, -) -> Callable[[Union[CompositeQuery[LogicalDataSource], LogicalQuery]], None]: - def selector(query: Union[CompositeQuery[LogicalDataSource], LogicalQuery]) -> None: +) -> Callable[[CompositeQuery[LogicalDataSource] | LogicalQuery], None]: + def selector(query: CompositeQuery[LogicalDataSource] | LogicalQuery) -> None: # If you are doing a JOIN, then you have to specify the entity if isinstance(query, CompositeQuery): return @@ -1417,8 +1391,8 @@ def replace_time_condition_aliases(exp: Expression) -> Expression: def _post_process( - query: Union[CompositeQuery[LogicalDataSource], LogicalQuery], - funcs: Sequence[Callable[[Union[CompositeQuery[LogicalDataSource], LogicalQuery]], None]], + query: CompositeQuery[LogicalDataSource] | LogicalQuery, + funcs: Sequence[Callable[[CompositeQuery[LogicalDataSource] | LogicalQuery], None]], settings: QuerySettings | None = None, ) -> None: for func in funcs: @@ -1456,17 +1430,15 @@ def _post_process( ] -CustomProcessors = Sequence[ - Callable[[Union[CompositeQuery[LogicalDataSource], LogicalQuery]], None] -] +CustomProcessors = Sequence[Callable[[CompositeQuery[LogicalDataSource] | LogicalQuery], None]] def parse_snql_query( body: str, dataset: Dataset, - custom_processing: Optional[CustomProcessors] = None, + custom_processing: CustomProcessors | None = None, settings: QuerySettings | None = None, -) -> Union[CompositeQuery[LogicalDataSource], LogicalQuery]: +) -> CompositeQuery[LogicalDataSource] | LogicalQuery: with sentry_sdk.start_span(op="parser", description="parse_snql_query_initial"): query = parse_snql_query_initial(body) @@ -1495,8 +1467,8 @@ def parse_snql_query( return res.data except InvalidQueryException: raise - except Exception: - raise PostProcessingError(query) + except Exception as e: + raise PostProcessingError(query) from e class PostProcessAndValidateQuery( diff --git a/snuba/query/validation/__init__.py b/snuba/query/validation/__init__.py index 3e6a818eaee..a31990c8b0f 100644 --- a/snuba/query/validation/__init__.py +++ b/snuba/query/validation/__init__.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Sequence +from collections.abc import Sequence from snuba.query.data_source import DataSource from snuba.query.expressions import Expression diff --git a/snuba/query/validation/functions.py b/snuba/query/validation/functions.py index 93c366a9931..b77fff4f010 100644 --- a/snuba/query/validation/functions.py +++ b/snuba/query/validation/functions.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba import environment, state from snuba.query.data_source import DataSource @@ -24,5 +24,4 @@ def validate( if state.get_config("function-validator.enabled", False): raise InvalidFunctionCall(f"Invalid function name: {func_name}") - else: - metrics.increment("invalid_funcs", tags={"func_name": func_name}) + metrics.increment("invalid_funcs", tags={"func_name": func_name}) diff --git a/snuba/query/validation/signature.py b/snuba/query/validation/signature.py index 166e208491b..eeb68dc79f7 100644 --- a/snuba/query/validation/signature.py +++ b/snuba/query/validation/signature.py @@ -1,7 +1,7 @@ import logging -from abc import ABC +from abc import ABC, abstractmethod +from collections.abc import Sequence from datetime import date, datetime -from typing import Sequence, Set, Type, Union from snuba.clickhouse.columns import ( UUID, @@ -14,6 +14,7 @@ IPv6, Nullable, String, + TypeModifiers, UInt, ) from snuba.query.data_source import DataSource @@ -29,6 +30,7 @@ class ParamType(ABC): + @abstractmethod def validate(self, expression: Expression, schema: ColumnSet) -> None: raise NotImplementedError @@ -46,28 +48,22 @@ def __str__(self) -> str: column_name=Param("column_name", AnyMatcher(str)), ) -AllowedTypes = Union[ - Type[Array], - Type[String], - Type[UUID], - Type[IPv4], - Type[IPv6], - Type[FixedString], - Type[UInt], - Type[Float], - Type[Date], - Type[DateTime], -] - -AllowedScalarTypes = Union[ - Type[None], - Type[bool], - Type[str], - Type[float], - Type[int], - Type[date], - Type[datetime], -] +AllowedTypes = ( + type[Array[TypeModifiers]] + | type[String[TypeModifiers]] + | type[UUID[TypeModifiers]] + | type[IPv4[TypeModifiers]] + | type[IPv6[TypeModifiers]] + | type[FixedString[TypeModifiers]] + | type[UInt[TypeModifiers]] + | type[Float[TypeModifiers]] + | type[Date[TypeModifiers]] + | type[DateTime[TypeModifiers]] +) + +AllowedScalarTypes = ( + type[None] | type[bool] | type[str] | type[float] | type[int] | type[date] | type[datetime] +) class Column(ParamType): @@ -84,7 +80,7 @@ class Column(ParamType): is False it will require non nullable columns. """ - def __init__(self, types: Set[AllowedTypes], allow_nullable: bool = True) -> None: + def __init__(self, types: set[AllowedTypes], allow_nullable: bool = True) -> None: self.__valid_types = types self.__allow_nullable = allow_nullable @@ -111,10 +107,8 @@ def validate(self, expression: Expression, schema: ColumnSet) -> None: nullable and not self.__allow_nullable ): raise InvalidFunctionCall( - ( - f"Illegal type {'Nullable ' if nullable else ''}{str(column.type)} " - f"of argument `{column_name}`. Required types {self.__valid_types}" - ) + f"Illegal type {'Nullable ' if nullable else ''}{str(column.type)} " + f"of argument `{column_name}`. Required types {self.__valid_types}" ) @@ -128,7 +122,7 @@ class Literal(ParamType): expressions can be passed as arguments in certain functions. """ - def __init__(self, types: Set[AllowedScalarTypes], allow_nullable: bool = False) -> None: + def __init__(self, types: set[AllowedScalarTypes], allow_nullable: bool = False) -> None: self.__valid_types = types if allow_nullable: self.__valid_types.add(type(None)) @@ -138,7 +132,7 @@ def __str__(self) -> str: def validate(self, expression: Expression, schema: ColumnSet) -> None: if not isinstance(expression, LiteralType): - return None + return value = expression.value if not isinstance(value, tuple(self.__valid_types)): @@ -176,8 +170,7 @@ def validate( except InvalidFunctionCall as exception: if self.__enforce: raise exception - else: - logger.warning(f"Query validation exception. Validator: {self}", exc_info=True) + logger.warning(f"Query validation exception. Validator: {self}", exc_info=True) def __validate_impl( self, func_name: str, parameters: Sequence[Expression], data_source: DataSource @@ -192,5 +185,5 @@ def __validate_impl( f"Too many arguments. Required {[str(t) for t in self.__param_types]}" ) - for validator, param in zip(self.__param_types, parameters): + for validator, param in zip(self.__param_types, parameters, strict=False): validator.validate(param, data_source.get_columns()) diff --git a/snuba/query/validation/validators.py b/snuba/query/validation/validators.py index 0b6b046bd75..b3f3e1449a0 100644 --- a/snuba/query/validation/validators.py +++ b/snuba/query/validation/validators.py @@ -2,9 +2,10 @@ import logging from abc import ABC, abstractmethod +from collections.abc import Sequence from datetime import datetime from enum import Enum -from typing import Optional, Sequence, Type, cast +from typing import cast from snuba.clickhouse.translators.snuba.allowed import DefaultNoneColumnMapper from snuba.clickhouse.translators.snuba.function_call_mappers import ( @@ -62,14 +63,14 @@ def config_key(cls) -> str: return cls.__name__ @classmethod - def get_from_name(cls, name: str) -> Type["QueryValidator"]: - return cast(Type["QueryValidator"], cls.class_from_name(name)) + def get_from_name(cls, name: str) -> type[QueryValidator]: + return cast(type["QueryValidator"], cls.class_from_name(name)) @abstractmethod def validate( self, query: Query, - alias: Optional[str] = None, + alias: str | None = None, ) -> None: """ Validate that the query is correct. If the query is not valid, raise an @@ -99,7 +100,7 @@ def __init__( self.required_columns = set(required_filter_columns) self.required_str_columns = set(required_str_columns) if required_str_columns else set() - def validate(self, query: Query, alias: Optional[str] = None) -> None: + def validate(self, query: Query, alias: str | None = None) -> None: condition = query.get_condition() top_level = get_first_level_and_conditions(condition) if condition else [] @@ -158,7 +159,7 @@ def __init__( elif isinstance(col_mapping, DefaultNoneColumnMapper): self.mapped_columns.update(col_mapping.column_names) - def validate(self, query: Query, alias: Optional[str] = None) -> None: + def validate(self, query: Query, alias: str | None = None) -> None: if self.validation_mode == ColumnValidationMode.DO_NOTHING: return @@ -175,7 +176,7 @@ def validate(self, query: Query, alias: Optional[str] = None) -> None: error_message = f"Tag keys ({', '.join(missing)}) not resolved" if self.validation_mode == ColumnValidationMode.ERROR: raise InvalidQueryException(error_message) - elif self.validation_mode == ColumnValidationMode.WARN: + if self.validation_mode == ColumnValidationMode.WARN: logger.warning(error_message, exc_info=True) @@ -201,7 +202,7 @@ def __init__(self, required_time_column: str) -> None: param_type=datetime, ) - def validate(self, query: Query, alias: Optional[str] = None) -> None: + def validate(self, query: Query, alias: str | None = None) -> None: condition = query.get_condition() top_level = get_first_level_and_conditions(condition) if condition else [] for cond in top_level: @@ -231,7 +232,7 @@ def __init__( @staticmethod def _validate_groupby_fields_have_matching_conditions( - query: Query, alias: Optional[str] = None + query: Query, alias: str | None = None ) -> None: """ Method that insures that for every field in the group by clause, there should be a @@ -273,7 +274,7 @@ def _validate_groupby_fields_have_matching_conditions( def validate( self, query: Query, - alias: Optional[str] = None, + alias: str | None = None, ) -> None: selected = query.get_selected_columns() if len(selected) > self.max_allowed_aggregations: @@ -304,7 +305,7 @@ def __init__(self, minimum: int, required: bool = False): self.minimum = minimum self.required = required - def validate(self, query: Query, alias: Optional[str] = None) -> None: + def validate(self, query: Query, alias: str | None = None) -> None: granularity = query.get_granularity() if granularity is None: if self.required: @@ -336,7 +337,7 @@ def __init__(self) -> None: array_ops=[ConditionFunctions.IN, ConditionFunctions.NOT_IN], ) - def validate(self, query: Query, alias: Optional[str] = None) -> None: + def validate(self, query: Query, alias: str | None = None) -> None: condition = query.get_condition() if not condition: return @@ -383,7 +384,7 @@ def initialize(self, schema: ColumnSet) -> None: ) self.matchers.append(matcher) - def validate(self, query: Query, alias: Optional[str] = None) -> None: + def validate(self, query: Query, alias: str | None = None) -> None: if not isinstance(query, LogicalQuery): return # TODO: This doesn't work for queries with multiple entities. @@ -461,7 +462,7 @@ def __init__(self) -> None: [String(f"{func_name}If") for func_name in common_aggregate_functions] ) - def validate(self, query: Query, alias: Optional[str] = None) -> None: + def validate(self, query: Query, alias: str | None = None) -> None: def find_illegal_aggregate_functions( expression: Expression, ) -> list[MatchResult]: diff --git a/snuba/querylog/__init__.py b/snuba/querylog/__init__.py index fcef4cccf7a..ad85d5c8303 100644 --- a/snuba/querylog/__init__.py +++ b/snuba/querylog/__init__.py @@ -1,8 +1,9 @@ from __future__ import annotations import time +from collections.abc import Mapping from random import random -from typing import Any, Mapping, Optional, Union +from typing import Any from uuid import UUID import sentry_sdk @@ -13,14 +14,18 @@ from snuba.cogs.accountant import record_cogs from snuba.datasets.storage import StorageNotAvailable from snuba.query.exceptions import QueryPlanException -from snuba.querylog.query_metadata import QueryStatus, SnubaQueryMetadata, Status +from snuba.querylog.query_metadata import ( + QueryStatus, + SnubaQueryMetadata, + Status, + get_request_status, +) from snuba.request import Request from snuba.utils.metrics.timer import Timer from snuba.utils.metrics.wrapper import MetricsWrapper from snuba.web import QueryException, QueryResult metrics = MetricsWrapper(environment.metrics, "api") -from snuba.querylog.query_metadata import get_request_status _ITEM_TYPE_TO_APP_FEATURE: dict[str, str] = { "TRACE_ITEM_TYPE_SPAN": "spans", @@ -49,7 +54,7 @@ def _record_timer_metrics( request: Request, timer: Timer, query_metadata: SnubaQueryMetadata, - result: Union[QueryResult, QueryException, QueryPlanException], + result: QueryResult | QueryException | QueryPlanException, ) -> None: final = str(request.query.get_final()) referrer = request.referrer or "none" @@ -95,7 +100,7 @@ def _record_timer_metrics( def _record_bytes_scanned_metrics( query_metadata: SnubaQueryMetadata, - result: Union[QueryResult, QueryException, QueryPlanException], + result: QueryResult | QueryException | QueryPlanException, ) -> None: """ Experimental metrics - trying to understand whether or not @@ -122,7 +127,7 @@ def _record_bytes_scanned_metrics( def _record_cogs( request: Request, query_metadata: SnubaQueryMetadata, - result: Union[QueryResult, QueryException, QueryPlanException], + result: QueryResult | QueryException | QueryPlanException, ) -> None: """ Record bytes scanned for the clickhouse compute of resource of a query. @@ -189,7 +194,7 @@ def record_query( request: Request, timer: Timer, query_metadata: SnubaQueryMetadata, - result: Union[QueryResult, QueryException, QueryPlanException], + result: QueryResult | QueryException | QueryPlanException, ) -> None: """ Records a request after it has been parsed and validated, whether @@ -213,8 +218,8 @@ def record_query( def _add_tags( timer: Timer, - experiments: Optional[Mapping[str, Any]] = None, - metadata: Optional[SnubaQueryMetadata] = None, + experiments: Mapping[str, Any] | None = None, + metadata: SnubaQueryMetadata | None = None, ) -> None: if sentry_sdk.get_current_span(): duration_group = timer.get_duration_group() @@ -238,7 +243,7 @@ def _build_failed_request_dict( dataset: str, organization: int, request_status: Status, - referrer: Optional[str], + referrer: str | None, exception_name: str | None = None, ) -> snuba_queries_v1.Querylog: return { @@ -272,7 +277,7 @@ def record_invalid_request( organization: int, timer: Timer, request_status: Status, - referrer: Optional[str], + referrer: str | None, exception_name: str | None = None, ) -> None: """ @@ -303,7 +308,7 @@ def record_error_building_request( organization: int, timer: Timer, request_status: Status, - referrer: Optional[str], + referrer: str | None, exception_name: str | None = None, ) -> None: """ @@ -331,7 +336,7 @@ def _record_failure_metric_with_status( status: QueryStatus, request_status: Status, timer: Timer, - referrer: Optional[str], + referrer: str | None, exception_name: str | None = None, ) -> None: # TODO: Revisit if recording some data for these queries in the querylog diff --git a/snuba/querylog/query_metadata.py b/snuba/querylog/query_metadata.py index ed913609c0e..b0ca84f9e8b 100644 --- a/snuba/querylog/query_metadata.py +++ b/snuba/querylog/query_metadata.py @@ -1,9 +1,10 @@ from __future__ import annotations +from collections.abc import Mapping, MutableSequence from dataclasses import dataclass from datetime import datetime from enum import Enum -from typing import Any, Dict, Mapping, MutableSequence, Optional, Set, cast +from typing import Any, cast from clickhouse_driver.errors import ErrorCodes from sentry_kafka_schemas.schema_types import snuba_queries_v1 @@ -126,7 +127,7 @@ def slo(self) -> SLO: def get_request_status( - cause: Exception | None = None, context: Optional[Mapping[str, Any]] = None + cause: Exception | None = None, context: Mapping[str, Any] | None = None ) -> Status: slo_status: RequestStatus if cause is None: @@ -152,18 +153,22 @@ def get_request_status( elif isinstance(cause, ExecutionTimeoutError): slo_status = RequestStatus.CACHE_WAIT_TIMEOUT elif isinstance( - cause, (StorageNotAvailable, InvalidJsonRequestException, InvalidQueryException) + cause, + ( + StorageNotAvailable, + InvalidJsonRequestException, + InvalidQueryException, + QueryTooLongException, + ), ): slo_status = RequestStatus.INVALID_REQUEST - elif isinstance(cause, QueryTooLongException): - slo_status = RequestStatus.INVALID_REQUEST else: slo_status = RequestStatus.ERROR return Status(slo_status) -Columnset = Set[str] +Columnset = set[str] @dataclass(frozen=True) @@ -187,7 +192,7 @@ class ClickhouseQueryProfile: it easier to analyze both in the querylog and in discover. """ - time_range: Optional[int] # range in days + time_range: int | None # range in days table: str all_columns: Columnset # True if we have a combination of AND and OR instead of @@ -216,14 +221,14 @@ def to_dict(self) -> snuba_queries_v1.ClickhouseQueryProfile: class ClickhouseQueryMetadata: sql: str sql_anonymized: str - start_timestamp: Optional[datetime] - end_timestamp: Optional[datetime] - stats: Dict[str, Any] + start_timestamp: datetime | None + end_timestamp: datetime | None + stats: dict[str, Any] status: QueryStatus request_status: Status profile: ClickhouseQueryProfile trace_id: str - result_profile: Optional[snuba_queries_v1._QueryMetadataResultProfileObject] = None + result_profile: snuba_queries_v1._QueryMetadataResultProfileObject | None = None def to_dict(self) -> snuba_queries_v1.QueryMetadata: start = int(self.start_timestamp.timestamp()) if self.start_timestamp else None @@ -257,7 +262,7 @@ def __init__( end_timestamp: datetime | None = None, entity: str | None = None, query_list: MutableSequence[ClickhouseQueryMetadata] | None = None, - projects: Set[int] | None = None, + projects: set[int] | None = None, snql_anonymized: str | None = None, ): if not ( diff --git a/snuba/reader.py b/snuba/reader.py index 788e26b818d..ceb1aa2f23c 100644 --- a/snuba/reader.py +++ b/snuba/reader.py @@ -3,44 +3,37 @@ import itertools import re from abc import ABC, abstractmethod +from collections.abc import Callable, Iterator, Mapping, Sequence +from re import Pattern from typing import ( Any, - Callable, - Dict, - Iterator, - List, - Mapping, - Optional, - Pattern, - Sequence, - Tuple, TypedDict, TypeVar, ) from snuba.clickhouse.formatter.nodes import FormattedQuery -Column = TypedDict("Column", {"name": str, "type": str}, total=False) -Row = Dict[str, Any] - -Result = TypedDict( - "Result", - { - "meta": List[Column], - "data": List[Row], - "totals": Row, - "profile": Optional[Dict[str, Any]], - "trace_output": str, - }, - total=False, -) + +class Column(TypedDict, total=False): + name: str + type: str + + +Row = dict[str, Any] + + +class Result(TypedDict, total=False): + meta: list[Column] + data: list[Row] + totals: Row + profile: dict[str, Any] | None + trace_output: str def iterate_rows(result: Result) -> Iterator[Row]: if "totals" in result: return itertools.chain(result["data"], [result["totals"]]) - else: - return iter(result["data"]) + return iter(result["data"]) def transform_rows(result: Result, transformer: Callable[[Row], Row]) -> None: @@ -58,30 +51,28 @@ def transform_rows(result: Result, transformer: Callable[[Row], Row]) -> None: NULLABLE_RE = re.compile(r"^Nullable\((.+)\)$") -def unwrap_nullable_type(type: str) -> Tuple[bool, str]: +def unwrap_nullable_type(type: str) -> tuple[bool, str]: match = NULLABLE_RE.match(type) if match is not None: return True, match.groups()[0] - else: - return False, type + return False, type T = TypeVar("T") R = TypeVar("R") -def transform_nullable(function: Callable[[T], R]) -> Callable[[Optional[T]], Optional[R]]: - def transform_column(value: Optional[T]) -> Optional[R]: +def transform_nullable(function: Callable[[T], R]) -> Callable[[T | None], R | None]: + def transform_column(value: T | None) -> R | None: if value is None: return value - else: - return function(value) + return function(value) return transform_column def build_result_transformer( - column_transformations: Sequence[Tuple[Pattern[str], Callable[[Any], Any]]], + column_transformations: Sequence[tuple[Pattern[str], Callable[[Any], Any]]], ) -> Callable[[Result], None]: """ Builds and returns a function that can be used to mutate a ``Result`` @@ -116,9 +107,7 @@ def transform_result(result: Result) -> None: class Reader(ABC): - def __init__( - self, cache_partition_id: Optional[str], query_settings_prefix: Optional[str] - ) -> None: + def __init__(self, cache_partition_id: str | None, query_settings_prefix: str | None) -> None: self.__cache_partition_id = cache_partition_id self.__query_settings_prefix = query_settings_prefix @@ -126,7 +115,7 @@ def __init__( def execute( self, query: FormattedQuery, - settings: Optional[Mapping[str, str]] = None, + settings: Mapping[str, str] | None = None, with_totals: bool = False, robust: bool = False, capture_trace: bool = False, @@ -135,7 +124,7 @@ def execute( raise NotImplementedError @property - def cache_partition_id(self) -> Optional[str]: + def cache_partition_id(self) -> str | None: """ Return the cache partition if there is one. @@ -147,7 +136,7 @@ def cache_partition_id(self) -> Optional[str]: """ return self.__cache_partition_id - def get_query_settings_prefix(self) -> Optional[str]: + def get_query_settings_prefix(self) -> str | None: """ Return the query settings prefix if there is one. """ diff --git a/snuba/redis.py b/snuba/redis.py index 0c251b0bdfc..860b71be8dc 100644 --- a/snuba/redis.py +++ b/snuba/redis.py @@ -1,9 +1,10 @@ -from __future__ import absolute_import, annotations +from __future__ import annotations import time +from collections.abc import Callable, Iterable, Mapping from enum import Enum from functools import wraps -from typing import Any, Callable, Iterable, Mapping, TypeVar, Union, cast +from typing import Any, TypeAlias, TypeVar, cast from redis.cluster import ClusterNode, RedisCluster from redis.exceptions import RedisClusterException @@ -19,7 +20,7 @@ # FailoverRedis anyway. SingleNodeRedis = FailoverRedis -RedisClientType = Union[SingleNodeRedis, RedisCluster] +RedisClientType: TypeAlias = SingleNodeRedis | RedisCluster class FailedClusterInitization(SerializableException): @@ -70,18 +71,17 @@ def _initialize_redis_cluster(config: settings.RedisClusterConfig) -> RedisClien reinitialize_steps=config["reinitialize_steps"], socket_timeout=config.get("socket_timeout", settings.REDIS_SOCKET_TIMEOUT), ) - else: - return SingleNodeRedis( - _retries=2, - _backoff_max=3, - host=config["host"], - port=config["port"], - password=config["password"], - db=config["db"], - ssl=config.get("ssl", False), - socket_keepalive=True, - socket_timeout=config.get("socket_timeout", settings.REDIS_SOCKET_TIMEOUT), - ) + return SingleNodeRedis( + _retries=2, + _backoff_max=3, + host=config["host"], + port=config["port"], + password=config["password"], + db=config["db"], + ssl=config.get("ssl", False), + socket_keepalive=True, + socket_timeout=config.get("socket_timeout", settings.REDIS_SOCKET_TIMEOUT), + ) _default_redis_client: RedisClientType = _initialize_redis_cluster( @@ -105,7 +105,7 @@ def _initialize_specialized_redis_cluster( if config is None: return _default_redis_client for k, v in overrides.items(): - config[k] = v # type: ignore + config[k] = v # type: ignore[literal-required] return _initialize_redis_cluster(config) diff --git a/snuba/replacer.py b/snuba/replacer.py index 6cf9e7f547b..12d83459cf6 100644 --- a/snuba/replacer.py +++ b/snuba/replacer.py @@ -4,18 +4,12 @@ import time from abc import ABC, abstractmethod from collections import defaultdict +from collections.abc import Callable, Mapping, MutableMapping, Sequence from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime from functools import partial from typing import ( Any, - Callable, - List, - Mapping, - MutableMapping, - Optional, - Sequence, - Tuple, TypeVar, ) @@ -106,7 +100,7 @@ def __init__( cluster: ClickhouseCluster, ) -> None: self.__cluster = cluster - self.__nodes: Mapping[int, List[ClickhouseNode]] = defaultdict(list) + self.__nodes: Mapping[int, list[ClickhouseNode]] = defaultdict(list) self.__nodes_refreshed_at = time.time() def get_connections(self) -> Mapping[int, Sequence[ClickhouseNode]]: @@ -289,7 +283,7 @@ class ReplacerStrategyFactory(ProcessingStrategyFactory[KafkaPayload]): def __init__( self, worker: ReplacerWorker, - health_check_file: Optional[str] = None, + health_check_file: str | None = None, ) -> None: self.__worker = worker self.__health_check_file = health_check_file @@ -335,7 +329,7 @@ def __init__( self.__sharded_pool = InOrderConnectionPool(self.__storage.get_cluster()) self.__rate_limiter = RateLimiter("replacements") - self.__last_offset_processed_per_partition: MutableMapping[str, int] = dict() + self.__last_offset_processed_per_partition: MutableMapping[str, int] = {} self.__consumer_group = consumer_group def __get_insert_executor(self, replacement: Replacement) -> InsertExecutor: @@ -363,11 +357,11 @@ def run_query( ) -> None: t = time.time() - logger.debug("Executing replace query: %s" % query) + logger.debug(f"Executing replace query: {query}") connection.execute_robust(query) duration = int((time.time() - t) * 1000) - logger.info("Replacing %s rows took %sms" % (records_count, duration)) + logger.info(f"Replacing {records_count} rows took {duration}ms") metrics.timing( "replacements.count", records_count, @@ -406,7 +400,7 @@ def run_query( def process_message( self, message: Message[KafkaPayload] - ) -> Optional[Tuple[ReplacementMessageMetadata, Replacement]]: + ) -> tuple[ReplacementMessageMetadata, Replacement] | None: assert isinstance(message.value, BrokerValue) metadata = ReplacementMessageMetadata( partition_index=message.value.partition.index, @@ -437,12 +431,10 @@ def process_message( ) if processed is not None: return metadata, processed - else: - return None - else: - raise InvalidMessageVersion("Unknown message format: " + str(seq_message)) + return None + raise InvalidMessageVersion("Unknown message format: " + str(seq_message)) - def flush_batch(self, batch: Sequence[Tuple[ReplacementMessageMetadata, Replacement]]) -> None: + def flush_batch(self, batch: Sequence[tuple[ReplacementMessageMetadata, Replacement]]) -> None: need_optimize = False clickhouse_read = self.__storage.get_cluster().get_query_connection( ClickhouseClientSettings.REPLACE @@ -487,7 +479,7 @@ def flush_batch(self, batch: Sequence[Tuple[ReplacementMessageMetadata, Replacem num_dropped = run_optimize( clickhouse_read, self.__storage, self.__database_name, before=today ) - logger.info("Optimized %s partitions on %s" % (num_dropped, clickhouse_read.host)) + logger.info(f"Optimized {num_dropped} partitions on {clickhouse_read.host}") def _message_already_processed(self, metadata: ReplacementMessageMetadata) -> bool: """ @@ -592,5 +584,5 @@ def _attempt_emitting_metric_for_projects_exceeding_limit( ",".join(str(project_id) for project_id in projects_exceeding_limit) ) ) - for project_id in projects_exceeding_limit: + for _project_id in projects_exceeding_limit: self.metrics.increment("project_processing_time_exceeded_time_limit") diff --git a/snuba/replacers/errors_replacer.py b/snuba/replacers/errors_replacer.py index 011601b4ea5..b449b599697 100644 --- a/snuba/replacers/errors_replacer.py +++ b/snuba/replacers/errors_replacer.py @@ -6,20 +6,12 @@ import uuid from abc import abstractmethod from collections import deque +from collections.abc import Mapping, MutableMapping, Sequence from dataclasses import dataclass from datetime import datetime from functools import cached_property from typing import ( Any, - Deque, - List, - Mapping, - MutableMapping, - Optional, - Sequence, - Tuple, - Type, - Union, cast, ) @@ -78,7 +70,7 @@ class ExcludeGroups: group_ids: Sequence[int] -QueryTimeFlags = Union[NeedsFinal, ExcludeGroups] +QueryTimeFlags = NeedsFinal | ExcludeGroups @dataclass(frozen=True) @@ -97,11 +89,11 @@ class Replacement(ReplacementBase): @abstractmethod def parse_message( cls, message: ReplacementMessage[Any], context: ReplacementContext - ) -> Optional[Replacement]: + ) -> Replacement | None: raise NotImplementedError() @abstractmethod - def get_query_time_flags(self) -> Optional[QueryTimeFlags]: + def get_query_time_flags(self) -> QueryTimeFlags | None: raise NotImplementedError() @abstractmethod @@ -111,9 +103,7 @@ def get_project_id(self) -> int: def should_write_every_node(self) -> bool: write_node_replacement_setting = get_float_config("write_node_replacements_global", 1.0) assert isinstance(write_node_replacement_setting, float) - if random.random() < write_node_replacement_setting: - return True - return False + return random.random() < write_node_replacement_setting class ErrorsReplacer(ReplacerProcessor[Replacement]): @@ -155,9 +145,7 @@ def __initialize_schema(self) -> None: promoted_tags=self.__promoted_tags, ) - def process_message( - self, message: ReplacementMessage[Mapping[str, Any]] - ) -> Optional[Replacement]: + def process_message(self, message: ReplacementMessage[Mapping[str, Any]]) -> Replacement | None: if not self.__schema: self.__initialize_schema() assert self.__schema is not None @@ -174,7 +162,7 @@ def process_message( tags={"type": type_, "consumer_group": message.metadata.consumer_group}, ) - processed: Optional[Replacement] + processed: Replacement | None if type_ in ( ReplacementType.START_DELETE_GROUPS, @@ -183,13 +171,13 @@ def process_message( ReplacementType.START_DELETE_TAG, ): return None - elif type_ in _REPLACEMENT_BY_TYPE: + if type_ in _REPLACEMENT_BY_TYPE: processed = _REPLACEMENT_BY_TYPE[type_].parse_message( message, self.__replacement_context, ) else: - raise InvalidMessageType("Invalid message type: {}".format(type_)) + raise InvalidMessageType(f"Invalid message type: {type_}") if processed is not None: manual_bypass_projects = get_config("replacements_bypass_projects", "[]") @@ -262,10 +250,10 @@ def pre_replacement(self, replacement: Replacement, matching_records: int) -> bo def _build_event_set_filter( project_id: int, event_ids: Sequence[str], - from_timestamp: Optional[str], - to_timestamp: Optional[str], -) -> Tuple[List[str], List[str], MutableMapping[str, str]]: - def get_timestamp_condition(msg_value: Optional[str], operator: str) -> str: + from_timestamp: str | None, + to_timestamp: str | None, +) -> tuple[list[str], list[str], MutableMapping[str, str]]: + def get_timestamp_condition(msg_value: str | None, operator: str) -> str: if not msg_value: return "" @@ -280,7 +268,7 @@ def get_timestamp_condition(msg_value: Optional[str], operator: str) -> str: to_condition = get_timestamp_condition(to_timestamp, "<=") event_id_lhs = "event_id" - event_id_list = ", ".join("'%s'" % uuid.UUID(eid) for eid in event_ids) + event_id_list = ", ".join(f"'{uuid.UUID(eid)}'" for eid in event_ids) prewhere = [f"{event_id_lhs} IN (%(event_ids)s)"] where = ["project_id = %(project_id)s", "NOT deleted"] @@ -313,8 +301,8 @@ class ReplaceGroupReplacement(Replacement): event_ids: Sequence[str] project_id: int - from_timestamp: Optional[str] - to_timestamp: Optional[str] + from_timestamp: str | None + to_timestamp: str | None new_group_id: int all_columns: Sequence[FlattenedColumn] @@ -323,7 +311,7 @@ def parse_message( cls, message: ReplacementMessage[ReplaceGroupMessageBody], context: ReplacementContext, - ) -> Optional[ReplaceGroupReplacement]: + ) -> ReplaceGroupReplacement | None: event_ids = message.data["event_ids"] if not event_ids: return None @@ -340,7 +328,7 @@ def parse_message( def get_project_id(self) -> int: return self.project_id - def get_query_time_flags(self) -> Optional[QueryTimeFlags]: + def get_query_time_flags(self) -> QueryTimeFlags | None: return None @classmethod @@ -358,20 +346,17 @@ def _where_clause(self) -> str: return f"PREWHERE {' AND '.join(prewhere)} WHERE {' AND '.join(where)}" % query_args - def get_count_query(self, table_name: str) -> Optional[str]: + def get_count_query(self, table_name: str) -> str | None: return f"""\ SELECT count() FROM {table_name} FINAL {self._where_clause} """ - def get_insert_query(self, table_name: str) -> Optional[str]: + def get_insert_query(self, table_name: str) -> str | None: all_column_names = [c.escaped for c in self.all_columns] select_columns = ", ".join( - map( - lambda i: i if i != "group_id" else str(self.new_group_id), - all_column_names, - ) + i if i != "group_id" else str(self.new_group_id) for i in all_column_names ) all_columns = ", ".join(all_column_names) @@ -396,7 +381,7 @@ def parse_message( cls, message: ReplacementMessage[EndDeleteGroupsMessageBody], context: ReplacementContext, - ) -> Optional[DeleteGroupsReplacement]: + ) -> DeleteGroupsReplacement | None: group_ids = message.data["group_ids"] if not group_ids: return None @@ -415,7 +400,7 @@ def parse_message( def get_project_id(self) -> int: return self.project_id - def get_query_time_flags(self) -> Optional[QueryTimeFlags]: + def get_query_time_flags(self) -> QueryTimeFlags | None: return ExcludeGroups(group_ids=self.group_ids) @classmethod @@ -434,18 +419,16 @@ def _where_clause(self) -> str: AND NOT deleted """ - def get_count_query(self, table_name: str) -> Optional[str]: + def get_count_query(self, table_name: str) -> str | None: return f"""\ SELECT count() FROM {table_name} FINAL {self._where_clause} """ - def get_insert_query(self, table_name: str) -> Optional[str]: + def get_insert_query(self, table_name: str) -> str | None: required_columns = ", ".join(self.required_columns) - select_columns = ", ".join( - map(lambda i: i if i != "deleted" else "1", self.required_columns) - ) + select_columns = ", ".join(i if i != "deleted" else "1" for i in self.required_columns) return f"""\ INSERT INTO {table_name} ({required_columns}) SELECT {select_columns} @@ -457,10 +440,10 @@ def get_insert_query(self, table_name: str) -> Optional[str]: @dataclass class TombstoneEventsReplacement(Replacement): event_ids: Sequence[str] - old_primary_hash: Optional[str] + old_primary_hash: str | None project_id: int - from_timestamp: Optional[str] - to_timestamp: Optional[str] + from_timestamp: str | None + to_timestamp: str | None required_columns: Sequence[str] @@ -469,7 +452,7 @@ def parse_message( cls, message: ReplacementMessage[TombstoneEventsMessageBody], context: ReplacementContext, - ) -> Optional[TombstoneEventsReplacement]: + ) -> TombstoneEventsReplacement | None: event_ids = message.data["event_ids"] if not event_ids: return None @@ -493,24 +476,22 @@ def _where_clause(self) -> str: ) if self.old_primary_hash: - query_args["old_primary_hash"] = "'%s'" % (str(uuid.UUID(self.old_primary_hash)),) + query_args["old_primary_hash"] = f"'{str(uuid.UUID(self.old_primary_hash))}'" prewhere.append("primary_hash = %(old_primary_hash)s") return f"PREWHERE {' AND '.join(prewhere)} WHERE {' AND '.join(where)}" % query_args - def get_count_query(self, table_name: str) -> Optional[str]: + def get_count_query(self, table_name: str) -> str | None: return f"""\ SELECT count() FROM {table_name} FINAL {self._where_clause} """ - def get_insert_query(self, table_name: str) -> Optional[str]: + def get_insert_query(self, table_name: str) -> str | None: required_columns = ", ".join(self.required_columns) - select_columns = ", ".join( - map(lambda i: i if i != "deleted" else "1", self.required_columns) - ) + select_columns = ", ".join(i if i != "deleted" else "1" for i in self.required_columns) return f"""\ INSERT INTO {table_name} ({required_columns}) @@ -519,7 +500,7 @@ def get_insert_query(self, table_name: str) -> Optional[str]: {self._where_clause} """ - def get_query_time_flags(self) -> Optional[QueryTimeFlags]: + def get_query_time_flags(self) -> QueryTimeFlags | None: return None def get_project_id(self) -> int: @@ -556,7 +537,7 @@ def parse_message( cls, message: ReplacementMessage[ExcludeGroupsMessageBody], context: ReplacementContext, - ) -> Optional[ExcludeGroupsReplacement]: + ) -> ExcludeGroupsReplacement | None: if not message.data["group_ids"]: return None @@ -568,21 +549,21 @@ def parse_message( def get_project_id(self) -> int: return self.project_id - def get_query_time_flags(self) -> Optional[QueryTimeFlags]: + def get_query_time_flags(self) -> QueryTimeFlags | None: return ExcludeGroups(group_ids=self.group_ids) @classmethod def get_replacement_type(cls) -> ReplacementType: return ReplacementType.EXCLUDE_GROUPS - def get_insert_query(self, table_name: str) -> Optional[str]: + def get_insert_query(self, table_name: str) -> str | None: return None - def get_count_query(self, table_name: str) -> Optional[str]: + def get_count_query(self, table_name: str) -> str | None: return None -SEEN_MERGE_TXN_CACHE: Deque[str] = deque(maxlen=100) +SEEN_MERGE_TXN_CACHE: deque[str] = deque(maxlen=100) @dataclass @@ -612,7 +593,7 @@ def parse_message( cls, message: ReplacementMessage[EndMergeMessageBody], context: ReplacementContext, - ) -> Optional[MergeReplacement]: + ) -> MergeReplacement | None: project_id = message.data["project_id"] previous_group_ids = message.data["previous_group_ids"] if not previous_group_ids: @@ -632,8 +613,7 @@ def parse_message( extra={"project_id": project_id}, ) return None - else: - SEEN_MERGE_TXN_CACHE.append(txn) + SEEN_MERGE_TXN_CACHE.append(txn) # new_group_first_seen was added to the message schema; keep this check # for backwards compatibility. @@ -666,14 +646,14 @@ def _where_clause(self) -> str: AND NOT deleted """ - def get_count_query(self, table_name: str) -> Optional[str]: + def get_count_query(self, table_name: str) -> str | None: return f"""\ SELECT count() FROM {table_name} FINAL {self._where_clause} """ - def get_insert_query(self, table_name: str) -> Optional[str]: + def get_insert_query(self, table_name: str) -> str | None: all_column_names = [c.escaped for c in self.all_columns] all_columns = ", ".join(all_column_names) replacement_columns = {"group_id": str(self.new_group_id)} @@ -682,13 +662,7 @@ def get_insert_query(self, table_name: str) -> Optional[str]: group_first_seen_str = self.new_group_first_seen.strftime(DATETIME_FORMAT) replacement_columns["group_first_seen"] = f"CAST('{group_first_seen_str}' AS DateTime)" - select_columns = ", ".join( - map( - # Get i from replacement_columns; default to i if no replacement. - lambda i: replacement_columns.get(i, i), - all_column_names, - ) - ) + select_columns = ", ".join(replacement_columns.get(i, i) for i in all_column_names) return f"""\ INSERT INTO {table_name} ({all_columns}) SELECT {select_columns} @@ -722,7 +696,7 @@ def parse_message( cls, message: ReplacementMessage[EndUnmergeMessageBody], context: ReplacementContext, - ) -> Optional["Replacement"]: + ) -> Replacement | None: hashes = message.data["hashes"] if not hashes: return None @@ -744,7 +718,7 @@ def parse_message( def get_project_id(self) -> int: return self.project_id - def get_query_time_flags(self) -> Optional[QueryTimeFlags]: + def get_query_time_flags(self) -> QueryTimeFlags | None: return NeedsFinal() @classmethod @@ -754,9 +728,9 @@ def get_replacement_type(cls) -> ReplacementType: @cached_property def _where_clause(self) -> str: if self.state_name == ReplacerState.ERRORS: - hashes = ", ".join(["'%s'" % str(uuid.UUID(_hashify(h))) for h in self.hashes]) + hashes = ", ".join([f"'{str(uuid.UUID(_hashify(h)))}'" for h in self.hashes]) else: - hashes = ", ".join("'%s'" % _hashify(h) for h in self.hashes) + hashes = ", ".join(f"'{_hashify(h)}'" for h in self.hashes) timestamp = self.timestamp.strftime(DATETIME_FORMAT) @@ -768,20 +742,17 @@ def _where_clause(self) -> str: AND NOT deleted """ - def get_count_query(self, table_name: str) -> Optional[str]: + def get_count_query(self, table_name: str) -> str | None: return f"""\ SELECT count() FROM {table_name} FINAL {self._where_clause} """ - def get_insert_query(self, table_name: str) -> Optional[str]: + def get_insert_query(self, table_name: str) -> str | None: all_column_names = [c.escaped for c in self.all_columns] select_columns = ", ".join( - map( - lambda i: i if i != "group_id" else str(self.new_group_id), - all_column_names, - ) + i if i != "group_id" else str(self.new_group_id) for i in all_column_names ) all_columns = ", ".join(all_column_names) @@ -797,14 +768,11 @@ def get_insert_query(self, table_name: str) -> Optional[str]: def _convert_hash(hash: str, state_name: ReplacerState, convert_types: bool = False) -> str: if state_name == ReplacerState.ERRORS: if convert_types: - return "toUUID('%s')" % str(uuid.UUID(_hashify(hash))) - else: - return "'%s'" % str(uuid.UUID(_hashify(hash))) - else: - if convert_types: - return "toFixedString('%s', 32)" % _hashify(hash) - else: - return "'%s'" % _hashify(hash) + return f"toUUID('{str(uuid.UUID(_hashify(hash)))}')" + return f"'{str(uuid.UUID(_hashify(hash)))}'" + if convert_types: + return f"toFixedString('{_hashify(hash)}', 32)" + return f"'{_hashify(hash)}'" @dataclass @@ -822,7 +790,7 @@ def parse_message( cls, message: ReplacementMessage[EndDeleteTagMessageBody], context: ReplacementContext, - ) -> Optional[DeleteTagReplacement]: + ) -> DeleteTagReplacement | None: tag = message.data["tag"] if not tag: return None @@ -870,20 +838,18 @@ def _select_columns(self) -> Sequence[str]: select_columns.append("''") elif col.flattened == "tags.key": select_columns.append( - "arrayFilter(x -> (indexOf(`tags.key`, x) != indexOf(`tags.key`, %s)), `tags.key`)" - % escape_string(self.tag) + f"arrayFilter(x -> (indexOf(`tags.key`, x) != indexOf(`tags.key`, {escape_string(self.tag)})), `tags.key`)" ) elif col.flattened == "tags.value": select_columns.append( - "arrayMap(x -> arrayElement(`tags.value`, x), arrayFilter(x -> x != indexOf(`tags.key`, %s), arrayEnumerate(`tags.value`)))" - % escape_string(self.tag) + f"arrayMap(x -> arrayElement(`tags.value`, x), arrayFilter(x -> x != indexOf(`tags.key`, {escape_string(self.tag)}), arrayEnumerate(`tags.value`)))" ) else: select_columns.append(col.escaped) return select_columns - def get_insert_query(self, table_name: str) -> Optional[str]: + def get_insert_query(self, table_name: str) -> str | None: all_columns = ", ".join(col.escaped for col in self.all_columns) select_columns = ", ".join(self._select_columns) @@ -894,7 +860,7 @@ def get_insert_query(self, table_name: str) -> Optional[str]: {self._where_clause} """ - def get_count_query(self, table_name: str) -> Optional[str]: + def get_count_query(self, table_name: str) -> str | None: return f"""\ SELECT count() FROM {table_name} FINAL @@ -912,10 +878,10 @@ def get_replacement_type(cls) -> ReplacementType: return ReplacementType.END_DELETE_TAG -_REPLACEMENT_BY_TYPE: Mapping[ReplacementType, Type[Replacement]] = dict( - (cls.get_replacement_type(), cls) +_REPLACEMENT_BY_TYPE: Mapping[ReplacementType, type[Replacement]] = { + cls.get_replacement_type(): cls for cls in cast( - Sequence[Type[Replacement]], + Sequence[type[Replacement]], [ DeleteGroupsReplacement, MergeReplacement, @@ -926,4 +892,4 @@ def get_replacement_type(cls) -> ReplacementType: ExcludeGroupsReplacement, ], ) -) +} diff --git a/snuba/replacers/projects_query_flags.py b/snuba/replacers/projects_query_flags.py index b1e59f99341..84ad83408a5 100644 --- a/snuba/replacers/projects_query_flags.py +++ b/snuba/replacers/projects_query_flags.py @@ -2,9 +2,10 @@ import sys import time +from collections.abc import Mapping, MutableMapping, Sequence from dataclasses import dataclass from datetime import datetime -from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Set, Tuple +from typing import Any import sentry_sdk from redis.cluster import ClusterPipeline as StrictClusterPipeline @@ -35,14 +36,14 @@ class ProjectsQueryFlags: """ needs_final: bool - group_ids_to_exclude: Set[int] - replacement_types: Set[str] - latest_replacement_time: Optional[datetime] + group_ids_to_exclude: set[int] + replacement_types: set[str] + latest_replacement_time: datetime | None @staticmethod def set_project_needs_final( project_id: int, - state_name: Optional[ReplacerState], + state_name: ReplacerState | None, replacement_type: ReplacementType, ) -> None: key, type_key = ProjectsQueryFlags._build_project_needs_final_key_and_type_key( @@ -57,7 +58,7 @@ def set_project_needs_final( def set_project_exclude_groups( project_id: int, group_ids: Sequence[int], - state_name: Optional[ReplacerState], + state_name: ReplacerState | None, # replacement type is just for metrics, not necessary for functionality replacement_type: ReplacementType, ) -> None: @@ -114,7 +115,7 @@ def set_project_exclude_groups( @classmethod def load_from_redis( - cls, project_ids: Sequence[int], state_name: Optional[ReplacerState] + cls, project_ids: Sequence[int], state_name: ReplacerState | None ) -> ProjectsQueryFlags: """ Loads flags for given project ids. @@ -151,13 +152,13 @@ def load_from_redis( sentry_sdk.capture_exception(e) return cls( needs_final=False, - group_ids_to_exclude=set([]), - replacement_types=set([]), + group_ids_to_exclude=set(), + replacement_types=set(), latest_replacement_time=None, ) @classmethod - def _process_redis_results(cls, results: List[Any], len_projects: int) -> ProjectsQueryFlags: + def _process_redis_results(cls, results: list[Any], len_projects: int) -> ProjectsQueryFlags: """ Produces readable data from flattened list of Redis pipeline results. @@ -212,8 +213,8 @@ def _process_redis_results(cls, results: List[Any], len_projects: int) -> Projec @staticmethod def _query_redis( - project_ids: Set[int], - state_name: Optional[ReplacerState], + project_ids: set[int], + state_name: ReplacerState | None, p: StrictClusterPipeline, ) -> None: """ @@ -261,7 +262,7 @@ def _query_redis( @staticmethod def _remove_stale_and_load_new_sorted_set_data( - p: StrictClusterPipeline, keys: List[str] + p: StrictClusterPipeline, keys: list[str] ) -> None: """ Remove stale data per key according to TTL. @@ -280,9 +281,9 @@ def _remove_stale_and_load_new_sorted_set_data( @staticmethod def _process_latest_replacement( needs_final: bool, - needs_final_result: List[Any], - latest_exclude_groups_result: List[Any], - ) -> Optional[datetime]: + needs_final_result: list[Any], + latest_exclude_groups_result: list[Any], + ) -> datetime | None: """ Process the relevant replacements data to look for the latest timestamp any replacement occured. @@ -308,15 +309,15 @@ def _process_latest_replacement( @staticmethod def _build_project_needs_final_key_and_type_key( - project_id: int, state_name: Optional[ReplacerState] - ) -> Tuple[str, str]: + project_id: int, state_name: ReplacerState | None + ) -> tuple[str, str]: key = f"project_needs_final:{f'{state_name.value}:' if state_name else ''}{project_id}" return key, f"{key}-type" @staticmethod def _build_project_exclude_groups_key_and_type_key( - project_id: int, state_name: Optional[ReplacerState] - ) -> Tuple[str, str]: + project_id: int, state_name: ReplacerState | None + ) -> tuple[str, str]: key = f"project_exclude_groups:{f'{state_name.value}:' if state_name else ''}{project_id}" return key, f"{key}-type" diff --git a/snuba/replacers/replacements_and_expiry.py b/snuba/replacers/replacements_and_expiry.py index a34adcb24c4..f7f7126f893 100644 --- a/snuba/replacers/replacements_and_expiry.py +++ b/snuba/replacers/replacements_and_expiry.py @@ -3,17 +3,16 @@ import logging import time import typing +from collections.abc import Mapping, Sequence from datetime import datetime, timedelta -from typing import Mapping, Sequence - -logger = logging.getLogger(__name__) - from snuba import environment from snuba.redis import RedisClientKey, get_redis_client from snuba.state import get_int_config from snuba.utils.metrics.wrapper import MetricsWrapper +logger = logging.getLogger(__name__) + metrics = MetricsWrapper(environment.metrics, "replacements_and_expiry") redis_client = get_redis_client(RedisClientKey.REPLACEMENTS_STORE) diff --git a/snuba/replacers/replacer_processor.py b/snuba/replacers/replacer_processor.py index 107ba308398..faf9a6665b3 100644 --- a/snuba/replacers/replacer_processor.py +++ b/snuba/replacers/replacer_processor.py @@ -1,8 +1,7 @@ from abc import ABC, abstractmethod +from collections.abc import Mapping from enum import Enum -from typing import Any, Generic, Mapping, Optional, TypeVar, cast - -from typing_extensions import NamedTuple +from typing import Any, Generic, NamedTuple, TypeVar, cast from snuba.datasets.schemas.tables import WritableTableSchema from snuba.processor import ReplacementType @@ -45,11 +44,11 @@ def get_replacement_type(cls) -> ReplacementType: raise NotImplementedError() @abstractmethod - def get_insert_query(self, table_name: str) -> Optional[str]: + def get_insert_query(self, table_name: str) -> str | None: raise NotImplementedError() @abstractmethod - def get_count_query(self, table_name: str) -> Optional[str]: + def get_count_query(self, table_name: str) -> str | None: raise NotImplementedError() @abstractmethod @@ -81,7 +80,7 @@ def get_from_name(cls, name: str) -> "ReplacerProcessor[R]": return cast("ReplacerProcessor[R]", cls.class_from_name(name)) @abstractmethod - def process_message(self, message: ReplacementMessage[Mapping[str, Any]]) -> Optional[R]: + def process_message(self, message: ReplacementMessage[Mapping[str, Any]]) -> R | None: """ Processes one message from the topic. """ diff --git a/snuba/request/__init__.py b/snuba/request/__init__.py index 04d2220e9c6..e5361bd7c01 100644 --- a/snuba/request/__init__.py +++ b/snuba/request/__init__.py @@ -2,7 +2,7 @@ import uuid from dataclasses import dataclass -from typing import Any, Dict, Union +from typing import Any from snuba.attribution.attribution_info import AttributionInfo from snuba.query import ProcessableQuery @@ -15,8 +15,8 @@ @dataclass(frozen=True) class Request: id: uuid.UUID - original_body: Dict[str, Any] - query: Union[Query, CompositeQuery[LogicalDataSource]] + original_body: dict[str, Any] + query: Query | CompositeQuery[LogicalDataSource] query_settings: QuerySettings attribution_info: AttributionInfo diff --git a/snuba/request/schema.py b/snuba/request/schema.py index a79afd18032..dd495d29865 100644 --- a/snuba/request/schema.py +++ b/snuba/request/schema.py @@ -1,7 +1,8 @@ from __future__ import annotations import itertools -from typing import Any, Mapping, MutableMapping, NamedTuple, Type +from collections.abc import Mapping, MutableMapping +from typing import Any, NamedTuple import jsonschema @@ -78,7 +79,7 @@ def __init__( @classmethod def build( cls, - settings_class: Type[QuerySettings], + settings_class: type[QuerySettings], is_mql: bool = False, is_delete: bool = False, ) -> RequestSchema: @@ -96,17 +97,17 @@ def validate(self, value: MutableMapping[str, Any]) -> RequestParts: raise JsonSchemaValidationException(str(error)) from error query_body = { - key: value.get(key) for key in self.__query_schema["properties"].keys() if key in value + key: value.get(key) for key in self.__query_schema["properties"] if key in value } query_settings = { key: value.get(key) - for key in self.__query_settings_schema["properties"].keys() + for key in self.__query_settings_schema["properties"] if key in value } attribution_info = { key: value.get(key) - for key in self.__attribution_info_schema["properties"].keys() + for key in self.__attribution_info_schema["properties"] if key in value } @@ -125,14 +126,14 @@ def __generate_template_impl(self, schema: Mapping[str, Any]) -> Any: if "default" in schema: default = schema["default"] return default() if callable(default) else default - elif typ == "object": + if typ == "object": return { prop: self.__generate_template_impl(subschema) for prop, subschema in schema.get("properties", {}).items() } - elif typ == "array": + if typ == "array": return [] - elif typ == "string": + if typ == "string": return "" return None @@ -154,7 +155,7 @@ def generate_template(self) -> Any: } -SETTINGS_SCHEMAS: Mapping[Type[QuerySettings], Schema] = { +SETTINGS_SCHEMAS: Mapping[type[QuerySettings], Schema] = { HTTPQuerySettings: { "type": "object", "properties": { diff --git a/snuba/request/validation.py b/snuba/request/validation.py index 43ea9643f16..3698f58b85f 100644 --- a/snuba/request/validation.py +++ b/snuba/request/validation.py @@ -3,7 +3,8 @@ import random import textwrap import uuid -from typing import Any, Dict, MutableMapping, Optional, Protocol, Type, Union +from collections.abc import MutableMapping +from typing import Any, Protocol import sentry_sdk @@ -42,16 +43,16 @@ def __call__( request_parts: RequestParts, settings: QuerySettings, dataset: Dataset, - custom_processing: Optional[CustomProcessors] = ..., - ) -> Union[Query, CompositeQuery[LogicalDataSource]]: ... + custom_processing: CustomProcessors | None = ..., + ) -> Query | CompositeQuery[LogicalDataSource]: ... def parse_snql_query( request_parts: RequestParts, settings: QuerySettings, dataset: Dataset, - custom_processing: Optional[CustomProcessors] = None, -) -> Union[Query, CompositeQuery[LogicalDataSource]]: + custom_processing: CustomProcessors | None = None, +) -> Query | CompositeQuery[LogicalDataSource]: return _parse_snql_query(request_parts.query["query"], dataset, custom_processing, settings) @@ -59,8 +60,8 @@ def parse_mql_query( request_parts: RequestParts, settings: QuerySettings, dataset: Dataset, - custom_processing: Optional[CustomProcessors] = None, -) -> Union[Query, CompositeQuery[LogicalDataSource]]: + custom_processing: CustomProcessors | None = None, +) -> Query | CompositeQuery[LogicalDataSource]: return _parse_mql_query( request_parts.query["query"], request_parts.query["mql_context"], @@ -76,15 +77,14 @@ def _consistent_override(original_setting: bool, referrer: str) -> bool: referrers_override = consistent_config.split(";") for config in referrers_override: referrer_config, percentage = config.split("=") - if referrer_config == referrer: - if random.random() > float(percentage): - return False + if referrer_config == referrer and random.random() > float(percentage): + return False return original_setting def update_attribution_info( - request_parts: RequestParts, referrer: str, query_project_id: Optional[int] + request_parts: RequestParts, referrer: str, query_project_id: int | None ) -> dict[str, Any]: attribution_info = dict(request_parts.attribution_info) @@ -99,14 +99,14 @@ def update_attribution_info( def build_request( - body: Dict[str, Any], + body: dict[str, Any], parser: Parser, - settings_class: Union[Type[HTTPQuerySettings], Type[SubscriptionQuerySettings]], + settings_class: type[HTTPQuerySettings] | type[SubscriptionQuerySettings], schema: RequestSchema, dataset: Dataset, timer: Timer, referrer: str, - custom_processing: Optional[CustomProcessors] = None, + custom_processing: CustomProcessors | None = None, ) -> Request: with sentry_sdk.start_span(description="build_request", op="validate") as span: try: @@ -189,7 +189,7 @@ def _get_referrer(request_parts: RequestParts, referrer: str) -> str: def _get_settings_object( - settings_class: Type[HTTPQuerySettings] | Type[SubscriptionQuerySettings], + settings_class: type[HTTPQuerySettings] | type[SubscriptionQuerySettings], request_parts: RequestParts, referrer: str, ) -> HTTPQuerySettings | SubscriptionQuerySettings: @@ -203,12 +203,12 @@ def _get_settings_object( # TODO: referrer probably doesn't need to be passed in, it should be from the body query_settings["referrer"] = referrer # the parameters accept either `str` or `bool` but we pass in `str | bool` - return settings_class(**query_settings) # type: ignore - elif settings_class == SubscriptionQuerySettings: + return settings_class(**query_settings) # type: ignore[arg-type] + if settings_class == SubscriptionQuerySettings: return settings_class( consistent=_consistent_override(True, referrer), ) - return None # type: ignore + return None # type: ignore[return-value] def _get_project_id(query: Query | CompositeQuery[LogicalDataSource]) -> int | None: diff --git a/snuba/schemas.py b/snuba/schemas.py index 5015d0be08f..3194a1f3800 100644 --- a/snuba/schemas.py +++ b/snuba/schemas.py @@ -1,5 +1,6 @@ import copy -from typing import Any, Generator, Mapping, MutableMapping +from collections.abc import Generator, Mapping, MutableMapping +from typing import Any import jsonschema @@ -23,7 +24,7 @@ def validate_and_default( properties: Mapping[str, Any], instance: MutableMapping[str, Any], schema: Mapping[str, Any], - ) -> Generator[Exception, None, None]: + ) -> Generator[Exception]: for property, subschema in properties.items(): if property not in instance and "default" in subschema: if callable(subschema["default"]): @@ -32,8 +33,7 @@ def validate_and_default( default_value = copy.deepcopy(subschema["default"]) instance[property] = default_value - for error in orig(validator, properties, instance, schema): - yield error + yield from orig(validator, properties, instance, schema) # Using schema defaults during validation will cause the input value to be # mutated, so to be on the safe side we create a deep copy of that value to diff --git a/snuba/settings/__init__.py b/snuba/settings/__init__.py index a1aebf43709..3f2f0af8660 100644 --- a/snuba/settings/__init__.py +++ b/snuba/settings/__init__.py @@ -1,15 +1,10 @@ from __future__ import annotations import os +from collections.abc import Mapping, MutableMapping, Sequence from pathlib import Path from typing import ( Any, - Mapping, - MutableMapping, - Optional, - Sequence, - Set, - Tuple, TypedDict, ) @@ -89,8 +84,8 @@ ALLOCATION_POLICY_ENABLED = True DEFAULT_DATASET_NAME = "events" -DISABLED_ENTITIES: Set[str] = set() -DISABLED_DATASETS: Set[str] = set() +DISABLED_ENTITIES: set[str] = set() +DISABLED_DATASETS: set[str] = set() # Clickhouse Options CLICKHOUSE_MAX_POOL_SIZE = 25 @@ -223,7 +218,7 @@ class RedisClusters(TypedDict): # Runtime Config Options CONFIG_MEMOIZE_TIMEOUT = 10 -CONFIG_STATE: Mapping[str, Optional[Any]] = {} +CONFIG_STATE: Mapping[str, Any | None] = {} # Sentry Options SENTRY_DSN: str | None = None @@ -283,7 +278,7 @@ class RedisClusters(TypedDict): ENFORCE_RETENTION: bool = False LOWER_RETENTION_DAYS = 30 DEFAULT_RETENTION_DAYS = 90 -VALID_RETENTION_DAYS = set([30, 90]) +VALID_RETENTION_DAYS = {30, 90} MAX_PREWHERE_CONDITIONS = 1 @@ -304,7 +299,7 @@ class RedisClusters(TypedDict): TURBO_SAMPLE_RATE = 0.1 -PROJECT_STACKTRACE_BLACKLIST: Set[int] = set() +PROJECT_STACKTRACE_BLACKLIST: set[int] = set() PRETTY_FORMAT_EXPRESSIONS = os.environ.get("PRETTY_FORMAT_EXPRESSIONS", "1") == "1" # By default, allocation policies won't block requests from going through in a production @@ -334,10 +329,10 @@ class RedisClusters(TypedDict): # The migration groups that can be skipped are listed in OPTIONAL_GROUPS. # Migrations for skipped groups will not be run. -SKIPPED_MIGRATION_GROUPS: Set[str] = set() +SKIPPED_MIGRATION_GROUPS: set[str] = set() # Dataset readiness states supported in this environment -SUPPORTED_STATES: Set[str] = { +SUPPORTED_STATES: set[str] = { "deprecate", "limited", "experimental", @@ -353,7 +348,7 @@ class RedisClusters(TypedDict): # These contexts will not be stored in the transactions table # Example: {123: {"context1", "context2"}} # where 123 is the project id. -TRANSACT_SKIP_CONTEXT_STORE: Mapping[int, Set[str]] = {} +TRANSACT_SKIP_CONTEXT_STORE: Mapping[int, set[str]] = {} # Map the Zookeeper path for the replicated merge tree to something else CLICKHOUSE_ZOOKEEPER_OVERRIDE: Mapping[str, str] = {} @@ -460,11 +455,11 @@ class RedisClusters(TypedDict): # Mapping of (logical topic names, slice id) pairs to custom physical topic names # This is only for sliced Kafka topics -SLICED_KAFKA_TOPIC_MAP: Mapping[Tuple[str, int], str] = {} +SLICED_KAFKA_TOPIC_MAP: Mapping[tuple[str, int], str] = {} # Mapping of (logical topic names, slice id) pairs to broker config # This is only for sliced Kafka topics -SLICED_KAFKA_BROKER_CONFIG: Mapping[Tuple[str, int], Mapping[str, Any]] = {} +SLICED_KAFKA_BROKER_CONFIG: Mapping[tuple[str, int], Mapping[str, Any]] = {} # When dataset yamls (i.e. dataset, storages, entities) are loaded into memory, should we validate # the jsonschema or not? In production we shouldn't need to do it, in CI we should. This is for performance @@ -476,13 +471,15 @@ class RedisClusters(TypedDict): MAX_ONGOING_MUTATIONS_FOR_DELETE = 5 MAX_PARTS_MUTATING_FOR_DELETE = 20 LW_DELETES_PARTITION_TRACKING_TTL = 86400 -SNQL_DISABLED_DATASETS: set[str] = set([]) +SNQL_DISABLED_DATASETS: set[str] = set() ENDPOINT_GET_TRACE_PAGINATION_MAX_ITEMS: int = 0 # 0 means no limit ENABLE_TRACE_PAGINATION_DEFAULT = 1 -def _load_settings(obj: MutableMapping[str, Any] = locals()) -> None: +# `locals()` default captures this module's namespace so settings can be injected as +# module-level globals; calling it in the body would return the function's locals instead. +def _load_settings(obj: MutableMapping[str, Any] = locals()) -> None: # noqa: B008 """Load settings from the path provided in the SNUBA_SETTINGS environment variable if provided. Users can provide a short name like `test` that will be expanded to `settings_test.py` in the main Snuba directory, or they can diff --git a/snuba/settings/settings_self_hosted.py b/snuba/settings/settings_self_hosted.py index 108e615ad91..cff9dcb3d44 100644 --- a/snuba/settings/settings_self_hosted.py +++ b/snuba/settings/settings_self_hosted.py @@ -1,5 +1,4 @@ import os -from typing import Set from snuba.utils.metrics.addr_config import get_statsd_addr @@ -9,16 +8,16 @@ DEBUG = env("DEBUG", "0").lower() in ("1", "true") DEFAULT_RETENTION_DAYS = int(env("SENTRY_EVENT_RETENTION_DAYS", 90)) -VALID_RETENTION_DAYS = set([int(env("SENTRY_EVENT_RETENTION_DAYS", 90)), 30, 60]) +VALID_RETENTION_DAYS = {int(env("SENTRY_EVENT_RETENTION_DAYS", 90)), 30, 60} LOWER_RETENTION_DAYS = min(DEFAULT_RETENTION_DAYS, 30) API_WORKERS = int(env("SNUBA_API_WORKERS", 1)) API_THREADS = int(env("SNUBA_API_THREADS", 8)) API_WORKERS_LIFETIME = ( - int(env("SNUBA_API_WORKERS_LIFETIME")) if env("SNUBA_API_WORKERS_LIFETIME") else None # type: ignore + int(_workers_lifetime) if (_workers_lifetime := env("SNUBA_API_WORKERS_LIFETIME")) else None ) API_WORKERS_MAX_RSS = ( - int(env("SNUBA_API_WORKERS_MAX_RSS")) if env("SNUBA_API_WORKERS_MAX_RSS") else None # type: ignore + int(_workers_max_rss) if (_workers_max_rss := env("SNUBA_API_WORKERS_MAX_RSS")) else None ) REDIS_HOST = env("REDIS_HOST", "127.0.0.1") @@ -31,7 +30,7 @@ DOGSTATSD_HOST, DOGSTATSD_PORT = get_statsd_addr() # Dataset readiness states supported in this environment -SUPPORTED_STATES: Set[str] = {"deprecate", "complete"} +SUPPORTED_STATES: set[str] = {"deprecate", "complete"} READINESS_STATE_FAIL_QUERIES: bool = False diff --git a/snuba/settings/settings_test.py b/snuba/settings/settings_test.py index 62f433792bc..15e293737f9 100644 --- a/snuba/settings/settings_test.py +++ b/snuba/settings/settings_test.py @@ -1,5 +1,4 @@ import os -from typing import Set TESTING = True @@ -11,8 +10,8 @@ SENTRY_DSN = os.getenv("SENTRY_DSN") -SKIPPED_MIGRATION_GROUPS: Set[str] = set() -SUPPORTED_STATES: Set[str] = { +SKIPPED_MIGRATION_GROUPS: set[str] = set() +SUPPORTED_STATES: set[str] = { "deprecate", "limited", "partial", diff --git a/snuba/settings/settings_test_distributed_migrations.py b/snuba/settings/settings_test_distributed_migrations.py index 17e6e03e0b4..3dbdfbd0469 100644 --- a/snuba/settings/settings_test_distributed_migrations.py +++ b/snuba/settings/settings_test_distributed_migrations.py @@ -1,5 +1,6 @@ import os -from typing import Any, Mapping, Sequence +from collections.abc import Mapping, Sequence +from typing import Any from snuba.settings.settings_test import * # noqa diff --git a/snuba/settings/validation.py b/snuba/settings/validation.py index cd3aef1b99f..8777ed2ea86 100644 --- a/snuba/settings/validation.py +++ b/snuba/settings/validation.py @@ -1,4 +1,6 @@ -from typing import Any, Mapping, MutableMapping +import contextlib +from collections.abc import Mapping, MutableMapping +from typing import Any from snuba.datasets.slicing import SENTRY_LOGICAL_PARTITIONS @@ -33,12 +35,10 @@ def validate_settings(locals: Mapping[str, Any]) -> None: for cluster in locals["CLUSTERS"]: for cluster_storage_set in cluster["storage_sets"]: - try: + # We allow definition of storage_sets in configuration files + # that are not defined in StorageSetKey. + with contextlib.suppress(ValueError): storage_set_to_cluster[StorageSetKey(cluster_storage_set)] = cluster - except ValueError: - # We allow definition of storage_sets in configuration files - # that are not defined in StorageSetKey. - pass def validate_slicing_settings(locals: Mapping[str, Any]) -> None: diff --git a/snuba/snapshots/__init__.py b/snuba/snapshots/__init__.py index d3ce5972e16..0029664cd7d 100644 --- a/snuba/snapshots/__init__.py +++ b/snuba/snapshots/__init__.py @@ -1,10 +1,11 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Generator, Iterator, Mapping, Sequence from contextlib import contextmanager from dataclasses import dataclass from enum import Enum -from typing import Any, Generator, Iterator, Mapping, NewType, Optional, Sequence +from typing import Any, NewType SnapshotId = NewType("SnapshotId", str) SnapshotTableRow = Mapping[str, Any] @@ -18,7 +19,7 @@ class TableConfig: table: str zip: bool - columns: Optional[Sequence[ColumnConfig]] + columns: Sequence[ColumnConfig] | None @classmethod def from_dict(cls, content: Mapping[str, Any]) -> TableConfig: @@ -27,16 +28,14 @@ def from_dict(cls, content: Mapping[str, Any]) -> TableConfig: # This has already been validated by the jsonschema validator assert isinstance(column, Mapping) if column.get("formatter") is not None: - formatter: Optional[FormatterConfig] = FormatterConfig.from_dict( - column["formatter"] - ) + formatter: FormatterConfig | None = FormatterConfig.from_dict(column["formatter"]) else: formatter = None columns.append(ColumnConfig(name=column["name"], formatter=formatter)) return TableConfig(content["table"], content["zip"], columns) -class FormatterConfig(ABC): +class FormatterConfig(ABC): # noqa: B024 - intentional abstract parent class with no abstract methods """ Parent class to all the the formatter configs. """ @@ -45,8 +44,7 @@ class FormatterConfig(ABC): def from_dict(cls, content: Mapping[str, str]) -> FormatterConfig: if content["type"] == "datetime": return DateTimeFormatterConfig.from_dict(content) - else: - raise ValueError("Unknown config for column formatter") + raise ValueError("Unknown config for column formatter") class DateFormatPrecision(Enum): @@ -70,7 +68,7 @@ class ColumnConfig: """ name: str - formatter: Optional[FormatterConfig] = None + formatter: FormatterConfig | None = None @dataclass(frozen=True) @@ -103,12 +101,10 @@ def get_descriptor(self) -> SnapshotDescriptor: @abstractmethod @contextmanager - def get_parsed_table_file( - self, table: str - ) -> Generator[Iterator[SnapshotTableRow], None, None]: + def get_parsed_table_file(self, table: str) -> Generator[Iterator[SnapshotTableRow]]: raise NotImplementedError @abstractmethod @contextmanager - def get_preprocessed_table_file(self, table: str) -> Generator[Iterator[bytes], None, None]: + def get_preprocessed_table_file(self, table: str) -> Generator[Iterator[bytes]]: raise NotImplementedError diff --git a/snuba/snapshots/loaders/__init__.py b/snuba/snapshots/loaders/__init__.py index 54490715664..9a53e26465a 100644 --- a/snuba/snapshots/loaders/__init__.py +++ b/snuba/snapshots/loaders/__init__.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Callable, Optional +from collections.abc import Callable from snuba.clickhouse.http import JSONRow from snuba.writer import BatchWriter, BufferedWriterWrapper, WriterTableRow @@ -21,7 +21,7 @@ def load( self, writer: BufferedWriterWrapper[JSONRow, WriterTableRow], ignore_existing_data: bool, - progress_callback: Optional[ProgressCallback], + progress_callback: ProgressCallback | None, ) -> None: raise NotImplementedError @@ -30,6 +30,6 @@ def load_preprocessed( self, writer: BatchWriter[bytes], ignore_existing_data: bool, - progress_callback: Optional[ProgressCallback], + progress_callback: ProgressCallback | None, ) -> None: raise NotImplementedError diff --git a/snuba/snapshots/loaders/single_table.py b/snuba/snapshots/loaders/single_table.py index d2fc8a59a53..84a4cc20bf9 100644 --- a/snuba/snapshots/loaders/single_table.py +++ b/snuba/snapshots/loaders/single_table.py @@ -1,5 +1,5 @@ import logging -from typing import Iterable, Optional +from collections.abc import Iterable from snuba.clickhouse.http import JSONRow from snuba.clickhouse.native import ClickhousePool @@ -33,11 +33,11 @@ def __init__( def __validate_table(self, ignore_existing_data: bool) -> None: clickhouse_tables = self.__clickhouse.execute("show tables").results if (self.__dest_table,) not in clickhouse_tables: - raise ValueError("Destination table %s does not exists" % self.__dest_table) + raise ValueError(f"Destination table {self.__dest_table} does not exists") if not ignore_existing_data: table_content = self.__clickhouse.execute( - "select count(*) from %s" % self.__dest_table + f"select count(*) from {self.__dest_table}" ).results if table_content != [(0,)]: raise ValueError("Destination Table is not empty") @@ -46,7 +46,7 @@ def load( self, writer: BufferedWriterWrapper[JSONRow, WriterTableRow], ignore_existing_data: bool, - progress_callback: Optional[ProgressCallback], + progress_callback: ProgressCallback | None, ) -> None: self.__validate_table(ignore_existing_data) descriptor = self.__source.get_descriptor() @@ -66,7 +66,7 @@ def load_preprocessed( self, writer: BatchWriter[bytes], ignore_existing_data: bool, - progress_callback: Optional[ProgressCallback], + progress_callback: ProgressCallback | None, ) -> None: self.__validate_table(ignore_existing_data) descriptor = self.__source.get_descriptor() diff --git a/snuba/snapshots/postgres_snapshot.py b/snuba/snapshots/postgres_snapshot.py index 8dd41faeedc..955a1352067 100644 --- a/snuba/snapshots/postgres_snapshot.py +++ b/snuba/snapshots/postgres_snapshot.py @@ -4,9 +4,10 @@ import json import logging import os.path +from collections.abc import Generator, Iterable, Iterator, Sequence from contextlib import contextmanager from dataclasses import dataclass -from typing import Generator, Iterable, Iterator, NewType, Sequence +from typing import NewType import jsonschema @@ -115,7 +116,7 @@ def __init__(self, path: str, descriptor: PostgresSnapshotDescriptor) -> None: @classmethod def load(cls, product: str, path: str) -> PostgresSnapshot: meta_file_name = os.path.join(path, "metadata.json") - with open(meta_file_name, "r") as meta_file: + with open(meta_file_name) as meta_file: json_desc = json.load(meta_file) jsonschema.validate( json_desc, @@ -124,8 +125,9 @@ def load(cls, product: str, path: str) -> PostgresSnapshot: if json_desc["product"] != product: raise ValueError( - "Invalid product in Postgres snapshot %s. Expected %s" - % (json_desc["product"], product) + "Invalid product in Postgres snapshot {}. Expected {}".format( + json_desc["product"], product + ) ) desc_content = [TableConfig.from_dict(table) for table in json_desc["content"]] @@ -158,13 +160,13 @@ def get_table_file_size(self, table_name: str) -> int: def get_parsed_table_file( self, table: str, - ) -> Generator[Iterator[SnapshotTableRow], None, None]: + ) -> Generator[Iterator[SnapshotTableRow]]: table_desc = self.__descriptor.get_table(table) assert not table_desc.zip, "Cannot parse a gzip table file on the fly" table_path = self.__get_table_path(table) try: - with open(table_path, "r") as table_file: + with open(table_path) as table_file: csv_file = csv.DictReader(table_file) columns = csv_file.fieldnames @@ -179,11 +181,7 @@ def get_parsed_table_file( existing_set = set(columns) if not expected_set <= existing_set: raise ValueError( - "The table %s is missing columns %r " - % ( - table, - expected_set - existing_set, - ) + f"The table {table} is missing columns {expected_set - existing_set!r} " ) if len(existing_set) != len(expected_set): @@ -199,27 +197,24 @@ def get_parsed_table_file( yield csv_file - except FileNotFoundError: + except FileNotFoundError as e: raise ValueError( - "The snapshot does not contain the requested table %s" % table, - ) + f"The snapshot does not contain the requested table {table}", + ) from e @contextmanager - def get_preprocessed_table_file(self, table: str) -> Generator[Iterator[bytes], None, None]: + def get_preprocessed_table_file(self, table: str) -> Generator[Iterator[bytes]]: table_path = self.__get_table_path(table) try: with open(table_path, "rb") as table_file: def chunks_provider() -> Iterator[bytes]: - for chunk in iter( - lambda: table_file.read(settings.BULK_BINARY_LOAD_CHUNK), b"" - ): - yield chunk + yield from iter(lambda: table_file.read(settings.BULK_BINARY_LOAD_CHUNK), b"") yield chunks_provider() - except FileNotFoundError: + except FileNotFoundError as e: raise ValueError( - "The snapshot does not contain the requested table %s" % table, - ) + f"The snapshot does not contain the requested table {table}", + ) from e diff --git a/snuba/snuba_migrations/discover/0001_discover_merge_table.py b/snuba/snuba_migrations/discover/0001_discover_merge_table.py index 7d3a4b5883c..f96985d4639 100644 --- a/snuba/snuba_migrations/discover/0001_discover_merge_table.py +++ b/snuba/snuba_migrations/discover/0001_discover_merge_table.py @@ -1,4 +1,4 @@ -from typing import List, Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import ( UUID, @@ -14,7 +14,7 @@ from snuba.migrations import migration, operations, table_engines from snuba.migrations.columns import MigrationModifiers as Modifiers -columns: List[Column[Modifiers]] = [ +columns: list[Column[Modifiers]] = [ Column("event_id", UUID()), Column("project_id", UInt(64)), Column("type", String(Modifiers(low_cardinality=True))), @@ -51,9 +51,7 @@ def forwards_local(self) -> Sequence[operations.SqlOperation]: storage_set=StorageSetKey.DISCOVER, table_name="discover_local", columns=columns, - engine=table_engines.Merge( - table_name_regex="^errors_local$|^transactions_local$" - ), + engine=table_engines.Merge(table_name_regex="^errors_local$|^transactions_local$"), ), ] @@ -80,7 +78,5 @@ def forwards_dist(self) -> Sequence[operations.SqlOperation]: def backwards_dist(self) -> Sequence[operations.SqlOperation]: return [ - operations.DropTable( - storage_set=StorageSetKey.DISCOVER, table_name="discover_dist" - ) + operations.DropTable(storage_set=StorageSetKey.DISCOVER, table_name="discover_dist") ] diff --git a/snuba/snuba_migrations/discover/0002_discover_add_deleted_tags_hash_map.py b/snuba/snuba_migrations/discover/0002_discover_add_deleted_tags_hash_map.py index 34371ee500c..da21eb78a79 100644 --- a/snuba/snuba_migrations/discover/0002_discover_add_deleted_tags_hash_map.py +++ b/snuba/snuba_migrations/discover/0002_discover_add_deleted_tags_hash_map.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Array, Column, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -13,9 +13,7 @@ class Migration(migration.ClickhouseNodeMigrationLegacy): blocking = False - def __forward_migrations( - self, table_name: str - ) -> Sequence[operations.SqlOperation]: + def __forward_migrations(self, table_name: str) -> Sequence[operations.SqlOperation]: return [ operations.AddColumn( storage_set=StorageSetKey.DISCOVER, @@ -31,9 +29,7 @@ def __forward_migrations( ), ] - def __backward_migrations( - self, table_name: str - ) -> Sequence[operations.SqlOperation]: + def __backward_migrations(self, table_name: str) -> Sequence[operations.SqlOperation]: return [ operations.DropColumn( storage_set=StorageSetKey.DISCOVER, diff --git a/snuba/snuba_migrations/discover/0003_discover_fix_user_column.py b/snuba/snuba_migrations/discover/0003_discover_fix_user_column.py index 4acb3e2e70d..7d080c8ebde 100644 --- a/snuba/snuba_migrations/discover/0003_discover_fix_user_column.py +++ b/snuba/snuba_migrations/discover/0003_discover_fix_user_column.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, String from snuba.clusters.storage_sets import StorageSetKey @@ -14,9 +14,7 @@ class Migration(migration.ClickhouseNodeMigrationLegacy): blocking = False - def __forward_migrations( - self, table_name: str - ) -> Sequence[operations.SqlOperation]: + def __forward_migrations(self, table_name: str) -> Sequence[operations.SqlOperation]: return [ operations.ModifyColumn( storage_set=StorageSetKey.DISCOVER, @@ -28,9 +26,7 @@ def __forward_migrations( ) ] - def __backwards_migrations( - self, table_name: str - ) -> Sequence[operations.SqlOperation]: + def __backwards_migrations(self, table_name: str) -> Sequence[operations.SqlOperation]: return [ operations.ModifyColumn( storage_set=StorageSetKey.DISCOVER, diff --git a/snuba/snuba_migrations/discover/0004_discover_fix_title_and_message.py b/snuba/snuba_migrations/discover/0004_discover_fix_title_and_message.py index 3f2d5d92b48..fc552ae1950 100644 --- a/snuba/snuba_migrations/discover/0004_discover_fix_title_and_message.py +++ b/snuba/snuba_migrations/discover/0004_discover_fix_title_and_message.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, String from snuba.clusters.storage_sets import StorageSetKey @@ -14,9 +14,7 @@ class Migration(migration.ClickhouseNodeMigrationLegacy): blocking = False - def __forward_migrations( - self, table_name: str - ) -> Sequence[operations.SqlOperation]: + def __forward_migrations(self, table_name: str) -> Sequence[operations.SqlOperation]: return [ operations.ModifyColumn( storage_set=StorageSetKey.DISCOVER, @@ -30,9 +28,7 @@ def __forward_migrations( ), ] - def __backwards_migrations( - self, table_name: str - ) -> Sequence[operations.SqlOperation]: + def __backwards_migrations(self, table_name: str) -> Sequence[operations.SqlOperation]: return [ operations.ModifyColumn( storage_set=StorageSetKey.DISCOVER, diff --git a/snuba/snuba_migrations/discover/0005_discover_fix_transaction_name.py b/snuba/snuba_migrations/discover/0005_discover_fix_transaction_name.py index aab34ae6654..76dd3f77ede 100644 --- a/snuba/snuba_migrations/discover/0005_discover_fix_transaction_name.py +++ b/snuba/snuba_migrations/discover/0005_discover_fix_transaction_name.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, String from snuba.clusters.storage_sets import StorageSetKey @@ -13,22 +13,16 @@ class Migration(migration.ClickhouseNodeMigrationLegacy): blocking = False - def __forward_migrations( - self, table_name: str - ) -> Sequence[operations.SqlOperation]: + def __forward_migrations(self, table_name: str) -> Sequence[operations.SqlOperation]: return [ operations.ModifyColumn( storage_set=StorageSetKey.DISCOVER, table_name=table_name, - column=Column( - "transaction_name", String(Modifiers(low_cardinality=True)) - ), + column=Column("transaction_name", String(Modifiers(low_cardinality=True))), ), ] - def __backwards_migrations( - self, table_name: str - ) -> Sequence[operations.SqlOperation]: + def __backwards_migrations(self, table_name: str) -> Sequence[operations.SqlOperation]: return [ operations.ModifyColumn( storage_set=StorageSetKey.DISCOVER, diff --git a/snuba/snuba_migrations/discover/0006_discover_add_trace_id.py b/snuba/snuba_migrations/discover/0006_discover_add_trace_id.py index 8ba369146b8..94625d7d64c 100644 --- a/snuba/snuba_migrations/discover/0006_discover_add_trace_id.py +++ b/snuba/snuba_migrations/discover/0006_discover_add_trace_id.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import UUID, Column from snuba.clusters.storage_sets import StorageSetKey @@ -13,9 +13,7 @@ class Migration(migration.ClickhouseNodeMigrationLegacy): blocking = False - def __forward_migrations( - self, table_name: str - ) -> Sequence[operations.SqlOperation]: + def __forward_migrations(self, table_name: str) -> Sequence[operations.SqlOperation]: return [ operations.AddColumn( storage_set=StorageSetKey.DISCOVER, @@ -25,9 +23,7 @@ def __forward_migrations( ), ] - def __backwards_migrations( - self, table_name: str - ) -> Sequence[operations.SqlOperation]: + def __backwards_migrations(self, table_name: str) -> Sequence[operations.SqlOperation]: return [ operations.DropColumn( storage_set=StorageSetKey.DISCOVER, diff --git a/snuba/snuba_migrations/discover/0007_discover_add_span_id.py b/snuba/snuba_migrations/discover/0007_discover_add_span_id.py index 884292ec855..02482d603ee 100644 --- a/snuba/snuba_migrations/discover/0007_discover_add_span_id.py +++ b/snuba/snuba_migrations/discover/0007_discover_add_span_id.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -13,9 +13,7 @@ class Migration(migration.ClickhouseNodeMigrationLegacy): blocking = False - def __forward_migrations( - self, table_name: str - ) -> Sequence[operations.SqlOperation]: + def __forward_migrations(self, table_name: str) -> Sequence[operations.SqlOperation]: return [ operations.AddColumn( storage_set=StorageSetKey.DISCOVER, @@ -25,9 +23,7 @@ def __forward_migrations( ), ] - def __backwards_migrations( - self, table_name: str - ) -> Sequence[operations.SqlOperation]: + def __backwards_migrations(self, table_name: str) -> Sequence[operations.SqlOperation]: return [ operations.DropColumn( storage_set=StorageSetKey.DISCOVER, diff --git a/snuba/snuba_migrations/discover/0008_discover_fix_add_local_table.py b/snuba/snuba_migrations/discover/0008_discover_fix_add_local_table.py index 92c0f3a5469..6c82549a11c 100644 --- a/snuba/snuba_migrations/discover/0008_discover_fix_add_local_table.py +++ b/snuba/snuba_migrations/discover/0008_discover_fix_add_local_table.py @@ -1,4 +1,4 @@ -from typing import List, Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import ( UUID, @@ -16,7 +16,7 @@ from snuba.migrations.columns import MigrationModifiers as Modifiers from snuba.migrations.operations import OperationTarget -columns: List[Column[Modifiers]] = [ +columns: list[Column[Modifiers]] = [ Column("event_id", UUID()), Column("project_id", UInt(64)), Column("type", String(Modifiers(low_cardinality=True))), @@ -58,9 +58,7 @@ def forwards_ops(self) -> Sequence[operations.SqlOperation]: storage_set=StorageSetKey.DISCOVER, table_name=self.local_table_name, columns=columns, - engine=table_engines.Merge( - table_name_regex="^errors_local$|^transactions_local$" - ), + engine=table_engines.Merge(table_name_regex="^errors_local$|^transactions_local$"), target=OperationTarget.LOCAL, ), ] diff --git a/snuba/snuba_migrations/discover/0009_discover_add_replay_id.py b/snuba/snuba_migrations/discover/0009_discover_add_replay_id.py index f86a3116bcf..ab351e15c37 100644 --- a/snuba/snuba_migrations/discover/0009_discover_add_replay_id.py +++ b/snuba/snuba_migrations/discover/0009_discover_add_replay_id.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import UUID, Column from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/events/0007_groupedmessages.py b/snuba/snuba_migrations/events/0007_groupedmessages.py index 7a709e3e9a1..db208c6f67f 100644 --- a/snuba/snuba_migrations/events/0007_groupedmessages.py +++ b/snuba/snuba_migrations/events/0007_groupedmessages.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, DateTime, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/events/0008_groupassignees.py b/snuba/snuba_migrations/events/0008_groupassignees.py index 231b24442d0..868553842b5 100644 --- a/snuba/snuba_migrations/events/0008_groupassignees.py +++ b/snuba/snuba_migrations/events/0008_groupassignees.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, DateTime, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/events/0010_groupedmessages_onpremise_compatibility.py b/snuba/snuba_migrations/events/0010_groupedmessages_onpremise_compatibility.py index 890833019c0..84a64049a61 100644 --- a/snuba/snuba_migrations/events/0010_groupedmessages_onpremise_compatibility.py +++ b/snuba/snuba_migrations/events/0010_groupedmessages_onpremise_compatibility.py @@ -1,5 +1,5 @@ import logging -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, UInt from snuba.clusters.cluster import ClickhouseClientSettings, get_cluster @@ -46,10 +46,9 @@ def fix_order_by(_logger: logging.Logger) -> None: clickhouse.execute(add_column_sql) # There shouldn't be any data in the table yet - assert ( - clickhouse.execute(f"SELECT COUNT() FROM {TABLE_NAME} FINAL;").results[0][0] - == 0 - ), f"{TABLE_NAME} is not empty" + assert clickhouse.execute(f"SELECT COUNT() FROM {TABLE_NAME} FINAL;").results[0][0] == 0, ( + f"{TABLE_NAME} is not empty" + ) new_order_by = f"ORDER BY ({new_primary_key})" old_order_by = f"ORDER BY {old_primary_key}" diff --git a/snuba/snuba_migrations/events/0011_rebuild_errors.py b/snuba/snuba_migrations/events/0011_rebuild_errors.py index 97e13252f3e..7504ac6d7b5 100644 --- a/snuba/snuba_migrations/events/0011_rebuild_errors.py +++ b/snuba/snuba_migrations/events/0011_rebuild_errors.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import ( UUID, @@ -123,8 +123,7 @@ def forwards_local(self) -> Sequence[operations.SqlOperation]: engine=table_engines.ReplacingMergeTree( storage_set=StorageSetKey.EVENTS, version_column="deleted", - order_by="(project_id, toStartOfDay(timestamp), primary_hash, %s)" - % sample_expr, + order_by=f"(project_id, toStartOfDay(timestamp), primary_hash, {sample_expr})", partition_by="(retention_days, toMonday(timestamp))", sample_by=sample_expr, ttl="timestamp + toIntervalDay(retention_days)", @@ -143,11 +142,7 @@ def forwards_local(self) -> Sequence[operations.SqlOperation]: ] def backwards_local(self) -> Sequence[operations.SqlOperation]: - return [ - operations.DropTable( - storage_set=StorageSetKey.EVENTS, table_name="errors_local" - ) - ] + return [operations.DropTable(storage_set=StorageSetKey.EVENTS, table_name="errors_local")] def forwards_dist(self) -> Sequence[operations.SqlOperation]: return [ @@ -182,10 +177,6 @@ def forwards_dist(self) -> Sequence[operations.SqlOperation]: def backwards_dist(self) -> Sequence[operations.SqlOperation]: return [ - operations.DropTable( - storage_set=StorageSetKey.EVENTS, table_name="errors_dist" - ), - operations.DropTable( - storage_set=StorageSetKey.EVENTS_RO, table_name="errors_dist_ro" - ), + operations.DropTable(storage_set=StorageSetKey.EVENTS, table_name="errors_dist"), + operations.DropTable(storage_set=StorageSetKey.EVENTS_RO, table_name="errors_dist_ro"), ] diff --git a/snuba/snuba_migrations/events/0012_errors_make_level_nullable.py b/snuba/snuba_migrations/events/0012_errors_make_level_nullable.py index 8d7f5e7cee2..dc95a60d539 100644 --- a/snuba/snuba_migrations/events/0012_errors_make_level_nullable.py +++ b/snuba/snuba_migrations/events/0012_errors_make_level_nullable.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, String from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/events/0013_errors_add_hierarchical_hashes.py b/snuba/snuba_migrations/events/0013_errors_add_hierarchical_hashes.py index 20f66437077..d577aaa3e8d 100644 --- a/snuba/snuba_migrations/events/0013_errors_add_hierarchical_hashes.py +++ b/snuba/snuba_migrations/events/0013_errors_add_hierarchical_hashes.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import UUID, Array, Column from snuba.clusters.storage_sets import StorageSetKey @@ -28,9 +28,7 @@ def forwards_local(self) -> Sequence[operations.SqlOperation]: def backwards_local(self) -> Sequence[operations.SqlOperation]: return [ - operations.DropColumn( - StorageSetKey.EVENTS, "errors_local", "hierarchical_hashes" - ), + operations.DropColumn(StorageSetKey.EVENTS, "errors_local", "hierarchical_hashes"), ] def forwards_dist(self) -> Sequence[operations.SqlOperation]: @@ -45,7 +43,5 @@ def forwards_dist(self) -> Sequence[operations.SqlOperation]: def backwards_dist(self) -> Sequence[operations.SqlOperation]: return [ - operations.DropColumn( - StorageSetKey.EVENTS, "errors_dist", "hierarchical_hashes" - ), + operations.DropColumn(StorageSetKey.EVENTS, "errors_dist", "hierarchical_hashes"), ] diff --git a/snuba/snuba_migrations/events/0017_errors_add_indexes.py b/snuba/snuba_migrations/events/0017_errors_add_indexes.py index fcb43e1c408..f82c23390a6 100644 --- a/snuba/snuba_migrations/events/0017_errors_add_indexes.py +++ b/snuba/snuba_migrations/events/0017_errors_add_indexes.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations diff --git a/snuba/snuba_migrations/events/0018_errors_ro_add_tags_hash_map.py b/snuba/snuba_migrations/events/0018_errors_ro_add_tags_hash_map.py index 19e11ffd042..fda41065cd8 100644 --- a/snuba/snuba_migrations/events/0018_errors_ro_add_tags_hash_map.py +++ b/snuba/snuba_migrations/events/0018_errors_ro_add_tags_hash_map.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import UUID, Array, Column, String, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -32,9 +32,7 @@ def forwards_ops(self) -> Sequence[operations.SqlOperation]: operations.ModifyColumn( storage_set=StorageSetKey.EVENTS_RO, table_name="errors_dist_ro", - column=Column( - "level", String(Modifiers(low_cardinality=True, nullable=True)) - ), + column=Column("level", String(Modifiers(low_cardinality=True, nullable=True))), target=operations.OperationTarget.DISTRIBUTED, ), ] diff --git a/snuba/snuba_migrations/events/0019_add_replay_id_column.py b/snuba/snuba_migrations/events/0019_add_replay_id_column.py index 70d84d35779..8e369f9c901 100644 --- a/snuba/snuba_migrations/events/0019_add_replay_id_column.py +++ b/snuba/snuba_migrations/events/0019_add_replay_id_column.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import UUID, Column from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/events/0020_add_main_thread_column.py b/snuba/snuba_migrations/events/0020_add_main_thread_column.py index 5aaa905900b..76152215c92 100644 --- a/snuba/snuba_migrations/events/0020_add_main_thread_column.py +++ b/snuba/snuba_migrations/events/0020_add_main_thread_column.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -15,9 +15,7 @@ def forwards_ops(self) -> Sequence[operations.SqlOperation]: operations.AddColumn( storage_set=StorageSetKey.EVENTS, table_name=table_name, - column=Column( - "exception_main_thread", UInt(8, Modifiers(nullable=True)) - ), + column=Column("exception_main_thread", UInt(8, Modifiers(nullable=True))), after="modules.version", target=target, ) diff --git a/snuba/snuba_migrations/events/0021_add_replay_id_errors_ro.py b/snuba/snuba_migrations/events/0021_add_replay_id_errors_ro.py index 6be0e223c39..2d8678d4704 100644 --- a/snuba/snuba_migrations/events/0021_add_replay_id_errors_ro.py +++ b/snuba/snuba_migrations/events/0021_add_replay_id_errors_ro.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import UUID, Column from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/events/0022_add_main_thread_column_errors_ro.py b/snuba/snuba_migrations/events/0022_add_main_thread_column_errors_ro.py index 78c40623f49..153453d3cb8 100644 --- a/snuba/snuba_migrations/events/0022_add_main_thread_column_errors_ro.py +++ b/snuba/snuba_migrations/events/0022_add_main_thread_column_errors_ro.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -17,9 +17,7 @@ def forwards_ops(self) -> Sequence[operations.SqlOperation]: operations.AddColumn( storage_set=StorageSetKey.EVENTS_RO, table_name=table_name, - column=Column( - "exception_main_thread", UInt(8, Modifiers(nullable=True)) - ), + column=Column("exception_main_thread", UInt(8, Modifiers(nullable=True))), after="modules.version", target=OperationTarget.DISTRIBUTED, ) diff --git a/snuba/snuba_migrations/events/0023_add_trace_sampled_num_processing_errors_columns.py b/snuba/snuba_migrations/events/0023_add_trace_sampled_num_processing_errors_columns.py index 1f0f1467ec6..77be3a28ed4 100644 --- a/snuba/snuba_migrations/events/0023_add_trace_sampled_num_processing_errors_columns.py +++ b/snuba/snuba_migrations/events/0023_add_trace_sampled_num_processing_errors_columns.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column from snuba.clusters.storage_sets import StorageSetKey @@ -24,9 +24,7 @@ def forwards_ops(self) -> Sequence[operations.SqlOperation]: operations.AddColumn( storage_set=StorageSetKey.EVENTS, table_name=table_name, - column=Column( - "num_processing_errors", UInt(64, Modifiers(nullable=True)) - ), + column=Column("num_processing_errors", UInt(64, Modifiers(nullable=True))), after="trace_sampled", target=target, ), diff --git a/snuba/snuba_migrations/events/0024_add_trace_sampled_num_processing_errors_columns_ro.py b/snuba/snuba_migrations/events/0024_add_trace_sampled_num_processing_errors_columns_ro.py index ab8e9bc9eba..9fae4b363fc 100644 --- a/snuba/snuba_migrations/events/0024_add_trace_sampled_num_processing_errors_columns_ro.py +++ b/snuba/snuba_migrations/events/0024_add_trace_sampled_num_processing_errors_columns_ro.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -24,9 +24,7 @@ def forwards_ops(self) -> Sequence[operations.SqlOperation]: operations.AddColumn( storage_set=StorageSetKey.EVENTS_RO, table_name=table_name, - column=Column( - "num_processing_errors", UInt(64, Modifiers(nullable=True)) - ), + column=Column("num_processing_errors", UInt(64, Modifiers(nullable=True))), after="trace_sampled", target=OperationTarget.DISTRIBUTED, ), diff --git a/snuba/snuba_migrations/events/0025_add_flags_column.py b/snuba/snuba_migrations/events/0025_add_flags_column.py index 38f2d693699..34e44e3d264 100644 --- a/snuba/snuba_migrations/events/0025_add_flags_column.py +++ b/snuba/snuba_migrations/events/0025_add_flags_column.py @@ -1,4 +1,4 @@ -from typing import Iterator, Sequence +from collections.abc import Iterator, Sequence from snuba.clickhouse.columns import Array, Column, Nested, String, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -31,9 +31,7 @@ def forward_ops() -> Iterator[operations.SqlOperation]: operations.AddColumn( storage_set=storage_set, table_name=table_name, - column=Column( - "flags", Nested([("key", String()), ("value", String())]) - ), + column=Column("flags", Nested([("key", String()), ("value", String())])), after="_tags_hash_map", target=target, ), diff --git a/snuba/snuba_migrations/events/0026_add_symbolicated_in_app_column.py b/snuba/snuba_migrations/events/0026_add_symbolicated_in_app_column.py index f46bca4e780..67cca13ea7e 100644 --- a/snuba/snuba_migrations/events/0026_add_symbolicated_in_app_column.py +++ b/snuba/snuba_migrations/events/0026_add_symbolicated_in_app_column.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations.columns import MigrationModifiers diff --git a/snuba/snuba_migrations/events/0027_add_symbolicated_in_app_column_ro.py b/snuba/snuba_migrations/events/0027_add_symbolicated_in_app_column_ro.py index f52c4380a72..51bd8284610 100644 --- a/snuba/snuba_migrations/events/0027_add_symbolicated_in_app_column_ro.py +++ b/snuba/snuba_migrations/events/0027_add_symbolicated_in_app_column_ro.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations.columns import MigrationModifiers diff --git a/snuba/snuba_migrations/events/0028_add_timestamp_ms_column_errors.py b/snuba/snuba_migrations/events/0028_add_timestamp_ms_column_errors.py index 0d33add58e8..5ad2c3dfe10 100644 --- a/snuba/snuba_migrations/events/0028_add_timestamp_ms_column_errors.py +++ b/snuba/snuba_migrations/events/0028_add_timestamp_ms_column_errors.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations.columns import MigrationModifiers diff --git a/snuba/snuba_migrations/events/0029_add_sample_weight_column_to_errors.py b/snuba/snuba_migrations/events/0029_add_sample_weight_column_to_errors.py index 40cd6581e5a..1b909d8a5eb 100644 --- a/snuba/snuba_migrations/events/0029_add_sample_weight_column_to_errors.py +++ b/snuba/snuba_migrations/events/0029_add_sample_weight_column_to_errors.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations.columns import MigrationModifiers diff --git a/snuba/snuba_migrations/events/0030_add_group_first_seen_column_to_errors.py b/snuba/snuba_migrations/events/0030_add_group_first_seen_column_to_errors.py index a9bf59429a3..2d1510baf35 100644 --- a/snuba/snuba_migrations/events/0030_add_group_first_seen_column_to_errors.py +++ b/snuba/snuba_migrations/events/0030_add_group_first_seen_column_to_errors.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations.columns import MigrationModifiers @@ -20,27 +20,21 @@ def forwards_ops(self) -> Sequence[SqlOperation]: AddColumn( storage_set=StorageSetKey.EVENTS, table_name="errors_local", - column=Column( - "group_first_seen", DateTime(MigrationModifiers(nullable=True)) - ), + column=Column("group_first_seen", DateTime(MigrationModifiers(nullable=True))), after="sample_weight", target=OperationTarget.LOCAL, ), AddColumn( storage_set=StorageSetKey.EVENTS, table_name="errors_dist", - column=Column( - "group_first_seen", DateTime(MigrationModifiers(nullable=True)) - ), + column=Column("group_first_seen", DateTime(MigrationModifiers(nullable=True))), after="sample_weight", target=OperationTarget.DISTRIBUTED, ), AddColumn( storage_set=StorageSetKey.EVENTS_RO, table_name="errors_dist_ro", - column=Column( - "group_first_seen", DateTime(MigrationModifiers(nullable=True)) - ), + column=Column("group_first_seen", DateTime(MigrationModifiers(nullable=True))), after="sample_weight", target=OperationTarget.DISTRIBUTED, ), diff --git a/snuba/snuba_migrations/events_analytics_platform/0001_spans.py b/snuba/snuba_migrations/events_analytics_platform/0001_spans.py index 0e79fdb05a3..6acacdea4c6 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0001_spans.py +++ b/snuba/snuba_migrations/events_analytics_platform/0001_spans.py @@ -1,4 +1,4 @@ -from typing import List, Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations, table_engines @@ -21,7 +21,7 @@ dist_table_name = "eap_spans_dist" num_attr_buckets = 20 -columns: List[Column[Modifiers]] = [ +columns: list[Column[Modifiers]] = [ Column("organization_id", UInt(64)), Column("project_id", UInt(64)), Column("service", String(Modifiers(codecs=["ZSTD(1)"]))), @@ -117,7 +117,7 @@ class Migration(migration.ClickhouseNodeMigration): blocking = False def forwards_ops(self) -> Sequence[SqlOperation]: - res: List[SqlOperation] = [ + res: list[SqlOperation] = [ operations.CreateTable( storage_set=storage_set_name, table_name=local_table_name, diff --git a/snuba/snuba_migrations/events_analytics_platform/0002_spans_attributes_mv.py b/snuba/snuba_migrations/events_analytics_platform/0002_spans_attributes_mv.py index aa5f12ce901..5f47cbd7664 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0002_spans_attributes_mv.py +++ b/snuba/snuba_migrations/events_analytics_platform/0002_spans_attributes_mv.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import AggregateFunction, Column, DateTime, String, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/events_analytics_platform/0003_eap_spans_project_id_index.py b/snuba/snuba_migrations/events_analytics_platform/0003_eap_spans_project_id_index.py index 5eb892c2697..fc9a13e7092 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0003_eap_spans_project_id_index.py +++ b/snuba/snuba_migrations/events_analytics_platform/0003_eap_spans_project_id_index.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations diff --git a/snuba/snuba_migrations/events_analytics_platform/0004_modify_sampling_weight.py b/snuba/snuba_migrations/events_analytics_platform/0004_modify_sampling_weight.py index 1a263b328bf..c2d6c6bc2e9 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0004_modify_sampling_weight.py +++ b/snuba/snuba_migrations/events_analytics_platform/0004_modify_sampling_weight.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations diff --git a/snuba/snuba_migrations/events_analytics_platform/0005_remove_attribute_mv.py b/snuba/snuba_migrations/events_analytics_platform/0005_remove_attribute_mv.py index 9f032258767..24c8157b4b6 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0005_remove_attribute_mv.py +++ b/snuba/snuba_migrations/events_analytics_platform/0005_remove_attribute_mv.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import AggregateFunction, Column, DateTime, String, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/events_analytics_platform/0006_drop_attribute_key_project_id_indexes.py b/snuba/snuba_migrations/events_analytics_platform/0006_drop_attribute_key_project_id_indexes.py index 321cf02ece7..5cfc6509d1d 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0006_drop_attribute_key_project_id_indexes.py +++ b/snuba/snuba_migrations/events_analytics_platform/0006_drop_attribute_key_project_id_indexes.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.migrations import migration from snuba.migrations.operations import SqlOperation diff --git a/snuba/snuba_migrations/events_analytics_platform/0007_drop_project_id_index.py b/snuba/snuba_migrations/events_analytics_platform/0007_drop_project_id_index.py index 57293d5535b..4a0a2dd4356 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0007_drop_project_id_index.py +++ b/snuba/snuba_migrations/events_analytics_platform/0007_drop_project_id_index.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations diff --git a/snuba/snuba_migrations/events_analytics_platform/0008_drop_index_attribute_key_bucket_0.py b/snuba/snuba_migrations/events_analytics_platform/0008_drop_index_attribute_key_bucket_0.py index bbd2e70dad2..e84fb291b39 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0008_drop_index_attribute_key_bucket_0.py +++ b/snuba/snuba_migrations/events_analytics_platform/0008_drop_index_attribute_key_bucket_0.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations diff --git a/snuba/snuba_migrations/events_analytics_platform/0009_drop_index_attribute_key_buckets_1_19.py b/snuba/snuba_migrations/events_analytics_platform/0009_drop_index_attribute_key_buckets_1_19.py index e05de544261..5bc50e274cf 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0009_drop_index_attribute_key_buckets_1_19.py +++ b/snuba/snuba_migrations/events_analytics_platform/0009_drop_index_attribute_key_buckets_1_19.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations diff --git a/snuba/snuba_migrations/events_analytics_platform/0010_drop_indexes_on_attribute_values.py b/snuba/snuba_migrations/events_analytics_platform/0010_drop_indexes_on_attribute_values.py index f7fcf4c76d6..909a718f4fa 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0010_drop_indexes_on_attribute_values.py +++ b/snuba/snuba_migrations/events_analytics_platform/0010_drop_indexes_on_attribute_values.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations diff --git a/snuba/snuba_migrations/events_analytics_platform/0011_span_attribute_table.py b/snuba/snuba_migrations/events_analytics_platform/0011_span_attribute_table.py index b407591f026..3a15d97bca2 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0011_span_attribute_table.py +++ b/snuba/snuba_migrations/events_analytics_platform/0011_span_attribute_table.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration diff --git a/snuba/snuba_migrations/events_analytics_platform/0012_span_attribute_table_numeric.py b/snuba/snuba_migrations/events_analytics_platform/0012_span_attribute_table_numeric.py index 39c32fb7137..811b7d6577b 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0012_span_attribute_table_numeric.py +++ b/snuba/snuba_migrations/events_analytics_platform/0012_span_attribute_table_numeric.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration diff --git a/snuba/snuba_migrations/events_analytics_platform/0013_span_attribute_table_shard_keys.py b/snuba/snuba_migrations/events_analytics_platform/0013_span_attribute_table_shard_keys.py index f8bbbe7cdd9..0e15285a066 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0013_span_attribute_table_shard_keys.py +++ b/snuba/snuba_migrations/events_analytics_platform/0013_span_attribute_table_shard_keys.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration diff --git a/snuba/snuba_migrations/events_analytics_platform/0014_span_attribute_table_smaller.py b/snuba/snuba_migrations/events_analytics_platform/0014_span_attribute_table_smaller.py index da7ea0fa55b..4be5ff4f305 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0014_span_attribute_table_smaller.py +++ b/snuba/snuba_migrations/events_analytics_platform/0014_span_attribute_table_smaller.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration diff --git a/snuba/snuba_migrations/events_analytics_platform/0015_span_attribute_table_namespaced.py b/snuba/snuba_migrations/events_analytics_platform/0015_span_attribute_table_namespaced.py index 0b850d92776..34e85be1195 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0015_span_attribute_table_namespaced.py +++ b/snuba/snuba_migrations/events_analytics_platform/0015_span_attribute_table_namespaced.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration diff --git a/snuba/snuba_migrations/events_analytics_platform/0016_spans_v2.py b/snuba/snuba_migrations/events_analytics_platform/0016_spans_v2.py index 4f6169b0c91..51460bb91e8 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0016_spans_v2.py +++ b/snuba/snuba_migrations/events_analytics_platform/0016_spans_v2.py @@ -1,4 +1,4 @@ -from typing import List, Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations, table_engines @@ -21,7 +21,7 @@ dist_table_name = "eap_spans_2_dist" num_attr_buckets = 20 -columns: List[Column[Modifiers]] = [ +columns: list[Column[Modifiers]] = [ Column("organization_id", UInt(64)), Column("project_id", UInt(64)), Column("service", String(Modifiers(codecs=["ZSTD(1)"]))), @@ -92,7 +92,7 @@ class Migration(migration.ClickhouseNodeMigration): blocking = False def forwards_ops(self) -> Sequence[SqlOperation]: - res: List[SqlOperation] = [ + res: list[SqlOperation] = [ operations.CreateTable( storage_set=storage_set_name, table_name=local_table_name, diff --git a/snuba/snuba_migrations/events_analytics_platform/0017_span_attribute_table_v3.py b/snuba/snuba_migrations/events_analytics_platform/0017_span_attribute_table_v3.py index 672555ff731..a73baf07446 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0017_span_attribute_table_v3.py +++ b/snuba/snuba_migrations/events_analytics_platform/0017_span_attribute_table_v3.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration diff --git a/snuba/snuba_migrations/events_analytics_platform/0018_drop_unused_span_tables.py b/snuba/snuba_migrations/events_analytics_platform/0018_drop_unused_span_tables.py index 07fa21910d8..767cfa0b4d5 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0018_drop_unused_span_tables.py +++ b/snuba/snuba_migrations/events_analytics_platform/0018_drop_unused_span_tables.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration diff --git a/snuba/snuba_migrations/events_analytics_platform/0019_uptime_monitors_init.py b/snuba/snuba_migrations/events_analytics_platform/0019_uptime_monitors_init.py index 2b501a87563..69ede87ddfc 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0019_uptime_monitors_init.py +++ b/snuba/snuba_migrations/events_analytics_platform/0019_uptime_monitors_init.py @@ -1,4 +1,4 @@ -from typing import List, Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import UUID, Column, String, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -12,7 +12,7 @@ local_table_name = f"{table_prefix}_local" dist_table_name = f"{table_prefix}_dist" -columns: List[Column[Modifiers]] = [ +columns: list[Column[Modifiers]] = [ Column("organization_id", UInt(64)), Column("project_id", UInt(64)), Column("environment", String(Modifiers(nullable=True, low_cardinality=True))), diff --git a/snuba/snuba_migrations/events_analytics_platform/0020_ourlogs_init.py b/snuba/snuba_migrations/events_analytics_platform/0020_ourlogs_init.py index 36b15aaac0f..fc032703927 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0020_ourlogs_init.py +++ b/snuba/snuba_migrations/events_analytics_platform/0020_ourlogs_init.py @@ -1,4 +1,4 @@ -from typing import List, Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import UUID, Column, String, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -12,7 +12,7 @@ dist_table_name = "ourlogs_dist" num_attr_buckets = 20 -columns: List[Column[Modifiers]] = [ +columns: list[Column[Modifiers]] = [ Column("organization_id", UInt(64)), Column("project_id", UInt(64)), Column("trace_id", UUID()), # optional diff --git a/snuba/snuba_migrations/events_analytics_platform/0021_ourlogs_attrs.py b/snuba/snuba_migrations/events_analytics_platform/0021_ourlogs_attrs.py index ee718c7081e..e9b4d78104d 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0021_ourlogs_attrs.py +++ b/snuba/snuba_migrations/events_analytics_platform/0021_ourlogs_attrs.py @@ -1,4 +1,4 @@ -from typing import List, Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import UUID, Column, String, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -11,7 +11,7 @@ local_table_name = "ourlogs_2_local" dist_table_name = "ourlogs_2_dist" -columns: List[Column[Modifiers]] = [ +columns: list[Column[Modifiers]] = [ Column("organization_id", UInt(64)), Column("project_id", UInt(64)), Column("trace_id", UUID()), # optional diff --git a/snuba/snuba_migrations/events_analytics_platform/0022_uptime_monitors_init_v2.py b/snuba/snuba_migrations/events_analytics_platform/0022_uptime_monitors_init_v2.py index e231e4c2438..7b600e2d1fc 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0022_uptime_monitors_init_v2.py +++ b/snuba/snuba_migrations/events_analytics_platform/0022_uptime_monitors_init_v2.py @@ -1,4 +1,4 @@ -from typing import List, Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import UUID, Column, String, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -12,7 +12,7 @@ local_table_name = f"{table_prefix}_local" dist_table_name = f"{table_prefix}_dist" -columns: List[Column[Modifiers]] = [ +columns: list[Column[Modifiers]] = [ Column("organization_id", UInt(64)), Column("project_id", UInt(64)), Column("environment", String(Modifiers(nullable=True, low_cardinality=True))), diff --git a/snuba/snuba_migrations/events_analytics_platform/0023_smart_autocomplete_mv.py b/snuba/snuba_migrations/events_analytics_platform/0023_smart_autocomplete_mv.py index 931619ddc96..837c0a4a6c1 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0023_smart_autocomplete_mv.py +++ b/snuba/snuba_migrations/events_analytics_platform/0023_smart_autocomplete_mv.py @@ -1,4 +1,5 @@ -from typing import Any, List, Sequence +from collections.abc import Sequence +from typing import Any from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations, table_engines @@ -22,7 +23,7 @@ _attr_columns = [Column(f"attrs_{type_name}", type_spec) for type_name, type_spec in _TYPES.items()] -columns: List[Column[Modifiers]] = [ +columns: list[Column[Modifiers]] = [ Column("project_id", UInt(64)), Column("item_type", String()), Column("date", Date(Modifiers(codecs=["DoubleDelta", "ZSTD(1)"]))), diff --git a/snuba/snuba_migrations/events_analytics_platform/0024_items.py b/snuba/snuba_migrations/events_analytics_platform/0024_items.py index 5e697f8acd2..23d763a8fc9 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0024_items.py +++ b/snuba/snuba_migrations/events_analytics_platform/0024_items.py @@ -1,4 +1,4 @@ -from typing import List, Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations, table_engines @@ -21,7 +21,7 @@ dist_table_name = "eap_items_1_dist" num_attr_buckets = 40 -columns: List[Column[Modifiers]] = [ +columns: list[Column[Modifiers]] = [ Column("organization_id", UInt(64)), Column("project_id", UInt(64)), Column("item_type", UInt(8)), @@ -96,7 +96,7 @@ class Migration(migration.ClickhouseNodeMigration): blocking = False def forwards_ops(self) -> Sequence[SqlOperation]: - res: List[SqlOperation] = [ + res: list[SqlOperation] = [ operations.CreateTable( storage_set=storage_set_name, table_name=local_table_name, diff --git a/snuba/snuba_migrations/events_analytics_platform/0025_smart_autocomplete_index.py b/snuba/snuba_migrations/events_analytics_platform/0025_smart_autocomplete_index.py index 533d0d98620..0a0f2af8fb8 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0025_smart_autocomplete_index.py +++ b/snuba/snuba_migrations/events_analytics_platform/0025_smart_autocomplete_index.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.datasets.storages.tags_hash_map import get_array_vals_hash @@ -8,7 +8,6 @@ class Migration(migration.ClickhouseNodeMigration): - blocking = False storage_set_key = StorageSetKey.EVENTS_ANALYTICS_PLATFORM granularity = "8192" @@ -30,9 +29,7 @@ def forwards_ops(self) -> Sequence[operations.SqlOperation]: name=self.str_hash_map_col, type=Array( UInt(64), - Modifiers( - materialized=get_array_vals_hash("mapKeys(attrs_string)") - ), + Modifiers(materialized=get_array_vals_hash("mapKeys(attrs_string)")), ), ), after="attrs_string", @@ -45,9 +42,7 @@ def forwards_ops(self) -> Sequence[operations.SqlOperation]: self.str_hash_map_col, type=Array( UInt(64), - Modifiers( - materialized=get_array_vals_hash("mapKeys(attrs_string)") - ), + Modifiers(materialized=get_array_vals_hash("mapKeys(attrs_string)")), ), ), after="attrs_string", diff --git a/snuba/snuba_migrations/events_analytics_platform/0026_items_add_attributes_hash_map.py b/snuba/snuba_migrations/events_analytics_platform/0026_items_add_attributes_hash_map.py index 03b5255b1c2..4365cf67edb 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0026_items_add_attributes_hash_map.py +++ b/snuba/snuba_migrations/events_analytics_platform/0026_items_add_attributes_hash_map.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Array, Column, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/events_analytics_platform/0027_uptime_checks_add_column_in_incident.py b/snuba/snuba_migrations/events_analytics_platform/0027_uptime_checks_add_column_in_incident.py index a8a853826e5..5bc4ced3b9a 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0027_uptime_checks_add_column_in_incident.py +++ b/snuba/snuba_migrations/events_analytics_platform/0027_uptime_checks_add_column_in_incident.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/events_analytics_platform/0028_ourlogs_v3.py b/snuba/snuba_migrations/events_analytics_platform/0028_ourlogs_v3.py index 8fba65bca0f..30da3a272a1 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0028_ourlogs_v3.py +++ b/snuba/snuba_migrations/events_analytics_platform/0028_ourlogs_v3.py @@ -1,4 +1,4 @@ -from typing import List, Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import UUID, Column, String, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -12,7 +12,7 @@ dist_table_name = "ourlogs_3_dist" num_attr_buckets = 20 -columns: List[Column[Modifiers]] = [ +columns: list[Column[Modifiers]] = [ Column("organization_id", UInt(64)), Column("project_id", UInt(64)), Column("id", UUID()), diff --git a/snuba/snuba_migrations/events_analytics_platform/0029_remove_smart_autocomplete_experimental.py b/snuba/snuba_migrations/events_analytics_platform/0029_remove_smart_autocomplete_experimental.py index fb49569af0a..7b2cfe0c2dc 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0029_remove_smart_autocomplete_experimental.py +++ b/snuba/snuba_migrations/events_analytics_platform/0029_remove_smart_autocomplete_experimental.py @@ -1,4 +1,5 @@ -from typing import Any, List, Sequence +from collections.abc import Sequence +from typing import Any from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations, table_engines @@ -22,7 +23,7 @@ _attr_columns = [Column(f"attrs_{type_name}", type_spec) for type_name, type_spec in _TYPES.items()] -columns: List[Column[Modifiers]] = [ +columns: list[Column[Modifiers]] = [ Column("project_id", UInt(64)), Column("item_type", String()), Column("date", Date(Modifiers(codecs=["DoubleDelta", "ZSTD(1)"]))), diff --git a/snuba/snuba_migrations/events_analytics_platform/0030_smart_autocomplete_items.py b/snuba/snuba_migrations/events_analytics_platform/0030_smart_autocomplete_items.py index b429ecacbc6..49f484af2cd 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0030_smart_autocomplete_items.py +++ b/snuba/snuba_migrations/events_analytics_platform/0030_smart_autocomplete_items.py @@ -1,4 +1,4 @@ -from typing import List, Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.datasets.storages.tags_hash_map import get_array_vals_hash @@ -9,7 +9,7 @@ num_attr_buckets = 40 -columns: List[Column[Modifiers]] = [ +columns: list[Column[Modifiers]] = [ Column("organization_id", UInt(64)), Column("project_id", UInt(64)), Column("item_type", UInt(8)), @@ -23,9 +23,7 @@ Array( UInt(64), Modifiers( - materialized=get_array_vals_hash( - "arrayConcat(attributes_string, attributes_float)" - ) + materialized=get_array_vals_hash("arrayConcat(attributes_string, attributes_float)") ), ), ), @@ -45,12 +43,8 @@ ] -_attr_num_names = ", ".join( - [f"mapKeys(attributes_float_{i})" for i in range(num_attr_buckets)] -) -_attr_str_names = ", ".join( - [f"mapKeys(attributes_string_{i})" for i in range(num_attr_buckets)] -) +_attr_num_names = ", ".join([f"mapKeys(attributes_float_{i})" for i in range(num_attr_buckets)]) +_attr_str_names = ", ".join([f"mapKeys(attributes_string_{i})" for i in range(num_attr_buckets)]) MV_QUERY = f""" @@ -67,7 +61,6 @@ class Migration(migration.ClickhouseNodeMigration): - blocking = False storage_set_key = StorageSetKey.EVENTS_ANALYTICS_PLATFORM granularity = "8192" diff --git a/snuba/snuba_migrations/events_analytics_platform/0032_sampled_storage_views.py b/snuba/snuba_migrations/events_analytics_platform/0032_sampled_storage_views.py index 05fc0ab4a90..d99a8f1322d 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0032_sampled_storage_views.py +++ b/snuba/snuba_migrations/events_analytics_platform/0032_sampled_storage_views.py @@ -1,4 +1,4 @@ -from typing import List, Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Array, Column, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -14,7 +14,7 @@ def hash_map_column_name(attribute_type: str, i: int) -> str: return f"_hash_map_{attribute_type}_{i}" -columns: List[Column[Modifiers]] = [ +columns: list[Column[Modifiers]] = [ Column("organization_id", UInt(64)), Column("project_id", UInt(64)), Column("item_type", UInt(8)), @@ -104,7 +104,6 @@ def hash_map_column_name(attribute_type: str, i: int) -> str: class Migration(migration.ClickhouseNodeMigration): - blocking = False storage_set_key = StorageSetKey.EVENTS_ANALYTICS_PLATFORM granularity = "8192" diff --git a/snuba/snuba_migrations/events_analytics_platform/0033_items_attribute_table_v1.py b/snuba/snuba_migrations/events_analytics_platform/0033_items_attribute_table_v1.py index e87604a25f4..183e4fec771 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0033_items_attribute_table_v1.py +++ b/snuba/snuba_migrations/events_analytics_platform/0033_items_attribute_table_v1.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration diff --git a/snuba/snuba_migrations/events_analytics_platform/0034_materialize_sampled_storage_views.py b/snuba/snuba_migrations/events_analytics_platform/0034_materialize_sampled_storage_views.py index bb9946e55b1..7b5dda75467 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0034_materialize_sampled_storage_views.py +++ b/snuba/snuba_migrations/events_analytics_platform/0034_materialize_sampled_storage_views.py @@ -1,4 +1,4 @@ -from typing import List, Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Array, Column, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -14,7 +14,7 @@ def hash_map_column_name(attribute_type: str, i: int) -> str: return f"_hash_map_{attribute_type}_{i}" -columns: List[Column[Modifiers]] = [ +columns: list[Column[Modifiers]] = [ Column("organization_id", UInt(64)), Column("project_id", UInt(64)), Column("item_type", UInt(8)), @@ -133,7 +133,6 @@ def get_mv_expr(sampling_weight: int) -> str: class Migration(migration.ClickhouseNodeMigration): - blocking = False storage_set_key = StorageSetKey.EVENTS_ANALYTICS_PLATFORM granularity = "8192" diff --git a/snuba/snuba_migrations/events_analytics_platform/0035_drop_item_attrs.py b/snuba/snuba_migrations/events_analytics_platform/0035_drop_item_attrs.py index 016addd09c5..d7621117e9d 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0035_drop_item_attrs.py +++ b/snuba/snuba_migrations/events_analytics_platform/0035_drop_item_attrs.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration diff --git a/snuba/snuba_migrations/events_analytics_platform/0036_items_attribute_table_v1.py b/snuba/snuba_migrations/events_analytics_platform/0036_items_attribute_table_v1.py index b48e14ee758..30297bd70ce 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0036_items_attribute_table_v1.py +++ b/snuba/snuba_migrations/events_analytics_platform/0036_items_attribute_table_v1.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration diff --git a/snuba/snuba_migrations/events_analytics_platform/0037_remove_items_attribute_mv_v1.py b/snuba/snuba_migrations/events_analytics_platform/0037_remove_items_attribute_mv_v1.py index c69f2e7ea60..e846d0ef3af 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0037_remove_items_attribute_mv_v1.py +++ b/snuba/snuba_migrations/events_analytics_platform/0037_remove_items_attribute_mv_v1.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration diff --git a/snuba/snuba_migrations/events_analytics_platform/0038_eap_items_add_sampling_factor.py b/snuba/snuba_migrations/events_analytics_platform/0038_eap_items_add_sampling_factor.py index 1b2ba0ac3a8..f2ac5a3de48 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0038_eap_items_add_sampling_factor.py +++ b/snuba/snuba_migrations/events_analytics_platform/0038_eap_items_add_sampling_factor.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, Float from snuba.clusters.storage_sets import StorageSetKey @@ -40,9 +40,7 @@ def forwards_ops(self) -> Sequence[operations.SqlOperation]: ] for sampling_weight in self.sampling_weights: - downsampled_local_table_name = ( - f"eap_items_1_downsample_{sampling_weight}_local" - ) + downsampled_local_table_name = f"eap_items_1_downsample_{sampling_weight}_local" ops.append( operations.AddColumn( storage_set=storage_set_name, @@ -75,9 +73,7 @@ def backwards_ops(self) -> Sequence[operations.SqlOperation]: ] for sampling_weight in self.sampling_weights: - downsampled_local_table_name = ( - f"eap_items_1_downsample_{sampling_weight}_local" - ) + downsampled_local_table_name = f"eap_items_1_downsample_{sampling_weight}_local" ops.append( operations.DropColumn( storage_set=storage_set_name, diff --git a/snuba/snuba_migrations/events_analytics_platform/0039_update_mv_with_sampling_factor.py b/snuba/snuba_migrations/events_analytics_platform/0039_update_mv_with_sampling_factor.py index b8424287dbb..c622e61c391 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0039_update_mv_with_sampling_factor.py +++ b/snuba/snuba_migrations/events_analytics_platform/0039_update_mv_with_sampling_factor.py @@ -1,5 +1,5 @@ +from collections.abc import Sequence from copy import deepcopy -from typing import List, Sequence from snuba.clickhouse.columns import Array, Column, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -15,7 +15,7 @@ def hash_map_column_name(attribute_type: str, i: int) -> str: return f"_hash_map_{attribute_type}_{i}" -columns: List[Column[Modifiers]] = [ +columns: list[Column[Modifiers]] = [ Column("organization_id", UInt(64)), Column("project_id", UInt(64)), Column("item_type", UInt(8)), @@ -151,7 +151,6 @@ def get_mv_expr_sampling_factor(sampling_weight: int) -> str: class Migration(migration.ClickhouseNodeMigration): - blocking = False storage_set_key = StorageSetKey.EVENTS_ANALYTICS_PLATFORM granularity = "8192" @@ -159,7 +158,7 @@ class Migration(migration.ClickhouseNodeMigration): sampling_weights = [8, 8**2, 8**3] def forwards_ops(self) -> Sequence[SqlOperation]: - ops: List[SqlOperation] = [] + ops: list[SqlOperation] = [] for sampling_weight in self.sampling_weights: local_table_name = f"eap_items_1_downsample_{sampling_weight}_local" mv_name = f"eap_items_1_downsample_{sampling_weight}_mv_2" @@ -191,7 +190,7 @@ def forwards_ops(self) -> Sequence[SqlOperation]: return ops def backwards_ops(self) -> Sequence[SqlOperation]: - ops: List[SqlOperation] = [] + ops: list[SqlOperation] = [] for sampling_weight in self.sampling_weights: local_table_name = f"eap_items_1_downsample_{sampling_weight}_local" mv_name = f"eap_items_1_downsample_{sampling_weight}_mv" diff --git a/snuba/snuba_migrations/events_analytics_platform/0040_eap_items_downsampled_dist_add_sampling_factor.py b/snuba/snuba_migrations/events_analytics_platform/0040_eap_items_downsampled_dist_add_sampling_factor.py index bfc44d57982..5bce22bb5b9 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0040_eap_items_downsampled_dist_add_sampling_factor.py +++ b/snuba/snuba_migrations/events_analytics_platform/0040_eap_items_downsampled_dist_add_sampling_factor.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, Float from snuba.clusters.storage_sets import StorageSetKey @@ -18,9 +18,7 @@ class Migration(migration.ClickhouseNodeMigration): def forwards_ops(self) -> Sequence[operations.SqlOperation]: ops = [] for sampling_weight in self.sampling_weights: - downsampled_dist_table_name = ( - f"eap_items_1_downsample_{sampling_weight}_dist" - ) + downsampled_dist_table_name = f"eap_items_1_downsample_{sampling_weight}_dist" ops.append( operations.AddColumn( storage_set=storage_set_name, @@ -40,9 +38,7 @@ def backwards_ops(self) -> Sequence[operations.SqlOperation]: ops = [] for sampling_weight in self.sampling_weights: - downsampled_dist_table_name = ( - f"eap_items_1_downsample_{sampling_weight}_dist" - ) + downsampled_dist_table_name = f"eap_items_1_downsample_{sampling_weight}_dist" ops.append( operations.DropColumn( storage_set=storage_set_name, diff --git a/snuba/snuba_migrations/events_analytics_platform/0041_hashed_attributes_index.py b/snuba/snuba_migrations/events_analytics_platform/0041_hashed_attributes_index.py index 0dddc1ecca6..05e5f8eb122 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0041_hashed_attributes_index.py +++ b/snuba/snuba_migrations/events_analytics_platform/0041_hashed_attributes_index.py @@ -1,4 +1,4 @@ -from typing import List, Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations @@ -10,12 +10,8 @@ def get_hashed_attributes_column_expression() -> str: column_expressions = [] for i in range(buckets): - hashed_keys_string = ( - f"arrayMap(k -> cityHash64(k), mapKeys(attributes_string_{i}))" - ) - hashed_keys_float = ( - f"arrayMap(k -> cityHash64(k), mapKeys(attributes_float_{i}))" - ) + hashed_keys_string = f"arrayMap(k -> cityHash64(k), mapKeys(attributes_string_{i}))" + hashed_keys_float = f"arrayMap(k -> cityHash64(k), mapKeys(attributes_float_{i}))" column_expressions.append(hashed_keys_string) column_expressions.append(hashed_keys_float) @@ -23,7 +19,6 @@ def get_hashed_attributes_column_expression() -> str: class Migration(migration.ClickhouseNodeMigration): - blocking = False storage_set_key = StorageSetKey.EVENTS_ANALYTICS_PLATFORM granularity = "8192" @@ -32,7 +27,7 @@ class Migration(migration.ClickhouseNodeMigration): dist_table_name = "eap_items_1_dist" def forwards_ops(self) -> Sequence[operations.SqlOperation]: - ops: List[operations.SqlOperation] = [ + ops: list[operations.SqlOperation] = [ operations.RunSql( storage_set=self.storage_set_key, statement=f"ALTER TABLE {self.local_table_name} ADD COLUMN IF NOT EXISTS hashed_keys Array(UInt64) MATERIALIZED {get_hashed_attributes_column_expression()}", @@ -57,7 +52,7 @@ def forwards_ops(self) -> Sequence[operations.SqlOperation]: return ops def backwards_ops(self) -> Sequence[operations.SqlOperation]: - ops: List[operations.SqlOperation] = [ + ops: list[operations.SqlOperation] = [ operations.DropIndex( storage_set=self.storage_set_key, table_name=self.local_table_name, diff --git a/snuba/snuba_migrations/events_analytics_platform/0042_remove_hashed_columns.py b/snuba/snuba_migrations/events_analytics_platform/0042_remove_hashed_columns.py index 3edd1501d98..54fe0b5229b 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0042_remove_hashed_columns.py +++ b/snuba/snuba_migrations/events_analytics_platform/0042_remove_hashed_columns.py @@ -1,4 +1,5 @@ -from typing import Any, List, Sequence +from collections.abc import Sequence +from typing import Any from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations @@ -24,7 +25,7 @@ def hash_map_column_name(attribute_type: str, i: int) -> str: return f"_hash_map_{attribute_type}_{i}" -base_columns: List[Column[Modifiers]] = [ +base_columns: list[Column[Modifiers]] = [ Column("organization_id", UInt(64)), Column("project_id", UInt(64)), Column("item_type", UInt(8)), @@ -89,9 +90,7 @@ def hash_map_column_name(attribute_type: str, i: int) -> str: def get_mv_expr(sampling_weight: int, with_hash_map_columns: bool = True) -> str: column_names = [ - c.name - for c in base_columns - if c.name not in {"sampling_weight", "sampling_factor"} + c.name for c in base_columns if c.name not in {"sampling_weight", "sampling_factor"} ] if with_hash_map_columns: column_names.extend([c.name for c in hash_map_columns]) @@ -109,7 +108,7 @@ class Migration(migration.ClickhouseNodeMigration): dist_table_name = "eap_items_1_dist" def forwards_ops(self) -> Sequence[operations.SqlOperation]: - ops: List[operations.SqlOperation] = [] + ops: list[operations.SqlOperation] = [] ops.append( operations.DropIndex( storage_set=self.storage_set_key, @@ -139,7 +138,7 @@ def forwards_ops(self) -> Sequence[operations.SqlOperation]: return ops def backwards_ops(self) -> Sequence[operations.SqlOperation]: - ops: List[operations.SqlOperation] = [] + ops: list[operations.SqlOperation] = [] for downsampled_factor in self.downsampled_factors: ops.append( operations.CreateMaterializedView( diff --git a/snuba/snuba_migrations/events_analytics_platform/0047_use_downsampled_retention_for_downsampled_tables.py b/snuba/snuba_migrations/events_analytics_platform/0047_use_downsampled_retention_for_downsampled_tables.py index 463e19f7ade..da838ba8ac9 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0047_use_downsampled_retention_for_downsampled_tables.py +++ b/snuba/snuba_migrations/events_analytics_platform/0047_use_downsampled_retention_for_downsampled_tables.py @@ -1,4 +1,4 @@ -from typing import List, Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -13,7 +13,7 @@ old_version = 2 new_version = 3 -columns: List[Column[Modifiers]] = [ +columns: list[Column[Modifiers]] = [ Column("organization_id", UInt(64)), Column("project_id", UInt(64)), Column("item_type", UInt(8)), @@ -78,11 +78,7 @@ def generate_old_materialized_view_expression(sampling_weight: int) -> str: column_names_str = ", ".join( - [ - c.name - for c in columns - if c.name != "sampling_weight" and c.name != "sampling_factor" - ] + [c.name for c in columns if c.name != "sampling_weight" and c.name != "sampling_factor"] ) return f"SELECT {column_names_str}, sampling_weight * {sampling_weight} AS sampling_weight, sampling_factor / {sampling_weight} AS sampling_factor FROM eap_items_1_local WHERE (cityHash64(item_id + {sampling_weight}) % {sampling_weight}) = 0" @@ -104,7 +100,7 @@ class Migration(migration.ClickhouseNodeMigration): blocking = False def forwards_ops(self) -> Sequence[SqlOperation]: - ops: List[SqlOperation] = [] + ops: list[SqlOperation] = [] for sampling_weight in sampling_weights: local_table_name = f"eap_items_1_downsample_{sampling_weight}_local" @@ -117,9 +113,7 @@ def forwards_ops(self) -> Sequence[SqlOperation]: columns=columns, destination_table_name=local_table_name, target=OperationTarget.LOCAL, - query=generate_new_materialized_view_expression( - sampling_weight - ), + query=generate_new_materialized_view_expression(sampling_weight), ), operations.DropTable( storage_set=storage_set_key, @@ -132,7 +126,7 @@ def forwards_ops(self) -> Sequence[SqlOperation]: return ops def backwards_ops(self) -> Sequence[SqlOperation]: - ops: List[SqlOperation] = [] + ops: list[SqlOperation] = [] for sampling_weight in sampling_weights: local_table_name = f"eap_items_1_downsample_{sampling_weight}_local" @@ -145,9 +139,7 @@ def backwards_ops(self) -> Sequence[SqlOperation]: columns=columns, destination_table_name=local_table_name, target=OperationTarget.LOCAL, - query=generate_old_materialized_view_expression( - sampling_weight - ), + query=generate_old_materialized_view_expression(sampling_weight), ), operations.DropTable( storage_set=storage_set_key, diff --git a/snuba/snuba_migrations/events_analytics_platform/0049_use_client_and_server_sample_rates_in_materialized_views.py b/snuba/snuba_migrations/events_analytics_platform/0049_use_client_and_server_sample_rates_in_materialized_views.py index bd79dc25e9b..7e05ee2dd2e 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0049_use_client_and_server_sample_rates_in_materialized_views.py +++ b/snuba/snuba_migrations/events_analytics_platform/0049_use_client_and_server_sample_rates_in_materialized_views.py @@ -1,4 +1,4 @@ -from typing import List, Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -13,7 +13,7 @@ old_version = 3 new_version = old_version + 1 -columns: List[Column[Modifiers]] = [ +columns: list[Column[Modifiers]] = [ Column("organization_id", UInt(64)), Column("project_id", UInt(64)), Column("item_type", UInt(8)), @@ -122,7 +122,7 @@ class Migration(migration.ClickhouseNodeMigration): blocking = False def forwards_ops(self) -> Sequence[SqlOperation]: - ops: List[SqlOperation] = [] + ops: list[SqlOperation] = [] for sampling_weight in sampling_weights: local_table_name = f"eap_items_1_downsample_{sampling_weight}_local" @@ -148,7 +148,7 @@ def forwards_ops(self) -> Sequence[SqlOperation]: return ops def backwards_ops(self) -> Sequence[SqlOperation]: - ops: List[SqlOperation] = [] + ops: list[SqlOperation] = [] for sampling_weight in sampling_weights: local_table_name = f"eap_items_1_downsample_{sampling_weight}_local" diff --git a/snuba/snuba_migrations/events_analytics_platform/0051_add_bool_keys_to_autocomplete.py b/snuba/snuba_migrations/events_analytics_platform/0051_add_bool_keys_to_autocomplete.py index 84852f2085f..f97e0fa5878 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0051_add_bool_keys_to_autocomplete.py +++ b/snuba/snuba_migrations/events_analytics_platform/0051_add_bool_keys_to_autocomplete.py @@ -1,4 +1,4 @@ -from typing import List, Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.datasets.storages.tags_hash_map import get_array_vals_hash @@ -9,7 +9,7 @@ num_attr_buckets = 40 -columns: List[Column[Modifiers]] = [ +columns: list[Column[Modifiers]] = [ Column("organization_id", UInt(64)), Column("project_id", UInt(64)), Column("item_type", UInt(8)), @@ -72,7 +72,7 @@ class Migration(migration.ClickhouseNodeMigration): new_mv_name = "eap_item_co_occurring_attrs_2_mv" def forwards_ops(self) -> Sequence[SqlOperation]: - ops: List[SqlOperation] = [] + ops: list[SqlOperation] = [] # Add bool_attribute_keys column to local table ops.append( @@ -132,7 +132,7 @@ def backwards_ops(self) -> Sequence[SqlOperation]: arrayConcat({_attr_num_names}) AS attributes_float FROM eap_items_1_local """ - old_columns: List[Column[Modifiers]] = [ + old_columns: list[Column[Modifiers]] = [ Column("organization_id", UInt(64)), Column("project_id", UInt(64)), Column("item_type", UInt(8)), @@ -165,7 +165,7 @@ def backwards_ops(self) -> Sequence[SqlOperation]: ), ] - ops: List[SqlOperation] = [] + ops: list[SqlOperation] = [] # Recreate old MV ops.append( diff --git a/snuba/snuba_migrations/events_analytics_platform/0052_create_deletes_workload.py b/snuba/snuba_migrations/events_analytics_platform/0052_create_deletes_workload.py index df9c18da7ee..3d3f93df149 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0052_create_deletes_workload.py +++ b/snuba/snuba_migrations/events_analytics_platform/0052_create_deletes_workload.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations diff --git a/snuba/snuba_migrations/events_analytics_platform/0053_alter_deletes_workload_max_threads.py b/snuba/snuba_migrations/events_analytics_platform/0053_alter_deletes_workload_max_threads.py index d6e8a1761bc..779847ca516 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0053_alter_deletes_workload_max_threads.py +++ b/snuba/snuba_migrations/events_analytics_platform/0053_alter_deletes_workload_max_threads.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations diff --git a/snuba/snuba_migrations/events_analytics_platform/0054_fix_bools_in_autocomplete.py b/snuba/snuba_migrations/events_analytics_platform/0054_fix_bools_in_autocomplete.py index 7f550f1698a..be2c7650fa1 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0054_fix_bools_in_autocomplete.py +++ b/snuba/snuba_migrations/events_analytics_platform/0054_fix_bools_in_autocomplete.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations diff --git a/snuba/snuba_migrations/events_analytics_platform/0055_fix_attribute_keys_hash_missing_bool.py b/snuba/snuba_migrations/events_analytics_platform/0055_fix_attribute_keys_hash_missing_bool.py index 145f14b4b64..286d130e0e1 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0055_fix_attribute_keys_hash_missing_bool.py +++ b/snuba/snuba_migrations/events_analytics_platform/0055_fix_attribute_keys_hash_missing_bool.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.datasets.storages.tags_hash_map import get_array_vals_hash diff --git a/snuba/snuba_migrations/events_analytics_platform/0056_eap_items_dist_ro.py b/snuba/snuba_migrations/events_analytics_platform/0056_eap_items_dist_ro.py index 9c49291da12..6ac647c9aa6 100644 --- a/snuba/snuba_migrations/events_analytics_platform/0056_eap_items_dist_ro.py +++ b/snuba/snuba_migrations/events_analytics_platform/0056_eap_items_dist_ro.py @@ -1,4 +1,4 @@ -from typing import Optional, Sequence +from collections.abc import Sequence from snuba.clusters.cluster import get_cluster from snuba.clusters.storage_sets import StorageSetKey @@ -8,7 +8,7 @@ storage_set = StorageSetKey.EVENTS_ANALYTICS_PLATFORM_RO # (new_dist_ro_table, source_dist_table, local_table, sharding_key) -DIST_RO_TABLES: list[tuple[str, str, str, Optional[str]]] = [ +DIST_RO_TABLES: list[tuple[str, str, str, str | None]] = [ ( "eap_items_1_dist_ro", "eap_items_1_dist", @@ -54,7 +54,7 @@ def _create_dist_ro_sql( new_table: str, source_table: str, local_table: str, - sharding_key: Optional[str], + sharding_key: str | None, ) -> str: cluster = get_cluster(storage_set) cluster_name = cluster.get_clickhouse_cluster_name() diff --git a/snuba/snuba_migrations/functions/0001_functions.py b/snuba/snuba_migrations/functions/0001_functions.py index 870db4c38e6..33871bf0899 100644 --- a/snuba/snuba_migrations/functions/0001_functions.py +++ b/snuba/snuba_migrations/functions/0001_functions.py @@ -1,4 +1,4 @@ -from typing import List, MutableMapping, Optional, Sequence, Union +from collections.abc import MutableMapping, Sequence from snuba.clickhouse.columns import ( UUID, @@ -15,7 +15,7 @@ from snuba.migrations import migration, migration_utilities, operations, table_engines from snuba.migrations.columns import MigrationModifiers as Modifiers -common_columns: List[Column[Modifiers]] = [ +common_columns: list[Column[Modifiers]] = [ Column("project_id", UInt(64)), Column("transaction_name", String()), Column("timestamp", DateTime()), @@ -34,13 +34,13 @@ Column("retention_days", UInt(16)), ] -raw_columns: List[Column[Modifiers]] = common_columns + [ +raw_columns: list[Column[Modifiers]] = common_columns + [ Column("durations", Array(Float(64))), Column("profile_id", UUID()), Column("materialization_version", UInt(8)), ] -agg_columns: List[Column[Modifiers]] = common_columns + [ +agg_columns: list[Column[Modifiers]] = common_columns + [ Column("count", AggregateFunction("count", [Float(64)])), Column( "percentiles", @@ -71,18 +71,16 @@ class Migration(migration.CodeMigration): local_view_table = "functions_local" def _create_functions_mv_table( - self, clickhouse: Optional[ClickhousePool] + self, clickhouse: ClickhousePool | None ) -> operations.SqlOperation: - table_settings: MutableMapping[str, Union[int, str]] = { + table_settings: MutableMapping[str, int | str] = { "index_granularity": self.index_granularity, } clickhouse_version = migration_utilities.get_clickhouse_version_for_storage_set( self.storage_set, clickhouse ) - if migration_utilities.supports_setting( - clickhouse_version, "allow_nullable_key" - ): + if migration_utilities.supports_setting(clickhouse_version, "allow_nullable_key"): table_settings["allow_nullable_key"] = 1 return operations.CreateTable( diff --git a/snuba/snuba_migrations/functions/0002_add_new_columns_to_raw_functions.py b/snuba/snuba_migrations/functions/0002_add_new_columns_to_raw_functions.py index 6f95508bc39..af93840f1eb 100644 --- a/snuba/snuba_migrations/functions/0002_add_new_columns_to_raw_functions.py +++ b/snuba/snuba_migrations/functions/0002_add_new_columns_to_raw_functions.py @@ -1,5 +1,5 @@ +from collections.abc import Sequence from dataclasses import dataclass -from typing import Sequence from snuba.clickhouse.columns import Column, String, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -28,21 +28,15 @@ class NewColumn: after="dist", ), NewColumn( - column=Column( - "transaction_status", UInt(8, Modifiers(default=str(UNKNOWN_SPAN_STATUS))) - ), + column=Column("transaction_status", UInt(8, Modifiers(default=str(UNKNOWN_SPAN_STATUS)))), after="transaction_op", ), NewColumn( - column=Column( - "http_method", String(Modifiers(nullable=True, low_cardinality=True)) - ), + column=Column("http_method", String(Modifiers(nullable=True, low_cardinality=True))), after="transaction_status", ), NewColumn( - column=Column( - "browser_name", String(Modifiers(nullable=True, low_cardinality=True)) - ), + column=Column("browser_name", String(Modifiers(nullable=True, low_cardinality=True))), after="http_method", ), NewColumn( diff --git a/snuba/snuba_migrations/functions/0003_add_new_columns_to_raw_functions.py b/snuba/snuba_migrations/functions/0003_add_new_columns_to_raw_functions.py index 3ebc2fba18e..016f2a7b11d 100644 --- a/snuba/snuba_migrations/functions/0003_add_new_columns_to_raw_functions.py +++ b/snuba/snuba_migrations/functions/0003_add_new_columns_to_raw_functions.py @@ -1,5 +1,5 @@ +from collections.abc import Sequence from dataclasses import dataclass -from typing import Sequence from snuba.clickhouse.columns import Column, DateTime64, String from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/functions/0004_functions_v2.py b/snuba/snuba_migrations/functions/0004_functions_v2.py index e1d456d6050..0538a813cb0 100644 --- a/snuba/snuba_migrations/functions/0004_functions_v2.py +++ b/snuba/snuba_migrations/functions/0004_functions_v2.py @@ -1,5 +1,5 @@ +from collections.abc import Sequence from dataclasses import dataclass -from typing import List, Sequence from snuba.clickhouse.columns import ( UUID, @@ -24,7 +24,7 @@ class NewColumn: after: str | None -columns: List[Column[Modifiers]] = [ +columns: list[Column[Modifiers]] = [ Column("project_id", UInt(64)), Column("transaction_name", String()), Column("timestamp", DateTime()), diff --git a/snuba/snuba_migrations/generic_metrics/0001_sets_aggregate_table.py b/snuba/snuba_migrations/generic_metrics/0001_sets_aggregate_table.py index 87f8a909770..a4cb62ad02b 100644 --- a/snuba/snuba_migrations/generic_metrics/0001_sets_aggregate_table.py +++ b/snuba/snuba_migrations/generic_metrics/0001_sets_aggregate_table.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import ( AggregateFunction, diff --git a/snuba/snuba_migrations/generic_metrics/0002_sets_raw_table.py b/snuba/snuba_migrations/generic_metrics/0002_sets_raw_table.py index 79336f1d0dc..ecfcd49aaec 100644 --- a/snuba/snuba_migrations/generic_metrics/0002_sets_raw_table.py +++ b/snuba/snuba_migrations/generic_metrics/0002_sets_raw_table.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import ( Array, diff --git a/snuba/snuba_migrations/generic_metrics/0003_sets_mv.py b/snuba/snuba_migrations/generic_metrics/0003_sets_mv.py index 895576249b3..853bc8879d1 100644 --- a/snuba/snuba_migrations/generic_metrics/0003_sets_mv.py +++ b/snuba/snuba_migrations/generic_metrics/0003_sets_mv.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import ( AggregateFunction, diff --git a/snuba/snuba_migrations/generic_metrics/0004_sets_raw_add_granularities.py b/snuba/snuba_migrations/generic_metrics/0004_sets_raw_add_granularities.py index 3455cb11cb6..acf0c6500f0 100644 --- a/snuba/snuba_migrations/generic_metrics/0004_sets_raw_add_granularities.py +++ b/snuba/snuba_migrations/generic_metrics/0004_sets_raw_add_granularities.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations diff --git a/snuba/snuba_migrations/generic_metrics/0005_sets_replace_mv.py b/snuba/snuba_migrations/generic_metrics/0005_sets_replace_mv.py index 6c32d859d6b..6a33427ed8a 100644 --- a/snuba/snuba_migrations/generic_metrics/0005_sets_replace_mv.py +++ b/snuba/snuba_migrations/generic_metrics/0005_sets_replace_mv.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import ( AggregateFunction, diff --git a/snuba/snuba_migrations/generic_metrics/0006_sets_raw_add_granularities_dist_table.py b/snuba/snuba_migrations/generic_metrics/0006_sets_raw_add_granularities_dist_table.py index 3f1b0650c4e..cadc051d5ac 100644 --- a/snuba/snuba_migrations/generic_metrics/0006_sets_raw_add_granularities_dist_table.py +++ b/snuba/snuba_migrations/generic_metrics/0006_sets_raw_add_granularities_dist_table.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations diff --git a/snuba/snuba_migrations/generic_metrics/0007_distributions_aggregate_table.py b/snuba/snuba_migrations/generic_metrics/0007_distributions_aggregate_table.py index 095bb2d3700..5e24aa279dc 100644 --- a/snuba/snuba_migrations/generic_metrics/0007_distributions_aggregate_table.py +++ b/snuba/snuba_migrations/generic_metrics/0007_distributions_aggregate_table.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import ( AggregateFunction, diff --git a/snuba/snuba_migrations/generic_metrics/0008_distributions_raw_table.py b/snuba/snuba_migrations/generic_metrics/0008_distributions_raw_table.py index 4b28e16ed35..9013e571c0b 100644 --- a/snuba/snuba_migrations/generic_metrics/0008_distributions_raw_table.py +++ b/snuba/snuba_migrations/generic_metrics/0008_distributions_raw_table.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import ( Array, diff --git a/snuba/snuba_migrations/generic_metrics/0009_distributions_mv.py b/snuba/snuba_migrations/generic_metrics/0009_distributions_mv.py index 40d43ccddee..6604e3ac37f 100644 --- a/snuba/snuba_migrations/generic_metrics/0009_distributions_mv.py +++ b/snuba/snuba_migrations/generic_metrics/0009_distributions_mv.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import ( AggregateFunction, diff --git a/snuba/snuba_migrations/generic_metrics/0010_counters_aggregate_table.py b/snuba/snuba_migrations/generic_metrics/0010_counters_aggregate_table.py index 68b2759a8f7..5fd26b265e8 100644 --- a/snuba/snuba_migrations/generic_metrics/0010_counters_aggregate_table.py +++ b/snuba/snuba_migrations/generic_metrics/0010_counters_aggregate_table.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import ( AggregateFunction, diff --git a/snuba/snuba_migrations/generic_metrics/0011_counters_raw_table.py b/snuba/snuba_migrations/generic_metrics/0011_counters_raw_table.py index a26a5d40651..c5737fffc3c 100644 --- a/snuba/snuba_migrations/generic_metrics/0011_counters_raw_table.py +++ b/snuba/snuba_migrations/generic_metrics/0011_counters_raw_table.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import ( Array, diff --git a/snuba/snuba_migrations/generic_metrics/0012_counters_mv.py b/snuba/snuba_migrations/generic_metrics/0012_counters_mv.py index ee24c913a6d..ed396599949 100644 --- a/snuba/snuba_migrations/generic_metrics/0012_counters_mv.py +++ b/snuba/snuba_migrations/generic_metrics/0012_counters_mv.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import ( AggregateFunction, diff --git a/snuba/snuba_migrations/generic_metrics/0013_distributions_dist_tags_hash.py b/snuba/snuba_migrations/generic_metrics/0013_distributions_dist_tags_hash.py index 4ee30cb1325..21fbe21e11a 100644 --- a/snuba/snuba_migrations/generic_metrics/0013_distributions_dist_tags_hash.py +++ b/snuba/snuba_migrations/generic_metrics/0013_distributions_dist_tags_hash.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Array, Column, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/generic_metrics/0014_distribution_add_options.py b/snuba/snuba_migrations/generic_metrics/0014_distribution_add_options.py index a91f11e7a46..5b2afd446fa 100644 --- a/snuba/snuba_migrations/generic_metrics/0014_distribution_add_options.py +++ b/snuba/snuba_migrations/generic_metrics/0014_distribution_add_options.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -17,34 +17,34 @@ class Migration(migration.ClickhouseNodeMigration): columns = [ ( - Column("enable_histogram", UInt(8, MigrationModifiers(default=str("0")))), + Column("enable_histogram", UInt(8, MigrationModifiers(default="0"))), "granularities", ), ( Column( "decasecond_retention_days", - UInt(8, MigrationModifiers(default=str("7"))), + UInt(8, MigrationModifiers(default="7")), ), "enable_histogram", ), ( Column( "min_retention_days", - UInt(8, MigrationModifiers(default=str("30"))), + UInt(8, MigrationModifiers(default="30")), ), "decasecond_retention_days", ), ( Column( "hr_retention_days", - UInt(8, MigrationModifiers(default=str("retention_days"))), + UInt(8, MigrationModifiers(default="retention_days")), ), "min_retention_days", ), ( Column( "day_retention_days", - UInt(8, MigrationModifiers(default=str("retention_days"))), + UInt(8, MigrationModifiers(default="retention_days")), ), "hr_retention_days", ), diff --git a/snuba/snuba_migrations/generic_metrics/0015_sets_add_options.py b/snuba/snuba_migrations/generic_metrics/0015_sets_add_options.py index 253b69ab685..c5928569d3a 100644 --- a/snuba/snuba_migrations/generic_metrics/0015_sets_add_options.py +++ b/snuba/snuba_migrations/generic_metrics/0015_sets_add_options.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -19,28 +19,28 @@ class Migration(migration.ClickhouseNodeMigration): ( Column( "decasecond_retention_days", - UInt(8, MigrationModifiers(default=str("retention_days"))), + UInt(8, MigrationModifiers(default="retention_days")), ), "granularities", ), ( Column( "min_retention_days", - UInt(8, MigrationModifiers(default=str("retention_days"))), + UInt(8, MigrationModifiers(default="retention_days")), ), "decasecond_retention_days", ), ( Column( "hr_retention_days", - UInt(8, MigrationModifiers(default=str("retention_days"))), + UInt(8, MigrationModifiers(default="retention_days")), ), "min_retention_days", ), ( Column( "day_retention_days", - UInt(8, MigrationModifiers(default=str("retention_days"))), + UInt(8, MigrationModifiers(default="retention_days")), ), "hr_retention_days", ), diff --git a/snuba/snuba_migrations/generic_metrics/0016_counters_add_options.py b/snuba/snuba_migrations/generic_metrics/0016_counters_add_options.py index e84aa03987e..0cdec006a82 100644 --- a/snuba/snuba_migrations/generic_metrics/0016_counters_add_options.py +++ b/snuba/snuba_migrations/generic_metrics/0016_counters_add_options.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -19,28 +19,28 @@ class Migration(migration.ClickhouseNodeMigration): ( Column( "decasecond_retention_days", - UInt(8, MigrationModifiers(default=str("retention_days"))), + UInt(8, MigrationModifiers(default="retention_days")), ), "granularities", ), ( Column( "min_retention_days", - UInt(8, MigrationModifiers(default=str("retention_days"))), + UInt(8, MigrationModifiers(default="retention_days")), ), "decasecond_retention_days", ), ( Column( "hr_retention_days", - UInt(8, MigrationModifiers(default=str("retention_days"))), + UInt(8, MigrationModifiers(default="retention_days")), ), "min_retention_days", ), ( Column( "day_retention_days", - UInt(8, MigrationModifiers(default=str("retention_days"))), + UInt(8, MigrationModifiers(default="retention_days")), ), "hr_retention_days", ), diff --git a/snuba/snuba_migrations/generic_metrics/0017_distributions_mv2.py b/snuba/snuba_migrations/generic_metrics/0017_distributions_mv2.py index 6d43666a455..dc2d646627f 100644 --- a/snuba/snuba_migrations/generic_metrics/0017_distributions_mv2.py +++ b/snuba/snuba_migrations/generic_metrics/0017_distributions_mv2.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import ( AggregateFunction, diff --git a/snuba/snuba_migrations/generic_metrics/0018_sets_update_opt_default.py b/snuba/snuba_migrations/generic_metrics/0018_sets_update_opt_default.py index d1651b8e520..fa732a9bd7e 100644 --- a/snuba/snuba_migrations/generic_metrics/0018_sets_update_opt_default.py +++ b/snuba/snuba_migrations/generic_metrics/0018_sets_update_opt_default.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -18,22 +18,22 @@ class Migration(migration.ClickhouseNodeMigration): before = [ Column( "decasecond_retention_days", - UInt(8, MigrationModifiers(default=str("retention_days"))), + UInt(8, MigrationModifiers(default="retention_days")), ), Column( "min_retention_days", - UInt(8, MigrationModifiers(default=str("retention_days"))), + UInt(8, MigrationModifiers(default="retention_days")), ), ] after = [ Column( "decasecond_retention_days", - UInt(8, MigrationModifiers(default=str("7"))), + UInt(8, MigrationModifiers(default="7")), ), Column( "min_retention_days", - UInt(8, MigrationModifiers(default=str("30"))), + UInt(8, MigrationModifiers(default="30")), ), ] diff --git a/snuba/snuba_migrations/generic_metrics/0019_counters_update_opt_default.py b/snuba/snuba_migrations/generic_metrics/0019_counters_update_opt_default.py index 07fdd132f92..b855af072d6 100644 --- a/snuba/snuba_migrations/generic_metrics/0019_counters_update_opt_default.py +++ b/snuba/snuba_migrations/generic_metrics/0019_counters_update_opt_default.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -18,22 +18,22 @@ class Migration(migration.ClickhouseNodeMigration): before = [ Column( "decasecond_retention_days", - UInt(8, MigrationModifiers(default=str("retention_days"))), + UInt(8, MigrationModifiers(default="retention_days")), ), Column( "min_retention_days", - UInt(8, MigrationModifiers(default=str("retention_days"))), + UInt(8, MigrationModifiers(default="retention_days")), ), ] after = [ Column( "decasecond_retention_days", - UInt(8, MigrationModifiers(default=str("7"))), + UInt(8, MigrationModifiers(default="7")), ), Column( "min_retention_days", - UInt(8, MigrationModifiers(default=str("30"))), + UInt(8, MigrationModifiers(default="30")), ), ] diff --git a/snuba/snuba_migrations/generic_metrics/0020_sets_mv2.py b/snuba/snuba_migrations/generic_metrics/0020_sets_mv2.py index c07a3ccd99f..c9c7d8abe95 100644 --- a/snuba/snuba_migrations/generic_metrics/0020_sets_mv2.py +++ b/snuba/snuba_migrations/generic_metrics/0020_sets_mv2.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import ( AggregateFunction, diff --git a/snuba/snuba_migrations/generic_metrics/0021_counters_mv2.py b/snuba/snuba_migrations/generic_metrics/0021_counters_mv2.py index 2481f3fd9a6..fd87381382f 100644 --- a/snuba/snuba_migrations/generic_metrics/0021_counters_mv2.py +++ b/snuba/snuba_migrations/generic_metrics/0021_counters_mv2.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import ( AggregateFunction, diff --git a/snuba/snuba_migrations/generic_metrics/0022_gauges_aggregate_table.py b/snuba/snuba_migrations/generic_metrics/0022_gauges_aggregate_table.py index b7450d78c2e..90ec970407b 100644 --- a/snuba/snuba_migrations/generic_metrics/0022_gauges_aggregate_table.py +++ b/snuba/snuba_migrations/generic_metrics/0022_gauges_aggregate_table.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import ( AggregateFunction, @@ -31,9 +31,7 @@ class Migration(migration.ClickhouseNodeMigration): Column("project_id", UInt(64)), Column("metric_id", UInt(64)), Column("granularity", UInt(8)), - Column( - "rounded_timestamp", DateTime(modifiers=Modifiers(codecs=["DoubleDelta"])) - ), + Column("rounded_timestamp", DateTime(modifiers=Modifiers(codecs=["DoubleDelta"]))), Column("last_timestamp", AggregateFunction("max", [DateTime()])), Column("retention_days", UInt(16)), Column( diff --git a/snuba/snuba_migrations/generic_metrics/0023_gauges_raw_table.py b/snuba/snuba_migrations/generic_metrics/0023_gauges_raw_table.py index 722e0ce0443..d8247ada3b6 100644 --- a/snuba/snuba_migrations/generic_metrics/0023_gauges_raw_table.py +++ b/snuba/snuba_migrations/generic_metrics/0023_gauges_raw_table.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import ( Array, @@ -60,19 +60,19 @@ class Migration(migration.ClickhouseNodeMigration): Column("granularities", Array(UInt(8))), Column( "decasecond_retention_days", - UInt(8, Modifiers(default=str("7"))), + UInt(8, Modifiers(default="7")), ), Column( "min_retention_days", - UInt(8, Modifiers(default=str("30"))), + UInt(8, Modifiers(default="30")), ), Column( "hr_retention_days", - UInt(8, Modifiers(default=str("retention_days"))), + UInt(8, Modifiers(default="retention_days")), ), Column( "day_retention_days", - UInt(8, Modifiers(default=str("retention_days"))), + UInt(8, Modifiers(default="retention_days")), ), ] diff --git a/snuba/snuba_migrations/generic_metrics/0024_gauges_mv.py b/snuba/snuba_migrations/generic_metrics/0024_gauges_mv.py index 2e291857240..bef45bb5c3b 100644 --- a/snuba/snuba_migrations/generic_metrics/0024_gauges_mv.py +++ b/snuba/snuba_migrations/generic_metrics/0024_gauges_mv.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, DateTime, Nested, String, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/generic_metrics/0025_counters_add_raw_tags_hash_column.py b/snuba/snuba_migrations/generic_metrics/0025_counters_add_raw_tags_hash_column.py index 398ca1c20bf..946846438d0 100644 --- a/snuba/snuba_migrations/generic_metrics/0025_counters_add_raw_tags_hash_column.py +++ b/snuba/snuba_migrations/generic_metrics/0025_counters_add_raw_tags_hash_column.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Array, Column, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/generic_metrics/0026_gauges_add_raw_tags_hash_column.py b/snuba/snuba_migrations/generic_metrics/0026_gauges_add_raw_tags_hash_column.py index 23d92c4103e..faf1c721377 100644 --- a/snuba/snuba_migrations/generic_metrics/0026_gauges_add_raw_tags_hash_column.py +++ b/snuba/snuba_migrations/generic_metrics/0026_gauges_add_raw_tags_hash_column.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Array, Column, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/generic_metrics/0027_sets_add_raw_tags_column.py b/snuba/snuba_migrations/generic_metrics/0027_sets_add_raw_tags_column.py index 76f304cc6e5..f91353529e2 100644 --- a/snuba/snuba_migrations/generic_metrics/0027_sets_add_raw_tags_column.py +++ b/snuba/snuba_migrations/generic_metrics/0027_sets_add_raw_tags_column.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Array, Column, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/generic_metrics/0028_distributions_add_indexed_tags_column.py b/snuba/snuba_migrations/generic_metrics/0028_distributions_add_indexed_tags_column.py index de77d9870b3..9e1acd31338 100644 --- a/snuba/snuba_migrations/generic_metrics/0028_distributions_add_indexed_tags_column.py +++ b/snuba/snuba_migrations/generic_metrics/0028_distributions_add_indexed_tags_column.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Array, Column, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/generic_metrics/0029_add_use_case_id_index.py b/snuba/snuba_migrations/generic_metrics/0029_add_use_case_id_index.py index 13efd2f8f7f..0c59d022559 100644 --- a/snuba/snuba_migrations/generic_metrics/0029_add_use_case_id_index.py +++ b/snuba/snuba_migrations/generic_metrics/0029_add_use_case_id_index.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations diff --git a/snuba/snuba_migrations/generic_metrics/0030_add_record_meta_column.py b/snuba/snuba_migrations/generic_metrics/0030_add_record_meta_column.py index e60c3492fc9..9f4a5ee5a4d 100644 --- a/snuba/snuba_migrations/generic_metrics/0030_add_record_meta_column.py +++ b/snuba/snuba_migrations/generic_metrics/0030_add_record_meta_column.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -17,14 +17,14 @@ def forwards_ops(self) -> Sequence[operations.SqlOperation]: operations.AddColumn( storage_set=self.storage_set_key, table_name=self.local_table_name, - column=Column("record_meta", UInt(8, Modifiers(default=str("0")))), + column=Column("record_meta", UInt(8, Modifiers(default="0"))), target=operations.OperationTarget.LOCAL, after="materialization_version", ), operations.AddColumn( storage_set=self.storage_set_key, table_name=self.dist_table_name, - column=Column("record_meta", UInt(8, Modifiers(default=str("0")))), + column=Column("record_meta", UInt(8, Modifiers(default="0"))), target=operations.OperationTarget.DISTRIBUTED, after="materialization_version", ), diff --git a/snuba/snuba_migrations/generic_metrics/0031_counters_meta_table.py b/snuba/snuba_migrations/generic_metrics/0031_counters_meta_table.py index 397066fc78f..601bf522a26 100644 --- a/snuba/snuba_migrations/generic_metrics/0031_counters_meta_table.py +++ b/snuba/snuba_migrations/generic_metrics/0031_counters_meta_table.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import AggregateFunction, Column, DateTime, String, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/generic_metrics/0032_counters_meta_table_mv.py b/snuba/snuba_migrations/generic_metrics/0032_counters_meta_table_mv.py index 03da680694d..4c8612b49ef 100644 --- a/snuba/snuba_migrations/generic_metrics/0032_counters_meta_table_mv.py +++ b/snuba/snuba_migrations/generic_metrics/0032_counters_meta_table_mv.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import AggregateFunction, Column, DateTime, String, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/generic_metrics/0033_counters_meta_tag_values_table.py b/snuba/snuba_migrations/generic_metrics/0033_counters_meta_tag_values_table.py index 08b89f3b863..fbe4ad112ac 100644 --- a/snuba/snuba_migrations/generic_metrics/0033_counters_meta_tag_values_table.py +++ b/snuba/snuba_migrations/generic_metrics/0033_counters_meta_tag_values_table.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import AggregateFunction, Column, DateTime, String, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/generic_metrics/0034_counters_meta_tag_values_table_mv.py b/snuba/snuba_migrations/generic_metrics/0034_counters_meta_tag_values_table_mv.py index 2be4fbfa23e..a8e238f0eae 100644 --- a/snuba/snuba_migrations/generic_metrics/0034_counters_meta_tag_values_table_mv.py +++ b/snuba/snuba_migrations/generic_metrics/0034_counters_meta_tag_values_table_mv.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import AggregateFunction, Column, DateTime, String, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/generic_metrics/0035_recreate_counters_meta_tag_value_table_mv.py b/snuba/snuba_migrations/generic_metrics/0035_recreate_counters_meta_tag_value_table_mv.py index 827486546ef..8e6955a4a63 100644 --- a/snuba/snuba_migrations/generic_metrics/0035_recreate_counters_meta_tag_value_table_mv.py +++ b/snuba/snuba_migrations/generic_metrics/0035_recreate_counters_meta_tag_value_table_mv.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import AggregateFunction, Column, DateTime, String, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/generic_metrics/0036_counters_meta_tables_final.py b/snuba/snuba_migrations/generic_metrics/0036_counters_meta_tables_final.py index 503dfb4c66c..d5a21976096 100644 --- a/snuba/snuba_migrations/generic_metrics/0036_counters_meta_tables_final.py +++ b/snuba/snuba_migrations/generic_metrics/0036_counters_meta_tables_final.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import AggregateFunction, Column, DateTime, String, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/generic_metrics/0037_add_record_meta_column_sets.py b/snuba/snuba_migrations/generic_metrics/0037_add_record_meta_column_sets.py index ab4650c5cf7..217721832c3 100644 --- a/snuba/snuba_migrations/generic_metrics/0037_add_record_meta_column_sets.py +++ b/snuba/snuba_migrations/generic_metrics/0037_add_record_meta_column_sets.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -17,14 +17,14 @@ def forwards_ops(self) -> Sequence[operations.SqlOperation]: operations.AddColumn( storage_set=self.storage_set_key, table_name=self.local_table_name, - column=Column("record_meta", UInt(8, Modifiers(default=str("0")))), + column=Column("record_meta", UInt(8, Modifiers(default="0"))), target=operations.OperationTarget.LOCAL, after="materialization_version", ), operations.AddColumn( storage_set=self.storage_set_key, table_name=self.dist_table_name, - column=Column("record_meta", UInt(8, Modifiers(default=str("0")))), + column=Column("record_meta", UInt(8, Modifiers(default="0"))), target=operations.OperationTarget.DISTRIBUTED, after="materialization_version", ), diff --git a/snuba/snuba_migrations/generic_metrics/0038_add_record_meta_column_distributions.py b/snuba/snuba_migrations/generic_metrics/0038_add_record_meta_column_distributions.py index 402f318acc1..5a184756744 100644 --- a/snuba/snuba_migrations/generic_metrics/0038_add_record_meta_column_distributions.py +++ b/snuba/snuba_migrations/generic_metrics/0038_add_record_meta_column_distributions.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -17,14 +17,14 @@ def forwards_ops(self) -> Sequence[operations.SqlOperation]: operations.AddColumn( storage_set=self.storage_set_key, table_name=self.local_table_name, - column=Column("record_meta", UInt(8, Modifiers(default=str("0")))), + column=Column("record_meta", UInt(8, Modifiers(default="0"))), target=operations.OperationTarget.LOCAL, after="materialization_version", ), operations.AddColumn( storage_set=self.storage_set_key, table_name=self.dist_table_name, - column=Column("record_meta", UInt(8, Modifiers(default=str("0")))), + column=Column("record_meta", UInt(8, Modifiers(default="0"))), target=operations.OperationTarget.DISTRIBUTED, after="materialization_version", ), diff --git a/snuba/snuba_migrations/generic_metrics/0039_add_record_meta_column_gauges.py b/snuba/snuba_migrations/generic_metrics/0039_add_record_meta_column_gauges.py index c8f4350a96d..e654eb4216b 100644 --- a/snuba/snuba_migrations/generic_metrics/0039_add_record_meta_column_gauges.py +++ b/snuba/snuba_migrations/generic_metrics/0039_add_record_meta_column_gauges.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -17,14 +17,14 @@ def forwards_ops(self) -> Sequence[operations.SqlOperation]: operations.AddColumn( storage_set=self.storage_set_key, table_name=self.local_table_name, - column=Column("record_meta", UInt(8, Modifiers(default=str("0")))), + column=Column("record_meta", UInt(8, Modifiers(default="0"))), target=operations.OperationTarget.LOCAL, after="materialization_version", ), operations.AddColumn( storage_set=self.storage_set_key, table_name=self.dist_table_name, - column=Column("record_meta", UInt(8, Modifiers(default=str("0")))), + column=Column("record_meta", UInt(8, Modifiers(default="0"))), target=operations.OperationTarget.DISTRIBUTED, after="materialization_version", ), diff --git a/snuba/snuba_migrations/generic_metrics/0040_remove_counters_meta_tables.py b/snuba/snuba_migrations/generic_metrics/0040_remove_counters_meta_tables.py index eff1a415ac3..5362e1d63dd 100644 --- a/snuba/snuba_migrations/generic_metrics/0040_remove_counters_meta_tables.py +++ b/snuba/snuba_migrations/generic_metrics/0040_remove_counters_meta_tables.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import AggregateFunction, Column, DateTime, String, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/generic_metrics/0041_adjust_partitioning_meta_tables.py b/snuba/snuba_migrations/generic_metrics/0041_adjust_partitioning_meta_tables.py index d0e2276ee61..7752b600a03 100644 --- a/snuba/snuba_migrations/generic_metrics/0041_adjust_partitioning_meta_tables.py +++ b/snuba/snuba_migrations/generic_metrics/0041_adjust_partitioning_meta_tables.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import AggregateFunction, Column, DateTime, String, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/generic_metrics/0042_rename_counters_meta_tables.py b/snuba/snuba_migrations/generic_metrics/0042_rename_counters_meta_tables.py index 6c004f6a11d..8e51920bd70 100644 --- a/snuba/snuba_migrations/generic_metrics/0042_rename_counters_meta_tables.py +++ b/snuba/snuba_migrations/generic_metrics/0042_rename_counters_meta_tables.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import AggregateFunction, Column, DateTime, String, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -163,9 +163,7 @@ def create_ops(self, prefix: str) -> list[operations.SqlOperation]: storage_set=self.storage_set_key, table_name=getattr(self, f"{prefix}_tag_value_dist_table_name"), engine=table_engines.Distributed( - local_table_name=getattr( - self, f"{prefix}_tag_value_local_table_name" - ), + local_table_name=getattr(self, f"{prefix}_tag_value_local_table_name"), sharding_key=None, ), columns=self.tag_value_table_columns, @@ -175,9 +173,7 @@ def create_ops(self, prefix: str) -> list[operations.SqlOperation]: storage_set=self.storage_set_key, view_name=getattr(self, f"{prefix}_tag_value_view_name"), columns=self.tag_value_table_columns, - destination_table_name=getattr( - self, f"{prefix}_tag_value_local_table_name" - ), + destination_table_name=getattr(self, f"{prefix}_tag_value_local_table_name"), target=OperationTarget.LOCAL, query=""" SELECT diff --git a/snuba/snuba_migrations/generic_metrics/0043_sets_meta_tables.py b/snuba/snuba_migrations/generic_metrics/0043_sets_meta_tables.py index 77ab5b259fb..658968a5ce7 100644 --- a/snuba/snuba_migrations/generic_metrics/0043_sets_meta_tables.py +++ b/snuba/snuba_migrations/generic_metrics/0043_sets_meta_tables.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import AggregateFunction, Column, DateTime, String, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/generic_metrics/0044_gauges_meta_tables.py b/snuba/snuba_migrations/generic_metrics/0044_gauges_meta_tables.py index efa6883759c..afbeec1609b 100644 --- a/snuba/snuba_migrations/generic_metrics/0044_gauges_meta_tables.py +++ b/snuba/snuba_migrations/generic_metrics/0044_gauges_meta_tables.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import AggregateFunction, Column, DateTime, String, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/generic_metrics/0045_distributions_meta_tables.py b/snuba/snuba_migrations/generic_metrics/0045_distributions_meta_tables.py index 5523c7719e6..a02ae24bdc7 100644 --- a/snuba/snuba_migrations/generic_metrics/0045_distributions_meta_tables.py +++ b/snuba/snuba_migrations/generic_metrics/0045_distributions_meta_tables.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import AggregateFunction, Column, DateTime, String, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/generic_metrics/0046_distributions_add_disable_percentiles.py b/snuba/snuba_migrations/generic_metrics/0046_distributions_add_disable_percentiles.py index 2e911b75dd2..4a852359ea4 100644 --- a/snuba/snuba_migrations/generic_metrics/0046_distributions_add_disable_percentiles.py +++ b/snuba/snuba_migrations/generic_metrics/0046_distributions_add_disable_percentiles.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -19,7 +19,7 @@ def forwards_ops(self) -> Sequence[operations.SqlOperation]: table_name=self.local_table_name, column=Column( "disable_percentiles", - UInt(8, Modifiers(default=str("0"), codecs=["T64"])), + UInt(8, Modifiers(default="0", codecs=["T64"])), ), target=operations.OperationTarget.LOCAL, after="granularities", @@ -29,7 +29,7 @@ def forwards_ops(self) -> Sequence[operations.SqlOperation]: table_name=self.dist_table_name, column=Column( "disable_percentiles", - UInt(8, Modifiers(default=str("0"), codecs=["T64"])), + UInt(8, Modifiers(default="0", codecs=["T64"])), ), target=operations.OperationTarget.DISTRIBUTED, after="granularities", diff --git a/snuba/snuba_migrations/generic_metrics/0047_distributions_mv3.py b/snuba/snuba_migrations/generic_metrics/0047_distributions_mv3.py index b6c79d5d9f4..cd8fa63fd33 100644 --- a/snuba/snuba_migrations/generic_metrics/0047_distributions_mv3.py +++ b/snuba/snuba_migrations/generic_metrics/0047_distributions_mv3.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import ( AggregateFunction, diff --git a/snuba/snuba_migrations/generic_metrics/0048_counters_meta_tables_support_empty_tags.py b/snuba/snuba_migrations/generic_metrics/0048_counters_meta_tables_support_empty_tags.py index 344f92d6960..230855acbc0 100644 --- a/snuba/snuba_migrations/generic_metrics/0048_counters_meta_tables_support_empty_tags.py +++ b/snuba/snuba_migrations/generic_metrics/0048_counters_meta_tables_support_empty_tags.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import AggregateFunction, Column, DateTime, String, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/generic_metrics/0049_sets_meta_tables_support_empty_tags.py b/snuba/snuba_migrations/generic_metrics/0049_sets_meta_tables_support_empty_tags.py index ace42a5817b..1288e7a45a3 100644 --- a/snuba/snuba_migrations/generic_metrics/0049_sets_meta_tables_support_empty_tags.py +++ b/snuba/snuba_migrations/generic_metrics/0049_sets_meta_tables_support_empty_tags.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import AggregateFunction, Column, DateTime, String, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/generic_metrics/0050_distributions_meta_tables_support_empty_tags.py b/snuba/snuba_migrations/generic_metrics/0050_distributions_meta_tables_support_empty_tags.py index 91902cef649..f5c0f52c622 100644 --- a/snuba/snuba_migrations/generic_metrics/0050_distributions_meta_tables_support_empty_tags.py +++ b/snuba/snuba_migrations/generic_metrics/0050_distributions_meta_tables_support_empty_tags.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import AggregateFunction, Column, DateTime, String, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/generic_metrics/0051_gauges_meta_tables_support_empty_tags.py b/snuba/snuba_migrations/generic_metrics/0051_gauges_meta_tables_support_empty_tags.py index f21f94cfd4b..63f4a736003 100644 --- a/snuba/snuba_migrations/generic_metrics/0051_gauges_meta_tables_support_empty_tags.py +++ b/snuba/snuba_migrations/generic_metrics/0051_gauges_meta_tables_support_empty_tags.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import AggregateFunction, Column, DateTime, String, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/generic_metrics/0052_counters_raw_add_sampling_weight.py b/snuba/snuba_migrations/generic_metrics/0052_counters_raw_add_sampling_weight.py index 73a4a0677fa..e28dd26ad7f 100644 --- a/snuba/snuba_migrations/generic_metrics/0052_counters_raw_add_sampling_weight.py +++ b/snuba/snuba_migrations/generic_metrics/0052_counters_raw_add_sampling_weight.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -18,7 +18,7 @@ class Migration(migration.ClickhouseNodeMigration): columns: Sequence[Column[MigrationModifiers]] = [ Column( "sampling_weight", - UInt(64, MigrationModifiers(codecs=["ZSTD(1)"], default=str("1"))), + UInt(64, MigrationModifiers(codecs=["ZSTD(1)"], default="1")), ), ] diff --git a/snuba/snuba_migrations/generic_metrics/0053_counters_aggregated_add_sampling_weight.py b/snuba/snuba_migrations/generic_metrics/0053_counters_aggregated_add_sampling_weight.py index dc79d6888e5..34761cc4089 100644 --- a/snuba/snuba_migrations/generic_metrics/0053_counters_aggregated_add_sampling_weight.py +++ b/snuba/snuba_migrations/generic_metrics/0053_counters_aggregated_add_sampling_weight.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/generic_metrics/0054_counters_mv3.py b/snuba/snuba_migrations/generic_metrics/0054_counters_mv3.py index 061c6dfc351..cd09a01bb95 100644 --- a/snuba/snuba_migrations/generic_metrics/0054_counters_mv3.py +++ b/snuba/snuba_migrations/generic_metrics/0054_counters_mv3.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import ( AggregateFunction, diff --git a/snuba/snuba_migrations/generic_metrics/0055_gauges_raw_add_sampling_weight.py b/snuba/snuba_migrations/generic_metrics/0055_gauges_raw_add_sampling_weight.py index 05c2ebac78e..abd679710cc 100644 --- a/snuba/snuba_migrations/generic_metrics/0055_gauges_raw_add_sampling_weight.py +++ b/snuba/snuba_migrations/generic_metrics/0055_gauges_raw_add_sampling_weight.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -16,7 +16,7 @@ class Migration(migration.ClickhouseNodeMigration): columns: Sequence[Column[MigrationModifiers]] = [ Column( "sampling_weight", - UInt(64, MigrationModifiers(codecs=["ZSTD(1)"], default=str("1"))), + UInt(64, MigrationModifiers(codecs=["ZSTD(1)"], default="1")), ), ] diff --git a/snuba/snuba_migrations/generic_metrics/0056_gauges_aggregated_add_weighted_columns.py b/snuba/snuba_migrations/generic_metrics/0056_gauges_aggregated_add_weighted_columns.py index 94e881f5b1e..3c09c924ea6 100644 --- a/snuba/snuba_migrations/generic_metrics/0056_gauges_aggregated_add_weighted_columns.py +++ b/snuba/snuba_migrations/generic_metrics/0056_gauges_aggregated_add_weighted_columns.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/generic_metrics/0057_gauges_mv3.py b/snuba/snuba_migrations/generic_metrics/0057_gauges_mv3.py index 96af1f80deb..7825031be10 100644 --- a/snuba/snuba_migrations/generic_metrics/0057_gauges_mv3.py +++ b/snuba/snuba_migrations/generic_metrics/0057_gauges_mv3.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, DateTime, Nested, String, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/generic_metrics/0058_distributions_raw_add_sampling_weight.py b/snuba/snuba_migrations/generic_metrics/0058_distributions_raw_add_sampling_weight.py index 8be8b25c335..f8d52ff1182 100644 --- a/snuba/snuba_migrations/generic_metrics/0058_distributions_raw_add_sampling_weight.py +++ b/snuba/snuba_migrations/generic_metrics/0058_distributions_raw_add_sampling_weight.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -16,7 +16,7 @@ class Migration(migration.ClickhouseNodeMigration): columns: Sequence[Column[MigrationModifiers]] = [ Column( "sampling_weight", - UInt(64, MigrationModifiers(codecs=["ZSTD(1)"], default=str("1"))), + UInt(64, MigrationModifiers(codecs=["ZSTD(1)"], default="1")), ), ] diff --git a/snuba/snuba_migrations/generic_metrics/0059_distributions_aggregated_add_weighted_columns.py b/snuba/snuba_migrations/generic_metrics/0059_distributions_aggregated_add_weighted_columns.py index 771b2e65376..19086211ebc 100644 --- a/snuba/snuba_migrations/generic_metrics/0059_distributions_aggregated_add_weighted_columns.py +++ b/snuba/snuba_migrations/generic_metrics/0059_distributions_aggregated_add_weighted_columns.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/generic_metrics/0060_distributions_mv4.py b/snuba/snuba_migrations/generic_metrics/0060_distributions_mv4.py index 0b81b30cdb4..3a03695064e 100644 --- a/snuba/snuba_migrations/generic_metrics/0060_distributions_mv4.py +++ b/snuba/snuba_migrations/generic_metrics/0060_distributions_mv4.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import ( AggregateFunction, diff --git a/snuba/snuba_migrations/generic_metrics/0061_remove_distribution_meta_tag_values_mv.py b/snuba/snuba_migrations/generic_metrics/0061_remove_distribution_meta_tag_values_mv.py index 119a01b8b2a..84c39873817 100644 --- a/snuba/snuba_migrations/generic_metrics/0061_remove_distribution_meta_tag_values_mv.py +++ b/snuba/snuba_migrations/generic_metrics/0061_remove_distribution_meta_tag_values_mv.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import AggregateFunction, Column, DateTime, String, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/group_attributes/0001_group_attributes.py b/snuba/snuba_migrations/group_attributes/0001_group_attributes.py index 54c6f6e32e1..1f84c1edf07 100644 --- a/snuba/snuba_migrations/group_attributes/0001_group_attributes.py +++ b/snuba/snuba_migrations/group_attributes/0001_group_attributes.py @@ -1,4 +1,4 @@ -from typing import List, Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, DateTime, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -6,7 +6,7 @@ from snuba.migrations.columns import MigrationModifiers as Modifiers from snuba.migrations.operations import OperationTarget, SqlOperation -columns: List[Column[Modifiers]] = [ +columns: list[Column[Modifiers]] = [ Column("project_id", UInt(64)), Column("group_id", UInt(64)), # Group diff --git a/snuba/snuba_migrations/group_attributes/0002_add_priority_to_group_attributes.py b/snuba/snuba_migrations/group_attributes/0002_add_priority_to_group_attributes.py index 19a2fb34257..8c777858039 100644 --- a/snuba/snuba_migrations/group_attributes/0002_add_priority_to_group_attributes.py +++ b/snuba/snuba_migrations/group_attributes/0002_add_priority_to_group_attributes.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/group_attributes/0003_add_first_release_id_to_group_attributes.py b/snuba/snuba_migrations/group_attributes/0003_add_first_release_id_to_group_attributes.py index 9fec68e7ea2..aa26dfc6715 100644 --- a/snuba/snuba_migrations/group_attributes/0003_add_first_release_id_to_group_attributes.py +++ b/snuba/snuba_migrations/group_attributes/0003_add_first_release_id_to_group_attributes.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import UUID, Column from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/group_attributes/0004_add_new_first_release_column_to_group_attributes.py b/snuba/snuba_migrations/group_attributes/0004_add_new_first_release_column_to_group_attributes.py index 6c937d5b3c6..5793aa433eb 100644 --- a/snuba/snuba_migrations/group_attributes/0004_add_new_first_release_column_to_group_attributes.py +++ b/snuba/snuba_migrations/group_attributes/0004_add_new_first_release_column_to_group_attributes.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/metrics/0001_metrics_buckets.py b/snuba/snuba_migrations/metrics/0001_metrics_buckets.py index 8edfeb281c4..3c6545bb546 100644 --- a/snuba/snuba_migrations/metrics/0001_metrics_buckets.py +++ b/snuba/snuba_migrations/metrics/0001_metrics_buckets.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Array, Column, UInt from snuba.migrations import migration, operations diff --git a/snuba/snuba_migrations/metrics/0002_metrics_sets.py b/snuba/snuba_migrations/metrics/0002_metrics_sets.py index 5fcb52a3786..9a496f4a105 100644 --- a/snuba/snuba_migrations/metrics/0002_metrics_sets.py +++ b/snuba/snuba_migrations/metrics/0002_metrics_sets.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import AggregateFunction, Column, UInt from snuba.migrations import migration, operations diff --git a/snuba/snuba_migrations/metrics/0003_counters_to_buckets.py b/snuba/snuba_migrations/metrics/0003_counters_to_buckets.py index 5a7b93e7d32..7b79ac6e497 100644 --- a/snuba/snuba_migrations/metrics/0003_counters_to_buckets.py +++ b/snuba/snuba_migrations/metrics/0003_counters_to_buckets.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, Float from snuba.migrations import migration, operations diff --git a/snuba/snuba_migrations/metrics/0004_metrics_counters.py b/snuba/snuba_migrations/metrics/0004_metrics_counters.py index 9bc86da9d1a..57c71798b0b 100644 --- a/snuba/snuba_migrations/metrics/0004_metrics_counters.py +++ b/snuba/snuba_migrations/metrics/0004_metrics_counters.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import AggregateFunction, Column, Float from snuba.migrations import migration, operations diff --git a/snuba/snuba_migrations/metrics/0005_metrics_distributions_buckets.py b/snuba/snuba_migrations/metrics/0005_metrics_distributions_buckets.py index 08476493104..2430f122e31 100644 --- a/snuba/snuba_migrations/metrics/0005_metrics_distributions_buckets.py +++ b/snuba/snuba_migrations/metrics/0005_metrics_distributions_buckets.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Array, Column, Float from snuba.migrations import migration, operations diff --git a/snuba/snuba_migrations/metrics/0006_metrics_distributions.py b/snuba/snuba_migrations/metrics/0006_metrics_distributions.py index 4163b1a64a2..dc9f0fc5ea4 100644 --- a/snuba/snuba_migrations/metrics/0006_metrics_distributions.py +++ b/snuba/snuba_migrations/metrics/0006_metrics_distributions.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.migrations import migration, operations from snuba.snuba_migrations.metrics.templates import ( @@ -14,9 +14,7 @@ class Migration(migration.ClickhouseNodeMigrationLegacy): blocking = False def forwards_local(self) -> Sequence[operations.SqlOperation]: - return get_forward_migrations_local( - **get_migration_args_for_distributions(granularity=60) - ) + return get_forward_migrations_local(**get_migration_args_for_distributions(granularity=60)) def backwards_local(self) -> Sequence[operations.SqlOperation]: return [ diff --git a/snuba/snuba_migrations/metrics/0007_metrics_sets_granularity_10.py b/snuba/snuba_migrations/metrics/0007_metrics_sets_granularity_10.py index ed36639e018..ad21811b70c 100644 --- a/snuba/snuba_migrations/metrics/0007_metrics_sets_granularity_10.py +++ b/snuba/snuba_migrations/metrics/0007_metrics_sets_granularity_10.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations @@ -20,11 +20,7 @@ class Migration(migration.ClickhouseNodeMigrationLegacy): blocking = False def forwards_local(self) -> Sequence[operations.SqlOperation]: - return ( - get_forward_view_migration_local( - **get_migration_args_for_sets(granularity=10) - ), - ) + return (get_forward_view_migration_local(**get_migration_args_for_sets(granularity=10)),) def backwards_local(self) -> Sequence[operations.SqlOperation]: return [ diff --git a/snuba/snuba_migrations/metrics/0008_metrics_counters_granularity_10.py b/snuba/snuba_migrations/metrics/0008_metrics_counters_granularity_10.py index ff2a853416d..0d488c444b1 100644 --- a/snuba/snuba_migrations/metrics/0008_metrics_counters_granularity_10.py +++ b/snuba/snuba_migrations/metrics/0008_metrics_counters_granularity_10.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations @@ -21,9 +21,7 @@ class Migration(migration.ClickhouseNodeMigrationLegacy): def forwards_local(self) -> Sequence[operations.SqlOperation]: return ( - get_forward_view_migration_local( - **get_migration_args_for_counters(granularity=10) - ), + get_forward_view_migration_local(**get_migration_args_for_counters(granularity=10)), ) def backwards_local(self) -> Sequence[operations.SqlOperation]: diff --git a/snuba/snuba_migrations/metrics/0009_metrics_distributions_granularity_10.py b/snuba/snuba_migrations/metrics/0009_metrics_distributions_granularity_10.py index 639b2c68800..848fda0a67f 100644 --- a/snuba/snuba_migrations/metrics/0009_metrics_distributions_granularity_10.py +++ b/snuba/snuba_migrations/metrics/0009_metrics_distributions_granularity_10.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations diff --git a/snuba/snuba_migrations/metrics/0010_metrics_sets_granularity_1h.py b/snuba/snuba_migrations/metrics/0010_metrics_sets_granularity_1h.py index 2db351e1c8b..5d2468bb1b4 100644 --- a/snuba/snuba_migrations/metrics/0010_metrics_sets_granularity_1h.py +++ b/snuba/snuba_migrations/metrics/0010_metrics_sets_granularity_1h.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations @@ -20,9 +20,7 @@ class Migration(migration.ClickhouseNodeMigrationLegacy): def forwards_local(self) -> Sequence[operations.SqlOperation]: return ( - get_forward_view_migration_local( - **get_migration_args_for_sets(granularity=60 * 60) - ), + get_forward_view_migration_local(**get_migration_args_for_sets(granularity=60 * 60)), ) def backwards_local(self) -> Sequence[operations.SqlOperation]: diff --git a/snuba/snuba_migrations/metrics/0011_metrics_counters_granularity_1h.py b/snuba/snuba_migrations/metrics/0011_metrics_counters_granularity_1h.py index a1f5ba1956c..174d6488d26 100644 --- a/snuba/snuba_migrations/metrics/0011_metrics_counters_granularity_1h.py +++ b/snuba/snuba_migrations/metrics/0011_metrics_counters_granularity_1h.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations diff --git a/snuba/snuba_migrations/metrics/0012_metrics_distributions_granularity_1h.py b/snuba/snuba_migrations/metrics/0012_metrics_distributions_granularity_1h.py index 0a475768f94..888dcfdcfbc 100644 --- a/snuba/snuba_migrations/metrics/0012_metrics_distributions_granularity_1h.py +++ b/snuba/snuba_migrations/metrics/0012_metrics_distributions_granularity_1h.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations diff --git a/snuba/snuba_migrations/metrics/0013_metrics_sets_granularity_1d.py b/snuba/snuba_migrations/metrics/0013_metrics_sets_granularity_1d.py index cdf09d1c483..8b1b75bc972 100644 --- a/snuba/snuba_migrations/metrics/0013_metrics_sets_granularity_1d.py +++ b/snuba/snuba_migrations/metrics/0013_metrics_sets_granularity_1d.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations diff --git a/snuba/snuba_migrations/metrics/0014_metrics_counters_granularity_1d.py b/snuba/snuba_migrations/metrics/0014_metrics_counters_granularity_1d.py index 2f8d9ba5e94..01b15c06f56 100644 --- a/snuba/snuba_migrations/metrics/0014_metrics_counters_granularity_1d.py +++ b/snuba/snuba_migrations/metrics/0014_metrics_counters_granularity_1d.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations diff --git a/snuba/snuba_migrations/metrics/0015_metrics_distributions_granularity_1d.py b/snuba/snuba_migrations/metrics/0015_metrics_distributions_granularity_1d.py index 0e3321e8261..6a4ef3282fb 100644 --- a/snuba/snuba_migrations/metrics/0015_metrics_distributions_granularity_1d.py +++ b/snuba/snuba_migrations/metrics/0015_metrics_distributions_granularity_1d.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations diff --git a/snuba/snuba_migrations/metrics/0016_metrics_sets_consolidated_granularity.py b/snuba/snuba_migrations/metrics/0016_metrics_sets_consolidated_granularity.py index 562ad518abd..b6dcacf3f40 100644 --- a/snuba/snuba_migrations/metrics/0016_metrics_sets_consolidated_granularity.py +++ b/snuba/snuba_migrations/metrics/0016_metrics_sets_consolidated_granularity.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import AggregateFunction, Column, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/metrics/0017_metrics_counters_consolidated_granularity.py b/snuba/snuba_migrations/metrics/0017_metrics_counters_consolidated_granularity.py index 234a53d6bbc..30967b7a40b 100644 --- a/snuba/snuba_migrations/metrics/0017_metrics_counters_consolidated_granularity.py +++ b/snuba/snuba_migrations/metrics/0017_metrics_counters_consolidated_granularity.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import AggregateFunction, Column, Float from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/metrics/0018_metrics_distributions_consolidated_granularity.py b/snuba/snuba_migrations/metrics/0018_metrics_distributions_consolidated_granularity.py index ec358f39fc5..9ea48f99d5d 100644 --- a/snuba/snuba_migrations/metrics/0018_metrics_distributions_consolidated_granularity.py +++ b/snuba/snuba_migrations/metrics/0018_metrics_distributions_consolidated_granularity.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations diff --git a/snuba/snuba_migrations/metrics/0019_aggregate_tables_add_ttl.py b/snuba/snuba_migrations/metrics/0019_aggregate_tables_add_ttl.py index e78435c0a6f..515f6fa21a2 100644 --- a/snuba/snuba_migrations/metrics/0019_aggregate_tables_add_ttl.py +++ b/snuba/snuba_migrations/metrics/0019_aggregate_tables_add_ttl.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations @@ -21,8 +21,7 @@ def forwards_local(self) -> Sequence[operations.SqlOperation]: operations.RunSql( storage_set=StorageSetKey.METRICS, statement=( - f"ALTER TABLE {table_name} MODIFY TTL " - "timestamp + toIntervalDay(retention_days)" + f"ALTER TABLE {table_name} MODIFY TTL timestamp + toIntervalDay(retention_days)" ), ) for table_name in self.table_names diff --git a/snuba/snuba_migrations/metrics/0020_polymorphic_buckets_table.py b/snuba/snuba_migrations/metrics/0020_polymorphic_buckets_table.py index babdc2bbd39..865e37419b7 100644 --- a/snuba/snuba_migrations/metrics/0020_polymorphic_buckets_table.py +++ b/snuba/snuba_migrations/metrics/0020_polymorphic_buckets_table.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations, table_engines @@ -73,7 +73,5 @@ def forwards_dist(self) -> Sequence[operations.SqlOperation]: def backwards_dist(self) -> Sequence[operations.SqlOperation]: return [ - operations.DropTable( - storage_set=StorageSetKey.METRICS, table_name=self.dist_table_name - ) + operations.DropTable(storage_set=StorageSetKey.METRICS, table_name=self.dist_table_name) ] diff --git a/snuba/snuba_migrations/metrics/0021_polymorphic_bucket_materialized_views.py b/snuba/snuba_migrations/metrics/0021_polymorphic_bucket_materialized_views.py index 9f8f0e83b5f..3a512d08837 100644 --- a/snuba/snuba_migrations/metrics/0021_polymorphic_bucket_materialized_views.py +++ b/snuba/snuba_migrations/metrics/0021_polymorphic_bucket_materialized_views.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations diff --git a/snuba/snuba_migrations/metrics/0022_repartition_polymorphic_table.py b/snuba/snuba_migrations/metrics/0022_repartition_polymorphic_table.py index 93b0e251118..63ecfdd85ee 100644 --- a/snuba/snuba_migrations/metrics/0022_repartition_polymorphic_table.py +++ b/snuba/snuba_migrations/metrics/0022_repartition_polymorphic_table.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations, table_engines @@ -74,7 +74,5 @@ def forwards_dist(self) -> Sequence[operations.SqlOperation]: def backwards_dist(self) -> Sequence[operations.SqlOperation]: return [ - operations.DropTable( - storage_set=StorageSetKey.METRICS, table_name=self.dist_table_name - ) + operations.DropTable(storage_set=StorageSetKey.METRICS, table_name=self.dist_table_name) ] diff --git a/snuba/snuba_migrations/metrics/0023_polymorphic_repartitioned_bucket_matview.py b/snuba/snuba_migrations/metrics/0023_polymorphic_repartitioned_bucket_matview.py index 3485677d164..69bd42b7c1a 100644 --- a/snuba/snuba_migrations/metrics/0023_polymorphic_repartitioned_bucket_matview.py +++ b/snuba/snuba_migrations/metrics/0023_polymorphic_repartitioned_bucket_matview.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations diff --git a/snuba/snuba_migrations/metrics/0024_metrics_distributions_add_histogram.py b/snuba/snuba_migrations/metrics/0024_metrics_distributions_add_histogram.py index 76b72632754..4788bc2d2b5 100644 --- a/snuba/snuba_migrations/metrics/0024_metrics_distributions_add_histogram.py +++ b/snuba/snuba_migrations/metrics/0024_metrics_distributions_add_histogram.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import AggregateFunction, Column, Float, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -18,9 +18,7 @@ class Migration(migration.ClickhouseNodeMigrationLegacy): blocking = False raw_table_name = "metrics_raw_v2_local" - def __forward_migrations( - self, table_name: str - ) -> Sequence[operations.SqlOperation]: + def __forward_migrations(self, table_name: str) -> Sequence[operations.SqlOperation]: return [ operations.AddColumn( storage_set=StorageSetKey.METRICS, @@ -33,9 +31,7 @@ def __forward_migrations( ) ] - def __backward_migrations( - self, table_name: str - ) -> Sequence[operations.SqlOperation]: + def __backward_migrations(self, table_name: str) -> Sequence[operations.SqlOperation]: return [ # operations.DropColumn( # storage_set=StorageSetKey.METRICS, diff --git a/snuba/snuba_migrations/metrics/0025_metrics_counters_aggregate_v2.py b/snuba/snuba_migrations/metrics/0025_metrics_counters_aggregate_v2.py index 4210f555e5c..d5449e16c04 100644 --- a/snuba/snuba_migrations/metrics/0025_metrics_counters_aggregate_v2.py +++ b/snuba/snuba_migrations/metrics/0025_metrics_counters_aggregate_v2.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Array, Column, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -65,11 +65,7 @@ def forwards_local(self) -> Sequence[operations.SqlOperation]: ] def backwards_local(self) -> Sequence[operations.SqlOperation]: - return [ - operations.DropTable( - storage_set=StorageSetKey.METRICS, table_name=self.table_name - ) - ] + return [operations.DropTable(storage_set=StorageSetKey.METRICS, table_name=self.table_name)] def forwards_dist(self) -> Sequence[operations.SqlOperation]: return [ @@ -94,7 +90,5 @@ def forwards_dist(self) -> Sequence[operations.SqlOperation]: def backwards_dist(self) -> Sequence[operations.SqlOperation]: return [ - operations.DropTable( - storage_set=StorageSetKey.METRICS, table_name=self.dist_table_name - ) + operations.DropTable(storage_set=StorageSetKey.METRICS, table_name=self.dist_table_name) ] diff --git a/snuba/snuba_migrations/metrics/0026_metrics_counters_v2_writing_matview.py b/snuba/snuba_migrations/metrics/0026_metrics_counters_v2_writing_matview.py index 37b5331722e..ecd5d0cc53b 100644 --- a/snuba/snuba_migrations/metrics/0026_metrics_counters_v2_writing_matview.py +++ b/snuba/snuba_migrations/metrics/0026_metrics_counters_v2_writing_matview.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations diff --git a/snuba/snuba_migrations/metrics/0027_fix_migration_0026.py b/snuba/snuba_migrations/metrics/0027_fix_migration_0026.py index 5254b6001cb..857a47abd82 100644 --- a/snuba/snuba_migrations/metrics/0027_fix_migration_0026.py +++ b/snuba/snuba_migrations/metrics/0027_fix_migration_0026.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations diff --git a/snuba/snuba_migrations/metrics/0028_metrics_sets_aggregate_v2.py b/snuba/snuba_migrations/metrics/0028_metrics_sets_aggregate_v2.py index a1366264e9c..0aba085cc8a 100644 --- a/snuba/snuba_migrations/metrics/0028_metrics_sets_aggregate_v2.py +++ b/snuba/snuba_migrations/metrics/0028_metrics_sets_aggregate_v2.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Array, Column, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -65,11 +65,7 @@ def forwards_local(self) -> Sequence[operations.SqlOperation]: ] def backwards_local(self) -> Sequence[operations.SqlOperation]: - return [ - operations.DropTable( - storage_set=StorageSetKey.METRICS, table_name=self.table_name - ) - ] + return [operations.DropTable(storage_set=StorageSetKey.METRICS, table_name=self.table_name)] def forwards_dist(self) -> Sequence[operations.SqlOperation]: return [ @@ -94,7 +90,5 @@ def forwards_dist(self) -> Sequence[operations.SqlOperation]: def backwards_dist(self) -> Sequence[operations.SqlOperation]: return [ - operations.DropTable( - storage_set=StorageSetKey.METRICS, table_name=self.dist_table_name - ) + operations.DropTable(storage_set=StorageSetKey.METRICS, table_name=self.dist_table_name) ] diff --git a/snuba/snuba_migrations/metrics/0029_metrics_distributions_aggregate_v2.py b/snuba/snuba_migrations/metrics/0029_metrics_distributions_aggregate_v2.py index 0b3448ff83c..24d737b5438 100644 --- a/snuba/snuba_migrations/metrics/0029_metrics_distributions_aggregate_v2.py +++ b/snuba/snuba_migrations/metrics/0029_metrics_distributions_aggregate_v2.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Array, Column, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -74,11 +74,7 @@ def forwards_local(self) -> Sequence[operations.SqlOperation]: ] def backwards_local(self) -> Sequence[operations.SqlOperation]: - return [ - operations.DropTable( - storage_set=StorageSetKey.METRICS, table_name=self.table_name - ) - ] + return [operations.DropTable(storage_set=StorageSetKey.METRICS, table_name=self.table_name)] def forwards_dist(self) -> Sequence[operations.SqlOperation]: return [ @@ -103,7 +99,5 @@ def forwards_dist(self) -> Sequence[operations.SqlOperation]: def backwards_dist(self) -> Sequence[operations.SqlOperation]: return [ - operations.DropTable( - storage_set=StorageSetKey.METRICS, table_name=self.dist_table_name - ) + operations.DropTable(storage_set=StorageSetKey.METRICS, table_name=self.dist_table_name) ] diff --git a/snuba/snuba_migrations/metrics/0030_metrics_distributions_v2_writing_mv.py b/snuba/snuba_migrations/metrics/0030_metrics_distributions_v2_writing_mv.py index 6b9c65f3204..aed87cd6782 100644 --- a/snuba/snuba_migrations/metrics/0030_metrics_distributions_v2_writing_mv.py +++ b/snuba/snuba_migrations/metrics/0030_metrics_distributions_v2_writing_mv.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations @@ -25,9 +25,7 @@ def forwards_local(self) -> Sequence[operations.SqlOperation]: source_table_name=self.raw_table_name, table_name=self.table_name, aggregation_col_schema=COL_SCHEMA_DISTRIBUTIONS_V2, - mv_name=get_polymorphic_mv_variant_name( - "distributions", self.mv_version - ), + mv_name=get_polymorphic_mv_variant_name("distributions", self.mv_version), aggregation_states=( "quantilesState(0.5, 0.75, 0.9, 0.95, 0.99)((arrayJoin(distribution_values) AS values_rows)) as percentiles, " "minState(values_rows) as min, " @@ -47,9 +45,7 @@ def backwards_local(self) -> Sequence[operations.SqlOperation]: return [ operations.DropTable( storage_set=StorageSetKey.METRICS, - table_name=get_polymorphic_mv_variant_name( - "distributions", self.mv_version - ), + table_name=get_polymorphic_mv_variant_name("distributions", self.mv_version), ) ] diff --git a/snuba/snuba_migrations/metrics/0031_metrics_sets_v2_writing_mv.py b/snuba/snuba_migrations/metrics/0031_metrics_sets_v2_writing_mv.py index 8accbca98dc..dc2955de347 100644 --- a/snuba/snuba_migrations/metrics/0031_metrics_sets_v2_writing_mv.py +++ b/snuba/snuba_migrations/metrics/0031_metrics_sets_v2_writing_mv.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import AggregateFunction, Column, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/metrics/0032_redo_0030_and_0031_without_timestamps.py b/snuba/snuba_migrations/metrics/0032_redo_0030_and_0031_without_timestamps.py index b4fbcd383bf..b70e3b436df 100644 --- a/snuba/snuba_migrations/metrics/0032_redo_0030_and_0031_without_timestamps.py +++ b/snuba/snuba_migrations/metrics/0032_redo_0030_and_0031_without_timestamps.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations @@ -25,17 +25,13 @@ def forwards_local(self) -> Sequence[operations.SqlOperation]: return [ operations.DropTable( storage_set=StorageSetKey.METRICS, - table_name=get_polymorphic_mv_variant_name( - "distributions", self.mv_version - ), + table_name=get_polymorphic_mv_variant_name("distributions", self.mv_version), ), get_forward_view_migration_polymorphic_table_v3( source_table_name=self.raw_table_name, table_name=self.dist_table_name, aggregation_col_schema=COL_SCHEMA_DISTRIBUTIONS_V2, - mv_name=get_polymorphic_mv_variant_name( - "distributions", self.mv_version - ), + mv_name=get_polymorphic_mv_variant_name("distributions", self.mv_version), aggregation_states=( "quantilesState(0.5, 0.75, 0.9, 0.95, 0.99)((arrayJoin(distribution_values) AS values_rows)) as percentiles, " "minState(values_rows) as min, " diff --git a/snuba/snuba_migrations/metrics/0033_metrics_cleanup_old_views.py b/snuba/snuba_migrations/metrics/0033_metrics_cleanup_old_views.py index 60a96daa069..3567228f286 100644 --- a/snuba/snuba_migrations/metrics/0033_metrics_cleanup_old_views.py +++ b/snuba/snuba_migrations/metrics/0033_metrics_cleanup_old_views.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations @@ -13,9 +13,7 @@ class Migration(migration.ClickhouseNodeMigrationLegacy): def forwards_local(self) -> Sequence[operations.SqlOperation]: return [ - operations.DropTable( - storage_set=StorageSetKey.METRICS, table_name=table_name - ) + operations.DropTable(storage_set=StorageSetKey.METRICS, table_name=table_name) for table_name in [ "metrics_counters_consolidated_mv_local", "metrics_counters_mv_10s_local", diff --git a/snuba/snuba_migrations/metrics/0034_metrics_cleanup_old_tables.py b/snuba/snuba_migrations/metrics/0034_metrics_cleanup_old_tables.py index 42f7d69644d..3ba9d4160ea 100644 --- a/snuba/snuba_migrations/metrics/0034_metrics_cleanup_old_tables.py +++ b/snuba/snuba_migrations/metrics/0034_metrics_cleanup_old_tables.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations @@ -13,9 +13,7 @@ class Migration(migration.ClickhouseNodeMigrationLegacy): def forwards_local(self) -> Sequence[operations.SqlOperation]: return [ - operations.DropTable( - storage_set=StorageSetKey.METRICS, table_name=table_name - ) + operations.DropTable(storage_set=StorageSetKey.METRICS, table_name=table_name) for table_name in [ "metrics_buckets_local", "metrics_counters_buckets_local", @@ -32,9 +30,7 @@ def backwards_local(self) -> Sequence[operations.SqlOperation]: def forwards_dist(self) -> Sequence[operations.SqlOperation]: return [ - operations.DropTable( - storage_set=StorageSetKey.METRICS, table_name=table_name - ) + operations.DropTable(storage_set=StorageSetKey.METRICS, table_name=table_name) for table_name in [ "metrics_buckets_dist", "metrics_counters_buckets_dist", diff --git a/snuba/snuba_migrations/metrics/0035_metrics_raw_timeseries_id.py b/snuba/snuba_migrations/metrics/0035_metrics_raw_timeseries_id.py index 450b93fb208..2e79e7d493c 100644 --- a/snuba/snuba_migrations/metrics/0035_metrics_raw_timeseries_id.py +++ b/snuba/snuba_migrations/metrics/0035_metrics_raw_timeseries_id.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -7,7 +7,6 @@ class Migration(migration.ClickhouseNodeMigration): - """ Adds the timeseries_id column to the metrics raw table so we can add a sharding key for each timeseries and scale the release health cluster. diff --git a/snuba/snuba_migrations/metrics/templates.py b/snuba/snuba_migrations/metrics/templates.py index a56f243981f..417abbf9e46 100644 --- a/snuba/snuba_migrations/metrics/templates.py +++ b/snuba/snuba_migrations/metrics/templates.py @@ -1,4 +1,5 @@ -from typing import Sequence, TypedDict +from collections.abc import Sequence +from typing import TypedDict from snuba.clickhouse.columns import ( AggregateFunction, @@ -434,9 +435,7 @@ def get_forward_migrations_dist( storage_set=StorageSetKey.METRICS, table_name=dist_table_name, columns=[*COMMON_AGGR_COLUMNS, *aggregation_col_schema], - engine=table_engines.Distributed( - local_table_name=local_table_name, sharding_key=None - ), + engine=table_engines.Distributed(local_table_name=local_table_name, sharding_key=None), ), operations.AddColumn( storage_set=StorageSetKey.METRICS, diff --git a/snuba/snuba_migrations/outcomes/0001_outcomes.py b/snuba/snuba_migrations/outcomes/0001_outcomes.py index 3f312f61bdd..d17067cf5c6 100644 --- a/snuba/snuba_migrations/outcomes/0001_outcomes.py +++ b/snuba/snuba_migrations/outcomes/0001_outcomes.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import UUID, Column, DateTime, String, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/outcomes/0002_outcomes_remove_size_and_bytes.py b/snuba/snuba_migrations/outcomes/0002_outcomes_remove_size_and_bytes.py index 32601b903cd..eb5fd3194e7 100644 --- a/snuba/snuba_migrations/outcomes/0002_outcomes_remove_size_and_bytes.py +++ b/snuba/snuba_migrations/outcomes/0002_outcomes_remove_size_and_bytes.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations diff --git a/snuba/snuba_migrations/outcomes/0003_outcomes_add_category_and_quantity.py b/snuba/snuba_migrations/outcomes/0003_outcomes_add_category_and_quantity.py index 6ddb4da053b..0f239173910 100644 --- a/snuba/snuba_migrations/outcomes/0003_outcomes_add_category_and_quantity.py +++ b/snuba/snuba_migrations/outcomes/0003_outcomes_add_category_and_quantity.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -44,15 +44,9 @@ def forwards_local(self) -> Sequence[operations.SqlOperation]: def backwards_local(self) -> Sequence[operations.SqlOperation]: return [ - operations.DropColumn( - StorageSetKey.OUTCOMES, "outcomes_raw_local", "quantity" - ), - operations.DropColumn( - StorageSetKey.OUTCOMES, "outcomes_raw_local", "category" - ), - operations.DropColumn( - StorageSetKey.OUTCOMES, "outcomes_hourly_local", "quantity" - ), + operations.DropColumn(StorageSetKey.OUTCOMES, "outcomes_raw_local", "quantity"), + operations.DropColumn(StorageSetKey.OUTCOMES, "outcomes_raw_local", "category"), + operations.DropColumn(StorageSetKey.OUTCOMES, "outcomes_hourly_local", "quantity"), operations.RunSql( storage_set=StorageSetKey.OUTCOMES, statement=""" @@ -60,9 +54,7 @@ def backwards_local(self) -> Sequence[operations.SqlOperation]: MODIFY ORDER BY (org_id, project_id, key_id, outcome, reason, timestamp); """, ), - operations.DropColumn( - StorageSetKey.OUTCOMES, "outcomes_hourly_local", "category" - ), + operations.DropColumn(StorageSetKey.OUTCOMES, "outcomes_hourly_local", "category"), ] def forwards_dist(self) -> Sequence[operations.SqlOperation]: @@ -95,16 +87,8 @@ def forwards_dist(self) -> Sequence[operations.SqlOperation]: def backwards_dist(self) -> Sequence[operations.SqlOperation]: return [ - operations.DropColumn( - StorageSetKey.OUTCOMES, "outcomes_raw_dist", "quantity" - ), - operations.DropColumn( - StorageSetKey.OUTCOMES, "outcomes_raw_dist", "category" - ), - operations.DropColumn( - StorageSetKey.OUTCOMES, "outcomes_hourly_dist", "quantity" - ), - operations.DropColumn( - StorageSetKey.OUTCOMES, "outcomes_hourly_dist", "category" - ), + operations.DropColumn(StorageSetKey.OUTCOMES, "outcomes_raw_dist", "quantity"), + operations.DropColumn(StorageSetKey.OUTCOMES, "outcomes_raw_dist", "category"), + operations.DropColumn(StorageSetKey.OUTCOMES, "outcomes_hourly_dist", "quantity"), + operations.DropColumn(StorageSetKey.OUTCOMES, "outcomes_hourly_dist", "category"), ] diff --git a/snuba/snuba_migrations/outcomes/0004_outcomes_matview_additions.py b/snuba/snuba_migrations/outcomes/0004_outcomes_matview_additions.py index 2d8d3ddadc3..111999ef410 100644 --- a/snuba/snuba_migrations/outcomes/0004_outcomes_matview_additions.py +++ b/snuba/snuba_migrations/outcomes/0004_outcomes_matview_additions.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, DateTime, String, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/outcomes/0005_outcomes_ttl.py b/snuba/snuba_migrations/outcomes/0005_outcomes_ttl.py index 0147a5907da..28219fc7275 100644 --- a/snuba/snuba_migrations/outcomes/0005_outcomes_ttl.py +++ b/snuba/snuba_migrations/outcomes/0005_outcomes_ttl.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations diff --git a/snuba/snuba_migrations/outcomes/0006_outcomes_add_size_col.py b/snuba/snuba_migrations/outcomes/0006_outcomes_add_size_col.py index dfb72e26751..b08494b1ac9 100644 --- a/snuba/snuba_migrations/outcomes/0006_outcomes_add_size_col.py +++ b/snuba/snuba_migrations/outcomes/0006_outcomes_add_size_col.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/outcomes/0007_outcomes_add_event_id_ttl_codec.py b/snuba/snuba_migrations/outcomes/0007_outcomes_add_event_id_ttl_codec.py index 0e27ee2b090..7c474a0a969 100644 --- a/snuba/snuba_migrations/outcomes/0007_outcomes_add_event_id_ttl_codec.py +++ b/snuba/snuba_migrations/outcomes/0007_outcomes_add_event_id_ttl_codec.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import UUID, Column from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/outcomes/0008_outcomes_add_indexes.py b/snuba/snuba_migrations/outcomes/0008_outcomes_add_indexes.py index 6cb024c70dd..cddd57d0f9e 100644 --- a/snuba/snuba_migrations/outcomes/0008_outcomes_add_indexes.py +++ b/snuba/snuba_migrations/outcomes/0008_outcomes_add_indexes.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations diff --git a/snuba/snuba_migrations/outcomes/0009_outcomes_daily_table.py b/snuba/snuba_migrations/outcomes/0009_outcomes_daily_table.py index 745a99064b2..b20f0a088f3 100644 --- a/snuba/snuba_migrations/outcomes/0009_outcomes_daily_table.py +++ b/snuba/snuba_migrations/outcomes/0009_outcomes_daily_table.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, DateTime, String, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/outcomes/0010_outcomes_daily_fixed_partitioning.py b/snuba/snuba_migrations/outcomes/0010_outcomes_daily_fixed_partitioning.py index 18fd2b53549..c3b481a0b50 100644 --- a/snuba/snuba_migrations/outcomes/0010_outcomes_daily_fixed_partitioning.py +++ b/snuba/snuba_migrations/outcomes/0010_outcomes_daily_fixed_partitioning.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, DateTime, String, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/outcomes/0011_add_quantity64.py b/snuba/snuba_migrations/outcomes/0011_add_quantity64.py index 98e0359ca0b..c4871724af2 100644 --- a/snuba/snuba_migrations/outcomes/0011_add_quantity64.py +++ b/snuba/snuba_migrations/outcomes/0011_add_quantity64.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations.migration import ClickhouseNodeMigration diff --git a/snuba/snuba_migrations/profile_chunks/0001_create_profile_chunks_table.py b/snuba/snuba_migrations/profile_chunks/0001_create_profile_chunks_table.py index 0f77b90f350..a6b06dd4a5f 100644 --- a/snuba/snuba_migrations/profile_chunks/0001_create_profile_chunks_table.py +++ b/snuba/snuba_migrations/profile_chunks/0001_create_profile_chunks_table.py @@ -1,4 +1,4 @@ -from typing import List, Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import UUID, Column, DateTime64, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -11,7 +11,7 @@ local_table_name = f"{table_prefix}_local" dist_table_name = f"{table_prefix}_dist" -columns: List[Column[Modifiers]] = [ +columns: list[Column[Modifiers]] = [ Column("project_id", UInt(64)), Column("profiler_id", UUID()), Column("chunk_id", UUID()), diff --git a/snuba/snuba_migrations/profile_chunks/0002_add_environment_column.py b/snuba/snuba_migrations/profile_chunks/0002_add_environment_column.py index 48dd8d71711..5f4aa7ec46f 100644 --- a/snuba/snuba_migrations/profile_chunks/0002_add_environment_column.py +++ b/snuba/snuba_migrations/profile_chunks/0002_add_environment_column.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, String from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/profiles/0001_profiles.py b/snuba/snuba_migrations/profiles/0001_profiles.py index 1c9e384a999..cccb799f891 100644 --- a/snuba/snuba_migrations/profiles/0001_profiles.py +++ b/snuba/snuba_migrations/profiles/0001_profiles.py @@ -1,11 +1,11 @@ -from typing import List, Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import UUID, Column, DateTime, String, UInt from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations, table_engines from snuba.migrations.columns import MigrationModifiers as Modifiers -columns: List[Column[Modifiers]] = [ +columns: list[Column[Modifiers]] = [ # primary key Column("organization_id", UInt(64)), Column("project_id", UInt(64)), @@ -20,9 +20,7 @@ Column("device_locale", String(Modifiers(low_cardinality=True))), Column("device_manufacturer", String(Modifiers(low_cardinality=True))), Column("device_model", String(Modifiers(low_cardinality=True))), - Column( - "device_os_build_number", String(Modifiers(low_cardinality=True, nullable=True)) - ), + Column("device_os_build_number", String(Modifiers(low_cardinality=True, nullable=True))), Column("device_os_name", String(Modifiers(low_cardinality=True))), Column("device_os_version", String(Modifiers(low_cardinality=True))), Column("duration_ns", UInt(64)), diff --git a/snuba/snuba_migrations/profiles/0002_disable_vertical_merge_algorithm.py b/snuba/snuba_migrations/profiles/0002_disable_vertical_merge_algorithm.py index 648caffe900..1688619c7ec 100644 --- a/snuba/snuba_migrations/profiles/0002_disable_vertical_merge_algorithm.py +++ b/snuba/snuba_migrations/profiles/0002_disable_vertical_merge_algorithm.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations diff --git a/snuba/snuba_migrations/profiles/0003_add_device_architecture.py b/snuba/snuba_migrations/profiles/0003_add_device_architecture.py index dc4857b773d..6589efba290 100644 --- a/snuba/snuba_migrations/profiles/0003_add_device_architecture.py +++ b/snuba/snuba_migrations/profiles/0003_add_device_architecture.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, String from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/profiles/0004_drop_profile_column.py b/snuba/snuba_migrations/profiles/0004_drop_profile_column.py index e501a3722a8..7c7c575f68d 100644 --- a/snuba/snuba_migrations/profiles/0004_drop_profile_column.py +++ b/snuba/snuba_migrations/profiles/0004_drop_profile_column.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, String from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/querylog/0001_querylog.py b/snuba/snuba_migrations/querylog/0001_querylog.py index 5d427665d4e..9406e4a3c74 100644 --- a/snuba/snuba_migrations/querylog/0001_querylog.py +++ b/snuba/snuba_migrations/querylog/0001_querylog.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import ( UUID, diff --git a/snuba/snuba_migrations/querylog/0002_status_type_change.py b/snuba/snuba_migrations/querylog/0002_status_type_change.py index 9b7fee4841e..fc7c8e2da50 100644 --- a/snuba/snuba_migrations/querylog/0002_status_type_change.py +++ b/snuba/snuba_migrations/querylog/0002_status_type_change.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Array, Column, Enum, String from snuba.clusters.storage_sets import StorageSetKey @@ -14,9 +14,7 @@ class Migration(migration.ClickhouseNodeMigrationLegacy): blocking = True - def __forward_migrations( - self, table_name: str - ) -> Sequence[operations.SqlOperation]: + def __forward_migrations(self, table_name: str) -> Sequence[operations.SqlOperation]: return [ operations.ModifyColumn( StorageSetKey.QUERYLOG, @@ -33,12 +31,8 @@ def __forward_migrations( ), ] - def __backwards_migrations( - self, table_name: str - ) -> Sequence[operations.SqlOperation]: - status_type = Enum[Modifiers]( - [("success", 0), ("error", 1), ("rate-limited", 2)] - ) + def __backwards_migrations(self, table_name: str) -> Sequence[operations.SqlOperation]: + status_type = Enum[Modifiers]([("success", 0), ("error", 1), ("rate-limited", 2)]) return [ operations.ModifyColumn( StorageSetKey.QUERYLOG, diff --git a/snuba/snuba_migrations/querylog/0003_add_profile_fields.py b/snuba/snuba_migrations/querylog/0003_add_profile_fields.py index e7825728f67..5a011e74792 100644 --- a/snuba/snuba_migrations/querylog/0003_add_profile_fields.py +++ b/snuba/snuba_migrations/querylog/0003_add_profile_fields.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Array, Column, String, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -13,9 +13,7 @@ class Migration(migration.ClickhouseNodeMigrationLegacy): blocking = True - def __forward_migrations( - self, table_name: str - ) -> Sequence[operations.SqlOperation]: + def __forward_migrations(self, table_name: str) -> Sequence[operations.SqlOperation]: return [ operations.AddColumn( storage_set=StorageSetKey.QUERYLOG, @@ -23,10 +21,8 @@ def __forward_migrations( column=Column( "clickhouse_queries.all_columns", Array( - Array((String(Modifiers(low_cardinality=True)))), - Modifiers( - default="arrayResize([['']], length(clickhouse_queries.sql))" - ), + Array(String(Modifiers(low_cardinality=True))), + Modifiers(default="arrayResize([['']], length(clickhouse_queries.sql))"), ), ), after="clickhouse_queries.consistent", @@ -38,9 +34,7 @@ def __forward_migrations( "clickhouse_queries.or_conditions", Array( UInt(8), - Modifiers( - default="arrayResize([0], length(clickhouse_queries.sql))" - ), + Modifiers(default="arrayResize([0], length(clickhouse_queries.sql))"), ), ), after="clickhouse_queries.all_columns", @@ -52,9 +46,7 @@ def __forward_migrations( "clickhouse_queries.where_columns", Array( Array(String(Modifiers(low_cardinality=True))), - Modifiers( - default="arrayResize([['']], length(clickhouse_queries.sql))" - ), + Modifiers(default="arrayResize([['']], length(clickhouse_queries.sql))"), ), ), after="clickhouse_queries.or_conditions", @@ -66,9 +58,7 @@ def __forward_migrations( "clickhouse_queries.where_mapping_columns", Array( Array(String(Modifiers(low_cardinality=True))), - Modifiers( - default="arrayResize([['']], length(clickhouse_queries.sql))" - ), + Modifiers(default="arrayResize([['']], length(clickhouse_queries.sql))"), ), ), after="clickhouse_queries.where_columns", @@ -80,9 +70,7 @@ def __forward_migrations( "clickhouse_queries.groupby_columns", Array( Array(String(Modifiers(low_cardinality=True))), - Modifiers( - default="arrayResize([['']], length(clickhouse_queries.sql))" - ), + Modifiers(default="arrayResize([['']], length(clickhouse_queries.sql))"), ), ), after="clickhouse_queries.where_mapping_columns", @@ -94,18 +82,14 @@ def __forward_migrations( "clickhouse_queries.array_join_columns", Array( Array(String(Modifiers(low_cardinality=True))), - Modifiers( - default="arrayResize([['']], length(clickhouse_queries.sql))" - ), + Modifiers(default="arrayResize([['']], length(clickhouse_queries.sql))"), ), ), after="clickhouse_queries.groupby_columns", ), ] - def __backwards_migrations( - self, table_name: str - ) -> Sequence[operations.SqlOperation]: + def __backwards_migrations(self, table_name: str) -> Sequence[operations.SqlOperation]: return [ operations.DropColumn( StorageSetKey.QUERYLOG, table_name, "clickhouse_queries.all_columns" diff --git a/snuba/snuba_migrations/querylog/0004_add_bytes_scanned.py b/snuba/snuba_migrations/querylog/0004_add_bytes_scanned.py index c9a55bd9eb9..071e54582cb 100644 --- a/snuba/snuba_migrations/querylog/0004_add_bytes_scanned.py +++ b/snuba/snuba_migrations/querylog/0004_add_bytes_scanned.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Array, Column, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -12,9 +12,7 @@ class Migration(migration.ClickhouseNodeMigrationLegacy): blocking = True - def __forward_migrations( - self, table_name: str - ) -> Sequence[operations.SqlOperation]: + def __forward_migrations(self, table_name: str) -> Sequence[operations.SqlOperation]: return [ operations.AddColumn( storage_set=StorageSetKey.QUERYLOG, @@ -29,9 +27,7 @@ def __forward_migrations( ), ] - def __backwards_migrations( - self, table_name: str - ) -> Sequence[operations.SqlOperation]: + def __backwards_migrations(self, table_name: str) -> Sequence[operations.SqlOperation]: return [ operations.DropColumn( StorageSetKey.QUERYLOG, table_name, "clickhouse_queries.bytes_scanned" diff --git a/snuba/snuba_migrations/querylog/0005_add_codec_update_settings.py b/snuba/snuba_migrations/querylog/0005_add_codec_update_settings.py index 81acb99131d..9fe4ca68173 100644 --- a/snuba/snuba_migrations/querylog/0005_add_codec_update_settings.py +++ b/snuba/snuba_migrations/querylog/0005_add_codec_update_settings.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import UUID, Array, Column, String from snuba.clusters.storage_sets import StorageSetKey @@ -16,7 +16,6 @@ class Migration(migration.ClickhouseNodeMigration): blocking = False def forwards_ops(self) -> Sequence[operations.SqlOperation]: - return [ operations.ModifyTableTTL( StorageSetKey.QUERYLOG, diff --git a/snuba/snuba_migrations/querylog/0006_sorting_key_change.py b/snuba/snuba_migrations/querylog/0006_sorting_key_change.py index ace42d8b351..b6a0c204058 100644 --- a/snuba/snuba_migrations/querylog/0006_sorting_key_change.py +++ b/snuba/snuba_migrations/querylog/0006_sorting_key_change.py @@ -1,7 +1,7 @@ import logging import math import time -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.native import ClickhousePool from snuba.clusters.cluster import ClickhouseClientSettings, get_cluster @@ -24,9 +24,7 @@ def update_querylog_table(clickhouse: ClickhousePool, database: str) -> None: f"SELECT sampling_key, sorting_key FROM system.tables WHERE name = '{TABLE_NAME}' AND database = '{database}'" ).results - new_create_table_statement = curr_create_table_statement.replace( - TABLE_NAME, TABLE_NAME_NEW - ) + new_create_table_statement = curr_create_table_statement.replace(TABLE_NAME, TABLE_NAME_NEW) # Switch the sorting key if curr_sorting_key != new_sorting_key: @@ -71,9 +69,7 @@ def update_querylog_table(clickhouse: ClickhousePool, database: str) -> None: clickhouse.execute(insert_op.format_sql()) # Ensure each table has the same number of rows before deleting the old one - [(new_row_count,)] = clickhouse.execute( - f"SELECT count() FROM {TABLE_NAME_NEW}" - ).results + [(new_row_count,)] = clickhouse.execute(f"SELECT count() FROM {TABLE_NAME_NEW}").results assert row_count == new_row_count clickhouse.execute(f"RENAME TABLE {TABLE_NAME} TO {TABLE_NAME_OLD};") diff --git a/snuba/snuba_migrations/querylog/0007_add_offset_column.py b/snuba/snuba_migrations/querylog/0007_add_offset_column.py index 22dd78bde0d..0b4439cd31d 100644 --- a/snuba/snuba_migrations/querylog/0007_add_offset_column.py +++ b/snuba/snuba_migrations/querylog/0007_add_offset_column.py @@ -1,11 +1,11 @@ -from typing import Generator, Sequence, Tuple +from collections.abc import Generator, Sequence from snuba.clickhouse.columns import Column, UInt from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations from snuba.migrations.columns import MigrationModifiers as Modifiers -columns: Sequence[Tuple[Column[Modifiers], str]] = [ +columns: Sequence[tuple[Column[Modifiers], str]] = [ (Column("partition", UInt(16)), "status"), (Column("offset", UInt(64)), "partition"), ] @@ -26,7 +26,7 @@ def backwards_ops(self) -> Sequence[operations.SqlOperation]: return list(_backward()) -def _forward() -> Generator[operations.SqlOperation, None, None]: +def _forward() -> Generator[operations.SqlOperation]: for column, after in columns: yield operations.AddColumn( StorageSetKey.QUERYLOG, @@ -45,8 +45,8 @@ def _forward() -> Generator[operations.SqlOperation, None, None]: ) -def _backward() -> Generator[operations.SqlOperation, None, None]: - for column, after in columns: +def _backward() -> Generator[operations.SqlOperation]: + for column, _after in columns: yield operations.DropColumn( storage_set=StorageSetKey.QUERYLOG, table_name="querylog_dist", diff --git a/snuba/snuba_migrations/replays/0001_replays.py b/snuba/snuba_migrations/replays/0001_replays.py index cd04bb72bbd..7fcc0594c53 100644 --- a/snuba/snuba_migrations/replays/0001_replays.py +++ b/snuba/snuba_migrations/replays/0001_replays.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import ( UUID, @@ -24,9 +24,7 @@ Column("trace_ids", Array(UUID())), Column( "_trace_ids_hashed", - Array( - UInt(64), Modifiers(materialized="arrayMap(t -> cityHash64(t), trace_ids)") - ), + Array(UInt(64), Modifiers(materialized="arrayMap(t -> cityHash64(t), trace_ids)")), ), Column("title", String()), ### columns used by other sentry events @@ -106,7 +104,5 @@ def forwards_dist(self) -> Sequence[operations.SqlOperation]: def backwards_dist(self) -> Sequence[operations.SqlOperation]: return [ - operations.DropTable( - storage_set=StorageSetKey.REPLAYS, table_name="replays_dist" - ), + operations.DropTable(storage_set=StorageSetKey.REPLAYS, table_name="replays_dist"), ] diff --git a/snuba/snuba_migrations/replays/0002_add_url.py b/snuba/snuba_migrations/replays/0002_add_url.py index 899f481f342..2c24e80bc38 100644 --- a/snuba/snuba_migrations/replays/0002_add_url.py +++ b/snuba/snuba_migrations/replays/0002_add_url.py @@ -1,11 +1,11 @@ -from typing import Sequence, Tuple +from collections.abc import Sequence from snuba.clickhouse.columns import Column, String from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations from snuba.migrations.columns import MigrationModifiers as Modifiers -new_columns: Sequence[Tuple[Column[Modifiers], str]] = [ +new_columns: Sequence[tuple[Column[Modifiers], str]] = [ (Column("url", String()), "title"), ] diff --git a/snuba/snuba_migrations/replays/0003_alter_url_allow_null.py b/snuba/snuba_migrations/replays/0003_alter_url_allow_null.py index 11f34f06990..f6f6e9770bb 100644 --- a/snuba/snuba_migrations/replays/0003_alter_url_allow_null.py +++ b/snuba/snuba_migrations/replays/0003_alter_url_allow_null.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, String from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/replays/0004_add_error_ids_column.py b/snuba/snuba_migrations/replays/0004_add_error_ids_column.py index c13b5fad48b..27eacf9128f 100644 --- a/snuba/snuba_migrations/replays/0004_add_error_ids_column.py +++ b/snuba/snuba_migrations/replays/0004_add_error_ids_column.py @@ -1,11 +1,11 @@ -from typing import List, Sequence, Tuple +from collections.abc import Sequence from snuba.clickhouse.columns import UUID, Array, Column, UInt from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations from snuba.migrations.columns import MigrationModifiers as Modifiers -new_columns: Sequence[Tuple[Column[Modifiers], str]] = [ +new_columns: Sequence[tuple[Column[Modifiers], str]] = [ (Column("error_ids", Array(UUID())), "url"), ( Column( @@ -19,7 +19,7 @@ ), ] -new_indexes: List[operations.SqlOperation] = [ +new_indexes: list[operations.SqlOperation] = [ operations.AddIndex( storage_set=StorageSetKey.REPLAYS, table_name="replays_local", @@ -30,7 +30,7 @@ ), ] -drop_indexes: List[operations.SqlOperation] = [ +drop_indexes: list[operations.SqlOperation] = [ operations.DropIndex( StorageSetKey.REPLAYS, "replays_local", @@ -43,7 +43,7 @@ class Migration(migration.ClickhouseNodeMigrationLegacy): blocking = False def forwards_local(self) -> Sequence[operations.SqlOperation]: - new_column_ops: List[operations.SqlOperation] = [ + new_column_ops: list[operations.SqlOperation] = [ operations.AddColumn( storage_set=StorageSetKey.REPLAYS, table_name="replays_local", @@ -55,7 +55,7 @@ def forwards_local(self) -> Sequence[operations.SqlOperation]: return new_column_ops + new_indexes def backwards_local(self) -> Sequence[operations.SqlOperation]: - drop_column_ops: List[operations.SqlOperation] = [ + drop_column_ops: list[operations.SqlOperation] = [ operations.DropColumn(StorageSetKey.REPLAYS, "replays_local", column.name) for column, _ in reversed(new_columns) ] diff --git a/snuba/snuba_migrations/replays/0005_add_urls_user_agent_replay_start_timestamp.py b/snuba/snuba_migrations/replays/0005_add_urls_user_agent_replay_start_timestamp.py index 191bea1fe4a..cd749ee5e9d 100644 --- a/snuba/snuba_migrations/replays/0005_add_urls_user_agent_replay_start_timestamp.py +++ b/snuba/snuba_migrations/replays/0005_add_urls_user_agent_replay_start_timestamp.py @@ -1,11 +1,11 @@ -from typing import List, Sequence, Tuple +from collections.abc import Sequence from snuba.clickhouse.columns import Array, Column, DateTime, String from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations from snuba.migrations.columns import MigrationModifiers as Modifiers -new_columns: Sequence[Tuple[Column[Modifiers], str]] = [ +new_columns: Sequence[tuple[Column[Modifiers], str]] = [ (Column("urls", Array(String())), "url"), (Column("replay_start_timestamp", DateTime(Modifiers(nullable=True))), "timestamp"), # OS @@ -26,7 +26,7 @@ class Migration(migration.ClickhouseNodeMigrationLegacy): blocking = False def forwards_local(self) -> Sequence[operations.SqlOperation]: - new_column_ops: List[operations.SqlOperation] = [ + new_column_ops: list[operations.SqlOperation] = [ operations.AddColumn( storage_set=StorageSetKey.REPLAYS, table_name="replays_local", @@ -38,7 +38,7 @@ def forwards_local(self) -> Sequence[operations.SqlOperation]: return new_column_ops def backwards_local(self) -> Sequence[operations.SqlOperation]: - drop_column_ops: List[operations.SqlOperation] = [ + drop_column_ops: list[operations.SqlOperation] = [ operations.DropColumn(StorageSetKey.REPLAYS, "replays_local", column.name) for column, _ in reversed(new_columns) ] @@ -46,7 +46,7 @@ def backwards_local(self) -> Sequence[operations.SqlOperation]: return drop_column_ops def forwards_dist(self) -> Sequence[operations.SqlOperation]: - new_column_ops: List[operations.SqlOperation] = [ + new_column_ops: list[operations.SqlOperation] = [ operations.AddColumn( storage_set=StorageSetKey.REPLAYS, table_name="replays_dist", @@ -58,7 +58,7 @@ def forwards_dist(self) -> Sequence[operations.SqlOperation]: return new_column_ops def backwards_dist(self) -> Sequence[operations.SqlOperation]: - drop_column_ops: List[operations.SqlOperation] = [ + drop_column_ops: list[operations.SqlOperation] = [ operations.DropColumn(StorageSetKey.REPLAYS, "replays_dist", column.name) for column, _ in reversed(new_columns) ] diff --git a/snuba/snuba_migrations/replays/0006_add_is_archived_column.py b/snuba/snuba_migrations/replays/0006_add_is_archived_column.py index 75adca5c566..e7b90d63943 100644 --- a/snuba/snuba_migrations/replays/0006_add_is_archived_column.py +++ b/snuba/snuba_migrations/replays/0006_add_is_archived_column.py @@ -1,11 +1,11 @@ -from typing import List, Sequence, Tuple +from collections.abc import Sequence from snuba.clickhouse.columns import Column, UInt from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations from snuba.migrations.columns import MigrationModifiers as Modifiers -new_columns: Sequence[Tuple[Column[Modifiers], str]] = [ +new_columns: Sequence[tuple[Column[Modifiers], str]] = [ (Column("is_archived", UInt(8, Modifiers(nullable=True))), "urls") ] @@ -14,7 +14,7 @@ class Migration(migration.ClickhouseNodeMigrationLegacy): blocking = False def forwards_local(self) -> Sequence[operations.SqlOperation]: - ops: List[operations.SqlOperation] = [] + ops: list[operations.SqlOperation] = [] for column, after in new_columns: ops.append( @@ -29,19 +29,15 @@ def forwards_local(self) -> Sequence[operations.SqlOperation]: return ops def backwards_local(self) -> Sequence[operations.SqlOperation]: - ops: List[operations.SqlOperation] = [] + ops: list[operations.SqlOperation] = [] for column, _ in reversed(new_columns): - ops.append( - operations.DropColumn( - StorageSetKey.REPLAYS, "replays_local", column.name - ) - ) + ops.append(operations.DropColumn(StorageSetKey.REPLAYS, "replays_local", column.name)) return ops def forwards_dist(self) -> Sequence[operations.SqlOperation]: - ops: List[operations.SqlOperation] = [] + ops: list[operations.SqlOperation] = [] for column, after in new_columns: ops.append( @@ -56,13 +52,9 @@ def forwards_dist(self) -> Sequence[operations.SqlOperation]: return ops def backwards_dist(self) -> Sequence[operations.SqlOperation]: - ops: List[operations.SqlOperation] = [] + ops: list[operations.SqlOperation] = [] for column, _ in reversed(new_columns): - ops.append( - operations.DropColumn( - StorageSetKey.REPLAYS, "replays_dist", column.name - ) - ) + ops.append(operations.DropColumn(StorageSetKey.REPLAYS, "replays_dist", column.name)) return ops diff --git a/snuba/snuba_migrations/replays/0007_add_replay_type_column.py b/snuba/snuba_migrations/replays/0007_add_replay_type_column.py index 4ea9d2f46f6..45479874cbc 100644 --- a/snuba/snuba_migrations/replays/0007_add_replay_type_column.py +++ b/snuba/snuba_migrations/replays/0007_add_replay_type_column.py @@ -1,11 +1,11 @@ -from typing import List, Sequence, Tuple +from collections.abc import Sequence from snuba.clickhouse.columns import Column, String from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations from snuba.migrations.columns import MigrationModifiers as Modifiers -new_columns: Sequence[Tuple[Column[Modifiers], str]] = [ +new_columns: Sequence[tuple[Column[Modifiers], str]] = [ ( Column("replay_type", String(Modifiers(low_cardinality=True, nullable=True))), "replay_id", @@ -17,7 +17,7 @@ class Migration(migration.ClickhouseNodeMigrationLegacy): blocking = False def forwards_local(self) -> Sequence[operations.SqlOperation]: - ops: List[operations.SqlOperation] = [] + ops: list[operations.SqlOperation] = [] for column, after in new_columns: ops.append( @@ -32,19 +32,15 @@ def forwards_local(self) -> Sequence[operations.SqlOperation]: return ops def backwards_local(self) -> Sequence[operations.SqlOperation]: - ops: List[operations.SqlOperation] = [] + ops: list[operations.SqlOperation] = [] for column, _ in reversed(new_columns): - ops.append( - operations.DropColumn( - StorageSetKey.REPLAYS, "replays_local", column.name - ) - ) + ops.append(operations.DropColumn(StorageSetKey.REPLAYS, "replays_local", column.name)) return ops def forwards_dist(self) -> Sequence[operations.SqlOperation]: - ops: List[operations.SqlOperation] = [] + ops: list[operations.SqlOperation] = [] for column, after in new_columns: ops.append( @@ -59,13 +55,9 @@ def forwards_dist(self) -> Sequence[operations.SqlOperation]: return ops def backwards_dist(self) -> Sequence[operations.SqlOperation]: - ops: List[operations.SqlOperation] = [] + ops: list[operations.SqlOperation] = [] for column, _ in reversed(new_columns): - ops.append( - operations.DropColumn( - StorageSetKey.REPLAYS, "replays_dist", column.name - ) - ) + ops.append(operations.DropColumn(StorageSetKey.REPLAYS, "replays_dist", column.name)) return ops diff --git a/snuba/snuba_migrations/replays/0008_add_sample_rate.py b/snuba/snuba_migrations/replays/0008_add_sample_rate.py index 2082c77e5ef..ee4a1025314 100644 --- a/snuba/snuba_migrations/replays/0008_add_sample_rate.py +++ b/snuba/snuba_migrations/replays/0008_add_sample_rate.py @@ -1,11 +1,11 @@ -from typing import Iterator, Sequence, Tuple +from collections.abc import Iterator, Sequence from snuba.clickhouse.columns import Column, Float from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations from snuba.migrations.columns import MigrationModifiers as Modifiers -columns: Sequence[Tuple[Column[Modifiers], str]] = [ +columns: Sequence[tuple[Column[Modifiers], str]] = [ ( Column("error_sample_rate", Float(64, Modifiers(nullable=True))), "replay_type", diff --git a/snuba/snuba_migrations/replays/0009_add_dom_index_columns.py b/snuba/snuba_migrations/replays/0009_add_dom_index_columns.py index e32d5c932fe..598fa206e46 100644 --- a/snuba/snuba_migrations/replays/0009_add_dom_index_columns.py +++ b/snuba/snuba_migrations/replays/0009_add_dom_index_columns.py @@ -1,11 +1,11 @@ -from typing import Iterator, Sequence, Tuple +from collections.abc import Iterator, Sequence from snuba.clickhouse.columns import Array, Column, String, UInt from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations from snuba.migrations.columns import MigrationModifiers as Modifiers -columns: Sequence[Tuple[Column[Modifiers], str]] = [ +columns: Sequence[tuple[Column[Modifiers], str]] = [ (Column("click_node_id", UInt(32, Modifiers(default="0"))), "tags.value"), ( Column("click_tag", String(Modifiers(default="''", low_cardinality=True))), diff --git a/snuba/snuba_migrations/replays/0010_add_nullable_columns.py b/snuba/snuba_migrations/replays/0010_add_nullable_columns.py index 9c2ac1861d8..d224bf1318b 100644 --- a/snuba/snuba_migrations/replays/0010_add_nullable_columns.py +++ b/snuba/snuba_migrations/replays/0010_add_nullable_columns.py @@ -1,11 +1,11 @@ -from typing import Iterator, Sequence, Tuple +from collections.abc import Iterator, Sequence from snuba.clickhouse.columns import Column, String from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations from snuba.migrations.columns import MigrationModifiers as Modifiers -columns: Sequence[Tuple[Column[Modifiers], Column[Modifiers]]] = [ +columns: Sequence[tuple[Column[Modifiers], Column[Modifiers]]] = [ ( Column("title", String(Modifiers(nullable=True))), Column("title", String(Modifiers(nullable=False, default="''"))), diff --git a/snuba/snuba_migrations/replays/0011_add_is_dead_rage.py b/snuba/snuba_migrations/replays/0011_add_is_dead_rage.py index 5ccf2e645ad..2d679a3d9d7 100644 --- a/snuba/snuba_migrations/replays/0011_add_is_dead_rage.py +++ b/snuba/snuba_migrations/replays/0011_add_is_dead_rage.py @@ -1,4 +1,4 @@ -from typing import Iterator, Sequence, Tuple +from collections.abc import Iterator, Sequence from snuba.clickhouse.columns import Column, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -6,7 +6,7 @@ from snuba.migrations.columns import MigrationModifiers as Modifiers # Columns to be added to the table. -new_columns: Sequence[Tuple[Column[Modifiers], str]] = [ +new_columns: Sequence[tuple[Column[Modifiers], str]] = [ (Column("click_is_dead", UInt(8)), "click_title"), (Column("click_is_rage", UInt(8)), "click_is_dead"), ] diff --git a/snuba/snuba_migrations/replays/0012_materialize_counts.py b/snuba/snuba_migrations/replays/0012_materialize_counts.py index 142e29206b1..77bbfefcff1 100644 --- a/snuba/snuba_migrations/replays/0012_materialize_counts.py +++ b/snuba/snuba_migrations/replays/0012_materialize_counts.py @@ -1,4 +1,4 @@ -from typing import Iterator, List, Sequence, Tuple +from collections.abc import Iterator, Sequence from snuba.clickhouse.columns import Column, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -52,7 +52,7 @@ def backward_columns_iter() -> Iterator[operations.SqlOperation]: ) -columns: List[Tuple[str, Column[Modifiers]]] = [ +columns: list[tuple[str, Column[Modifiers]]] = [ ( "click_is_rage", Column("count_errors", UInt(16, Modifiers(materialized="length(error_ids)"))), diff --git a/snuba/snuba_migrations/replays/0013_add_low_cardinality_codecs.py b/snuba/snuba_migrations/replays/0013_add_low_cardinality_codecs.py index aee8fda823e..ef559ca9868 100644 --- a/snuba/snuba_migrations/replays/0013_add_low_cardinality_codecs.py +++ b/snuba/snuba_migrations/replays/0013_add_low_cardinality_codecs.py @@ -1,4 +1,4 @@ -from typing import Iterator, List, Sequence +from collections.abc import Iterator, Sequence from snuba.clickhouse.columns import Column, String from snuba.clusters.storage_sets import StorageSetKey @@ -21,18 +21,14 @@ def forward_columns_iter() -> Iterator[operations.ModifyColumn]: yield operations.ModifyColumn( storage_set=StorageSetKey.REPLAYS, table_name="replays_local", - column=Column( - column_name, String(Modifiers(nullable=True, low_cardinality=True)) - ), + column=Column(column_name, String(Modifiers(nullable=True, low_cardinality=True))), target=operations.OperationTarget.LOCAL, ) yield operations.ModifyColumn( storage_set=StorageSetKey.REPLAYS, table_name="replays_dist", - column=Column( - column_name, String(Modifiers(nullable=True, low_cardinality=True)) - ), + column=Column(column_name, String(Modifiers(nullable=True, low_cardinality=True))), target=operations.OperationTarget.DISTRIBUTED, ) @@ -42,23 +38,19 @@ def backward_columns_iter() -> Iterator[operations.SqlOperation]: yield operations.ModifyColumn( storage_set=StorageSetKey.REPLAYS, table_name="replays_dist", - column=Column( - column_name, String(Modifiers(nullable=True, low_cardinality=False)) - ), + column=Column(column_name, String(Modifiers(nullable=True, low_cardinality=False))), target=operations.OperationTarget.DISTRIBUTED, ) yield operations.ModifyColumn( storage_set=StorageSetKey.REPLAYS, table_name="replays_local", - column=Column( - column_name, String(Modifiers(nullable=True, low_cardinality=False)) - ), + column=Column(column_name, String(Modifiers(nullable=True, low_cardinality=False))), target=operations.OperationTarget.LOCAL, ) -columns: List[str] = [ +columns: list[str] = [ "browser_name", "device_brand", "device_family", diff --git a/snuba/snuba_migrations/replays/0014_add_id_event_columns.py b/snuba/snuba_migrations/replays/0014_add_id_event_columns.py index 4298d5e624a..0b7047fca80 100644 --- a/snuba/snuba_migrations/replays/0014_add_id_event_columns.py +++ b/snuba/snuba_migrations/replays/0014_add_id_event_columns.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import UUID, Column from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/replays/0015_index_frequently_accessed_columns.py b/snuba/snuba_migrations/replays/0015_index_frequently_accessed_columns.py index 63a932a4e36..34824c3172a 100644 --- a/snuba/snuba_migrations/replays/0015_index_frequently_accessed_columns.py +++ b/snuba/snuba_migrations/replays/0015_index_frequently_accessed_columns.py @@ -1,4 +1,4 @@ -from typing import Iterator, Sequence +from collections.abc import Iterator, Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations diff --git a/snuba/snuba_migrations/replays/0016_materialize_new_event_counts.py b/snuba/snuba_migrations/replays/0016_materialize_new_event_counts.py index df64504051c..b0ac586d27f 100644 --- a/snuba/snuba_migrations/replays/0016_materialize_new_event_counts.py +++ b/snuba/snuba_migrations/replays/0016_materialize_new_event_counts.py @@ -1,4 +1,4 @@ -from typing import Iterator, List, Sequence, Tuple +from collections.abc import Iterator, Sequence from snuba.clickhouse.columns import Column, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -52,7 +52,7 @@ def backward_columns_iter() -> Iterator[operations.SqlOperation]: ) -columns: List[Tuple[str, Column[Modifiers]]] = [ +columns: list[tuple[str, Column[Modifiers]]] = [ ( "debug_id", Column( @@ -71,9 +71,7 @@ def backward_columns_iter() -> Iterator[operations.SqlOperation]: "count_warning_events", UInt( 8, - Modifiers( - materialized="warning_id != '00000000-0000-0000-0000-000000000000'" - ), + Modifiers(materialized="warning_id != '00000000-0000-0000-0000-000000000000'"), ), ), ), diff --git a/snuba/snuba_migrations/replays/0017_add_component_name_column.py b/snuba/snuba_migrations/replays/0017_add_component_name_column.py index 8ed4621ae00..f8979efcf13 100644 --- a/snuba/snuba_migrations/replays/0017_add_component_name_column.py +++ b/snuba/snuba_migrations/replays/0017_add_component_name_column.py @@ -1,4 +1,4 @@ -from typing import Iterator, List, Sequence, Tuple +from collections.abc import Iterator, Sequence from snuba.clickhouse.columns import Column, String from snuba.clusters.storage_sets import StorageSetKey @@ -52,6 +52,6 @@ def backward_columns_iter() -> Iterator[operations.SqlOperation]: ) -columns: List[Tuple[str, Column[Modifiers]]] = [ +columns: list[tuple[str, Column[Modifiers]]] = [ ("click_title", Column("click_component_name", String())) ] diff --git a/snuba/snuba_migrations/replays/0018_add_viewed_by_id_column.py b/snuba/snuba_migrations/replays/0018_add_viewed_by_id_column.py index 14807e7e59b..185c0f7fd04 100644 --- a/snuba/snuba_migrations/replays/0018_add_viewed_by_id_column.py +++ b/snuba/snuba_migrations/replays/0018_add_viewed_by_id_column.py @@ -1,4 +1,4 @@ -from typing import Iterator, List, Sequence, Tuple +from collections.abc import Iterator, Sequence from snuba.clickhouse.columns import Column, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -52,6 +52,4 @@ def backward_columns_iter() -> Iterator[operations.SqlOperation]: ) -columns: List[Tuple[str, Column[Modifiers]]] = [ - ("user_email", Column("viewed_by_id", UInt(64))) -] +columns: list[tuple[str, Column[Modifiers]]] = [("user_email", Column("viewed_by_id", UInt(64)))] diff --git a/snuba/snuba_migrations/replays/0019_add_materialization.py b/snuba/snuba_migrations/replays/0019_add_materialization.py index cebb66223c9..84d734ac186 100644 --- a/snuba/snuba_migrations/replays/0019_add_materialization.py +++ b/snuba/snuba_migrations/replays/0019_add_materialization.py @@ -1,4 +1,4 @@ -from typing import Iterator, List, Sequence +from collections.abc import Iterator, Sequence from snuba.clickhouse.columns import ( UUID, @@ -117,9 +117,7 @@ def any_if_string( ) -def any_if_nullable_string( - column_name: str, low_cardinality: bool = False -) -> Column[Modifiers]: +def any_if_nullable_string(column_name: str, low_cardinality: bool = False) -> Column[Modifiers]: """Returns an aggregate anyIf function.""" return any_if_string(column_name, nullable=True, low_cardinality=low_cardinality) @@ -131,12 +129,10 @@ def sum(column_name: str) -> Column[Modifiers]: def count_nullable(column_name: str) -> Column[Modifiers]: """Returns an aggregate count function capable of accepting nullable integer values.""" - return Column( - column_name, AggregateFunction("count", [UInt(64, Modifiers(nullable=True))]) - ) + return Column(column_name, AggregateFunction("count", [UInt(64, Modifiers(nullable=True))])) -columns: List[Column[Modifiers]] = [ +columns: list[Column[Modifiers]] = [ # Primary-key. Column("project_id", UInt(64)), Column("to_hour_timestamp", DateTime()), @@ -161,20 +157,14 @@ def count_nullable(column_name: str) -> Column[Modifiers]: Column("finished_at", AggregateFunction("maxIf", [DateTime(), UInt(8)])), Column("ip_address_v4", AggregateFunction("any", [IPv4(Modifiers(nullable=True))])), Column("ip_address_v6", AggregateFunction("any", [IPv6(Modifiers(nullable=True))])), - Column( - "is_archived", AggregateFunction("sum", [UInt(64, Modifiers(nullable=True))]) - ), - Column( - "min_segment_id", AggregateFunction("min", [UInt(16, Modifiers(nullable=True))]) - ), + Column("is_archived", AggregateFunction("sum", [UInt(64, Modifiers(nullable=True))])), + Column("min_segment_id", AggregateFunction("min", [UInt(16, Modifiers(nullable=True))])), any_if_nullable_string("os_name"), any_if_nullable_string("os_version"), any_if_string("platform", low_cardinality=False), any_if_nullable_string("sdk_name"), any_if_nullable_string("sdk_version"), - Column( - "started_at", AggregateFunction("min", [DateTime(Modifiers(nullable=True))]) - ), + Column("started_at", AggregateFunction("min", [DateTime(Modifiers(nullable=True))])), any_if_nullable_string("user"), any_if_nullable_string("user_id"), any_if_nullable_string("user_name"), diff --git a/snuba/snuba_migrations/replays/0020_add_dist_migration_for_materialization.py b/snuba/snuba_migrations/replays/0020_add_dist_migration_for_materialization.py index ab1821ec893..5236d022583 100644 --- a/snuba/snuba_migrations/replays/0020_add_dist_migration_for_materialization.py +++ b/snuba/snuba_migrations/replays/0020_add_dist_migration_for_materialization.py @@ -1,4 +1,4 @@ -from typing import Iterator, List, Sequence +from collections.abc import Iterator, Sequence from snuba.clickhouse.columns import ( UUID, @@ -61,9 +61,7 @@ def any_if_string( ) -def any_if_nullable_string( - column_name: str, low_cardinality: bool = False -) -> Column[Modifiers]: +def any_if_nullable_string(column_name: str, low_cardinality: bool = False) -> Column[Modifiers]: """Returns an aggregate anyIf function.""" return any_if_string(column_name, nullable=True, low_cardinality=low_cardinality) @@ -75,12 +73,10 @@ def sum(column_name: str) -> Column[Modifiers]: def count_nullable(column_name: str) -> Column[Modifiers]: """Returns an aggregate count function capable of accepting nullable integer values.""" - return Column( - column_name, AggregateFunction("count", [UInt(64, Modifiers(nullable=True))]) - ) + return Column(column_name, AggregateFunction("count", [UInt(64, Modifiers(nullable=True))])) -columns: List[Column[Modifiers]] = [ +columns: list[Column[Modifiers]] = [ # Primary-key. Column("project_id", UInt(64)), Column("to_hour_timestamp", DateTime()), @@ -105,20 +101,14 @@ def count_nullable(column_name: str) -> Column[Modifiers]: Column("finished_at", AggregateFunction("maxIf", [DateTime(), UInt(8)])), Column("ip_address_v4", AggregateFunction("any", [IPv4(Modifiers(nullable=True))])), Column("ip_address_v6", AggregateFunction("any", [IPv6(Modifiers(nullable=True))])), - Column( - "is_archived", AggregateFunction("sum", [UInt(64, Modifiers(nullable=True))]) - ), - Column( - "min_segment_id", AggregateFunction("min", [UInt(16, Modifiers(nullable=True))]) - ), + Column("is_archived", AggregateFunction("sum", [UInt(64, Modifiers(nullable=True))])), + Column("min_segment_id", AggregateFunction("min", [UInt(16, Modifiers(nullable=True))])), any_if_nullable_string("os_name"), any_if_nullable_string("os_version"), any_if_string("platform", low_cardinality=False), any_if_nullable_string("sdk_name"), any_if_nullable_string("sdk_version"), - Column( - "started_at", AggregateFunction("min", [DateTime(Modifiers(nullable=True))]) - ), + Column("started_at", AggregateFunction("min", [DateTime(Modifiers(nullable=True))])), any_if_nullable_string("user"), any_if_nullable_string("user_id"), any_if_nullable_string("user_name"), diff --git a/snuba/snuba_migrations/replays/0021_index_tags.py b/snuba/snuba_migrations/replays/0021_index_tags.py index 77eccebed7e..ab014d1260c 100644 --- a/snuba/snuba_migrations/replays/0021_index_tags.py +++ b/snuba/snuba_migrations/replays/0021_index_tags.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Array, Column, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/replays/0022_add_context_ota_updates.py b/snuba/snuba_migrations/replays/0022_add_context_ota_updates.py index 1858a1816be..5c5c44fcbe5 100644 --- a/snuba/snuba_migrations/replays/0022_add_context_ota_updates.py +++ b/snuba/snuba_migrations/replays/0022_add_context_ota_updates.py @@ -1,11 +1,11 @@ -from typing import Iterator, Sequence, Tuple +from collections.abc import Iterator, Sequence from snuba.clickhouse.columns import Column, String from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations from snuba.migrations.columns import MigrationModifiers as Modifiers -new_columns: Sequence[Tuple[Column[Modifiers], str]] = [ +new_columns: Sequence[tuple[Column[Modifiers], str]] = [ ( Column("ota_updates_channel", String()), "device_model", diff --git a/snuba/snuba_migrations/replays/0023_add_geo_columns.py b/snuba/snuba_migrations/replays/0023_add_geo_columns.py index b5273a17d4d..a15522115db 100644 --- a/snuba/snuba_migrations/replays/0023_add_geo_columns.py +++ b/snuba/snuba_migrations/replays/0023_add_geo_columns.py @@ -1,4 +1,4 @@ -from typing import Iterator, Sequence, Tuple +from collections.abc import Iterator, Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations.columns import MigrationModifiers as Modifiers @@ -11,7 +11,7 @@ ) from snuba.utils.schemas import Column, String -new_columns: Sequence[Tuple[Column[Modifiers], str]] = [ +new_columns: Sequence[tuple[Column[Modifiers], str]] = [ (Column("user_geo_city", String()), "user_email"), (Column("user_geo_country_code", String()), "user_geo_city"), (Column("user_geo_region", String()), "user_geo_country_code"), diff --git a/snuba/snuba_migrations/replays/0024_add_tap_columns.py b/snuba/snuba_migrations/replays/0024_add_tap_columns.py index 828742b2b86..da58d574679 100644 --- a/snuba/snuba_migrations/replays/0024_add_tap_columns.py +++ b/snuba/snuba_migrations/replays/0024_add_tap_columns.py @@ -1,4 +1,4 @@ -from typing import Iterator, Sequence, Tuple +from collections.abc import Iterator, Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations.columns import MigrationModifiers as Modifiers @@ -11,7 +11,7 @@ ) from snuba.utils.schemas import Column, String -new_columns: Sequence[Tuple[Column[Modifiers], str]] = [ +new_columns: Sequence[tuple[Column[Modifiers], str]] = [ (Column("tap_message", String()), "click_is_rage"), (Column("tap_view_class", String()), "tap_message"), (Column("tap_view_id", String()), "tap_view_class"), diff --git a/snuba/snuba_migrations/search_issues/0001_search_issues.py b/snuba/snuba_migrations/search_issues/0001_search_issues.py index bc259037659..01690522793 100644 --- a/snuba/snuba_migrations/search_issues/0001_search_issues.py +++ b/snuba/snuba_migrations/search_issues/0001_search_issues.py @@ -1,4 +1,4 @@ -from typing import List, Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import ( UUID, @@ -16,7 +16,7 @@ from snuba.migrations.columns import MigrationModifiers as Modifiers from snuba.migrations.operations import OperationTarget, SqlOperation -columns: List[Column[Modifiers]] = [ +columns: list[Column[Modifiers]] = [ Column("organization_id", UInt(64)), Column("project_id", UInt(64)), Column("group_id", UInt(64)), @@ -36,9 +36,7 @@ Column("client_timestamp", DateTime()), Column("tags", Nested([("key", String()), ("value", String())])), Column("user", String(Modifiers(nullable=True))), - Column( - "user_hash", UInt(64, Modifiers(nullable=True, materialized="cityHash64(user)")) - ), + Column("user_hash", UInt(64, Modifiers(nullable=True, materialized="cityHash64(user)"))), Column("user_id", String(Modifiers(nullable=True))), Column("user_name", String(Modifiers(nullable=True))), Column("user_email", String(Modifiers(nullable=True))), diff --git a/snuba/snuba_migrations/search_issues/0002_search_issues_add_tags_hash_map.py b/snuba/snuba_migrations/search_issues/0002_search_issues_add_tags_hash_map.py index 4b772eaa1a0..ea27b077099 100644 --- a/snuba/snuba_migrations/search_issues/0002_search_issues_add_tags_hash_map.py +++ b/snuba/snuba_migrations/search_issues/0002_search_issues_add_tags_hash_map.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Array, Column, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/search_issues/0003_search_issues_modify_occurrence_type_id_size.py b/snuba/snuba_migrations/search_issues/0003_search_issues_modify_occurrence_type_id_size.py index 95dd41f7869..f9ec8a2839e 100644 --- a/snuba/snuba_migrations/search_issues/0003_search_issues_modify_occurrence_type_id_size.py +++ b/snuba/snuba_migrations/search_issues/0003_search_issues_modify_occurrence_type_id_size.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, UInt from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/search_issues/0004_rebuild_search_issues_with_version.py b/snuba/snuba_migrations/search_issues/0004_rebuild_search_issues_with_version.py index 393c1b49540..df6a78d6e6c 100644 --- a/snuba/snuba_migrations/search_issues/0004_rebuild_search_issues_with_version.py +++ b/snuba/snuba_migrations/search_issues/0004_rebuild_search_issues_with_version.py @@ -1,4 +1,4 @@ -from typing import List, Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import ( UUID, @@ -17,7 +17,7 @@ from snuba.migrations.columns import MigrationModifiers as Modifiers from snuba.migrations.operations import OperationTarget, SqlOperation -columns: List[Column[Modifiers]] = [ +columns: list[Column[Modifiers]] = [ Column("organization_id", UInt(64)), Column("project_id", UInt(64)), Column("group_id", UInt(64)), @@ -37,9 +37,7 @@ Column("client_timestamp", DateTime()), Column("tags", Nested([("key", String()), ("value", String())])), Column("user", String(Modifiers(nullable=True))), - Column( - "user_hash", UInt(64, Modifiers(nullable=True, materialized="cityHash64(user)")) - ), + Column("user_hash", UInt(64, Modifiers(nullable=True, materialized="cityHash64(user)"))), Column("user_id", String(Modifiers(nullable=True))), Column("user_name", String(Modifiers(nullable=True))), Column("user_email", String(Modifiers(nullable=True))), diff --git a/snuba/snuba_migrations/search_issues/0005_search_issues_v2.py b/snuba/snuba_migrations/search_issues/0005_search_issues_v2.py index a5d15c835fa..bb1712ac330 100644 --- a/snuba/snuba_migrations/search_issues/0005_search_issues_v2.py +++ b/snuba/snuba_migrations/search_issues/0005_search_issues_v2.py @@ -1,4 +1,4 @@ -from typing import List, Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import ( UUID, @@ -17,7 +17,7 @@ from snuba.migrations.columns import MigrationModifiers as Modifiers from snuba.migrations.operations import OperationTarget, SqlOperation -columns: List[Column[Modifiers]] = [ +columns: list[Column[Modifiers]] = [ Column("organization_id", UInt(64)), Column("project_id", UInt(64)), Column("group_id", UInt(64)), @@ -37,9 +37,7 @@ Column("client_timestamp", DateTime()), Column("tags", Nested([("key", String()), ("value", String())])), Column("user", String(Modifiers(nullable=True))), - Column( - "user_hash", UInt(64, Modifiers(nullable=True, materialized="cityHash64(user)")) - ), + Column("user_hash", UInt(64, Modifiers(nullable=True, materialized="cityHash64(user)"))), Column("user_id", String(Modifiers(nullable=True))), Column("user_name", String(Modifiers(nullable=True))), Column("user_email", String(Modifiers(nullable=True))), diff --git a/snuba/snuba_migrations/search_issues/0006_add_subtitle_culprit_level_resource_id.py b/snuba/snuba_migrations/search_issues/0006_add_subtitle_culprit_level_resource_id.py index cf23690f7b7..89bea0c053b 100644 --- a/snuba/snuba_migrations/search_issues/0006_add_subtitle_culprit_level_resource_id.py +++ b/snuba/snuba_migrations/search_issues/0006_add_subtitle_culprit_level_resource_id.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/search_issues/0007_add_transaction_duration.py b/snuba/snuba_migrations/search_issues/0007_add_transaction_duration.py index ec604817f2c..e5dc748f744 100644 --- a/snuba/snuba_migrations/search_issues/0007_add_transaction_duration.py +++ b/snuba/snuba_migrations/search_issues/0007_add_transaction_duration.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/search_issues/0008_add_profile_id_replay_id.py b/snuba/snuba_migrations/search_issues/0008_add_profile_id_replay_id.py index 2f2ea21e321..f0565092144 100644 --- a/snuba/snuba_migrations/search_issues/0008_add_profile_id_replay_id.py +++ b/snuba/snuba_migrations/search_issues/0008_add_profile_id_replay_id.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import UUID, Column from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/search_issues/0009_add_message.py b/snuba/snuba_migrations/search_issues/0009_add_message.py index cc0faa018db..f2378384676 100644 --- a/snuba/snuba_migrations/search_issues/0009_add_message.py +++ b/snuba/snuba_migrations/search_issues/0009_add_message.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, String from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/search_issues/0010_add_group_first_seen.py b/snuba/snuba_migrations/search_issues/0010_add_group_first_seen.py index eea97b21547..c4516e23fb9 100644 --- a/snuba/snuba_migrations/search_issues/0010_add_group_first_seen.py +++ b/snuba/snuba_migrations/search_issues/0010_add_group_first_seen.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations @@ -16,9 +16,7 @@ def forwards_ops(self) -> Sequence[operations.SqlOperation]: operations.AddColumn( storage_set=StorageSetKey.SEARCH_ISSUES, table_name=table_name, - column=Column( - "group_first_seen", DateTime(MigrationModifiers(nullable=True)) - ), + column=Column("group_first_seen", DateTime(MigrationModifiers(nullable=True))), after="message", target=target, ), diff --git a/snuba/snuba_migrations/search_issues/0011_add_timestamp_ms.py b/snuba/snuba_migrations/search_issues/0011_add_timestamp_ms.py index 7ca2231fdf6..e4adafd82dc 100644 --- a/snuba/snuba_migrations/search_issues/0011_add_timestamp_ms.py +++ b/snuba/snuba_migrations/search_issues/0011_add_timestamp_ms.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations.columns import MigrationModifiers diff --git a/snuba/snuba_migrations/search_issues/0012_add_group_id_bloom_filter_index.py b/snuba/snuba_migrations/search_issues/0012_add_group_id_bloom_filter_index.py index 806b5a5fcfd..1060daaf183 100644 --- a/snuba/snuba_migrations/search_issues/0012_add_group_id_bloom_filter_index.py +++ b/snuba/snuba_migrations/search_issues/0012_add_group_id_bloom_filter_index.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations diff --git a/snuba/snuba_migrations/test_migration/0001_create_test_table.py b/snuba/snuba_migrations/test_migration/0001_create_test_table.py index 9b35d7790bf..de0080bfa76 100644 --- a/snuba/snuba_migrations/test_migration/0001_create_test_table.py +++ b/snuba/snuba_migrations/test_migration/0001_create_test_table.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -16,7 +16,6 @@ class Migration(migration.ClickhouseNodeMigration): blocking = False def forwards_ops(self) -> Sequence[operations.SqlOperation]: - return [ operations.CreateTable( storage_set=StorageSetKey.QUERYLOG, diff --git a/snuba/snuba_migrations/test_migration/0002_add_test_col.py b/snuba/snuba_migrations/test_migration/0002_add_test_col.py index 029424f2095..ce3aa2b4fbe 100644 --- a/snuba/snuba_migrations/test_migration/0002_add_test_col.py +++ b/snuba/snuba_migrations/test_migration/0002_add_test_col.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -16,7 +16,6 @@ class Migration(migration.ClickhouseNodeMigration): blocking = False def forwards_ops(self) -> Sequence[operations.SqlOperation]: - return [ operations.AddColumn( storage_set=StorageSetKey.QUERYLOG, diff --git a/snuba/snuba_migrations/transactions/0001_transactions.py b/snuba/snuba_migrations/transactions/0001_transactions.py index ad21ea399cb..6c295dfa700 100644 --- a/snuba/snuba_migrations/transactions/0001_transactions.py +++ b/snuba/snuba_migrations/transactions/0001_transactions.py @@ -1,4 +1,4 @@ -from typing import List, Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import ( UUID, @@ -16,7 +16,7 @@ UNKNOWN_SPAN_STATUS = 2 -columns: List[Column[Modifiers]] = [ +columns: list[Column[Modifiers]] = [ Column("project_id", UInt(64)), Column("event_id", UUID()), Column("trace_id", UUID()), diff --git a/snuba/snuba_migrations/transactions/0002_transactions_onpremise_fix_orderby_and_partitionby.py b/snuba/snuba_migrations/transactions/0002_transactions_onpremise_fix_orderby_and_partitionby.py index a7d3076f3de..b22a8cf9e11 100644 --- a/snuba/snuba_migrations/transactions/0002_transactions_onpremise_fix_orderby_and_partitionby.py +++ b/snuba/snuba_migrations/transactions/0002_transactions_onpremise_fix_orderby_and_partitionby.py @@ -1,7 +1,7 @@ import logging import math import time -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.cluster import ClickhouseClientSettings, get_cluster from snuba.clusters.storage_sets import StorageSetKey @@ -29,9 +29,7 @@ def forwards(logger: logging.Logger) -> None: new_sampling_key = "cityHash64(span_id)" new_partition_key = "(retention_days, toMonday(finish_ts))" - new_primary_key = ( - "project_id, toStartOfDay(finish_ts), transaction_name, cityHash64(span_id)" - ) + new_primary_key = "project_id, toStartOfDay(finish_ts), transaction_name, cityHash64(span_id)" ((curr_sampling_key, curr_partition_key, curr_primary_key),) = clickhouse.execute( f"SELECT sampling_key, partition_key, primary_key FROM system.tables WHERE name = '{TABLE_NAME}' AND database = '{database}'" @@ -56,9 +54,7 @@ def forwards(logger: logging.Logger) -> None: f"SHOW CREATE TABLE {database}.{TABLE_NAME}" ).results - new_create_table_statement = curr_create_table_statement.replace( - TABLE_NAME, TABLE_NAME_NEW - ) + new_create_table_statement = curr_create_table_statement.replace(TABLE_NAME, TABLE_NAME_NEW) # Insert sample clause before TTL if sampling_key_needs_update: diff --git a/snuba/snuba_migrations/transactions/0003_transactions_onpremise_fix_columns.py b/snuba/snuba_migrations/transactions/0003_transactions_onpremise_fix_columns.py index f9385603130..0adf7279788 100644 --- a/snuba/snuba_migrations/transactions/0003_transactions_onpremise_fix_columns.py +++ b/snuba/snuba_migrations/transactions/0003_transactions_onpremise_fix_columns.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, DateTime, String, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -28,17 +28,13 @@ def forwards_local(self) -> Sequence[operations.SqlOperation]: operations.AddColumn( storage_set=StorageSetKey.TRANSACTIONS, table_name="transactions_local", - column=Column( - "sdk_name", String(Modifiers(low_cardinality=True, default="''")) - ), + column=Column("sdk_name", String(Modifiers(low_cardinality=True, default="''"))), after="dist", ), operations.AddColumn( storage_set=StorageSetKey.TRANSACTIONS, table_name="transactions_local", - column=Column( - "sdk_version", String(Modifiers(low_cardinality=True, default="''")) - ), + column=Column("sdk_version", String(Modifiers(low_cardinality=True, default="''"))), after="sdk_name", ), operations.AddColumn( @@ -65,46 +61,34 @@ def forwards_local(self) -> Sequence[operations.SqlOperation]: operations.AddColumn( storage_set=StorageSetKey.TRANSACTIONS, table_name="transactions_local", - column=Column( - "user_hash", UInt(64, Modifiers(materialized="cityHash64(user)")) - ), + column=Column("user_hash", UInt(64, Modifiers(materialized="cityHash64(user)"))), after="user", ), # The following columns were originally created as non low cardinality strings operations.ModifyColumn( storage_set=StorageSetKey.TRANSACTIONS, table_name="transactions_local", - column=Column( - "transaction_name", String(Modifiers(low_cardinality=True)) - ), + column=Column("transaction_name", String(Modifiers(low_cardinality=True))), ), operations.ModifyColumn( storage_set=StorageSetKey.TRANSACTIONS, table_name="transactions_local", - column=Column( - "release", String(Modifiers(nullable=True, low_cardinality=True)) - ), + column=Column("release", String(Modifiers(nullable=True, low_cardinality=True))), ), operations.ModifyColumn( storage_set=StorageSetKey.TRANSACTIONS, table_name="transactions_local", - column=Column( - "dist", String(Modifiers(nullable=True, low_cardinality=True)) - ), + column=Column("dist", String(Modifiers(nullable=True, low_cardinality=True))), ), operations.ModifyColumn( storage_set=StorageSetKey.TRANSACTIONS, table_name="transactions_local", - column=Column( - "sdk_name", String(Modifiers(low_cardinality=True, default="''")) - ), + column=Column("sdk_name", String(Modifiers(low_cardinality=True, default="''"))), ), operations.ModifyColumn( storage_set=StorageSetKey.TRANSACTIONS, table_name="transactions_local", - column=Column( - "sdk_version", String(Modifiers(low_cardinality=True, default="''")) - ), + column=Column("sdk_version", String(Modifiers(low_cardinality=True, default="''"))), ), operations.ModifyColumn( storage_set=StorageSetKey.TRANSACTIONS, diff --git a/snuba/snuba_migrations/transactions/0004_transactions_add_tags_hash_map.py b/snuba/snuba_migrations/transactions/0004_transactions_add_tags_hash_map.py index 553cae00084..d3e95a554a8 100644 --- a/snuba/snuba_migrations/transactions/0004_transactions_add_tags_hash_map.py +++ b/snuba/snuba_migrations/transactions/0004_transactions_add_tags_hash_map.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Array, Column, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -55,7 +55,5 @@ def forwards_dist(self) -> Sequence[operations.SqlOperation]: def backwards_dist(self) -> Sequence[operations.SqlOperation]: return [ - operations.DropColumn( - StorageSetKey.TRANSACTIONS, "transactions_dist", "_tags_hash_map" - ) + operations.DropColumn(StorageSetKey.TRANSACTIONS, "transactions_dist", "_tags_hash_map") ] diff --git a/snuba/snuba_migrations/transactions/0005_transactions_add_measurements.py b/snuba/snuba_migrations/transactions/0005_transactions_add_measurements.py index b684717ba2c..9f46339d340 100644 --- a/snuba/snuba_migrations/transactions/0005_transactions_add_measurements.py +++ b/snuba/snuba_migrations/transactions/0005_transactions_add_measurements.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, Float, Nested, String from snuba.clusters.storage_sets import StorageSetKey @@ -33,9 +33,7 @@ def forwards_local(self) -> Sequence[operations.SqlOperation]: def backwards_local(self) -> Sequence[operations.SqlOperation]: return [ - operations.DropColumn( - StorageSetKey.TRANSACTIONS, "transactions_local", "measurements" - ), + operations.DropColumn(StorageSetKey.TRANSACTIONS, "transactions_local", "measurements"), ] def forwards_dist(self) -> Sequence[operations.SqlOperation]: @@ -58,7 +56,5 @@ def forwards_dist(self) -> Sequence[operations.SqlOperation]: def backwards_dist(self) -> Sequence[operations.SqlOperation]: return [ - operations.DropColumn( - StorageSetKey.TRANSACTIONS, "transactions_dist", "measurements" - ) + operations.DropColumn(StorageSetKey.TRANSACTIONS, "transactions_dist", "measurements") ] diff --git a/snuba/snuba_migrations/transactions/0006_transactions_add_http_fields.py b/snuba/snuba_migrations/transactions/0006_transactions_add_http_fields.py index 9c66a74e25e..591ab47b593 100644 --- a/snuba/snuba_migrations/transactions/0006_transactions_add_http_fields.py +++ b/snuba/snuba_migrations/transactions/0006_transactions_add_http_fields.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, String from snuba.clusters.storage_sets import StorageSetKey @@ -34,12 +34,8 @@ def forwards_local(self) -> Sequence[operations.SqlOperation]: def backwards_local(self) -> Sequence[operations.SqlOperation]: return [ - operations.DropColumn( - StorageSetKey.TRANSACTIONS, "transactions_local", "http_method" - ), - operations.DropColumn( - StorageSetKey.TRANSACTIONS, "transactions_local", "http_referer" - ), + operations.DropColumn(StorageSetKey.TRANSACTIONS, "transactions_local", "http_method"), + operations.DropColumn(StorageSetKey.TRANSACTIONS, "transactions_local", "http_referer"), ] def forwards_dist(self) -> Sequence[operations.SqlOperation]: @@ -63,10 +59,6 @@ def forwards_dist(self) -> Sequence[operations.SqlOperation]: def backwards_dist(self) -> Sequence[operations.SqlOperation]: return [ - operations.DropColumn( - StorageSetKey.TRANSACTIONS, "transactions_dist", "http_method" - ), - operations.DropColumn( - StorageSetKey.TRANSACTIONS, "transactions_dist", "http_referer" - ), + operations.DropColumn(StorageSetKey.TRANSACTIONS, "transactions_dist", "http_method"), + operations.DropColumn(StorageSetKey.TRANSACTIONS, "transactions_dist", "http_referer"), ] diff --git a/snuba/snuba_migrations/transactions/0007_transactions_add_discover_cols.py b/snuba/snuba_migrations/transactions/0007_transactions_add_discover_cols.py index 09c95a04d66..c985e5466db 100644 --- a/snuba/snuba_migrations/transactions/0007_transactions_add_discover_cols.py +++ b/snuba/snuba_migrations/transactions/0007_transactions_add_discover_cols.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, DateTime, String from snuba.clusters.storage_sets import StorageSetKey @@ -13,18 +13,14 @@ class Migration(migration.ClickhouseNodeMigrationLegacy): blocking = False - def __forward_migrations( - self, table_name: str - ) -> Sequence[operations.SqlOperation]: + def __forward_migrations(self, table_name: str) -> Sequence[operations.SqlOperation]: return [ operations.AddColumn( storage_set=StorageSetKey.TRANSACTIONS, table_name=table_name, column=Column( "type", - String( - Modifiers(low_cardinality=True, materialized="'transaction'") - ), + String(Modifiers(low_cardinality=True, materialized="'transaction'")), ), after="deleted", ), @@ -33,9 +29,7 @@ def __forward_migrations( table_name=table_name, column=Column( "message", - String( - Modifiers(low_cardinality=True, materialized="transaction_name") - ), + String(Modifiers(low_cardinality=True, materialized="transaction_name")), ), after="type", ), @@ -44,25 +38,19 @@ def __forward_migrations( table_name=table_name, column=Column( "title", - String( - Modifiers(low_cardinality=True, materialized="transaction_name") - ), + String(Modifiers(low_cardinality=True, materialized="transaction_name")), ), after="message", ), operations.AddColumn( storage_set=StorageSetKey.TRANSACTIONS, table_name=table_name, - column=Column( - "timestamp", DateTime(Modifiers(materialized="finish_ts")) - ), + column=Column("timestamp", DateTime(Modifiers(materialized="finish_ts"))), after=("type" if table_name == "transactions_local" else "title"), ), ] - def __backwards_migrations( - self, table_name: str - ) -> Sequence[operations.SqlOperation]: + def __backwards_migrations(self, table_name: str) -> Sequence[operations.SqlOperation]: return [ operations.DropColumn(StorageSetKey.TRANSACTIONS, table_name, "type"), operations.DropColumn(StorageSetKey.TRANSACTIONS, table_name, "message"), diff --git a/snuba/snuba_migrations/transactions/0008_transactions_add_timestamp_index.py b/snuba/snuba_migrations/transactions/0008_transactions_add_timestamp_index.py index 06a5fc43302..5cfb5b39449 100644 --- a/snuba/snuba_migrations/transactions/0008_transactions_add_timestamp_index.py +++ b/snuba/snuba_migrations/transactions/0008_transactions_add_timestamp_index.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations diff --git a/snuba/snuba_migrations/transactions/0009_transactions_fix_title_and_message.py b/snuba/snuba_migrations/transactions/0009_transactions_fix_title_and_message.py index 50598f2534f..a349673c8ad 100644 --- a/snuba/snuba_migrations/transactions/0009_transactions_fix_title_and_message.py +++ b/snuba/snuba_migrations/transactions/0009_transactions_fix_title_and_message.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, String from snuba.clusters.storage_sets import StorageSetKey @@ -14,9 +14,7 @@ class Migration(migration.ClickhouseNodeMigrationLegacy): blocking = False - def __forward_migrations( - self, table_name: str - ) -> Sequence[operations.SqlOperation]: + def __forward_migrations(self, table_name: str) -> Sequence[operations.SqlOperation]: return [ operations.ModifyColumn( storage_set=StorageSetKey.TRANSACTIONS, @@ -36,18 +34,14 @@ def __forward_migrations( ), ] - def __backward_migrations( - self, table_name: str - ) -> Sequence[operations.SqlOperation]: + def __backward_migrations(self, table_name: str) -> Sequence[operations.SqlOperation]: return [ operations.ModifyColumn( storage_set=StorageSetKey.TRANSACTIONS, table_name=table_name, column=Column( "title", - String( - Modifiers(low_cardinality=True, materialized="transaction_name") - ), + String(Modifiers(low_cardinality=True, materialized="transaction_name")), ), ), operations.ModifyColumn( @@ -55,9 +49,7 @@ def __backward_migrations( table_name=table_name, column=Column( "message", - String( - Modifiers(low_cardinality=True, materialized="transaction_name") - ), + String(Modifiers(low_cardinality=True, materialized="transaction_name")), ), ), ] diff --git a/snuba/snuba_migrations/transactions/0010_transactions_nullable_trace_id.py b/snuba/snuba_migrations/transactions/0010_transactions_nullable_trace_id.py index b2e983a55e4..fbbdee97ce2 100644 --- a/snuba/snuba_migrations/transactions/0010_transactions_nullable_trace_id.py +++ b/snuba/snuba_migrations/transactions/0010_transactions_nullable_trace_id.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import UUID, Column from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/transactions/0011_transactions_add_span_op_breakdowns.py b/snuba/snuba_migrations/transactions/0011_transactions_add_span_op_breakdowns.py index 1d274eaa4c5..fe2596133a6 100644 --- a/snuba/snuba_migrations/transactions/0011_transactions_add_span_op_breakdowns.py +++ b/snuba/snuba_migrations/transactions/0011_transactions_add_span_op_breakdowns.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, Float, Nested, String from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/transactions/0012_transactions_add_spans.py b/snuba/snuba_migrations/transactions/0012_transactions_add_spans.py index 1bae00fb9d9..67fc0559c1d 100644 --- a/snuba/snuba_migrations/transactions/0012_transactions_add_spans.py +++ b/snuba/snuba_migrations/transactions/0012_transactions_add_spans.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Array, Column, Float, Nested, String, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -7,7 +7,6 @@ class Migration(migration.ClickhouseNodeMigrationLegacy): - blocking = False def forwards_local(self) -> Sequence[operations.SqlOperation]: @@ -30,9 +29,7 @@ def forwards_local(self) -> Sequence[operations.SqlOperation]: operations.ModifyColumn( storage_set=StorageSetKey.TRANSACTIONS, table_name="transactions_local", - column=Column( - "spans.op", Array(String(Modifiers(low_cardinality=True))) - ), + column=Column("spans.op", Array(String(Modifiers(low_cardinality=True)))), ttl_month=("finish_ts", 1), ), operations.ModifyColumn( @@ -77,9 +74,7 @@ def backwards_local(self) -> Sequence[operations.SqlOperation]: table_name="transactions_local", index_name="bf_spans_op", ), - operations.DropColumn( - StorageSetKey.TRANSACTIONS, "transactions_local", "spans" - ), + operations.DropColumn(StorageSetKey.TRANSACTIONS, "transactions_local", "spans"), ] def forwards_dist(self) -> Sequence[operations.SqlOperation]: @@ -102,8 +97,4 @@ def forwards_dist(self) -> Sequence[operations.SqlOperation]: ] def backwards_dist(self) -> Sequence[operations.SqlOperation]: - return [ - operations.DropColumn( - StorageSetKey.TRANSACTIONS, "transactions_dist", "spans" - ) - ] + return [operations.DropColumn(StorageSetKey.TRANSACTIONS, "transactions_dist", "spans")] diff --git a/snuba/snuba_migrations/transactions/0013_transactions_reduce_spans_exclusive_time.py b/snuba/snuba_migrations/transactions/0013_transactions_reduce_spans_exclusive_time.py index 52114d1ac4a..80952cae217 100644 --- a/snuba/snuba_migrations/transactions/0013_transactions_reduce_spans_exclusive_time.py +++ b/snuba/snuba_migrations/transactions/0013_transactions_reduce_spans_exclusive_time.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Array, Column, Float from snuba.clusters.storage_sets import StorageSetKey @@ -6,7 +6,6 @@ class Migration(migration.ClickhouseNodeMigrationLegacy): - blocking = False def forwards_local(self) -> Sequence[operations.SqlOperation]: diff --git a/snuba/snuba_migrations/transactions/0014_transactions_remove_flattened_columns.py b/snuba/snuba_migrations/transactions/0014_transactions_remove_flattened_columns.py index c0bc55fce4f..1ec7247d4af 100644 --- a/snuba/snuba_migrations/transactions/0014_transactions_remove_flattened_columns.py +++ b/snuba/snuba_migrations/transactions/0014_transactions_remove_flattened_columns.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, String from snuba.clusters.storage_sets import StorageSetKey @@ -6,7 +6,6 @@ class Migration(migration.ClickhouseNodeMigrationLegacy): - blocking = False forwards_local_first: bool = False backwards_local_first: bool = True diff --git a/snuba/snuba_migrations/transactions/0015_transactions_add_source_column.py b/snuba/snuba_migrations/transactions/0015_transactions_add_source_column.py index 369b87c0cdc..3f9e597e52f 100644 --- a/snuba/snuba_migrations/transactions/0015_transactions_add_source_column.py +++ b/snuba/snuba_migrations/transactions/0015_transactions_add_source_column.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, String from snuba.clusters.storage_sets import StorageSetKey @@ -15,9 +15,7 @@ class Migration(migration.ClickhouseNodeMigrationLegacy): blocking = False - def __forward_migrations( - self, table_name: str - ) -> Sequence[operations.SqlOperation]: + def __forward_migrations(self, table_name: str) -> Sequence[operations.SqlOperation]: return [ operations.AddColumn( storage_set=StorageSetKey.TRANSACTIONS, @@ -30,13 +28,9 @@ def __forward_migrations( ), ] - def __backwards_migrations( - self, table_name: str - ) -> Sequence[operations.SqlOperation]: + def __backwards_migrations(self, table_name: str) -> Sequence[operations.SqlOperation]: return [ - operations.DropColumn( - StorageSetKey.TRANSACTIONS, table_name, "transaction_source" - ), + operations.DropColumn(StorageSetKey.TRANSACTIONS, table_name, "transaction_source"), ] def forwards_local(self) -> Sequence[operations.SqlOperation]: diff --git a/snuba/snuba_migrations/transactions/0016_transactions_add_group_ids_column.py b/snuba/snuba_migrations/transactions/0016_transactions_add_group_ids_column.py index d157b9d0b80..4a54b85b80b 100644 --- a/snuba/snuba_migrations/transactions/0016_transactions_add_group_ids_column.py +++ b/snuba/snuba_migrations/transactions/0016_transactions_add_group_ids_column.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Array, Column, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -13,9 +13,7 @@ class Migration(migration.ClickhouseNodeMigrationLegacy): blocking = False - def __forward_migrations( - self, table_name: str - ) -> Sequence[operations.SqlOperation]: + def __forward_migrations(self, table_name: str) -> Sequence[operations.SqlOperation]: return [ operations.AddColumn( storage_set=StorageSetKey.TRANSACTIONS, @@ -25,12 +23,8 @@ def __forward_migrations( ) ] - def __backwards_migrations( - self, table_name: str - ) -> Sequence[operations.SqlOperation]: - return [ - operations.DropColumn(StorageSetKey.TRANSACTIONS, table_name, "group_ids") - ] + def __backwards_migrations(self, table_name: str) -> Sequence[operations.SqlOperation]: + return [operations.DropColumn(StorageSetKey.TRANSACTIONS, table_name, "group_ids")] def forwards_local(self) -> Sequence[operations.SqlOperation]: return self.__forward_migrations("transactions_local") diff --git a/snuba/snuba_migrations/transactions/0017_transactions_add_app_start_type_column.py b/snuba/snuba_migrations/transactions/0017_transactions_add_app_start_type_column.py index 47a4b7d8dd5..a940bbd8e5a 100644 --- a/snuba/snuba_migrations/transactions/0017_transactions_add_app_start_type_column.py +++ b/snuba/snuba_migrations/transactions/0017_transactions_add_app_start_type_column.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Column, String from snuba.clusters.storage_sets import StorageSetKey @@ -14,9 +14,7 @@ class Migration(migration.ClickhouseNodeMigrationLegacy): blocking = False - def __forward_migrations( - self, table_name: str - ) -> Sequence[operations.SqlOperation]: + def __forward_migrations(self, table_name: str) -> Sequence[operations.SqlOperation]: return [ operations.AddColumn( storage_set=StorageSetKey.TRANSACTIONS, @@ -29,13 +27,9 @@ def __forward_migrations( ), ] - def __backwards_migrations( - self, table_name: str - ) -> Sequence[operations.SqlOperation]: + def __backwards_migrations(self, table_name: str) -> Sequence[operations.SqlOperation]: return [ - operations.DropColumn( - StorageSetKey.TRANSACTIONS, table_name, "app_start_type" - ), + operations.DropColumn(StorageSetKey.TRANSACTIONS, table_name, "app_start_type"), ] def forwards_local(self) -> Sequence[operations.SqlOperation]: diff --git a/snuba/snuba_migrations/transactions/0018_transactions_add_profile_id.py b/snuba/snuba_migrations/transactions/0018_transactions_add_profile_id.py index d66fbb730f6..2096175b180 100644 --- a/snuba/snuba_migrations/transactions/0018_transactions_add_profile_id.py +++ b/snuba/snuba_migrations/transactions/0018_transactions_add_profile_id.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import UUID, Column from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/transactions/0019_transactions_add_indexes_and_context_hash.py b/snuba/snuba_migrations/transactions/0019_transactions_add_indexes_and_context_hash.py index 661a0e82129..e7ef2a97b0b 100644 --- a/snuba/snuba_migrations/transactions/0019_transactions_add_indexes_and_context_hash.py +++ b/snuba/snuba_migrations/transactions/0019_transactions_add_indexes_and_context_hash.py @@ -1,4 +1,4 @@ -from typing import List, Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import Array, Column, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -38,7 +38,7 @@ class Migration(migration.ClickhouseNodeMigration): blocking = False def forwards_ops(self) -> Sequence[operations.SqlOperation]: - ops: List[operations.SqlOperation] = [ + ops: list[operations.SqlOperation] = [ operations.AddColumn( storage_set=StorageSetKey.TRANSACTIONS, table_name="transactions_local", @@ -64,7 +64,7 @@ def forwards_ops(self) -> Sequence[operations.SqlOperation]: # apply the indexes after the minmax_timestamp index and have them follow each other # one after the other index_names = [index[0] for index in indexes] - for index, after_col in zip(indexes, ["minmax_timestamp"] + index_names[:-1]): + for index, after_col in zip(indexes, ["minmax_timestamp"] + index_names[:-1], strict=False): index_name, index_expression, index_type, granularity = index ops.append( operations.AddIndex( @@ -82,7 +82,7 @@ def forwards_ops(self) -> Sequence[operations.SqlOperation]: return ops def backwards_ops(self) -> Sequence[operations.SqlOperation]: - ops: List[operations.SqlOperation] = [] + ops: list[operations.SqlOperation] = [] for index in indexes: index_name, _, _, _ = index diff --git a/snuba/snuba_migrations/transactions/0020_transactions_add_codecs.py b/snuba/snuba_migrations/transactions/0020_transactions_add_codecs.py index b5830c74efe..d1f66cd27ab 100644 --- a/snuba/snuba_migrations/transactions/0020_transactions_add_codecs.py +++ b/snuba/snuba_migrations/transactions/0020_transactions_add_codecs.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import UUID, Column, UInt from snuba.clusters.storage_sets import StorageSetKey @@ -20,9 +20,7 @@ def forwards_ops(self) -> Sequence[operations.SqlOperation]: operations.ModifyColumn( storage_set=StorageSetKey.TRANSACTIONS, table_name="transactions_local", - column=Column( - "trace_id", UUID(Modifiers(nullable=True, codecs=["LZ4"])) - ), + column=Column("trace_id", UUID(Modifiers(nullable=True, codecs=["LZ4"]))), target=operations.OperationTarget.LOCAL, ), operations.ModifyColumn( diff --git a/snuba/snuba_migrations/transactions/0021_transactions_add_replay_id.py b/snuba/snuba_migrations/transactions/0021_transactions_add_replay_id.py index 19c92efadfc..eddefdbb0e3 100644 --- a/snuba/snuba_migrations/transactions/0021_transactions_add_replay_id.py +++ b/snuba/snuba_migrations/transactions/0021_transactions_add_replay_id.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import UUID, Column from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/snuba_migrations/transactions/0022_transactions_add_index_on_trace_id.py b/snuba/snuba_migrations/transactions/0022_transactions_add_index_on_trace_id.py index 143a0f99288..9a902499ca5 100644 --- a/snuba/snuba_migrations/transactions/0022_transactions_add_index_on_trace_id.py +++ b/snuba/snuba_migrations/transactions/0022_transactions_add_index_on_trace_id.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clusters.storage_sets import StorageSetKey from snuba.migrations import migration, operations diff --git a/snuba/snuba_migrations/transactions/0023_add_profiler_id_column.py b/snuba/snuba_migrations/transactions/0023_add_profiler_id_column.py index fe5f3973253..51dbb99efad 100644 --- a/snuba/snuba_migrations/transactions/0023_add_profiler_id_column.py +++ b/snuba/snuba_migrations/transactions/0023_add_profiler_id_column.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import UUID, Column from snuba.clusters.storage_sets import StorageSetKey diff --git a/snuba/state/__init__.py b/snuba/state/__init__.py index 1df7c0a6f07..68b71fe09e9 100644 --- a/snuba/state/__init__.py +++ b/snuba/state/__init__.py @@ -1,20 +1,15 @@ -from __future__ import absolute_import, annotations +from __future__ import annotations import logging import os import time +from collections.abc import Iterable, Mapping, Sequence from dataclasses import dataclass from functools import partial from typing import ( Any, - Iterable, - Mapping, - Optional, Protocol, - Sequence, SupportsFloat, - Tuple, - Type, ) import simplejson as json @@ -64,8 +59,8 @@ def _kafka_producer() -> Producer: @dataclass(frozen=True) class MismatchedTypeException(Exception): key: str - original_type: Type[Any] - new_type: Type[Any] + original_type: type[Any] + new_type: type[Any] class ConfigKeyCallable(Protocol): # Necessary for typing the memoize @@ -120,13 +115,13 @@ def get_typed_value(value: Any) -> Any: def set_config( key: str, - value: Optional[Any], - user: Optional[str] = None, + value: Any | None, + user: str | None = None, force: bool = False, config_key: str = config_hash, ) -> None: value = get_typed_value(value) - enc_value = "{}".format(value).encode("utf-8") if value is not None else None + enc_value = f"{value}".encode() if value is not None else None try: enc_original_value = rds.hget(config_key, key) if enc_original_value is not None and value is not None: @@ -159,8 +154,8 @@ def set_config( def set_configs( - values: Mapping[str, Optional[Any]], - user: Optional[str] = None, + values: Mapping[str, Any | None], + user: str | None = None, force: bool = False, config_key: str = config_hash, ) -> None: @@ -169,52 +164,48 @@ def set_configs( def get_int_config( - key: str, default: Optional[int] = None, config_key: str = config_hash -) -> Optional[int]: + key: str, default: int | None = None, config_key: str = config_hash +) -> int | None: config = _get_config(key, default, config_key) return int(config) if config is not None else None def get_float_config( - key: str, default: Optional[float] = None, config_key: str = config_hash -) -> Optional[float]: + key: str, default: float | None = None, config_key: str = config_hash +) -> float | None: config = _get_config(key, default, config_key) return float(config) if config is not None else None def get_str_config( - key: str, default: Optional[str] = None, config_key: str = config_hash -) -> Optional[str]: + key: str, default: str | None = None, config_key: str = config_hash +) -> str | None: config = _get_config(key, default, config_key) return str(config) if config is not None else None # To be deprecated, use get_int_config, get_float_config, get_str_config instead -def get_config( - key: str, default: Optional[Any] = None, config_key: str = config_hash -) -> Optional[Any]: +def get_config(key: str, default: Any | None = None, config_key: str = config_hash) -> Any | None: return _get_config(key, default, config_key) -def _get_config( - key: str, default: Optional[Any] = None, config_key: str = config_hash -) -> Optional[Any]: +def _get_config(key: str, default: Any | None = None, config_key: str = config_hash) -> Any | None: return get_all_configs(config_key=config_key).get(key, default) def get_configs( - key_defaults: Iterable[Tuple[str, Optional[Any]]], config_key: str = config_hash -) -> Sequence[Optional[Any]]: + key_defaults: Iterable[tuple[str, Any | None]], config_key: str = config_hash +) -> Sequence[Any | None]: all_confs = get_all_configs(config_key=config_key) return [all_confs.get(k, d) for k, d in key_defaults] -def get_all_configs(config_key: str = config_hash) -> Mapping[str, Optional[Any]]: - return {k: v for k, v in get_raw_configs(config_key=config_key).items()} +def get_all_configs(config_key: str = config_hash) -> Mapping[str, Any | None]: + return dict(get_raw_configs(config_key=config_key).items()) @memoize(settings.CONFIG_MEMOIZE_TIMEOUT) -def get_raw_configs(config_key: str = config_hash) -> Mapping[str, Optional[Any]]: +def get_raw_configs(config_key: str = config_hash) -> Mapping[str, Any | None]: try: all_configs = rds.hgetall(config_key) configs = { @@ -234,11 +225,11 @@ def get_raw_configs(config_key: str = config_hash) -> Mapping[str, Optional[Any] return {} -def delete_config(key: str, user: Optional[Any] = None, config_key: str = config_hash) -> None: +def delete_config(key: str, user: Any | None = None, config_key: str = config_hash) -> None: set_config(key, None, user=user, config_key=config_key) -def get_uncached_config(key: str, config_key: str = config_hash) -> Optional[Any]: +def get_uncached_config(key: str, config_key: str = config_hash) -> Any | None: value = rds.hget(config_key, key.encode("utf-8")) if value is not None: return get_typed_value(value.decode("utf-8")) @@ -249,7 +240,7 @@ def get_config_changes_legacy() -> Sequence[Any]: return [json.loads(change) for change in rds.lrange(config_changes_list, 0, -1)] -def get_config_changes() -> Sequence[Tuple[str, float, Optional[str], Any, Any]]: +def get_config_changes() -> Sequence[tuple[str, float, str | None, Any, Any]]: """ Like get_config_changes_legacy() but ensures that values are cast to their correct type """ @@ -265,9 +256,9 @@ def get_config_changes() -> Sequence[Tuple[str, float, Optional[str], Any, Any]] def set_config_description( - key: str, description: Optional[str] = None, user: Optional[str] = None + key: str, description: str | None = None, user: str | None = None ) -> None: - enc_desc = "{}".format(description).encode("utf-8") if description is not None else None + enc_desc = f"{description}".encode() if description is not None else None try: enc_original_desc = rds.hget(config_description_hash, key) @@ -290,7 +281,7 @@ def set_config_description( logger.exception(e) -def get_config_description(key: str) -> Optional[str]: +def get_config_description(key: str) -> str | None: try: enc_desc = rds.hget(config_description_hash, key) return enc_desc.decode("utf-8") if enc_desc is not None else None @@ -299,7 +290,7 @@ def get_config_description(key: str) -> Optional[str]: return None -def get_all_config_descriptions() -> Mapping[str, Optional[str]]: +def get_all_config_descriptions() -> Mapping[str, str | None]: try: all_descriptions = rds.hgetall(config_description_hash) return { @@ -312,7 +303,7 @@ def get_all_config_descriptions() -> Mapping[str, Optional[str]]: return {} -def delete_config_description(key: str, user: Optional[str] = None) -> None: +def delete_config_description(key: str, user: str | None = None) -> None: set_config_description(key, None, user=user) @@ -328,7 +319,7 @@ def safe_dumps_default(value: Any) -> Any: safe_dumps = partial(json.dumps, for_json=True, default=safe_dumps_default) -def _record_query_delivery_callback(error: Optional[KafkaError], message: KafkaMessage) -> None: +def _record_query_delivery_callback(error: KafkaError | None, message: KafkaMessage) -> None: metrics.increment( "record_query.delivery_callback", tags={"status": "success" if error is None else "failure"}, diff --git a/snuba/state/cache/abstract.py b/snuba/state/cache/abstract.py index 09181609b0e..488180bfaa9 100644 --- a/snuba/state/cache/abstract.py +++ b/snuba/state/cache/abstract.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod -from typing import Callable, Generic, Optional, TypeVar +from collections.abc import Callable +from typing import Generic, TypeVar from snuba.utils.metrics.timer import Timer from snuba.utils.serializable_exception import SerializableException @@ -17,7 +18,7 @@ class ExecutionTimeoutError(ExecutionError): class Cache(Generic[TValue], ABC): @abstractmethod - def get(self, key: str) -> Optional[TValue]: + def get(self, key: str) -> TValue | None: """ Gets a value from the cache. """ @@ -36,7 +37,7 @@ def get_readthrough( key: str, function: Callable[[], TValue], record_cache_hit_type: Callable[[int], None], - timer: Optional[Timer] = None, + timer: Timer | None = None, ) -> TValue: """ Implements a read-through caching pattern for the value at the given diff --git a/snuba/state/cache/redis/backend.py b/snuba/state/cache/redis/backend.py index 375f1ab18cb..fe508923365 100644 --- a/snuba/state/cache/redis/backend.py +++ b/snuba/state/cache/redis/backend.py @@ -1,5 +1,5 @@ import logging -from typing import Callable, Optional, Type +from collections.abc import Callable import sentry_sdk from redis import ResponseError @@ -25,7 +25,7 @@ class FuzzyMatchException: - def __init__(self, exception: Type[Exception], message: str | None = None): + def __init__(self, exception: type[Exception], message: str | None = None): self._exception = exception self._message = message @@ -34,10 +34,9 @@ def __eq__(self, other: object) -> bool: return isinstance(other, self._exception) and ( self._message is None or self._message == str(other) ) - elif isinstance(other, self.__class__): + if isinstance(other, self.__class__): return other._exception == self._exception and other._message == self._message - else: - return False + return False DONT_CAPTURE_ERRORS = ( @@ -58,14 +57,12 @@ def __init__( self.__prefix = prefix self.__codec = codec - def __build_key( - self, key: str, prefix: Optional[str] = None, suffix: Optional[str] = None - ) -> str: + def __build_key(self, key: str, prefix: str | None = None, suffix: str | None = None) -> str: return self.__prefix + "/".join( [bit for bit in [prefix, f"{{{key}}}", suffix] if bit is not None] ) - def get(self, key: str) -> Optional[TValue]: + def get(self, key: str) -> TValue | None: value = self.__client.get(self.__build_key(key)) if value is None: return None @@ -84,7 +81,7 @@ def __get_value_with_simple_readthrough( key: str, function: Callable[[], TValue], record_cache_hit_type: Callable[[int], None], - timer: Optional[Timer] = None, + timer: Timer | None = None, ) -> TValue: record_cache_hit_type(SIMPLE_READTHROUGH) result_key = self.__build_key(key) @@ -106,37 +103,34 @@ def __get_value_with_simple_readthrough( if cached_value is not None: record_cache_hit_type(RESULT_VALUE) return self.__codec.decode(cached_value) - else: + try: + value = function() try: - value = function() - try: - self.__client.set( - result_key, - self.__codec.encode(value), - ex=get_config("cache_expiry_sec", 1), - ) - - except Exception as e: - metrics.increment( - "redis_cache_set_error", tags={"error": str(e), **metric_tags} - ) - if e not in DONT_CAPTURE_ERRORS: - sentry_sdk.capture_exception(e) - return value - record_cache_hit_type(RESULT_EXECUTE) - if timer is not None: - timer.mark("cache_set") + self.__client.set( + result_key, + self.__codec.encode(value), + ex=get_config("cache_expiry_sec", 1), + ) + except Exception as e: - metrics.increment("execute_error", tags=metric_tags) - raise e - return value + metrics.increment("redis_cache_set_error", tags={"error": str(e), **metric_tags}) + if e not in DONT_CAPTURE_ERRORS: + sentry_sdk.capture_exception(e) + return value + record_cache_hit_type(RESULT_EXECUTE) + if timer is not None: + timer.mark("cache_set") + except Exception as e: + metrics.increment("execute_error", tags=metric_tags) + raise e + return value def get_readthrough( self, key: str, function: Callable[[], TValue], record_cache_hit_type: Callable[[int], None], - timer: Optional[Timer] = None, + timer: Timer | None = None, ) -> TValue: # in case something is wrong with redis, we want to be able to # disable the read_through_cache but still serve traffic. diff --git a/snuba/state/explain_meta.py b/snuba/state/explain_meta.py index 4f4dfb41040..e0f35bc11be 100644 --- a/snuba/state/explain_meta.py +++ b/snuba/state/explain_meta.py @@ -1,9 +1,10 @@ from __future__ import annotations import difflib +from collections.abc import Generator from contextlib import contextmanager from dataclasses import dataclass, field -from typing import Any, Generator, Literal, cast +from typing import Any, Literal, cast from flask import g @@ -44,7 +45,7 @@ def add_step(self, step: ExplainStep) -> None: @contextmanager -def with_query_differ(category: str, name: str, query: Any) -> Generator[None, None, None]: +def with_query_differ(category: str, name: str, query: Any) -> Generator[None]: original = str(query) yield transformed = str(query) @@ -89,7 +90,7 @@ def get_explain_meta() -> ExplainMeta | None: if hasattr(g, "explain_meta"): return cast(ExplainMeta, g.explain_meta) g.explain_meta = ExplainMeta() - return g.explain_meta + return cast(ExplainMeta, g.explain_meta) except RuntimeError: # Code is executing outside of a flask context return None diff --git a/snuba/state/rate_limit.py b/snuba/state/rate_limit.py index ae4c0bb0012..1d17f17d317 100644 --- a/snuba/state/rate_limit.py +++ b/snuba/state/rate_limit.py @@ -5,11 +5,12 @@ import time import uuid from collections import ChainMap, namedtuple +from collections import ChainMap as TypingChainMap +from collections.abc import Iterator, MutableMapping, Sequence from contextlib import AbstractContextManager, ExitStack, contextmanager from dataclasses import dataclass from types import TracebackType -from typing import Any, Iterator, MutableMapping, Optional, Sequence, Type, cast -from typing import ChainMap as TypingChainMap +from typing import Any, cast from redis.exceptions import TimeoutError as RedisTimeoutError @@ -66,8 +67,8 @@ class RateLimitParameters: rate_limit_name: str bucket: str - per_second_limit: Optional[float] - concurrent_limit: Optional[int] + per_second_limit: float | None + concurrent_limit: int | None class RateLimitExceeded(SerializableException): @@ -103,7 +104,7 @@ def __init__(self) -> None: def add_stats(self, rate_limit_name: str, rate_limit_stats: RateLimitStats) -> None: self.__stats[rate_limit_name] = rate_limit_stats - def get_stats(self, rate_limit_name: str) -> Optional[RateLimitStats]: + def get_stats(self, rate_limit_name: str) -> RateLimitStats | None: return self.__stats.get(rate_limit_name) def __format_single_dict(self, name: str, stats: RateLimitStats) -> MutableMapping[str, float]: @@ -131,7 +132,7 @@ def _get_bucket_key(prefix: str, bucket: str, shard_id: int) -> str: # sharding. shard_suffix = f":shard-{shard_id}" - return "{}{}{}".format(prefix, bucket, shard_suffix) + return f"{prefix}{bucket}{shard_suffix}" def rate_limit_start_request( @@ -199,7 +200,7 @@ def rate_limit_start_request( # it is fine to only perform this cleanup for the shard of the current # query, because on average there will be many other queries that hit other # shards and perform cleanup there - pipe.zremrangebyscore(query_bucket, "-inf", "({:f}".format(now - rate_history_sec)) + pipe.zremrangebyscore(query_bucket, "-inf", f"({now - rate_history_sec:f}") # Now for the tricky bit: # ====================== @@ -261,7 +262,7 @@ def rate_limit_start_request( # of concurrent queries for shard_i in range(rate_limit_shard_factor): bucket = _get_bucket_key(rate_limit_prefix, rate_limit_params.bucket, shard_i) - pipe.zcount(bucket, "({:f}".format(now), "+inf") + pipe.zcount(bucket, f"({now:f}", "+inf") try: results = pipe.execute() @@ -334,7 +335,7 @@ def rate_limit_finish_request( @contextmanager def rate_limit( rate_limit_params: RateLimitParameters, -) -> Iterator[Optional[RateLimitStats]]: +) -> Iterator[RateLimitStats | None]: """ A context manager for rate limiting that allows for limiting based on: * a rolling-window per-second rate @@ -409,7 +410,7 @@ def rate_limit( ) raise RateLimitExceeded( - "{r.scope} {r.name} of {r.val:.0f} exceeds limit of {r.limit:.0f}".format(r=reason), + f"{reason.scope} {reason.name} of {reason.val:.0f} exceeds limit of {reason.limit:.0f}", scope=reason.scope, name=reason.name, ) @@ -457,7 +458,7 @@ def _record_metrics(exc: RateLimitExceeded, rate_limit_param: RateLimitParameter metrics.increment("rate-limited", tags=tags) -class RateLimitAggregator(AbstractContextManager): # type: ignore +class RateLimitAggregator(AbstractContextManager[RateLimitStatsContainer]): """ Runs the rate limits provided by the `rate_limit_params` configuration object. @@ -490,8 +491,8 @@ def __enter__(self) -> RateLimitStatsContainer: def __exit__( self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, ) -> None: self.stack.pop_all().close() diff --git a/snuba/subscriptions/codecs.py b/snuba/subscriptions/codecs.py index 9d71b47e695..17131675aa1 100644 --- a/snuba/subscriptions/codecs.py +++ b/snuba/subscriptions/codecs.py @@ -34,8 +34,8 @@ def encode(self, value: SubscriptionData) -> bytes: def decode(self, value: bytes) -> SubscriptionData: try: data = json.loads(value.decode("utf-8")) - except json.JSONDecodeError: - raise InvalidQueryException("Invalid JSON") + except json.JSONDecodeError as e: + raise InvalidQueryException("Invalid JSON") from e if data.get("subscription_type") == SubscriptionType.RPC.value: return RPCSubscriptionData.from_dict(data, self.entity_key) diff --git a/snuba/subscriptions/combined_scheduler_executor.py b/snuba/subscriptions/combined_scheduler_executor.py index bf316313784..69c75a7dfd8 100644 --- a/snuba/subscriptions/combined_scheduler_executor.py +++ b/snuba/subscriptions/combined_scheduler_executor.py @@ -1,5 +1,6 @@ +from collections.abc import Mapping, Sequence from datetime import timedelta -from typing import Mapping, NamedTuple, Optional, Sequence +from typing import NamedTuple from arroyo import Message, Partition, Topic from arroyo.backends.abstract import Producer @@ -48,10 +49,10 @@ def build_scheduler_executor_consumer( auto_offset_reset: str, strict_offset_reset: bool, schedule_ttl: int, - stale_threshold_seconds: Optional[int], + stale_threshold_seconds: int | None, total_concurrent_queries: int, metrics: MetricsBackend, - health_check_file: Optional[str] = None, + health_check_file: str | None = None, ) -> StreamProcessor[Tick]: dataset = get_dataset(dataset_name) @@ -156,10 +157,10 @@ def __init__( total_concurrent_queries: int, producer: Producer[KafkaPayload], metrics: MetricsBackend, - stale_threshold_seconds: Optional[int], + stale_threshold_seconds: int | None, result_topic: str, schedule_ttl: int, - health_check_file: Optional[str] = None, + health_check_file: str | None = None, ) -> None: self.__partitions = partitions self.__entity_names = entity_names @@ -260,7 +261,7 @@ def submit(self, message: Message[Tick]) -> None: tasks = [] for entity_scheduler in self.__schedulers: - tasks.extend([task for task in entity_scheduler[tick.partition].find(tick)]) + tasks.extend(list(entity_scheduler[tick.partition].find(tick))) encoded_tasks = [self.__encoder.encode(task) for task in tasks] @@ -275,5 +276,5 @@ def terminate(self) -> None: self.__closed = True self.__next_step.terminate() - def join(self, timeout: Optional[float] = None) -> None: + def join(self, timeout: float | None = None) -> None: self.__next_step.join(timeout) diff --git a/snuba/subscriptions/data.py b/snuba/subscriptions/data.py index d77676e118d..27c258e2d26 100644 --- a/snuba/subscriptions/data.py +++ b/snuba/subscriptions/data.py @@ -3,6 +3,7 @@ import base64 import logging from abc import ABC, abstractmethod +from collections.abc import Iterator, Mapping from concurrent.futures import Future from dataclasses import dataclass, field from datetime import datetime, timedelta @@ -11,15 +12,9 @@ from typing import ( Any, Generic, - Iterator, - List, - Mapping, NamedTuple, NewType, - Optional, - Tuple, TypeVar, - Union, ) from uuid import UUID @@ -117,12 +112,12 @@ class _SubscriptionData(ABC, Generic[TRequest]): time_window_sec: int entity: Entity metadata: Mapping[str, Any] - tenant_ids: Mapping[str, Any] = field(default_factory=lambda: dict()) + tenant_ids: Mapping[str, Any] = field(default_factory=lambda: {}) def validate(self) -> None: if self.time_window_sec < 60: raise InvalidSubscriptionError("Time window must be greater than or equal to 1 minute") - elif self.time_window_sec > 60 * 60 * 24: + if self.time_window_sec > 60 * 60 * 24: raise InvalidSubscriptionError("Time window must be less than or equal to 24 hours") if self.resolution_sec < 60: @@ -133,9 +128,9 @@ def build_request( self, dataset: Dataset, timestamp: datetime, - offset: Optional[int], + offset: int | None, timer: Timer, - metrics: Optional[MetricsBackend] = None, + metrics: MetricsBackend | None = None, referrer: str = SUBSCRIPTION_REFERRER, ) -> TRequest: raise NotImplementedError @@ -147,7 +142,7 @@ def run_query( request: TRequest, timer: Timer, robust: bool = False, - concurrent_queries_gauge: Optional[Gauge] = None, + concurrent_queries_gauge: Gauge | None = None, ) -> QueryResult: raise NotImplementedError @@ -220,9 +215,9 @@ def build_request( self, dataset: Dataset, timestamp: datetime, - offset: Optional[int], + offset: int | None, timer: Timer, - metrics: Optional[MetricsBackend] = None, + metrics: MetricsBackend | None = None, referrer: str = SUBSCRIPTION_REFERRER, ) -> TimeSeriesRequest: request_class = EndpointTimeSeries().request_class()() @@ -247,7 +242,7 @@ def run_query( request: TimeSeriesRequest, timer: Timer, robust: bool = False, - concurrent_queries_gauge: Optional[Gauge] = None, + concurrent_queries_gauge: Gauge | None = None, ) -> QueryResult: response = EndpointTimeSeries().execute(request) if not response.result_timeseries or not any( @@ -270,7 +265,7 @@ def run_query( def from_dict(cls, data: Mapping[str, Any], entity_key: EntityKey) -> RPCSubscriptionData: entity: Entity = get_entity(entity_key) metadata = {} - for key in data.keys(): + for key in data: if key == "metadata": metadata.update(data[key]) elif key not in SUBSCRIPTION_DATA_PAYLOAD_KEYS: @@ -285,7 +280,7 @@ def from_dict(cls, data: Mapping[str, Any], entity_key: EntityKey) -> RPCSubscri request_name=data["request_name"], entity=entity, metadata=metadata, - tenant_ids=data.get("tenant_ids", dict()), + tenant_ids=data.get("tenant_ids", {}), ) @classmethod @@ -342,12 +337,12 @@ class SnQLSubscriptionData(_SubscriptionData[Request]): def add_conditions( self, timestamp: datetime, - offset: Optional[int], - query: Union[CompositeQuery[EntityDS], Query], + offset: int | None, + query: CompositeQuery[EntityDS] | Query, ) -> None: added_timestamp_column = False from_clause = query.get_from_clause() - entities: List[Tuple[Optional[str], Entity]] = [] + entities: list[tuple[str | None, Entity]] = [] if isinstance(from_clause, JoinClause): for alias, node in from_clause.get_alias_node_map().items(): assert isinstance(node.data_source, EntityDS), node.data_source @@ -357,7 +352,7 @@ def add_conditions( else: raise InvalidSubscriptionError("Only simple queries and join queries are supported") for entity_alias, entity in entities: - conditions_to_add: List[Expression] = [ + conditions_to_add: list[Expression] = [ binary_condition( ConditionFunctions.EQ, Column(None, entity_alias, "project_id"), @@ -407,9 +402,9 @@ def build_request( self, dataset: Dataset, timestamp: datetime, - offset: Optional[int], + offset: int | None, timer: Timer, - metrics: Optional[MetricsBackend] = None, + metrics: MetricsBackend | None = None, referrer: str = SUBSCRIPTION_REFERRER, ) -> Request: schema = RequestSchema.build(SubscriptionQuerySettings) @@ -440,7 +435,7 @@ def build_request( referrer, # subscriptions are tied to entities, these validators are going to run on the entity # anyways so it's okay that the post-processing is done without type safety - custom_processing, # type: ignore + custom_processing, # type: ignore[arg-type] ) return request @@ -450,7 +445,7 @@ def run_query( request: Request, timer: Timer, robust: bool = False, - concurrent_queries_gauge: Optional[Gauge] = None, + concurrent_queries_gauge: Gauge | None = None, ) -> QueryResult: return run_query( dataset, @@ -465,7 +460,7 @@ def from_dict(cls, data: Mapping[str, Any], entity_key: EntityKey) -> SnQLSubscr entity: Entity = get_entity(entity_key) metadata = {} - for key in data.keys(): + for key in data: if key not in SUBSCRIPTION_DATA_PAYLOAD_KEYS: metadata[key] = data[key] @@ -477,7 +472,7 @@ def from_dict(cls, data: Mapping[str, Any], entity_key: EntityKey) -> SnQLSubscr query=data["query"], entity=entity, metadata=metadata, - tenant_ids=data.get("tenant_ids", dict()), + tenant_ids=data.get("tenant_ids", {}), ) def to_dict(self) -> Mapping[str, Any]: @@ -496,8 +491,8 @@ def to_dict(self) -> Mapping[str, Any]: return subscription_data_dict -SubscriptionData = Union[RPCSubscriptionData, SnQLSubscriptionData] -SubscriptionRequest = Union[Request, TimeSeriesRequest] +SubscriptionData = RPCSubscriptionData | SnQLSubscriptionData +SubscriptionRequest = Request | TimeSeriesRequest class Subscription(NamedTuple): @@ -545,9 +540,9 @@ def find(self, tick: Tick) -> Iterator[ScheduledSubscriptionTask]: class SubscriptionTaskResultFuture(NamedTuple): task: ScheduledSubscriptionTask - future: Future[Tuple[SubscriptionRequest, Result]] + future: Future[tuple[SubscriptionRequest, Result]] class SubscriptionTaskResult(NamedTuple): task: ScheduledSubscriptionTask - result: Tuple[SubscriptionRequest, Result] + result: tuple[SubscriptionRequest, Result] diff --git a/snuba/subscriptions/executor_consumer.py b/snuba/subscriptions/executor_consumer.py index afb94964266..d7844991b2e 100644 --- a/snuba/subscriptions/executor_consumer.py +++ b/snuba/subscriptions/executor_consumer.py @@ -4,10 +4,10 @@ import math import time from collections import deque +from collections.abc import Mapping, Sequence from concurrent.futures import ThreadPoolExecutor from concurrent.futures import TimeoutError as FutureTimeoutError from datetime import datetime -from typing import Deque, Mapping, Optional, Sequence, Tuple from arroyo import Message, Partition, Topic from arroyo.backends.abstract import Producer @@ -75,7 +75,7 @@ def calculate_max_concurrent_queries( fall back to 1 replica (meaning the max concurrent queries == the total) """ replicas = total_partition_count / assigned_partition_count - return math.ceil((total_concurrent_queries / replicas)) + return math.ceil(total_concurrent_queries / replicas) def build_executor_consumer( @@ -83,14 +83,14 @@ def build_executor_consumer( entity_names: Sequence[str], consumer_group: str, bootstrap_servers: Sequence[str], - slice_id: Optional[int], + slice_id: int | None, producer: Producer[KafkaPayload], total_concurrent_queries: int, auto_offset_reset: str, - strict_offset_reset: Optional[bool], + strict_offset_reset: bool | None, metrics: MetricsBackend, - stale_threshold_seconds: Optional[int], - health_check_file: Optional[str] = None, + stale_threshold_seconds: int | None, + health_check_file: str | None = None, ) -> StreamProcessor[KafkaPayload]: # Validate that a valid dataset/entity pair was passed in dataset = get_dataset(dataset_name) @@ -101,7 +101,7 @@ def build_executor_consumer( def get_topics_for_entity( entity_name: str, - ) -> Tuple[KafkaTopicSpec, KafkaTopicSpec]: + ) -> tuple[KafkaTopicSpec, KafkaTopicSpec]: assert entity_name in dataset_entity_names, ( f"Entity {entity_name} does not exist in dataset {dataset_name}" ) @@ -171,9 +171,9 @@ def __init__( entity_names: Sequence[str], producer: Producer[KafkaPayload], metrics: MetricsBackend, - stale_threshold_seconds: Optional[int], + stale_threshold_seconds: int | None, result_topic: str, - health_check_file: Optional[str] = None, + health_check_file: str | None = None, ) -> None: self.__total_concurrent_queries = total_concurrent_queries self.__total_partition_count = total_partition_count @@ -223,7 +223,7 @@ def __init__( dataset: Dataset, entity_names: Sequence[str], max_concurrent_queries: int, - stale_threshold_seconds: Optional[int], + stale_threshold_seconds: int | None, metrics: MetricsBackend, next_step: ProcessingStrategy[KafkaPayload], ) -> None: @@ -238,7 +238,7 @@ def __init__( self.__encoder = SubscriptionScheduledTaskEncoder() self.__result_encoder = SubscriptionTaskResultEncoder() - self.__queue: Deque[Tuple[Message[KafkaPayload], SubscriptionTaskResultFuture]] = deque() + self.__queue: deque[tuple[Message[KafkaPayload], SubscriptionTaskResultFuture]] = deque() self.__closed = False @@ -250,7 +250,7 @@ def __init__( def __execute_query( self, task: ScheduledSubscriptionTask, tick_upper_offset: int - ) -> Tuple[SubscriptionRequest, Result]: + ) -> tuple[SubscriptionRequest, Result]: # Measure the amount of time that took between the task's scheduled # time and it beginning to execute. self.__metrics.timing("executor.latency", (time.time() - task.timestamp.timestamp()) * 1000) @@ -269,7 +269,7 @@ def __execute_query( result = task.task.subscription.data.run_query( self.__dataset, - request, # type: ignore + request, # type: ignore[arg-type] timer, robust=True, concurrent_queries_gauge=self.__concurrent_clickhouse_gauge, @@ -300,7 +300,7 @@ def poll(self) -> None: tags={"error_type": str(cause.code)}, ) else: - raise SubscriptionQueryException(exc.message) + raise SubscriptionQueryException(exc.message) from exc self.__next_step.submit(transformed_message) @@ -359,7 +359,7 @@ def terminate(self) -> None: self.__executor.shutdown() self.__next_step.terminate() - def join(self, timeout: Optional[float] = None) -> None: + def join(self, timeout: float | None = None) -> None: start = time.time() while self.__queue: diff --git a/snuba/subscriptions/scheduler.py b/snuba/subscriptions/scheduler.py index 767f80837c6..b375e067345 100644 --- a/snuba/subscriptions/scheduler.py +++ b/snuba/subscriptions/scheduler.py @@ -1,17 +1,8 @@ import math from abc import ABC, abstractmethod +from collections.abc import Iterator, Mapping, MutableMapping, MutableSequence, Sequence from datetime import datetime, timedelta from enum import Enum -from typing import ( - Iterator, - List, - Mapping, - MutableMapping, - MutableSequence, - Optional, - Sequence, - Tuple, -) from snuba import settings, state from snuba.datasets.entities.entity_key import EntityKey @@ -45,11 +36,11 @@ class TaskBuilder(ABC): @abstractmethod def get_task( self, subscription_with_metadata: SubscriptionWithMetadata, timestamp: int - ) -> Optional[ScheduledSubscriptionTask]: + ) -> ScheduledSubscriptionTask | None: raise NotImplementedError @abstractmethod - def reset_metrics(self) -> Sequence[Tuple[str, int, Tags]]: + def reset_metrics(self) -> Sequence[tuple[str, int, Tags]]: raise NotImplementedError @@ -63,7 +54,7 @@ def __init__(self) -> None: def get_task( self, subscription_with_metadata: SubscriptionWithMetadata, timestamp: int - ) -> Optional[ScheduledSubscriptionTask]: + ) -> ScheduledSubscriptionTask | None: subscription = subscription_with_metadata.subscription resolution = subscription.data.resolution_sec @@ -73,11 +64,10 @@ def get_task( datetime.utcfromtimestamp(timestamp), subscription_with_metadata, ) - else: - return None + return None - def reset_metrics(self) -> Sequence[Tuple[str, int, Tags]]: - metrics: Sequence[Tuple[str, int, Tags]] = [("tasks.built", self.__count, {})] + def reset_metrics(self) -> Sequence[tuple[str, int, Tags]]: + metrics: Sequence[tuple[str, int, Tags]] = [("tasks.built", self.__count, {})] self.__count = 0 return metrics @@ -111,7 +101,7 @@ def __init__(self) -> None: def get_task( self, subscription_with_metadata: SubscriptionWithMetadata, timestamp: int - ) -> Optional[ScheduledSubscriptionTask]: + ) -> ScheduledSubscriptionTask | None: subscription = subscription_with_metadata.subscription resolution = subscription.data.resolution_sec @@ -124,8 +114,7 @@ def get_task( datetime.utcfromtimestamp(timestamp), subscription_with_metadata, ) - else: - return None + return None jitter = subscription.identifier.uuid.int % resolution if timestamp % resolution == jitter: @@ -134,11 +123,10 @@ def get_task( datetime.utcfromtimestamp(timestamp - jitter), subscription_with_metadata, ) - else: - return None + return None - def reset_metrics(self) -> Sequence[Tuple[str, int, Tags]]: - metrics: Sequence[Tuple[str, int, Tags]] = [ + def reset_metrics(self) -> Sequence[tuple[str, int, Tags]]: + metrics: Sequence[tuple[str, int, Tags]] = [ ("tasks.built", self.__count, {}), ("tasks.above.resolution", self.__count_max_resolution, {}), ] @@ -235,17 +223,16 @@ def __init__(self) -> None: def get_task( self, subscription_with_metadata: SubscriptionWithMetadata, timestamp: int - ) -> Optional[ScheduledSubscriptionTask]: + ) -> ScheduledSubscriptionTask | None: subscription = subscription_with_metadata.subscription primary_builder = self.__rollout_state.get_current_mode(subscription, timestamp) if primary_builder == TaskBuilderMode.JITTERED: return self.__jittered_builder.get_task(subscription_with_metadata, timestamp) - else: - return self.__immediate_builder.get_task(subscription_with_metadata, timestamp) + return self.__immediate_builder.get_task(subscription_with_metadata, timestamp) - def reset_metrics(self) -> Sequence[Tuple[str, int, Tags]]: + def reset_metrics(self) -> Sequence[tuple[str, int, Tags]]: def add_tag(tags: Tags, builder_type: str) -> Tags: return { **tags, @@ -271,7 +258,7 @@ def filter_subscriptions( subscriptions: MutableSequence[Subscription], entity_key: EntityKey, metrics: MetricsBackend, - slice_id: Optional[int] = None, + slice_id: int | None = None, ) -> MutableSequence[Subscription]: filtered_subscriptions: MutableSequence[Subscription] = [] @@ -312,7 +299,7 @@ def __init__( partition_id: PartitionId, cache_ttl: timedelta, metrics: MetricsBackend, - slice_id: Optional[int] = None, + slice_id: int | None = None, ) -> None: self.__entity_key = entity_key self.__slice_id = slice_id @@ -321,8 +308,8 @@ def __init__( self.__partition_id = partition_id self.__metrics = metrics - self.__subscriptions: List[Subscription] = [] - self.__last_refresh: Optional[datetime] = None + self.__subscriptions: list[Subscription] = [] + self.__last_refresh: datetime | None = None self.__delegate_builder = DelegateTaskBuilder() self.__jittered_builder = JitteredTaskBuilder() @@ -372,8 +359,7 @@ def __get_subscriptions(self) -> MutableSequence[Subscription]: return filter_subscriptions( self.__subscriptions, self.__entity_key, self.__metrics, self.__slice_id ) - else: - return self.__subscriptions + return self.__subscriptions def find(self, tick: Tick) -> Iterator[ScheduledSubscriptionTask]: self.__reset_builder() diff --git a/snuba/subscriptions/scheduler_consumer.py b/snuba/subscriptions/scheduler_consumer.py index d5d2c9de0c2..fed1764f7c0 100644 --- a/snuba/subscriptions/scheduler_consumer.py +++ b/snuba/subscriptions/scheduler_consumer.py @@ -1,6 +1,7 @@ import logging +from collections.abc import Callable, Mapping, MutableMapping, Sequence from datetime import timedelta -from typing import Callable, Mapping, MutableMapping, NamedTuple, Optional, Sequence +from typing import NamedTuple from arroyo.backends.abstract import Consumer, Producer from arroyo.backends.kafka import KafkaConsumer, KafkaPayload @@ -41,7 +42,7 @@ class MessageDetails(NamedTuple): orig_message_ts: float # The timestamp the message was first received by Sentry (Relay) # It is optional since it is not currently present on all topics - received_p99: Optional[float] + received_p99: float | None class CommitLogTickConsumer(Consumer[Tick]): @@ -99,7 +100,7 @@ def __init__( followed_consumer_group: str, metrics: MetricsBackend, synchronization_timestamp: str, - time_shift: Optional[timedelta] = None, + time_shift: timedelta | None = None, ) -> None: self.__consumer = consumer self.__followed_consumer_group = followed_consumer_group @@ -112,8 +113,8 @@ def __init__( def subscribe( self, topics: Sequence[Topic], - on_assign: Optional[Callable[[Mapping[Partition, int]], None]] = None, - on_revoke: Optional[Callable[[Sequence[Partition]], None]] = None, + on_assign: Callable[[Mapping[Partition, int]], None] | None = None, + on_revoke: Callable[[Sequence[Partition]], None] | None = None, ) -> None: def revocation_callback(partitions: Sequence[Partition]) -> None: self.__previous_messages = {} @@ -121,14 +122,12 @@ def revocation_callback(partitions: Sequence[Partition]) -> None: if on_revoke is not None: on_revoke(partitions) - self.__consumer.subscribe( - topics, on_assign=on_assign, on_revoke=revocation_callback - ) + self.__consumer.subscribe(topics, on_assign=on_assign, on_revoke=revocation_callback) def unsubscribe(self) -> None: self.__consumer.unsubscribe() - def poll(self, timeout: Optional[float] = None) -> Optional[BrokerValue[Tick]]: + def poll(self, timeout: float | None = None) -> BrokerValue[Tick] | None: value = self.__consumer.poll(timeout) if value is None: return None @@ -152,7 +151,7 @@ def poll(self, timeout: Optional[float] = None) -> Optional[BrokerValue[Tick]]: previous_message = self.__previous_messages.get(commit.partition) - result: Optional[BrokerValue[Tick]] + result: BrokerValue[Tick] | None if previous_message is not None: try: time_interval = Interval( @@ -204,10 +203,10 @@ def seek(self, offsets: Mapping[Partition, int]) -> None: def stage_offsets(self, offsets: Mapping[Partition, int]) -> None: return self.__consumer.stage_offsets(offsets) - def commit_offsets(self) -> Mapping[Partition, int]: + def commit_offsets(self) -> Mapping[Partition, int] | None: return self.__consumer.commit_offsets() - def close(self, timeout: Optional[float] = None) -> None: + def close(self, timeout: float | None = None) -> None: return self.__consumer.close(timeout) @property @@ -228,20 +227,20 @@ def __init__( bootstrap_servers: Sequence[str], producer: Producer[KafkaPayload], auto_offset_reset: str, - strict_offset_reset: Optional[bool], + strict_offset_reset: bool | None, schedule_ttl: int, - stale_threshold_seconds: Optional[int], + stale_threshold_seconds: int | None, metrics: MetricsBackend, - slice_id: Optional[int] = None, - health_check_file: Optional[str] = None, + slice_id: int | None = None, + health_check_file: str | None = None, ) -> None: self.__entity_key = EntityKey(entity_name) storage = get_entity(self.__entity_key).get_writable_storage() - assert ( - storage is not None - ), f"Entity {entity_name} does not have a writable storage by default." + assert storage is not None, ( + f"Entity {entity_name} does not have a writable storage by default." + ) stream_loader = storage.get_table_writer().get_stream_loader() @@ -252,9 +251,9 @@ def __init__( try: default_topic_spec = stream_loader.get_default_topic_spec() default_topic_config = default_topic_spec.topic_current_config_values - assert ( - default_topic_config["message.timestamp.type"] == "LogAppendTime" - ), f"{default_topic_spec.get_physical_topic_name()} topic requires LogAppendTime" + assert default_topic_config["message.timestamp.type"] == "LogAppendTime", ( + f"{default_topic_spec.get_physical_topic_name()} topic requires LogAppendTime" + ) except AssertionError: raise except Exception: @@ -272,9 +271,7 @@ def __init__( assert mode is not None self.__mode = mode - synchronization_timestamp = ( - stream_loader.get_subscription_sychronization_timestamp() - ) + synchronization_timestamp = stream_loader.get_subscription_sychronization_timestamp() assert synchronization_timestamp is not None self.__synchronization_timestamp = synchronization_timestamp @@ -299,9 +296,7 @@ def __init__( def build_consumer(self) -> StreamProcessor[Tick]: return StreamProcessor( self.__build_tick_consumer(), - Topic( - self.__commit_log_topic_spec.get_physical_topic_name(self.__slice_id) - ), + Topic(self.__commit_log_topic_spec.get_physical_topic_name(self.__slice_id)), self.__build_strategy_factory(), ONCE_PER_SECOND, ) @@ -349,13 +344,13 @@ def __init__( entity_key: EntityKey, mode: SchedulingWatermarkMode, schedule_ttl: int, - stale_threshold_seconds: Optional[int], + stale_threshold_seconds: int | None, partitions: int, producer: Producer[KafkaPayload], scheduled_topic_spec: KafkaTopicSpec, metrics: MetricsBackend, - slice_id: Optional[int] = None, - health_check_file: Optional[str] = None, + slice_id: int | None = None, + health_check_file: str | None = None, ) -> None: self.__mode = mode self.__stale_threshold_seconds = stale_threshold_seconds @@ -373,9 +368,7 @@ def __init__( self.__schedulers = { index: SubscriptionScheduler( entity_key, - RedisSubscriptionDataStore( - redis_client, entity_key, PartitionId(index) - ), + RedisSubscriptionDataStore(redis_client, entity_key, PartitionId(index)), partition_id=PartitionId(index), cache_ttl=timedelta(seconds=schedule_ttl), metrics=self.__metrics, diff --git a/snuba/subscriptions/scheduler_processing_strategy.py b/snuba/subscriptions/scheduler_processing_strategy.py index 4b81ba3eff8..6c1451bd10a 100644 --- a/snuba/subscriptions/scheduler_processing_strategy.py +++ b/snuba/subscriptions/scheduler_processing_strategy.py @@ -3,15 +3,10 @@ import logging import time from collections import deque +from collections.abc import Mapping, MutableMapping from datetime import datetime from typing import ( - Deque, - List, - Mapping, - MutableMapping, NamedTuple, - Optional, - Tuple, cast, ) @@ -37,7 +32,7 @@ class CommittableTick(NamedTuple): tick: Tick # Offset that we can safely committed once the tick is processed. # Not necessarily the same as the tick's offset. - offset_to_commit: Optional[int] + offset_to_commit: int | None class ProvideCommitStrategy(ProcessingStrategy[Tick]): @@ -81,11 +76,11 @@ def __init__( # Store the last message we received for each partition so know when # to commit offsets. - self.__latest_messages_by_partition: MutableMapping[int, Optional[BrokerValue[Tick]]] = { - index: None for index in range(self.__partitions) - } - self.__offset_low_watermark: Optional[int] = None - self.__offset_high_watermark: Optional[int] = None + self.__latest_messages_by_partition: MutableMapping[int, BrokerValue[Tick] | None] = ( + dict.fromkeys(range(self.__partitions)) + ) + self.__offset_low_watermark: int | None = None + self.__offset_high_watermark: int | None = None self.__closed = False @@ -164,7 +159,7 @@ def terminate(self) -> None: self.__closed = True self.__next_step.terminate() - def join(self, timeout: Optional[float] = None) -> None: + def join(self, timeout: float | None = None) -> None: self.__next_step.close() self.__next_step.join(timeout) @@ -197,7 +192,7 @@ def __init__( self, mode: SchedulingWatermarkMode, partitions: int, - max_ticks_buffered_per_partition: Optional[int], + max_ticks_buffered_per_partition: int | None, next_step: ProcessingStrategy[Tick], metrics: MetricsBackend, ) -> None: @@ -210,7 +205,7 @@ def __init__( self.__next_step = next_step self.__metrics = metrics - self.__buffers: Mapping[int, Deque[Message[Tick]]] = { + self.__buffers: Mapping[int, deque[Message[Tick]]] = { index: deque() for index in range(self.__partitions) } @@ -293,7 +288,7 @@ def terminate(self) -> None: self.__closed = True self.__next_step.terminate() - def join(self, timeout: Optional[float] = None) -> None: + def join(self, timeout: float | None = None) -> None: self.__next_step.close() self.__next_step.join(timeout) @@ -301,7 +296,7 @@ def join(self, timeout: Optional[float] = None) -> None: class TickSubscription(NamedTuple): tick_message: BrokerValue[CommittableTick] subscription_future: ProducerFuture[BrokerValue[KafkaPayload]] - offset_to_commit: Optional[int] + offset_to_commit: int | None class ScheduledSubscriptionQueue: @@ -310,22 +305,22 @@ class ScheduledSubscriptionQueue: """ def __init__(self) -> None: - self.__queues: Deque[ - Tuple[ + self.__queues: deque[ + tuple[ BrokerValue[CommittableTick], - Deque[ProducerFuture[BrokerValue[KafkaPayload]]], + deque[ProducerFuture[BrokerValue[KafkaPayload]]], ] ] = deque() def append( self, tick_message: BrokerValue[CommittableTick], - futures: Deque[ProducerFuture[BrokerValue[KafkaPayload]]], + futures: deque[ProducerFuture[BrokerValue[KafkaPayload]]], ) -> None: if len(futures) > 0: self.__queues.append((tick_message, futures)) - def peek(self) -> Optional[TickSubscription]: + def peek(self) -> TickSubscription | None: if self.__queues: tick, futures = self.__queues[0] @@ -385,9 +380,9 @@ def __init__( producer: Producer[KafkaPayload], scheduled_topic_spec: KafkaTopicSpec, commit: Commit, - stale_threshold_seconds: Optional[int], + stale_threshold_seconds: int | None, metrics: MetricsBackend, - slice_id: Optional[int] = None, + slice_id: int | None = None, ) -> None: self.__schedulers = schedulers self.__encoder = SubscriptionScheduledTaskEncoder() @@ -444,9 +439,9 @@ def submit(self, message: Message[CommittableTick]) -> None: self.__stale_threshold_seconds is not None and time.time() - tick.timestamps.lower > self.__stale_threshold_seconds ): - encoded_tasks: List[KafkaPayload] = [] + encoded_tasks: list[KafkaPayload] = [] else: - tasks = [task for task in self.__schedulers[tick.partition].find(tick)] + tasks = list(self.__schedulers[tick.partition].find(tick)) encoded_tasks = [] for task in tasks: @@ -491,7 +486,7 @@ def close(self) -> None: def terminate(self) -> None: self.__closed = True - def join(self, timeout: Optional[float] = None) -> None: + def join(self, timeout: float | None = None) -> None: start = time.time() while self.__queue: diff --git a/snuba/subscriptions/store.py b/snuba/subscriptions/store.py index b801438567d..291eab3fe20 100644 --- a/snuba/subscriptions/store.py +++ b/snuba/subscriptions/store.py @@ -1,6 +1,6 @@ import abc import time -from typing import Iterable, Tuple +from collections.abc import Iterable from uuid import UUID from snuba import environment @@ -30,7 +30,7 @@ def delete(self, key: UUID) -> None: pass @abc.abstractmethod - def all(self) -> Iterable[Tuple[UUID, SubscriptionData]]: + def all(self) -> Iterable[tuple[UUID, SubscriptionData]]: """ Fetches all `Subscriptions` from the store :return: An iterable of `Subscriptions`. @@ -63,7 +63,7 @@ def delete(self, key: UUID) -> None: """ self.client.hdel(self.__key, key.hex.encode("utf-8")) - def all(self) -> Iterable[Tuple[UUID, SubscriptionData]]: + def all(self) -> Iterable[tuple[UUID, SubscriptionData]]: """ Fetches all subscriptions from the store. :return: An iterable of `Subscriptions`. diff --git a/snuba/subscriptions/subscription.py b/snuba/subscriptions/subscription.py index 68aaf840c58..eb283a2220f 100644 --- a/snuba/subscriptions/subscription.py +++ b/snuba/subscriptions/subscription.py @@ -49,7 +49,7 @@ def create(self, data: SubscriptionData, timer: Timer) -> SubscriptionIdentifier def _test_request(self, data: SubscriptionData, timer: Timer) -> None: request = data.build_request(self.dataset, datetime.utcnow(), None, timer) - data.run_query(self.dataset, request, timer) # type: ignore + data.run_query(self.dataset, request, timer) # type: ignore[arg-type] class SubscriptionDeleter: diff --git a/snuba/util.py b/snuba/util.py index f0f07fc2b15..6d6b73f27d4 100644 --- a/snuba/util.py +++ b/snuba/util.py @@ -2,18 +2,15 @@ import _strptime # NOQA fixes _strptime deferred import issue import re +from collections.abc import Callable, MutableMapping, Sequence from datetime import UTC, datetime, timedelta from enum import Enum from functools import wraps +from re import Pattern from typing import ( Any, - Callable, - MutableMapping, NamedTuple, - Pattern, - Sequence, TypeVar, - Union, cast, ) @@ -120,14 +117,12 @@ def get_re(format: Sequence[PartSegment]) -> Pattern[Any]: int(retention_days), ) - else: - raise ValueError("Unknown part name/format: " + str(part_str)) + raise ValueError("Unknown part name/format: " + str(part_str)) -def force_bytes(s: Union[bytes, str]) -> bytes: +def force_bytes(s: bytes | str) -> bytes: if isinstance(s, bytes): return s - elif isinstance(s, str): + if isinstance(s, str): return s.encode("utf-8", "replace") - else: - raise TypeError(f"cannot convert {type(s).__name__} to bytes") + raise TypeError(f"cannot convert {type(s).__name__} to bytes") diff --git a/snuba/utils/bucket_timer.py b/snuba/utils/bucket_timer.py index 7bd1f05fd63..985c9cbb7d1 100644 --- a/snuba/utils/bucket_timer.py +++ b/snuba/utils/bucket_timer.py @@ -2,8 +2,8 @@ import typing from collections import defaultdict +from collections.abc import MutableMapping from datetime import datetime, timedelta -from typing import List, MutableMapping from snuba import environment, state from snuba.state import get_int_config @@ -79,7 +79,7 @@ def record_time_spent(self, project_id: int, start: datetime, end: datetime) -> right += timedelta(minutes=1) self.__add_to_bucket(project_id, start_minute, end - left) - def get_projects_exceeding_limit(self) -> List[int]: + def get_projects_exceeding_limit(self) -> list[int]: now = datetime.now() self.__trim_expired_buckets(now) project_groups: dict[int, timedelta] = defaultdict(lambda: timedelta(seconds=0)) diff --git a/snuba/utils/describer.py b/snuba/utils/describer.py index 33140052daf..a13e7b8d9ef 100644 --- a/snuba/utils/describer.py +++ b/snuba/utils/describer.py @@ -1,8 +1,9 @@ from __future__ import annotations from abc import ABC +from collections.abc import Sequence from dataclasses import dataclass -from typing import NamedTuple, Optional, Sequence, Union +from typing import NamedTuple class Property(NamedTuple): @@ -10,8 +11,8 @@ class Property(NamedTuple): value: str -class DescriptionVisitor(ABC): - def visit_header(self, header: Optional[str]) -> None: +class DescriptionVisitor(ABC): # noqa: B024 - intentional ABC; methods raise NotImplementedError rather than being abstract + def visit_header(self, header: str | None) -> None: raise NotImplementedError def visit_description(self, desc: Description) -> None: @@ -34,8 +35,8 @@ class Description: The serialization method is independent on the structure. """ - header: Optional[str] - content: Sequence[Union[Description, str, Property]] + header: str | None + content: Sequence[Description | str | Property] def accept(self, visitor: DescriptionVisitor) -> None: visitor.visit_header(self.header) @@ -48,7 +49,7 @@ def accept(self, visitor: DescriptionVisitor) -> None: visitor.visit_description(c) -class Describable(ABC): +class Describable(ABC): # noqa: B024 - intentional ABC; methods raise NotImplementedError rather than being abstract """ Class to be extended by any data structure we want to describe either via CLI commands, UI or API. diff --git a/snuba/utils/gcs.py b/snuba/utils/gcs.py index fef2ee1d9cd..3c7efc753e2 100644 --- a/snuba/utils/gcs.py +++ b/snuba/utils/gcs.py @@ -1,8 +1,9 @@ import os -from typing import NamedTuple, Optional, Sequence +from collections.abc import Sequence +from typing import NamedTuple import structlog -from google.cloud.storage.client import Client # type: ignore +from google.cloud.storage.client import Client # type: ignore[import-untyped] logger = structlog.get_logger().bind(module=__name__) @@ -29,9 +30,7 @@ def __init__(self, bucket_name: str): self.bucket_name = bucket_name self.storage_client = Client(project=self.project_id) - def upload_file( - self, source_file_name: str, destination_blob_name: Optional[str] = None - ) -> None: + def upload_file(self, source_file_name: str, destination_blob_name: str | None = None) -> None: """ Upload a file to the bucket. If no destination_blob_name is specified, the source file name is used. """ @@ -56,7 +55,7 @@ def download_file(self, source_blob_name: str, destination_file_name: str) -> No logger.info(f"File {source_blob_name} downloaded to {destination_file_name}.") - def list_blobs(self, prefix: Optional[str] = None, delimiter: Optional[str] = None) -> Blobs: + def list_blobs(self, prefix: str | None = None, delimiter: str | None = None) -> Blobs: """ List blob names. If the prefix is provided, it will list blob names that exists under that prefix only. Delimiter is used if you want to get back prefixes. See @@ -67,7 +66,7 @@ def list_blobs(self, prefix: Optional[str] = None, delimiter: Optional[str] = No names = [blob.name for blob in blobs] prefixes = [] if delimiter: - prefixes = [prefix for prefix in blobs.prefixes] + prefixes = list(blobs.prefixes) return Blobs(names, prefixes) @@ -77,4 +76,4 @@ def blob_exists(self, source_blob_name: str) -> bool: """ bucket = self.storage_client.bucket(self.bucket_name) blob = bucket.blob(source_blob_name) - return True if blob.exists() else False # satisfy mypy + return bool(blob.exists()) # satisfy mypy diff --git a/snuba/utils/health_info.py b/snuba/utils/health_info.py index 9b47c1342ac..a5ef8b4587f 100644 --- a/snuba/utils/health_info.py +++ b/snuba/utils/health_info.py @@ -2,8 +2,8 @@ import logging import time +from collections.abc import Mapping from dataclasses import dataclass -from typing import Dict, Mapping, Union import simplejson as json @@ -20,13 +20,13 @@ class HealthInfo: body: str status: int - content_type: Dict[str, str] + content_type: dict[str, str] -def get_health_info(thorough: Union[bool, str]) -> HealthInfo: +def get_health_info(thorough: bool | str) -> HealthInfo: start = time.time() - body: Mapping[str, Union[str, bool]] = {"status": "ok"} + body: Mapping[str, str | bool] = {"status": "ok"} payload = json.dumps(body) metrics.timing( diff --git a/snuba/utils/iterators.py b/snuba/utils/iterators.py index b14517161f8..19f4fbe2f95 100644 --- a/snuba/utils/iterators.py +++ b/snuba/utils/iterators.py @@ -1,4 +1,5 @@ -from typing import Iterable, Iterator, MutableSequence, Sequence, TypeVar +from collections.abc import Iterable, Iterator, MutableSequence, Sequence +from typing import TypeVar T = TypeVar("T") diff --git a/snuba/utils/manage_topics.py b/snuba/utils/manage_topics.py index ef1e86703f1..1c381e1b69e 100644 --- a/snuba/utils/manage_topics.py +++ b/snuba/utils/manage_topics.py @@ -1,9 +1,12 @@ import logging import time -from typing import Sequence +from collections.abc import Sequence from confluent_kafka import KafkaError, KafkaException -from confluent_kafka.admin import AdminClient, NewTopic +from confluent_kafka.admin import ( # type: ignore[attr-defined] # NewTopic lacks explicit re-export + AdminClient, + NewTopic, +) from snuba.datasets.table_storage import KafkaTopicSpec from snuba.utils.streams.topics import Topic @@ -21,16 +24,16 @@ def create_topics(client: AdminClient, topics: Sequence[Topic], num_partitions: topic_spec.topic_name, num_partitions=num_partitions, replication_factor=1, - config=topic_spec.topic_creation_config, + config=dict(topic_spec.topic_creation_config), ) logger.info("Creating Kafka topics...") - for topic, future in client.create_topics( + for topic_name, future in client.create_topics( list(topics_to_create.values()), operation_timeout=1 ).items(): try: future.result() - logger.info("Topic %s created", topic) + logger.info("Topic %s created", topic_name) except KafkaException as err: if err.args[0].code() != KafkaError.TOPIC_ALREADY_EXISTS: logger.error("Failed to create topic %s", topic, exc_info=err) @@ -43,7 +46,7 @@ def recreate_topic(client: AdminClient, topic: Topic, num_partitions: int = 1) - topic_spec.topic_name, num_partitions=num_partitions, replication_factor=1, - config=topic_spec.topic_creation_config, + config=dict(topic_spec.topic_creation_config), ) logger.info(f"Deleting Kafka topic {topic_spec.topic_name} ...") future = client.delete_topics([topic_spec.topic_name])[topic_spec.topic_name] @@ -58,10 +61,12 @@ def recreate_topic(client: AdminClient, topic: Topic, num_partitions: int = 1) - time.sleep(2) logger.info(f"Recreating Kafka topic {topic_spec.topic_name} ...") - for topic, future in client.create_topics([topic_to_recreate], operation_timeout=1).items(): + for topic_name, future in client.create_topics( + [topic_to_recreate], operation_timeout=1 + ).items(): try: future.result() - logger.info("Topic %s recreated", topic) + logger.info("Topic %s recreated", topic_name) except KafkaException as err: if err.args[0].code() != KafkaError.TOPIC_ALREADY_EXISTS: logger.error("Failed to recreate topic %s", topic, exc_info=err) diff --git a/snuba/utils/metrics/addr_config.py b/snuba/utils/metrics/addr_config.py index eab659c8c81..5946c4f723a 100644 --- a/snuba/utils/metrics/addr_config.py +++ b/snuba/utils/metrics/addr_config.py @@ -1,8 +1,7 @@ import os -from typing import Optional -def get_statsd_addr() -> tuple[Optional[str], Optional[int]]: +def get_statsd_addr() -> tuple[str | None, int | None]: """Returns the address of the StatsD server.""" snuba_statsd_address = os.environ.get("SNUBA_STATSD_ADDR") if snuba_statsd_address: diff --git a/snuba/utils/metrics/backends/abstract.py b/snuba/utils/metrics/backends/abstract.py index fbe6a6761bb..2431cf7d22d 100644 --- a/snuba/utils/metrics/backends/abstract.py +++ b/snuba/utils/metrics/backends/abstract.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from typing import Optional, Union from snuba.utils.metrics.types import Tags @@ -13,9 +12,9 @@ class MetricsBackend(ABC): def increment( self, name: str, - value: Union[int, float] = 1, - tags: Optional[Tags] = None, - unit: Optional[str] = None, + value: int | float = 1, + tags: Tags | None = None, + unit: str | None = None, ) -> None: """ Increment a counter metric. These increments can also be @@ -33,9 +32,9 @@ def increment( def gauge( self, name: str, - value: Union[int, float], - tags: Optional[Tags] = None, - unit: Optional[str] = None, + value: int | float, + tags: Tags | None = None, + unit: str | None = None, ) -> None: """ Emit a metric that is the authoritative value for a quantity at a point in time @@ -50,9 +49,9 @@ def gauge( def timing( self, name: str, - value: Union[int, float], - tags: Optional[Tags] = None, - unit: Optional[str] = None, + value: int | float, + tags: Tags | None = None, + unit: str | None = None, ) -> None: """ Emit a metric for the timing performance of an operation. @@ -67,9 +66,9 @@ def timing( def distribution( self, name: str, - value: Union[int, float], - tags: Optional[Tags] = None, - unit: Optional[str] = None, + value: int | float, + tags: Tags | None = None, + unit: str | None = None, ) -> None: """ Emit a metric for the performance of an operation. @@ -87,6 +86,6 @@ def events( text: str, alert_type: str, priority: str, - tags: Optional[Tags] = None, + tags: Tags | None = None, ) -> None: raise NotImplementedError diff --git a/snuba/utils/metrics/backends/datadog.py b/snuba/utils/metrics/backends/datadog.py index 1bd05692e9f..f2243ecad23 100644 --- a/snuba/utils/metrics/backends/datadog.py +++ b/snuba/utils/metrics/backends/datadog.py @@ -1,9 +1,9 @@ from __future__ import annotations import threading -from typing import Callable, Mapping, Optional, Sequence, Union +from collections.abc import Callable, Mapping -from datadog import DogStatsd +from datadog.dogstatsd.base import DogStatsd from snuba.utils.metrics.backends.abstract import MetricsBackend from snuba.utils.metrics.types import Tags @@ -17,7 +17,7 @@ class DatadogMetricsBackend(MetricsBackend): def __init__( self, client_factory: Callable[[], DogStatsd], - sample_rates: Optional[Mapping[str, float]] = None, + sample_rates: Mapping[str, float] | None = None, ) -> None: """ :param client_factory: A function that returns a new ``DogStatsd`` @@ -35,23 +35,22 @@ def __init__( @property def __client(self) -> DogStatsd: try: - client = self.__thread_state.client + client: DogStatsd = self.__thread_state.client except AttributeError: client = self.__thread_state.client = self.__client_factory() return client - def __normalize_tags(self, tags: Optional[Tags]) -> Optional[Sequence[str]]: + def __normalize_tags(self, tags: Tags | None) -> list[str] | None: if tags is None: return None - else: - return [f"{key}:{value.replace('|', '_')}" for key, value in tags.items()] + return [f"{key}:{value.replace('|', '_')}" for key, value in tags.items()] def increment( self, name: str, - value: Union[int, float] = 1, - tags: Optional[Tags] = None, - unit: Optional[str] = None, + value: int | float = 1, + tags: Tags | None = None, + unit: str | None = None, ) -> None: self.__client.increment( name, @@ -63,9 +62,9 @@ def increment( def gauge( self, name: str, - value: Union[int, float], - tags: Optional[Tags] = None, - unit: Optional[str] = None, + value: int | float, + tags: Tags | None = None, + unit: str | None = None, ) -> None: self.__client.gauge( name, @@ -77,9 +76,9 @@ def gauge( def timing( self, name: str, - value: Union[int, float], - tags: Optional[Tags] = None, - unit: Optional[str] = None, + value: int | float, + tags: Tags | None = None, + unit: str | None = None, ) -> None: self.__client.timing( name, @@ -91,9 +90,9 @@ def timing( def distribution( self, name: str, - value: Union[int, float], - tags: Optional[Tags] = None, - unit: Optional[str] = None, + value: int | float, + tags: Tags | None = None, + unit: str | None = None, ) -> None: self.__client.distribution( name, @@ -108,9 +107,12 @@ def events( text: str, alert_type: str, priority: str, - tags: Optional[Tags] = None, + tags: Tags | None = None, ) -> None: - self.__client.event( + # datadog's DogStatsd.event is untyped and its second positional + # parameter is named ``message`` in the installed package; keep the + # existing keyword call to preserve runtime behavior. + self.__client.event( # type: ignore[no-untyped-call, call-arg] title=title, text=text, alert_type=alert_type, diff --git a/snuba/utils/metrics/backends/dummy.py b/snuba/utils/metrics/backends/dummy.py index c3d3ada1b91..ec3203c66a0 100644 --- a/snuba/utils/metrics/backends/dummy.py +++ b/snuba/utils/metrics/backends/dummy.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Mapping, Optional, Union +from collections.abc import Mapping from snuba.utils.metrics.backends.abstract import MetricsBackend from snuba.utils.metrics.types import Tags @@ -28,9 +28,9 @@ def __validate_tags(self, tags: Tags) -> None: def increment( self, name: str, - value: Union[int, float] = 1, - tags: Optional[Tags] = None, - unit: Optional[str] = None, + value: int | float = 1, + tags: Tags | None = None, + unit: str | None = None, ) -> None: if self.__strict: assert isinstance(name, str) @@ -41,9 +41,9 @@ def increment( def gauge( self, name: str, - value: Union[int, float], - tags: Optional[Tags] = None, - unit: Optional[str] = None, + value: int | float, + tags: Tags | None = None, + unit: str | None = None, ) -> None: if self.__strict: assert isinstance(name, str) @@ -54,9 +54,9 @@ def gauge( def timing( self, name: str, - value: Union[int, float], - tags: Optional[Tags] = None, - unit: Optional[str] = None, + value: int | float, + tags: Tags | None = None, + unit: str | None = None, ) -> None: if self.__strict: assert isinstance(name, str) @@ -67,9 +67,9 @@ def timing( def distribution( self, name: str, - value: Union[int, float], - tags: Optional[Tags] = None, - unit: Optional[str] = None, + value: int | float, + tags: Tags | None = None, + unit: str | None = None, ) -> None: if self.__strict: assert isinstance(name, str) @@ -83,7 +83,7 @@ def events( text: str, alert_type: str, priority: str, - tags: Optional[Tags] = None, + tags: Tags | None = None, ) -> None: if self.__strict: assert isinstance(title, str) diff --git a/snuba/utils/metrics/backends/sentry.py b/snuba/utils/metrics/backends/sentry.py index 42d0b3c88ed..dc144e1401b 100644 --- a/snuba/utils/metrics/backends/sentry.py +++ b/snuba/utils/metrics/backends/sentry.py @@ -40,7 +40,7 @@ def timing( unit: str | None = None, ) -> None: # The Sentry SDK has strict typing on the unit, so it doesn't allow passing arbitrary units - metrics.timing(name, value, unit or "millisecond", tags) # type: ignore + metrics.timing(name, value, unit or "millisecond", tags) # type: ignore[arg-type] def distribution( self, diff --git a/snuba/utils/metrics/backends/testing.py b/snuba/utils/metrics/backends/testing.py index ef5d99bf85d..d07d86a16de 100644 --- a/snuba/utils/metrics/backends/testing.py +++ b/snuba/utils/metrics/backends/testing.py @@ -1,7 +1,7 @@ from __future__ import annotations +from collections.abc import Mapping, MutableMapping from dataclasses import dataclass -from typing import List, Mapping, MutableMapping, Optional, Union from snuba.utils.metrics.backends.abstract import MetricsBackend from snuba.utils.metrics.types import Tags @@ -20,19 +20,19 @@ class RecordedEventCall: tags: Tags -RecordedMetricCalls = List[RecordedMetricCall] +RecordedMetricCalls = list[RecordedMetricCall] -RECORDED_METRIC_CALLS: MutableMapping[str, MutableMapping[str, List[RecordedMetricCall]]] = {} -RECORDED_EVENT_CALLS: MutableMapping[str, List[RecordedEventCall]] = {} +RECORDED_METRIC_CALLS: MutableMapping[str, MutableMapping[str, list[RecordedMetricCall]]] = {} +RECORDED_EVENT_CALLS: MutableMapping[str, list[RecordedEventCall]] = {} def record_metric_call( mtype: str, name: str, value: int | float, - tags: Optional[Tags], - unit: Optional[str] = None, + tags: Tags | None, + unit: str | None = None, ) -> None: if mtype not in RECORDED_METRIC_CALLS: RECORDED_METRIC_CALLS[mtype] = {} @@ -46,7 +46,7 @@ def record_metric_call( def record_event_call( - title: str, text: str, alert_type: str, priority: str, tags: Optional[Tags] = None + title: str, text: str, alert_type: str, priority: str, tags: Tags | None = None ) -> None: value = str( { @@ -69,7 +69,7 @@ def get_recorded_metric_calls(mtype: str, name: str) -> RecordedMetricCalls | No """ Used in tests to determine if the metrics were called with the correct values """ - return RECORDED_METRIC_CALLS.get(mtype, dict()).get(name) + return RECORDED_METRIC_CALLS.get(mtype, {}).get(name) class TestingMetricsBackend(MetricsBackend): @@ -92,9 +92,9 @@ def __validate_tags(self, tags: Tags) -> None: def increment( self, name: str, - value: Union[int, float] = 1, - tags: Optional[Tags] = None, - unit: Optional[str] = None, + value: int | float = 1, + tags: Tags | None = None, + unit: str | None = None, ) -> None: record_metric_call("increment", name, value, tags) if self.__strict: @@ -106,9 +106,9 @@ def increment( def gauge( self, name: str, - value: Union[int, float], - tags: Optional[Tags] = None, - unit: Optional[str] = None, + value: int | float, + tags: Tags | None = None, + unit: str | None = None, ) -> None: record_metric_call("gauge", name, value, tags) if self.__strict: @@ -120,9 +120,9 @@ def gauge( def timing( self, name: str, - value: Union[int, float], - tags: Optional[Tags] = None, - unit: Optional[str] = None, + value: int | float, + tags: Tags | None = None, + unit: str | None = None, ) -> None: record_metric_call("timing", name, value, tags) if self.__strict: @@ -134,9 +134,9 @@ def timing( def distribution( self, name: str, - value: Union[int, float], - tags: Optional[Tags] = None, - unit: Optional[str] = None, + value: int | float, + tags: Tags | None = None, + unit: str | None = None, ) -> None: record_metric_call("distribution", name, value, tags) if self.__strict: @@ -151,7 +151,7 @@ def events( text: str, alert_type: str, priority: str, - tags: Optional[Tags] = None, + tags: Tags | None = None, ) -> None: record_event_call(title, text, alert_type, priority, tags) if self.__strict: diff --git a/snuba/utils/metrics/gauge.py b/snuba/utils/metrics/gauge.py index 1735d8b7f95..f98e0e3a83c 100644 --- a/snuba/utils/metrics/gauge.py +++ b/snuba/utils/metrics/gauge.py @@ -1,5 +1,5 @@ from threading import Lock -from typing import Any, Optional +from typing import Any from snuba.utils.metrics.backends.abstract import MetricsBackend from snuba.utils.metrics.types import Tags @@ -10,7 +10,7 @@ def __init__( self, metrics: MetricsBackend, name: str, - tags: Optional[Tags] = None, + tags: Tags | None = None, ) -> None: self.__metrics = metrics self.__name = name @@ -25,9 +25,9 @@ def __enter__(self) -> None: def __exit__( self, - type: Optional[Any] = None, - value: Optional[Any] = None, - traceback: Optional[Any] = None, + type: Any | None = None, + value: Any | None = None, + traceback: Any | None = None, ) -> None: self.decrement() @@ -55,8 +55,8 @@ def __init__( self, metrics: MetricsBackend, name: str, - tags: Optional[Tags] = None, - lock: Optional[Lock] = None, + tags: Tags | None = None, + lock: Lock | None = None, ) -> None: if lock is None: lock = Lock() diff --git a/snuba/utils/metrics/timer.py b/snuba/utils/metrics/timer.py index b8a52bfadba..dd306e1267c 100644 --- a/snuba/utils/metrics/timer.py +++ b/snuba/utils/metrics/timer.py @@ -1,5 +1,5 @@ +from collections.abc import Mapping, MutableSequence from itertools import groupby -from typing import Dict, List, Mapping, MutableSequence, Optional, Tuple from sentry_kafka_schemas.schema_types.snuba_queries_v1 import TimerData @@ -10,7 +10,7 @@ class MissingTimerMarksException(SerializableException): - def __init__(self, marks: List[str]): + def __init__(self, marks: list[str]): super().__init__(f"Please pass in timer marks that exist: missing {marks}") @@ -18,15 +18,15 @@ class Timer: def __init__( self, name: str, - clock: Clock = SystemClock(), - tags: Optional[Mapping[str, str]] = None, + clock: Clock | None = None, + tags: Mapping[str, str] | None = None, ): self.__name = name - self.__clock = clock + self.__clock = clock if clock is not None else SystemClock() - self.__marks: MutableSequence[Tuple[str, float]] = [(self.__name, self.__clock.time())] - self.__data: Optional[TimerData] = None - self.__tags: Dict[str, str] = dict(tags or {}) + self.__marks: MutableSequence[tuple[str, float]] = [(self.__name, self.__clock.time())] + self.__data: TimerData | None = None + self.__tags: dict[str, str] = dict(tags or {}) def mark(self, name: str) -> None: self.__data = None @@ -95,7 +95,7 @@ def finish(self) -> TimerData: return self.__data @property - def tags(self) -> Optional[Mapping[str, str]]: + def tags(self) -> Mapping[str, str] | None: return self.__tags def for_json(self) -> TimerData: @@ -104,8 +104,8 @@ def for_json(self) -> TimerData: def send_metrics_to( self, backend: MetricsBackend, - tags: Optional[Tags] = None, - mark_tags: Optional[Tags] = None, + tags: Tags | None = None, + mark_tags: Tags | None = None, ) -> None: data = self.finish() merged_tags = {**data["tags"], **tags} if tags else self.__tags diff --git a/snuba/utils/metrics/types.py b/snuba/utils/metrics/types.py index 1572291d66e..10e99f428fd 100644 --- a/snuba/utils/metrics/types.py +++ b/snuba/utils/metrics/types.py @@ -1,3 +1,3 @@ -from typing import Mapping +from collections.abc import Mapping Tags = Mapping[str, str] diff --git a/snuba/utils/metrics/util.py b/snuba/utils/metrics/util.py index 5a897af055c..a756eabb86b 100644 --- a/snuba/utils/metrics/util.py +++ b/snuba/utils/metrics/util.py @@ -1,8 +1,9 @@ +import _strptime # NOQA fixes _strptime deferred import issue import inspect from functools import partial, wraps -from typing import Any, Callable, Mapping, Optional, TypeVar, cast +from typing import Any, TypeVar, cast +from collections.abc import Callable, Mapping -import _strptime # NOQA fixes _strptime deferred import issue import sentry_sdk from snuba import settings @@ -12,30 +13,30 @@ def create_metrics( prefix: str, - tags: Optional[Tags] = None, - sample_rates: Optional[Mapping[str, float]] = None, + tags: Tags | None = None, + sample_rates: Mapping[str, float] | None = None, ) -> MetricsBackend: """Create a DogStatsd object if DOGSTATSD_HOST and DOGSTATSD_PORT are defined, with the specified prefix and tags. Return a DummyMetricsBackend otherwise. Prefixes must start with `snuba.`, for example: `snuba.processor`. """ - host: Optional[str] = settings.DOGSTATSD_HOST - port: Optional[int] = settings.DOGSTATSD_PORT + host: str | None = settings.DOGSTATSD_HOST + port: int | None = settings.DOGSTATSD_PORT if settings.TESTING: from snuba.utils.metrics.backends.testing import TestingMetricsBackend return TestingMetricsBackend() - elif host is None and port is None: + if host is None and port is None: from snuba.utils.metrics.backends.dummy import DummyMetricsBackend return DummyMetricsBackend() - elif host is None or port is None: + if host is None or port is None: raise ValueError( f"DOGSTATSD_HOST and DOGSTATSD_PORT should both be None or not None. Found DOGSTATSD_HOST: {host}, DOGSTATSD_PORT: {port} instead." ) - from datadog import DogStatsd + from datadog import DogStatsd # type: ignore[attr-defined] # datadog lacks explicit re-export from snuba.utils.metrics.backends.datadog import DatadogMetricsBackend from snuba.utils.metrics.backends.dualwrite import SentryDatadogMetricsBackend diff --git a/snuba/utils/metrics/wrapper.py b/snuba/utils/metrics/wrapper.py index f81e68d5483..1464d8e14e2 100644 --- a/snuba/utils/metrics/wrapper.py +++ b/snuba/utils/metrics/wrapper.py @@ -1,5 +1,3 @@ -from typing import Optional, Union - from snuba.utils.metrics.backends.abstract import MetricsBackend from snuba.utils.metrics.types import Tags @@ -8,8 +6,8 @@ class MetricsWrapper(MetricsBackend): def __init__( self, backend: MetricsBackend, - name: Optional[str] = None, - tags: Optional[Tags] = None, + name: str | None = None, + tags: Tags | None = None, ) -> None: self.__backend = backend self.__name = name @@ -18,50 +16,48 @@ def __init__( def __merge_name(self, name: str) -> str: if self.__name is None: return name - else: - return f"{self.__name}.{name}" + return f"{self.__name}.{name}" - def __merge_tags(self, tags: Optional[Tags]) -> Optional[Tags]: + def __merge_tags(self, tags: Tags | None) -> Tags | None: if self.__tags is None: return tags - elif tags is None: + if tags is None: return self.__tags - else: - return {**tags, **self.__tags} + return {**tags, **self.__tags} def increment( self, name: str, - value: Union[int, float] = 1, - tags: Optional[Tags] = None, - unit: Optional[str] = None, + value: int | float = 1, + tags: Tags | None = None, + unit: str | None = None, ) -> None: self.__backend.increment(self.__merge_name(name), value, self.__merge_tags(tags), unit) def gauge( self, name: str, - value: Union[int, float], - tags: Optional[Tags] = None, - unit: Optional[str] = None, + value: int | float, + tags: Tags | None = None, + unit: str | None = None, ) -> None: self.__backend.gauge(self.__merge_name(name), value, self.__merge_tags(tags), unit) def timing( self, name: str, - value: Union[int, float], - tags: Optional[Tags] = None, - unit: Optional[str] = None, + value: int | float, + tags: Tags | None = None, + unit: str | None = None, ) -> None: self.__backend.timing(self.__merge_name(name), value, self.__merge_tags(tags), unit) def distribution( self, name: str, - value: Union[int, float], - tags: Optional[Tags] = None, - unit: Optional[str] = None, + value: int | float, + tags: Tags | None = None, + unit: str | None = None, ) -> None: self.__backend.distribution(self.__merge_name(name), value, self.__merge_tags(tags), unit) @@ -71,6 +67,6 @@ def events( text: str, alert_type: str, priority: str, - tags: Optional[Tags] = None, + tags: Tags | None = None, ) -> None: self.__backend.events(title, text, alert_type, priority, self.__merge_tags(tags)) diff --git a/snuba/utils/profiler.py b/snuba/utils/profiler.py index fece6ac753c..8727df523df 100644 --- a/snuba/utils/profiler.py +++ b/snuba/utils/profiler.py @@ -2,7 +2,6 @@ import socket import threading import time -from typing import Union import sentry_sdk from sentry_sdk.tracing import NoOpSpan, Transaction @@ -31,7 +30,7 @@ def _profiler_main() -> None: logger.warning("starting ondemand profile for %s", own_hostname) with sentry_sdk.Hub.main: - open_transaction: Union[Transaction, NoOpSpan, None] = sentry_sdk.start_transaction( + open_transaction: Transaction | NoOpSpan | None = sentry_sdk.start_transaction( name=f"ondemand profile: {own_hostname}", sampled=True ) assert isinstance(open_transaction, Transaction) diff --git a/snuba/utils/rate_limiter.py b/snuba/utils/rate_limiter.py index 04b0a878cfb..7f4e838d6be 100644 --- a/snuba/utils/rate_limiter.py +++ b/snuba/utils/rate_limiter.py @@ -3,7 +3,7 @@ import time from enum import Enum from threading import Lock -from typing import Any, Optional, Tuple +from typing import Any from snuba import state @@ -30,14 +30,14 @@ class RateLimiter: window. """ - def __init__(self, bucket: str, max_rate_per_sec: Optional[float] = None) -> None: + def __init__(self, bucket: str, max_rate_per_sec: float | None = None) -> None: self.__lock = Lock() - self.__bucket_epoch: Optional[int] = None - self.__bucket_attempts: Optional[int] = None + self.__bucket_epoch: int | None = None + self.__bucket_attempts: int | None = None self.__max_rate_per_sec = max_rate_per_sec self.__bucket = bucket - def __enter__(self) -> Tuple[RateLimitResult, int]: + def __enter__(self) -> tuple[RateLimitResult, int]: limit = ( state.get_config(f"{RATE_LIMIT_PER_SEC_KEY_PREFIX}{self.__bucket}", None) if not self.__max_rate_per_sec diff --git a/snuba/utils/registered_class.py b/snuba/utils/registered_class.py index f448fb5b722..bcd182f2e14 100644 --- a/snuba/utils/registered_class.py +++ b/snuba/utils/registered_class.py @@ -3,7 +3,8 @@ import importlib import os from abc import ABCMeta -from typing import Any, Dict, Sequence, Tuple, Type, cast +from collections.abc import Sequence +from typing import Any, cast class NoConfigKeyError(Exception): @@ -22,9 +23,9 @@ class _ClassRegistry: """Keep a mapping of classes to their names""" def __init__(self) -> None: - self.__mapping: Dict[str, RegisteredClass] = {} + self.__mapping: dict[str, RegisteredClass] = {} - def register_class(self, cls: "RegisteredClass") -> None: + def register_class(self, cls: RegisteredClass) -> None: key = cls.config_key() existing_class = self.__mapping.get(key) if not existing_class: @@ -34,7 +35,7 @@ def register_class(self, cls: "RegisteredClass") -> None: f"Class with name {key} already exists in the registry, change the config_key property in the class {cls} or {existing_class}" ) - def get_class_from_name(self, config_key: str) -> "RegisteredClass": + def get_class_from_name(self, config_key: str) -> RegisteredClass: res = self.__mapping.get(config_key) if res is None: raise InvalidConfigKeyError( @@ -42,7 +43,7 @@ def get_class_from_name(self, config_key: str) -> "RegisteredClass": ) return res - def all_classes(self) -> Sequence["RegisteredClass"]: + def all_classes(self) -> Sequence[RegisteredClass]: return list(self.__mapping.values()) def all_names(self) -> Sequence[str]: @@ -78,30 +79,34 @@ def get_from_name(cls, name: str) -> "SomeGenericClass": """ + # Injected onto each class that uses this metaclass in __new__ below. + # Declared here so attribute access on registered classes type-checks. + _registry: _ClassRegistry + def config_key(cls) -> str: raise NotImplementedError - def __new__(cls, name: str, bases: Tuple[Type[Any]], dct: Dict[str, Any]) -> Any: + def __new__(cls, name: str, bases: tuple[type[Any]], dct: dict[str, Any]) -> Any: res = super().__new__(cls, name, bases, dct) if not hasattr(res, "config_key"): raise NoConfigKeyError("RegisteredClass(es) must define the `config-key` property") if not hasattr(res, "_registry"): - setattr(res, "_registry", _ClassRegistry()) + res._registry = _ClassRegistry() else: - getattr(res, "_registry").register_class(res) + res._registry.register_class(res) return res - def class_from_name(self, name: str) -> Type[Any]: + def class_from_name(self, name: str) -> type[Any]: return cast( - Type[Any], - getattr(self, "_registry").get_class_from_name(name), + type[Any], + self._registry.get_class_from_name(name), ) - def all_classes(self) -> Sequence[Type[Any]]: - return [cast(Type[Any], rclass) for rclass in getattr(self, "_registry").all_classes()] + def all_classes(self) -> Sequence[type[Any]]: + return [cast(type[Any], rclass) for rclass in self._registry.all_classes()] def all_names(self) -> Sequence[str]: - return list(getattr(self, "_registry").all_names()) + return list(self._registry.all_names()) TModule = object diff --git a/snuba/utils/schemas.py b/snuba/utils/schemas.py index c0bb11d1c20..6bdcbbc1c7e 100644 --- a/snuba/utils/schemas.py +++ b/snuba/utils/schemas.py @@ -2,22 +2,14 @@ import uuid from abc import ABC, abstractmethod +from collections.abc import Callable, Iterator, Mapping, MutableMapping, Sequence from dataclasses import dataclass from functools import partial from itertools import chain from typing import Any as AnyType from typing import ( - Callable, Generic, - Iterator, - List, - Mapping, - MutableMapping, - Optional, - Sequence, - Type, TypeVar, - Union, cast, ) @@ -68,7 +60,7 @@ def _get_modifiers(self) -> Sequence[TypeModifier]: """ raise NotImplementedError - def has_modifier(self, modifier: Type[TypeModifier]) -> bool: + def has_modifier(self, modifier: type[TypeModifier]) -> bool: """ Returns true if a modifier of the type provided is present in this container. @@ -88,20 +80,15 @@ def has_modifier(self, modifier: Type[TypeModifier]) -> bool: # have to be provided everytime we define a column in a migration or # in a schema. class ColumnType(Generic[TModifiers]): - def __init__(self, modifiers: Optional[TModifiers] = None): + def __init__(self, modifiers: TModifiers | None = None): self.__modifiers = modifiers def __repr__(self) -> str: # return f"{self.__class__.__name__}({self._repr_content()})[{self.__modifiers}]" repr_content = self._repr_content() if repr_content: - return "schemas.{}({}, modifiers={})".format( - self.__class__.__name__, repr_content, repr(self.__modifiers) - ) - else: - return "schemas.{}(modifiers={})".format( - self.__class__.__name__, repr(self.__modifiers) - ) + return f"schemas.{self.__class__.__name__}({repr_content}, modifiers={repr(self.__modifiers)})" + return f"schemas.{self.__class__.__name__}(modifiers={repr(self.__modifiers)})" def _repr_content(self) -> str: """ @@ -111,10 +98,7 @@ def _repr_content(self) -> str: return "" def __eq__(self, other: object) -> bool: - return ( - self.__class__ == other.__class__ - and self.__modifiers == cast(ColumnType[TModifiers], other).get_modifiers() - ) + return self.__class__ == other.__class__ and self.__modifiers == other.get_modifiers() def for_schema(self) -> str: return ( @@ -133,10 +117,10 @@ def _for_schema_impl(self) -> str: def flatten(self, name: str) -> Sequence[FlattenedColumn]: return [FlattenedColumn(None, name, self)] - def get_modifiers(self) -> Optional[TModifiers]: + def get_modifiers(self) -> TModifiers | None: return self.__modifiers - def set_modifiers(self, modifiers: Optional[TModifiers]) -> ColumnType[TModifiers]: + def set_modifiers(self, modifiers: TModifiers | None) -> ColumnType[TModifiers]: """ Returns a new instance of this class with the provided modifiers. @@ -146,7 +130,7 @@ def set_modifiers(self, modifiers: Optional[TModifiers]) -> ColumnType[TModifier def get_raw(self) -> ColumnType[TModifiers]: return type(self)() - def has_modifier(self, modifier: Type[TypeModifier]) -> bool: + def has_modifier(self, modifier: type[TypeModifier]) -> bool: if self.__modifiers is None: return False return self.__modifiers.has_modifier(modifier) @@ -162,46 +146,44 @@ def __init__(self, name: str, type: ColumnType[TModifiers]) -> None: self.escaped: str = escaped def __repr__(self) -> str: - return "Column({}, {})".format(repr(self.name), repr(self.type)) + return f"Column({repr(self.name)}, {repr(self.type)})" def __eq__(self, other: object) -> bool: return ( self.__class__ == other.__class__ - and self.name == cast(Column[TModifiers], other).name - and self.type == cast(Column[TModifiers], other).type + and self.name == other.name + and self.type == other.type ) def for_schema(self) -> str: - return "{} {}".format(escape_identifier(self.name), self.type.for_schema()) + return f"{escape_identifier(self.name)} {self.type.for_schema()}" @staticmethod def to_columns( - columns: Sequence[Union[Column[TModifiers], tuple[str, ColumnType[TModifiers]]]], + columns: Sequence[Column[TModifiers] | tuple[str, ColumnType[TModifiers]]], ) -> Sequence[Column[TModifiers]]: return [Column(*col) if not isinstance(col, Column) else col for col in columns] class FlattenedColumn: - def __init__(self, base_name: Optional[str], name: str, type: ColumnType[TModifiers]) -> None: + def __init__(self, base_name: str | None, name: str, type: ColumnType[TModifiers]) -> None: self.base_name = base_name self.name = name self.type = type - self.flattened = "{}.{}".format(self.base_name, self.name) if self.base_name else self.name + self.flattened = f"{self.base_name}.{self.name}" if self.base_name else self.name escaped = escape_identifier(self.flattened) assert escaped is not None self.escaped: str = escaped def __repr__(self) -> str: - return "FlattenedColumn({}, {}, {})".format( - repr(self.base_name), repr(self.name), repr(self.type) - ) + return f"FlattenedColumn({repr(self.base_name)}, {repr(self.name)}, {repr(self.type)})" def __eq__(self, other: object) -> bool: return ( self.__class__ == other.__class__ - and self.flattened == cast(FlattenedColumn, other).flattened - and self.type == cast(FlattenedColumn, other).type + and self.flattened == other.flattened + and self.type == other.type ) @@ -215,7 +197,7 @@ class SchemaModifiers(TypeModifiers): readonly: bool = False def _get_modifiers(self) -> Sequence[TypeModifier]: - ret: List[TypeModifier] = [] + ret: list[TypeModifier] = [] if self.nullable: ret.append(Nullable()) if self.readonly: @@ -236,10 +218,10 @@ def for_schema(self, content: str) -> str: @dataclass(frozen=True) class Nullable(TypeModifier): def for_schema(self, content: str) -> str: - return "Nullable({})".format(content) + return f"Nullable({content})" -class ColumnSet(ABC): +class ColumnSet(ABC): # noqa: B024 - intentional abstract base shared by entity and ClickHouse column sets """ Base column set extended by both ClickHouse column set and entity column set A base column set class that will be shared by logical (entity) and physical (ClickHouse) @@ -256,7 +238,7 @@ def __init__(self, columns: Sequence[Column[SchemaModifiers]]) -> None: self._lookup: MutableMapping[str, Sequence[FlattenedColumn]] = {} self._nested = {} - self._flattened: List[FlattenedColumn] = [] + self._flattened: list[FlattenedColumn] = [] self._flattened_lookup: MutableMapping[str, FlattenedColumn] = {} for column in self.__columns: @@ -268,7 +250,7 @@ def __init__(self, columns: Sequence[Column[SchemaModifiers]]) -> None: for col in self._flattened: if col.flattened in self._flattened_lookup: - raise RuntimeError("Duplicate column: {}".format(col.flattened)) + raise RuntimeError(f"Duplicate column: {col.flattened}") if col.base_name: self._nested[col.flattened] = col @@ -278,8 +260,8 @@ def __init__(self, columns: Sequence[Column[SchemaModifiers]]) -> None: def __eq__(self, other: object) -> bool: return ( self.__class__ == other.__class__ - and self._flattened == cast(ColumnSet, other)._flattened - and self._wildcard_columns == cast(ColumnSet, other)._wildcard_columns + and self._flattened == other._flattened + and self._wildcard_columns == other._wildcard_columns ) def __getitem__(self, key: str) -> FlattenedColumn: @@ -296,7 +278,7 @@ def __getitem__(self, key: str) -> FlattenedColumn: raise KeyError(key) - def get(self, key: str, default: Optional[FlattenedColumn] = None) -> Optional[FlattenedColumn]: + def get(self, key: str, default: FlattenedColumn | None = None) -> FlattenedColumn | None: try: return self[key] except KeyError: @@ -350,7 +332,7 @@ def __init__(self) -> None: class Array(ColumnType[TModifiers]): def __init__( - self, inner_type: ColumnType[TModifiers], modifiers: Optional[TModifiers] = None + self, inner_type: ColumnType[TModifiers], modifiers: TModifiers | None = None ) -> None: super().__init__(modifiers) self.inner_type = inner_type @@ -361,8 +343,8 @@ def _repr_content(self) -> str: def __eq__(self, other: object) -> bool: return ( self.__class__ == other.__class__ - and self.inner_type == cast(Array[TModifiers], other).inner_type - and self.get_modifiers() == cast(Array[TModifiers], other).get_modifiers() + and self.inner_type == other.inner_type + and self.get_modifiers() == other.get_modifiers() ) def _for_schema_impl(self) -> str: @@ -370,11 +352,10 @@ def _for_schema_impl(self) -> str: if len(inner_schema) == 1: inner_type = inner_schema[0] return f"Array({inner_type})" - else: - inner_type, codec_modifiers = inner_schema - return f"Array({inner_type}) CODEC {codec_modifiers}" + inner_type, codec_modifiers = inner_schema + return f"Array({inner_type}) CODEC {codec_modifiers}" - def set_modifiers(self, modifiers: Optional[TModifiers]) -> Array[TModifiers]: + def set_modifiers(self, modifiers: TModifiers | None) -> Array[TModifiers]: return Array(inner_type=self.inner_type, modifiers=modifiers) def get_raw(self) -> Array[TModifiers]: @@ -386,7 +367,7 @@ def __init__( self, key: ColumnType[TModifiers], value: ColumnType[TModifiers], - modifiers: Optional[TModifiers] = None, + modifiers: TModifiers | None = None, ) -> None: super().__init__(modifiers) self.key = key @@ -398,15 +379,15 @@ def _repr_content(self) -> str: def __eq__(self, other: object) -> bool: return ( self.__class__ == other.__class__ - and self.key == cast(Map[TModifiers], other).key - and self.value == cast(Map[TModifiers], other).value - and self.get_modifiers() == cast(Map[TModifiers], other).get_modifiers() + and self.key == other.key + and self.value == other.value + and self.get_modifiers() == other.get_modifiers() ) def _for_schema_impl(self) -> str: return f"Map({self.key.for_schema()}, {self.value.for_schema()})" - def set_modifiers(self, modifiers: Optional[TModifiers]) -> Map[TModifiers]: + def set_modifiers(self, modifiers: TModifiers | None) -> Map[TModifiers]: return Map(key=self.key, value=self.value, modifiers=modifiers) def get_raw(self) -> Map[TModifiers]: @@ -416,8 +397,8 @@ def get_raw(self) -> Map[TModifiers]: class Nested(ColumnType[TModifiers]): def __init__( self, - nested_columns: Sequence[Union[Column[TModifiers], tuple[str, ColumnType[TModifiers]]]], - modifiers: Optional[TModifiers] = None, + nested_columns: Sequence[Column[TModifiers] | tuple[str, ColumnType[TModifiers]]], + modifiers: TModifiers | None = None, ) -> None: super().__init__(modifiers) self.nested_columns = Column.to_columns(nested_columns) @@ -428,8 +409,8 @@ def _repr_content(self) -> str: def __eq__(self, other: object) -> bool: return ( self.__class__ == other.__class__ - and self.get_modifiers() == cast(Nested[TModifiers], other).get_modifiers() - and self.nested_columns == cast(Nested[TModifiers], other).nested_columns + and self.get_modifiers() == other.get_modifiers() + and self.nested_columns == other.nested_columns ) def _for_schema_impl(self) -> str: @@ -440,7 +421,7 @@ def flatten(self, name: str) -> Sequence[FlattenedColumn]: FlattenedColumn(name, column.name, Array(column.type)) for column in self.nested_columns ] - def set_modifiers(self, modifiers: Optional[TModifiers]) -> Nested[TModifiers]: + def set_modifiers(self, modifiers: TModifiers | None) -> Nested[TModifiers]: return Nested(nested_columns=self.nested_columns, modifiers=modifiers) def get_raw(self) -> Nested[TModifiers]: @@ -453,7 +434,7 @@ def __init__( self, func: str, arg_types: Sequence[ColumnType[TModifiers]], - modifiers: Optional[TModifiers] = None, + modifiers: TModifiers | None = None, ) -> None: super().__init__(modifiers) self.func = func @@ -465,9 +446,9 @@ def _repr_content(self) -> str: def __eq__(self, other: object) -> bool: return ( self.__class__ == other.__class__ - and self.get_modifiers() == cast(AggregateFunction[TModifiers], other).get_modifiers() - and self.func == cast(AggregateFunction[TModifiers], other).func - and self.arg_types == cast(AggregateFunction[TModifiers], other).arg_types + and self.get_modifiers() == other.get_modifiers() + and self.func == other.func + and self.arg_types == other.arg_types ) def _for_schema_impl(self) -> str: @@ -475,7 +456,7 @@ def _for_schema_impl(self) -> str: ", ".join(chain([self.func], (x.for_schema() for x in self.arg_types))), ) - def set_modifiers(self, modifiers: Optional[TModifiers]) -> AggregateFunction[TModifiers]: + def set_modifiers(self, modifiers: TModifiers | None) -> AggregateFunction[TModifiers]: return AggregateFunction(self.func, self.arg_types, modifiers) def get_raw(self) -> AggregateFunction[TModifiers]: @@ -487,7 +468,7 @@ def __init__( self, func: str, arg_types: Sequence[ColumnType[TModifiers]], - modifiers: Optional[TModifiers] = None, + modifiers: TModifiers | None = None, ) -> None: super().__init__(modifiers) self.func = func @@ -499,10 +480,9 @@ def _repr_content(self) -> str: def __eq__(self, other: object) -> bool: return ( self.__class__ == other.__class__ - and self.get_modifiers() - == cast(SimpleAggregateFunction[TModifiers], other).get_modifiers() - and self.func == cast(SimpleAggregateFunction[TModifiers], other).func - and self.arg_types == cast(SimpleAggregateFunction[TModifiers], other).arg_types + and self.get_modifiers() == other.get_modifiers() + and self.func == other.func + and self.arg_types == other.arg_types ) def _for_schema_impl(self) -> str: @@ -510,7 +490,7 @@ def _for_schema_impl(self) -> str: ", ".join(chain([self.func], (x.for_schema() for x in self.arg_types))), ) - def set_modifiers(self, modifiers: Optional[TModifiers]) -> SimpleAggregateFunction[TModifiers]: + def set_modifiers(self, modifiers: TModifiers | None) -> SimpleAggregateFunction[TModifiers]: return SimpleAggregateFunction(self.func, self.arg_types, modifiers) def get_raw(self) -> SimpleAggregateFunction[TModifiers]: @@ -534,7 +514,7 @@ class IPv6(ColumnType[TModifiers]): class FixedString(ColumnType[TModifiers]): - def __init__(self, length: int, modifiers: Optional[TModifiers] = None) -> None: + def __init__(self, length: int, modifiers: TModifiers | None = None) -> None: super().__init__(modifiers) self.length = length @@ -544,14 +524,14 @@ def _repr_content(self) -> str: def __eq__(self, other: object) -> bool: return ( self.__class__ == other.__class__ - and self.get_modifiers() == cast(FixedString[TModifiers], other).get_modifiers() - and self.length == cast(FixedString[TModifiers], other).length + and self.get_modifiers() == other.get_modifiers() + and self.length == other.length ) def _for_schema_impl(self) -> str: - return "FixedString({})".format(self.length) + return f"FixedString({self.length})" - def set_modifiers(self, modifiers: Optional[TModifiers]) -> FixedString[TModifiers]: + def set_modifiers(self, modifiers: TModifiers | None) -> FixedString[TModifiers]: return FixedString(length=self.length, modifiers=modifiers) def get_raw(self) -> FixedString[TModifiers]: @@ -559,7 +539,7 @@ def get_raw(self) -> FixedString[TModifiers]: class UInt(ColumnType[TModifiers]): - def __init__(self, size: int, modifiers: Optional[TModifiers] = None) -> None: + def __init__(self, size: int, modifiers: TModifiers | None = None) -> None: super().__init__(modifiers) assert size in (8, 16, 32, 64, 128) self.size = size @@ -570,14 +550,14 @@ def _repr_content(self) -> str: def __eq__(self, other: object) -> bool: return ( self.__class__ == other.__class__ - and self.get_modifiers() == cast(UInt[TModifiers], other).get_modifiers() - and self.size == cast(UInt[TModifiers], other).size + and self.get_modifiers() == other.get_modifiers() + and self.size == other.size ) def _for_schema_impl(self) -> str: - return "UInt{}".format(self.size) + return f"UInt{self.size}" - def set_modifiers(self, modifiers: Optional[TModifiers]) -> UInt[TModifiers]: + def set_modifiers(self, modifiers: TModifiers | None) -> UInt[TModifiers]: return UInt(size=self.size, modifiers=modifiers) def get_raw(self) -> UInt[TModifiers]: @@ -585,7 +565,7 @@ def get_raw(self) -> UInt[TModifiers]: class Int(ColumnType[TModifiers]): - def __init__(self, size: int, modifiers: Optional[TModifiers] = None) -> None: + def __init__(self, size: int, modifiers: TModifiers | None = None) -> None: super().__init__(modifiers) assert size in (8, 16, 32, 64, 128) self.size = size @@ -596,14 +576,14 @@ def _repr_content(self) -> str: def __eq__(self, other: object) -> bool: return ( self.__class__ == other.__class__ - and self.get_modifiers() == cast(Int[TModifiers], other).get_modifiers() - and self.size == cast(Int[TModifiers], other).size + and self.get_modifiers() == other.get_modifiers() + and self.size == other.size ) def _for_schema_impl(self) -> str: - return "Int{}".format(self.size) + return f"Int{self.size}" - def set_modifiers(self, modifiers: Optional[TModifiers]) -> Int[TModifiers]: + def set_modifiers(self, modifiers: TModifiers | None) -> Int[TModifiers]: return Int(size=self.size, modifiers=modifiers) def get_raw(self) -> Int[TModifiers]: @@ -614,7 +594,7 @@ class Float(ColumnType[TModifiers]): def __init__( self, size: int, - modifiers: Optional[TModifiers] = None, + modifiers: TModifiers | None = None, ) -> None: super().__init__(modifiers) assert size in (32, 64) @@ -626,14 +606,14 @@ def _repr_content(self) -> str: def __eq__(self, other: object) -> bool: return ( self.__class__ == other.__class__ - and self.get_modifiers() == cast(Float[TModifiers], other).get_modifiers() - and self.size == cast(Float[TModifiers], other).size + and self.get_modifiers() == other.get_modifiers() + and self.size == other.size ) def _for_schema_impl(self) -> str: - return "Float{}".format(self.size) + return f"Float{self.size}" - def set_modifiers(self, modifiers: Optional[TModifiers]) -> Float[TModifiers]: + def set_modifiers(self, modifiers: TModifiers | None) -> Float[TModifiers]: return Float(size=self.size, modifiers=modifiers) def get_raw(self) -> Float[TModifiers]: @@ -652,8 +632,8 @@ class DateTime64(ColumnType[TModifiers]): def __init__( self, precision: int = 3, - timezone: Optional[str] = None, - modifiers: Optional[TModifiers] = None, + timezone: str | None = None, + modifiers: TModifiers | None = None, ) -> None: assert precision <= 9 super().__init__(modifiers) @@ -669,19 +649,15 @@ def _repr_content(self) -> str: def __eq__(self, other: object) -> bool: return ( self.__class__ == other.__class__ - and self.get_modifiers() - == cast( - DateTime64[TModifiers], - other, - ).get_modifiers() - and self.precision == cast(DateTime64[TModifiers], other).precision - and self.timezone == cast(DateTime64[TModifiers], other).timezone + and self.get_modifiers() == other.get_modifiers() + and self.precision == other.precision + and self.timezone == other.timezone ) def _for_schema_impl(self) -> str: return f"DateTime64({self._repr_content()})" - def set_modifiers(self, modifiers: Optional[TModifiers]) -> DateTime64[TModifiers]: + def set_modifiers(self, modifiers: TModifiers | None) -> DateTime64[TModifiers]: return DateTime64( precision=self.precision, timezone=self.timezone, @@ -699,25 +675,25 @@ class Enum(ColumnType[TModifiers]): def __init__( self, values: Sequence[tuple[str, int]], - modifiers: Optional[TModifiers] = None, + modifiers: TModifiers | None = None, ) -> None: super().__init__(modifiers) self.values = values def _repr_content(self) -> str: - return ", ".join("'{}' = {}".format(v[0], v[1]) for v in self.values) + return ", ".join(f"'{v[0]}' = {v[1]}" for v in self.values) def __eq__(self, other: object) -> bool: return ( self.__class__ == other.__class__ - and self.get_modifiers() == cast(Enum[TModifiers], other).get_modifiers() - and self.values == cast(Enum[TModifiers], other).values + and self.get_modifiers() == other.get_modifiers() + and self.values == other.values ) def _for_schema_impl(self) -> str: - return "Enum({})".format(", ".join("'{}' = {}".format(v[0], v[1]) for v in self.values)) + return "Enum({})".format(", ".join(f"'{v[0]}' = {v[1]}" for v in self.values)) - def set_modifiers(self, modifiers: Optional[TModifiers]) -> Enum[TModifiers]: + def set_modifiers(self, modifiers: TModifiers | None) -> Enum[TModifiers]: return Enum(values=self.values, modifiers=modifiers) def get_raw(self) -> Enum[TModifiers]: @@ -728,25 +704,25 @@ class Tuple(ColumnType[TModifiers]): def __init__( self, types: tuple[ColumnType[TModifiers], ...], - modifiers: Optional[TModifiers] = None, + modifiers: TModifiers | None = None, ) -> None: super().__init__(modifiers) self.types = types def _repr_content(self) -> str: - return ", ".join("{}".format(v) for v in self.types) + return ", ".join(f"{v}" for v in self.types) def __eq__(self, other: object) -> bool: return ( self.__class__ == other.__class__ - and self.get_modifiers() == cast(Tuple[TModifiers], other).get_modifiers() - and self.types == cast(Tuple[TModifiers], other).types + and self.get_modifiers() == other.get_modifiers() + and self.types == other.types ) def _for_schema_impl(self) -> str: - return "Tuple({})".format(", ".join("{}".format(t.for_schema()) for t in self.types)) + return "Tuple({})".format(", ".join(f"{t.for_schema()}" for t in self.types)) - def set_modifiers(self, modifiers: Optional[TModifiers]) -> Tuple[TModifiers]: + def set_modifiers(self, modifiers: TModifiers | None) -> Tuple[TModifiers]: return Tuple(types=self.types, modifiers=modifiers) def get_raw(self) -> Tuple[TModifiers]: @@ -822,7 +798,7 @@ def _valid_tuple(self, tuple_column: Tuple[AnyType], value: tuple[Any]) -> bool: class Bool(ColumnType[TModifiers]): - def __init__(self, modifiers: Optional[TModifiers] = None) -> None: + def __init__(self, modifiers: TModifiers | None = None) -> None: super().__init__(modifiers) def __eq__(self, other: object) -> bool: @@ -834,7 +810,7 @@ def __eq__(self, other: object) -> bool: def _for_schema_impl(self) -> str: return "Bool" - def set_modifiers(self, modifiers: Optional[TModifiers]) -> Bool[TModifiers]: + def set_modifiers(self, modifiers: TModifiers | None) -> Bool[TModifiers]: return Bool(modifiers=modifiers) def get_raw(self) -> Bool[TModifiers]: @@ -844,13 +820,17 @@ def get_raw(self) -> Bool[TModifiers]: class JSON(ColumnType[TModifiers]): def __init__( self, - max_dynamic_paths: Optional[int] = None, - max_dynamic_types: Optional[int] = None, + max_dynamic_paths: int | None = None, + max_dynamic_types: int | None = None, type_hints: Mapping[str, ColumnType[TModifiers]] = {}, - skip_paths: list[str] = [], - skip_regexp: list[str] = [], - modifiers: Optional[TModifiers] = None, + skip_paths: list[str] | None = None, + skip_regexp: list[str] | None = None, + modifiers: TModifiers | None = None, ) -> None: + if skip_regexp is None: + skip_regexp = [] + if skip_paths is None: + skip_paths = [] super().__init__(modifiers) self.max_dynamic_paths = max_dynamic_paths self.max_dynamic_types = max_dynamic_types @@ -872,12 +852,12 @@ def _repr_content(self) -> str: def __eq__(self, other: object) -> bool: return ( self.__class__ == other.__class__ - and self.get_modifiers() == cast(JSON[TModifiers], other).get_modifiers() - and self.max_dynamic_paths == cast(JSON[TModifiers], other).max_dynamic_paths - and self.max_dynamic_types == cast(JSON[TModifiers], other).max_dynamic_types - and self.type_hints == cast(JSON[TModifiers], other).type_hints - and self.skip_paths == cast(JSON[TModifiers], other).skip_paths - and self.skip_regexp == cast(JSON[TModifiers], other).skip_regexp + and self.get_modifiers() == other.get_modifiers() + and self.max_dynamic_paths == other.max_dynamic_paths + and self.max_dynamic_types == other.max_dynamic_types + and self.type_hints == other.type_hints + and self.skip_paths == other.skip_paths + and self.skip_regexp == other.skip_regexp ) def _for_schema_impl(self) -> str: @@ -904,10 +884,9 @@ def _for_schema_impl(self) -> str: if parts: return f"JSON({', '.join(parts)})" - else: - return "JSON" + return "JSON" - def set_modifiers(self, modifiers: Optional[TModifiers]) -> JSON[TModifiers]: + def set_modifiers(self, modifiers: TModifiers | None) -> JSON[TModifiers]: return JSON( max_dynamic_paths=self.max_dynamic_paths, max_dynamic_types=self.max_dynamic_types, diff --git a/snuba/utils/serializable_exception.py b/snuba/utils/serializable_exception.py index ef4d77f600d..d10679be487 100644 --- a/snuba/utils/serializable_exception.py +++ b/snuba/utils/serializable_exception.py @@ -44,19 +44,19 @@ from __future__ import annotations -from typing import Any, Dict, List, Optional, Type, TypedDict, Union, cast +from typing import Any, TypedDict, cast import rapidjson # mypy has not figured out recursive types yet so this can't be totally typesafe -JsonSerializable = Union[str, int, float, bool, None, Dict[str, Any], List[Any]] +JsonSerializable = str | int | float | bool | None | dict[str, Any] | list[Any] class SerializableExceptionDict(TypedDict): __type__: str __name__: str __message__: str - __extra_data__: Dict[str, JsonSerializable] + __extra_data__: dict[str, JsonSerializable] __should_report__: bool @@ -64,14 +64,14 @@ class _ExceptionRegistry: """Keep a mapping of SerializableExceptions to their names""" def __init__(self) -> None: - self.__mapping: Dict[str, Type["SerializableException"]] = {} + self.__mapping: dict[str, type[SerializableException]] = {} - def register_class(self, cls: Type["SerializableException"]) -> None: + def register_class(self, cls: type[SerializableException]) -> None: existing_class = self.__mapping.get(cls.__name__) if not existing_class: self.__mapping[cls.__name__] = cls - def get_class_by_name(self, cls_name: str) -> Optional[Type["SerializableException"]]: + def get_class_by_name(self, cls_name: str) -> type[SerializableException] | None: return self.__mapping.get(cls_name) @@ -96,7 +96,7 @@ class SerializableException(Exception): def __init__( self, - message: Optional[str] = None, + message: str | None = None, should_report: bool = True, **extra_data: JsonSerializable, ) -> None: @@ -112,7 +112,7 @@ def format_message(self, message: str) -> str: """ return message - def _format_message(self, message: Optional[str]) -> str: + def _format_message(self, message: str | None) -> str: if not message: return "" @@ -128,7 +128,7 @@ def to_dict(self) -> SerializableExceptionDict: } @classmethod - def from_dict(cls, edict: SerializableExceptionDict) -> "SerializableException": + def from_dict(cls, edict: SerializableExceptionDict) -> SerializableException: assert edict["__type__"] == "SerializableException" defined_exception = _get_registry().get_class_by_name(edict.get("__name__", "")) @@ -160,7 +160,7 @@ def __init_subclass__(cls) -> None: return super().__init_subclass__() @classmethod - def from_standard_exception_instance(cls, exc: Exception) -> "SerializableException": + def from_standard_exception_instance(cls, exc: Exception) -> SerializableException: if isinstance(exc, cls): return exc return cls.from_dict( diff --git a/snuba/utils/server.py b/snuba/utils/server.py index 88a32a5a558..cb3d8ee414b 100644 --- a/snuba/utils/server.py +++ b/snuba/utils/server.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Optional - from granian import Granian from granian.constants import Interfaces @@ -10,7 +8,7 @@ def serve( module: str, bind: str, processes: int = 1, - threads: Optional[int] = None, + threads: int | None = None, backlog: int = 128, reload: bool = False, name: str | None = None, @@ -33,4 +31,4 @@ def serve( reload=reload, process_name=name, ) - server.serve() # type: ignore + server.serve() diff --git a/snuba/utils/streams/configuration_builder.py b/snuba/utils/streams/configuration_builder.py index c1dab5dbf4b..2f02fe193b3 100644 --- a/snuba/utils/streams/configuration_builder.py +++ b/snuba/utils/streams/configuration_builder.py @@ -1,4 +1,5 @@ -from typing import Any, Mapping, Optional, Sequence +from collections.abc import Mapping, Sequence +from typing import Any from arroyo.backends.kafka import build_kafka_configuration from arroyo.backends.kafka import ( @@ -14,24 +15,22 @@ def _get_default_topic_configuration( - topic: Optional[Topic], slice_id: Optional[int] = None + topic: Topic | None, slice_id: int | None = None ) -> Mapping[str, Any]: if topic is not None: if slice_id is not None: return settings.SLICED_KAFKA_BROKER_CONFIG.get( (topic.value, slice_id), settings.BROKER_CONFIG ) - else: - return settings.KAFKA_BROKER_CONFIG.get(topic.value, settings.BROKER_CONFIG) - else: - return settings.BROKER_CONFIG + return settings.KAFKA_BROKER_CONFIG.get(topic.value, settings.BROKER_CONFIG) + return settings.BROKER_CONFIG def get_default_kafka_configuration( - topic: Optional[Topic] = None, - slice_id: Optional[int] = None, - bootstrap_servers: Optional[Sequence[str]] = None, - override_params: Optional[Mapping[str, Any]] = None, + topic: Topic | None = None, + slice_id: int | None = None, + bootstrap_servers: Sequence[str] | None = None, + override_params: Mapping[str, Any] | None = None, ) -> KafkaBrokerConfig: default_topic_config = _get_default_topic_configuration(topic, slice_id) @@ -39,15 +38,15 @@ def get_default_kafka_configuration( def build_kafka_consumer_configuration( - topic: Optional[Topic], + topic: Topic | None, group_id: str, - slice_id: Optional[int] = None, - auto_offset_reset: Optional[str] = None, - queued_max_messages_kbytes: Optional[int] = None, - queued_min_messages: Optional[int] = None, - bootstrap_servers: Optional[Sequence[str]] = None, - override_params: Optional[Mapping[str, Any]] = None, - strict_offset_reset: Optional[bool] = None, + slice_id: int | None = None, + auto_offset_reset: str | None = None, + queued_max_messages_kbytes: int | None = None, + queued_min_messages: int | None = None, + bootstrap_servers: Sequence[str] | None = None, + override_params: Mapping[str, Any] | None = None, + strict_offset_reset: bool | None = None, ) -> KafkaBrokerConfig: default_topic_config = _get_default_topic_configuration(topic, slice_id) @@ -64,10 +63,10 @@ def build_kafka_consumer_configuration( def build_kafka_producer_configuration( - topic: Optional[Topic], - slice_id: Optional[int] = None, - bootstrap_servers: Optional[Sequence[str]] = None, - override_params: Optional[Mapping[str, Any]] = None, + topic: Topic | None, + slice_id: int | None = None, + bootstrap_servers: Sequence[str] | None = None, + override_params: Mapping[str, Any] | None = None, ) -> KafkaBrokerConfig: default_topic_config = _get_default_topic_configuration(topic, slice_id) diff --git a/snuba/utils/streams/metrics_adapter.py b/snuba/utils/streams/metrics_adapter.py index 7991f9691a9..cf9657d3941 100644 --- a/snuba/utils/streams/metrics_adapter.py +++ b/snuba/utils/streams/metrics_adapter.py @@ -1,5 +1,3 @@ -from typing import Optional, Union - from snuba.utils.metrics import MetricsBackend from snuba.utils.metrics.types import Tags @@ -8,13 +6,11 @@ class StreamMetricsAdapter: def __init__(self, metrics: MetricsBackend) -> None: self.__wrapper = metrics - def increment( - self, name: str, value: Union[int, float] = 1, tags: Optional[Tags] = None - ) -> None: + def increment(self, name: str, value: int | float = 1, tags: Tags | None = None) -> None: self.__wrapper.increment(name, value, tags) - def gauge(self, name: str, value: Union[int, float], tags: Optional[Tags] = None) -> None: + def gauge(self, name: str, value: int | float, tags: Tags | None = None) -> None: self.__wrapper.gauge(name, value, tags) - def timing(self, name: str, value: Union[int, float], tags: Optional[Tags] = None) -> None: + def timing(self, name: str, value: int | float, tags: Tags | None = None) -> None: self.__wrapper.timing(name, value, tags) diff --git a/snuba/utils/streams/topics.py b/snuba/utils/streams/topics.py index d872e448383..7eb547d0b9a 100644 --- a/snuba/utils/streams/topics.py +++ b/snuba/utils/streams/topics.py @@ -1,5 +1,5 @@ +from collections.abc import Mapping from enum import Enum -from typing import Mapping from sentry_kafka_schemas import SchemaNotFound, get_topic diff --git a/snuba/utils/streams/types.py b/snuba/utils/streams/types.py index 3dd64582194..dfe00be5aae 100644 --- a/snuba/utils/streams/types.py +++ b/snuba/utils/streams/types.py @@ -1,3 +1,3 @@ -from typing import Any, Dict +from typing import Any -KafkaBrokerConfig = Dict[str, Any] +KafkaBrokerConfig = dict[str, Any] diff --git a/snuba/utils/threaded_function_delegator.py b/snuba/utils/threaded_function_delegator.py index c38771c967b..59a4e69dbbc 100644 --- a/snuba/utils/threaded_function_delegator.py +++ b/snuba/utils/threaded_function_delegator.py @@ -1,9 +1,10 @@ import logging import time +from collections.abc import Callable, Iterator, Mapping from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from functools import partial -from typing import Callable, Generic, Iterator, List, Mapping, Optional, Tuple, TypeVar +from typing import Generic, TypeVar logger = logging.getLogger(__name__) @@ -35,8 +36,8 @@ class ThreadedFunctionDelegator(Generic[TInput, TResult]): def __init__( self, callables: Mapping[str, Callable[[], TResult]], - selector_func: Callable[[TInput], Tuple[str, List[str]]], - callback_func: Optional[Callable[[Optional[Result[TResult]], List[Result[TResult]]], None]], + selector_func: Callable[[TInput], tuple[str, list[str]]], + callback_func: Callable[[Result[TResult] | None, list[Result[TResult]]], None] | None, ignore_secondary_exceptions: bool = False, ) -> None: self.__callables = callables @@ -66,8 +67,8 @@ def __execute_callables(self, input: TInput) -> Iterator[Result[TResult]]: def execute(self, input: TInput) -> TResult: generator = self.__execute_callables(input) - primary_result: Optional[Result[TResult]] = None - other_results: List[Result[TResult]] = [] + primary_result: Result[TResult] | None = None + other_results: list[Result[TResult]] = [] try: primary_result = next(generator) diff --git a/snuba/utils/types.py b/snuba/utils/types.py index 61bc961f33a..39beeb31953 100644 --- a/snuba/utils/types.py +++ b/snuba/utils/types.py @@ -1,3 +1,3 @@ -from typing import MutableMapping, Tuple +from collections.abc import MutableMapping -ColumnStatesMapType = MutableMapping[Tuple[str, int, str, str], str] +ColumnStatesMapType = MutableMapping[tuple[str, int, str, str], str] diff --git a/snuba/web/__init__.py b/snuba/web/__init__.py index f633a531540..b87531479d0 100644 --- a/snuba/web/__init__.py +++ b/snuba/web/__init__.py @@ -1,14 +1,15 @@ from __future__ import annotations +from collections.abc import Mapping from dataclasses import dataclass -from typing import Any, Dict, Mapping, TypedDict, cast +from typing import Any, TypedDict, cast from snuba.reader import Column, Result, Row, transform_rows from snuba.utils.serializable_exception import JsonSerializable, SerializableException class QueryExtraData(TypedDict): - stats: Dict[str, Any] + stats: dict[str, Any] sql: str experiments: Mapping[str, Any] @@ -33,9 +34,7 @@ def __init__( super().__init__(message, should_report, **extra_data) @classmethod - def from_args( - cls, exception_type: str, message: str, extra: QueryExtraData - ) -> "QueryException": + def from_args(cls, exception_type: str, message: str, extra: QueryExtraData) -> QueryException: return cls( exception_type=exception_type, message=message, @@ -66,7 +65,10 @@ class QueryResult: @property def quota_allowance(self) -> Mapping[str, Mapping[str, Any]]: - return self.extra.get("stats", {}).get("quota_allowance", {}) + result: Mapping[str, Mapping[str, Any]] = self.extra.get("stats", {}).get( + "quota_allowance", {} + ) + return result def transform_column_names(result: QueryResult, mapping: Mapping[str, list[str]]) -> None: diff --git a/snuba/web/bulk_delete_query.py b/snuba/web/bulk_delete_query.py index 5bcb5299a04..4a43d498a43 100644 --- a/snuba/web/bulk_delete_query.py +++ b/snuba/web/bulk_delete_query.py @@ -1,7 +1,8 @@ import logging import time +from collections.abc import Mapping, MutableMapping, Sequence from threading import Thread -from typing import Any, Dict, Mapping, MutableMapping, Optional, Sequence, TypedDict +from typing import Any, TypedDict import rapidjson from confluent_kafka import KafkaError, Producer @@ -50,8 +51,8 @@ class DeleteQueryMessage(TypedDict, total=False): storage_name: str conditions: ConditionsType tenant_ids: Mapping[str, str | int] - attribute_conditions: Optional[Dict[str, WireAttributeCondition]] - attribute_conditions_item_type: Optional[int] + attribute_conditions: dict[str, WireAttributeCondition] | None + attribute_conditions_item_type: int | None PRODUCER_MAP: MutableMapping[str, Producer] = {} @@ -95,7 +96,7 @@ def _flush_producers() -> None: Thread(target=_flush_producers, name="flush_producers", daemon=True).start() -def _delete_query_delivery_callback(error: Optional[KafkaError], message: KafkaMessage) -> None: +def _delete_query_delivery_callback(error: KafkaError | None, message: KafkaMessage) -> None: metrics.increment( "delete_query.delivery_callback", tags={"status": "failure" if error else "success"}, @@ -155,7 +156,7 @@ def _validate_attribute_conditions( try: item_type_name = get_trace_item_type_name(attribute_conditions.item_type) except ValueError as e: - raise InvalidQueryException(str(e)) + raise InvalidQueryException(str(e)) from e # Check if this specific item_type has any allowed attributes configured if item_type_name not in allowed_attrs_config: @@ -183,9 +184,9 @@ def _validate_attribute_conditions( @with_span() def delete_from_storage( storage: WritableTableStorage, - column_conditions: Dict[str, list[Any]], + column_conditions: dict[str, list[Any]], attribution_info: Mapping[str, Any], - attribute_conditions: Optional[AttributeConditions] = None, + attribute_conditions: AttributeConditions | None = None, ) -> dict[str, Result]: """ This method does a series of validation checks (outline below), @@ -218,7 +219,7 @@ def delete_from_storage( for col, values in column_conditions.items(): column_validator.validate(col, values) except InvalidColumnType as e: - raise InvalidQueryException(e.message) + raise InvalidQueryException(e.message) from e # validate attribute conditions if provided if attribute_conditions: @@ -258,8 +259,8 @@ def construct_query(storage: WritableTableStorage, table: str, condition: Expres def _serialize_attribute_conditions( attribute_conditions: AttributeConditions, -) -> Dict[str, WireAttributeCondition]: - result: Dict[str, WireAttributeCondition] = {} +) -> dict[str, WireAttributeCondition]: + result: dict[str, WireAttributeCondition] = {} for key, (attr_key_enum, values) in attribute_conditions.attributes.items(): result[key] = { "attr_key_type": attr_key_enum.type, diff --git a/snuba/web/db_query.py b/snuba/web/db_query.py index ff8344330a6..3e38f6d63c6 100644 --- a/snuba/web/db_query.py +++ b/snuba/web/db_query.py @@ -3,11 +3,12 @@ import logging import random import uuid +from collections.abc import Mapping, MutableMapping, MutableSequence from dataclasses import dataclass from functools import partial from hashlib import md5 from threading import Lock -from typing import Any, Mapping, MutableMapping, MutableSequence, Optional, Union, cast +from typing import Any, cast import rapidjson import sentry_sdk @@ -114,7 +115,7 @@ def encode_exception(self, value: SerializableException) -> bytes: def update_query_metadata_and_stats( - query: Union[Query, CompositeQuery[Table]], + query: Query | CompositeQuery[Table], sql: str, stats: MutableMapping[str, Any], query_metadata_list: MutableSequence[ClickhouseQueryMetadata], @@ -122,9 +123,9 @@ def update_query_metadata_and_stats( trace_id: str, status: QueryStatus, request_status: Status, - profile_data: Optional[snuba_queries_v1._QueryMetadataResultProfileObject] = None, - error_code: Optional[int] = None, - triggered_rate_limiter: Optional[str] = None, + profile_data: snuba_queries_v1._QueryMetadataResultProfileObject | None = None, + error_code: int | None = None, + triggered_rate_limiter: str | None = None, ) -> MutableMapping[str, Any]: """ If query logging is enabled then logs details about the query and its status, as @@ -162,7 +163,7 @@ def execute_query( # as the execute method depends on it. Otherwise we can make this # file rely either entirely on clickhouse query or entirely on # the formatter. - clickhouse_query: Union[Query, CompositeQuery[Table]], + clickhouse_query: Query | CompositeQuery[Table], query_settings: QuerySettings, formatted_query: FormattedQuery, reader: Reader, @@ -225,7 +226,7 @@ def _get_cache_partition(reader: Reader) -> Cache[Result]: @with_span(op="function") def execute_query_with_query_id( - clickhouse_query: Union[Query, CompositeQuery[Table]], + clickhouse_query: Query | CompositeQuery[Table], query_settings: QuerySettings, formatted_query: FormattedQuery, reader: Reader, @@ -280,7 +281,7 @@ def execute_query_with_query_id( @with_span(op="function") def execute_query_with_readthrough_caching( - clickhouse_query: Union[Query, CompositeQuery[Table]], + clickhouse_query: Query | CompositeQuery[Table], query_settings: QuerySettings, formatted_query: FormattedQuery, reader: Reader, @@ -349,9 +350,9 @@ def record_cache_hit_type(hit_type: int) -> None: def _get_query_settings_from_config( - override_prefix: Optional[str], + override_prefix: str | None, async_override: bool, - referrer: Optional[str], + referrer: str | None, ) -> MutableMapping[str, Any]: """ Helper function to get the query settings from the config. Order of precedence @@ -391,7 +392,7 @@ def _get_query_settings_from_config( def _raw_query( - clickhouse_query: Union[Query, CompositeQuery[Table]], + clickhouse_query: Query | CompositeQuery[Table], query_settings: QuerySettings, attribution_info: AttributionInfo, dataset_name: str, @@ -402,7 +403,7 @@ def _raw_query( timer: Timer, # NOTE: This variable is a piece of state which is updated and used outside this function stats: MutableMapping[str, Any], - trace_id: Optional[str] = None, + trace_id: str | None = None, robust: bool = False, ) -> QueryResult: """ @@ -448,7 +449,7 @@ def _raw_query( sql=sql, stats=stats, query_settings=clickhouse_query_settings, - trace_id=trace_id, + trace_id=cast(str, trace_id), ) try: @@ -476,23 +477,22 @@ def _raw_query( elif isinstance(cause, ClickhouseError): error_code = cause.code status = get_query_status_from_error_codes(error_code) - if error_code == ErrorCodes.TOO_MANY_BYTES: - # Only treat as rate limiting if the limit was set by allocation policy - if stats.get("max_bytes_to_read_set_by_policy", False): - calculated_cause = RateLimitExceeded( - "Query scanned more than the allocated amount of bytes", - quota_allowance=stats["quota_allowance"], - ) - status = QueryStatus.RATE_LIMITED + # Only treat as rate limiting if the limit was set by allocation policy + if error_code == ErrorCodes.TOO_MANY_BYTES and stats.get( + "max_bytes_to_read_set_by_policy", False + ): + calculated_cause = RateLimitExceeded( + "Query scanned more than the allocated amount of bytes", + quota_allowance=stats["quota_allowance"], + ) + status = QueryStatus.RATE_LIMITED with configure_scope() as scope: fingerprint = ["{{default}}", str(cause.code), dataset_name] if error_code not in constants.CLICKHOUSE_SYSTEMATIC_FAILURES: fingerprint.append(attribution_info.referrer) scope.fingerprint = fingerprint - elif isinstance(cause, TimeoutError): - status = QueryStatus.TIMEOUT - elif isinstance(cause, ExecutionTimeoutError): + elif isinstance(cause, (TimeoutError, ExecutionTimeoutError)): status = QueryStatus.TIMEOUT with configure_scope() as scope: @@ -520,7 +520,10 @@ def _raw_query( stats = update_with_status( status=QueryStatus.SUCCESS, request_status=get_request_status(), - profile_data=result["profile"], + profile_data=cast( + "snuba_queries_v1._QueryMetadataResultProfileObject | None", + result["profile"], + ), ) return QueryResult( result, @@ -593,7 +596,7 @@ def _record_bytes_scanned( def db_query( - clickhouse_query: Union[Query, CompositeQuery[Table]], + clickhouse_query: Query | CompositeQuery[Table], query_settings: QuerySettings, attribution_info: AttributionInfo, dataset_name: str, @@ -741,7 +744,7 @@ def db_query( stats = dict(result.extra["stats"]) stats["sampling_tier"] = query_settings.get_sampling_tier() result.extra["stats"] = stats - return result + return result # noqa: B012 - intentional: all error paths above are caught into `error`, so this finally is the single terminal point after quota/bytes-scanned bookkeeping raise error or Exception("No error or result when running query, this should never happen") @@ -775,8 +778,8 @@ def _add_quota_info( def _populate_query_status( summary: dict[str, Any], - rejection_quota_and_policy: Optional[_QuotaAndPolicy], - throttle_quota_and_policy: Optional[_QuotaAndPolicy], + rejection_quota_and_policy: _QuotaAndPolicy | None, + throttle_quota_and_policy: _QuotaAndPolicy | None, ) -> None: is_successful = "is_successful" is_rejected = "is_rejected" diff --git a/snuba/web/delete_query.py b/snuba/web/delete_query.py index 075917d1751..3c3870f5a60 100644 --- a/snuba/web/delete_query.py +++ b/snuba/web/delete_query.py @@ -1,6 +1,7 @@ import typing import uuid -from typing import Any, Dict, Mapping, Optional, Sequence +from collections.abc import Mapping, Sequence +from typing import Any from snuba import settings from snuba.attribution import get_app_id @@ -153,7 +154,7 @@ def _delete_from_table( for col, values in conditions.items(): column_validator.validate(col, values) except InvalidColumnType as e: - raise InvalidQueryException(e.message) + raise InvalidQueryException(e.message) from e try: _enforce_max_rows(query) @@ -255,7 +256,7 @@ def get_new_from_clause() -> Table: updates the from_clause to have the correct table. """ dist_table_name = ( - get_writable_storage((storage_key)).get_table_writer().get_schema().get_table_name() + get_writable_storage(storage_key).get_table_writer().get_schema().get_table_name() ) from_clause = delete_query.get_from_clause() return Table( @@ -308,7 +309,7 @@ def _execute_query( query: Query, storage: WritableTableStorage, table: str, - cluster_name: Optional[str], + cluster_name: str | None, attribution_info: AttributionInfo, query_settings: HTTPQuerySettings, ) -> Result: @@ -324,7 +325,7 @@ def _execute_query( result = None error = None - stats: Dict[str, Any] = { + stats: dict[str, Any] = { "clickhouse_table": table, "referrer": attribution_info.referrer, "cluster_name": cluster_name or "", @@ -387,7 +388,7 @@ def _execute_query( result_or_error=result_or_error, ) if result: - return result + return result # noqa: B012 quota balance must run before returning/raising raise error or Exception("No error or result when running query, this should never happen") diff --git a/snuba/web/query.py b/snuba/web/query.py index d1ecc48c283..eb1426eeb99 100644 --- a/snuba/web/query.py +++ b/snuba/web/query.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -from typing import Any, Optional +from typing import Any import sentry_sdk @@ -38,7 +38,7 @@ def _run_query_pipeline( timer: Timer, query_metadata: SnubaQueryMetadata, robust: bool = False, - concurrent_queries_gauge: Optional[Gauge] = None, + concurrent_queries_gauge: Gauge | None = None, force_dry_run: bool = False, ) -> QueryResult: clickhouse_query = EntityProcessingStage().execute( @@ -59,7 +59,7 @@ def _run_query_pipeline( ).execute(clickhouse_query) if res.error: raise res.error - elif res.data: + if res.data: return res.data # we should never get here raise Exception("No result or data, very bad exception") @@ -71,7 +71,7 @@ def run_query( request: Request, timer: Timer, robust: bool = False, - concurrent_queries_gauge: Optional[Gauge] = None, + concurrent_queries_gauge: Gauge | None = None, ) -> QueryResult: """ Processes, runs a Snuba Query, then records the metadata about the query that was run. @@ -109,7 +109,7 @@ def run_query( return result -def _get_dataset(dataset_name: Optional[str]) -> Dataset: +def _get_dataset(dataset_name: str | None) -> Dataset: if dataset_name: return get_dataset(dataset_name) return PluggableDataset(name=settings.DEFAULT_DATASET_NAME, all_entities=[]) @@ -120,8 +120,8 @@ def parse_and_run_query( body: dict[str, Any], timer: Timer, is_mql: bool = False, - dataset_name: Optional[str] = None, - referrer: Optional[str] = None, + dataset_name: str | None = None, + referrer: str | None = None, ) -> tuple[Request, QueryResult]: """Top level entrypoint from a raw query body to a query result diff --git a/snuba/web/rpc/__init__.py b/snuba/web/rpc/__init__.py index 8846525a960..f0785817528 100644 --- a/snuba/web/rpc/__init__.py +++ b/snuba/web/rpc/__init__.py @@ -3,7 +3,7 @@ import random import uuid from bisect import bisect_left -from typing import Generic, List, Tuple, Type, cast, final +from typing import Any, Generic, cast, final import sentry_sdk from clickhouse_driver.errors import ErrorCodes as clickhouse_errors @@ -129,8 +129,8 @@ def trace_item_type(cls) -> TraceItemType.ValueType: def get_from_trace_item_type( cls, trace_item_type: TraceItemType.ValueType, - ) -> "Type[TraceItemDataResolver[Tin, Tout]]": - registry = getattr(cls, "_registry") + ) -> "type[TraceItemDataResolver[Tin, Tout]]": + registry = cls._registry try: shape = registry.get_class_from_name(f"{cls.endpoint_name()}__{trace_item_type}") except InvalidConfigKeyError: @@ -138,7 +138,7 @@ def get_from_trace_item_type( f"{cls.endpoint_name()}__{TraceItemType.TRACE_ITEM_TYPE_UNSPECIFIED}" ) return cast( - Type["TraceItemDataResolver[Tin, Tout]"], + type["TraceItemDataResolver[Tin, Tout]"], shape, ) @@ -152,11 +152,11 @@ def __init__(self, metrics_backend: MetricsBackend | None = None) -> None: self._metrics_backend = metrics_backend or environment.metrics @classmethod - def request_class(cls) -> Type[Tin]: + def request_class(cls) -> type[Tin]: raise NotImplementedError @classmethod - def response_class(cls) -> Type[Tout]: + def response_class(cls) -> type[Tout]: raise NotImplementedError @classmethod @@ -181,10 +181,10 @@ def metrics(self) -> MetricsWrapper: ) @classmethod - def get_from_name(cls, name: str, version: str) -> Type["RPCEndpoint[Tin, Tout]"]: + def get_from_name(cls, name: str, version: str) -> type["RPCEndpoint[Tin, Tout]"]: return cast( - Type["RPCEndpoint[Tin, Tout]"], - getattr(cls, "_registry").get_class_from_name(f"{name}__{version}"), + type["RPCEndpoint[Tin, Tout]"], + cls._registry.get_class_from_name(f"{name}__{version}"), ) def parse_from_string(self, bytestring: bytes) -> Tin: @@ -275,9 +275,8 @@ def execute(self, in_msg: Tin) -> Tout: def __before_execute(self, in_msg: Tin) -> None: # Generate request_id if not already present meta = getattr(in_msg, "meta", None) - if meta is not None: - if not hasattr(meta, "request_id") or not meta.request_id: - meta.request_id = self.routing_context.query_id + if meta is not None and (not hasattr(meta, "request_id") or not meta.request_id): + meta.request_id = self.routing_context.query_id self._timer.update_tags(self.__extract_request_tags(in_msg)) @@ -371,11 +370,10 @@ def __after_execute(self, in_msg: Tin, out_msg: Tout, error: Exception | None) - tags=self._timer.tags, ) raise error - else: - self.metrics.increment( - "request_success", - tags=self._timer.tags, - ) + self.metrics.increment( + "request_success", + tags=self._timer.tags, + ) return res def _after_execute(self, in_msg: Tin, out_msg: Tout, error: Exception | None) -> Tout: @@ -396,7 +394,7 @@ def _after_execute(self, in_msg: Tin, out_msg: Tout, error: Exception | None) -> return out_msg -def list_all_endpoint_names() -> List[Tuple[str, str]]: +def list_all_endpoint_names() -> list[tuple[str, str]]: return [ (name.split("__")[0], name.split("__")[1]) for name in RPCEndpoint.all_names() @@ -414,7 +412,7 @@ def list_all_endpoint_names() -> List[Tuple[str, str]]: def run_rpc_handler(name: str, version: str, data: bytes) -> ProtobufMessage | ErrorProto: try: - endpoint = RPCEndpoint.get_from_name(name, version)() # type: ignore + endpoint: RPCEndpoint[Any, Any] = RPCEndpoint.get_from_name(name, version)() except (AttributeError, InvalidConfigKeyError) as e: return convert_rpc_exception_to_proto( RPCRequestException( diff --git a/snuba/web/rpc/common/common.py b/snuba/web/rpc/common/common.py index ff0349ef271..f5514d6ff7c 100644 --- a/snuba/web/rpc/common/common.py +++ b/snuba/web/rpc/common/common.py @@ -1,6 +1,7 @@ import json -from datetime import datetime, timedelta, timezone -from typing import Any, Callable, TypeVar, cast +from collections.abc import Callable +from datetime import UTC, datetime, timedelta +from typing import Any, TypeVar, cast from google.protobuf.message import Message as ProtobufMessage from sentry_protos.snuba.v1.request_common_pb2 import RequestMeta @@ -241,10 +242,9 @@ def transform(exp: Expression) -> Expression: if exp.function_name == "and": return combine_and_conditions(exp.parameters) - elif exp.function_name == "or": + if exp.function_name == "or": return combine_or_conditions(exp.parameters) - else: - return exp + return exp query.transform_expressions(transform) @@ -564,7 +564,7 @@ def trace_item_filters_to_expression( filters = item_filter.and_filter.filters if len(filters) == 0: return literal(True) - elif len(filters) == 1: + if len(filters) == 1: return trace_item_filters_to_expression(filters[0], attribute_key_to_expression) return and_cond( *(trace_item_filters_to_expression(x, attribute_key_to_expression) for x in filters) @@ -574,7 +574,7 @@ def trace_item_filters_to_expression( filters = item_filter.or_filter.filters if len(filters) == 0: raise BadSnubaRPCRequestException("Invalid trace item filter, empty 'or' clause") - elif len(filters) == 1: + if len(filters) == 1: return trace_item_filters_to_expression(filters[0], attribute_key_to_expression) return or_cond( *(trace_item_filters_to_expression(x, attribute_key_to_expression) for x in filters) @@ -584,7 +584,7 @@ def trace_item_filters_to_expression( filters = item_filter.not_filter.filters if len(filters) == 0: raise BadSnubaRPCRequestException("Invalid trace item filter, empty 'not' clause") - elif len(filters) == 1: + if len(filters) == 1: return not_cond( trace_item_filters_to_expression(filters[0], attribute_key_to_expression) ) @@ -622,18 +622,15 @@ def trace_item_filters_to_expression( return _type_array_includes_scalar_expression( k_expression, v, item_filter.comparison_filter.ignore_case ) - else: - expr = ( - f.equals(f.lower(k_expression), f.lower(v_expression)) - if item_filter.comparison_filter.ignore_case - else f.equals(k_expression, v_expression) - ) - # we redefine the way equals works for nulls - # now null=null is true - expr_with_null = or_cond( - expr, and_cond(f.isNull(k_expression), f.isNull(v_expression)) - ) - return expr_with_null + expr = ( + f.equals(f.lower(k_expression), f.lower(v_expression)) + if item_filter.comparison_filter.ignore_case + else f.equals(k_expression, v_expression) + ) + # we redefine the way equals works for nulls + # now null=null is true + expr_with_null = or_cond(expr, and_cond(f.isNull(k_expression), f.isNull(v_expression))) + return expr_with_null if op == ComparisonFilter.OP_NOT_EQUALS: _check_non_string_values_cannot_ignore_case(item_filter.comparison_filter) if k.type == AttributeKey.Type.TYPE_ARRAY: @@ -642,18 +639,15 @@ def trace_item_filters_to_expression( k_expression, v, item_filter.comparison_filter.ignore_case ) ) - else: - expr = ( - f.notEquals(f.lower(k_expression), f.lower(v_expression)) - if item_filter.comparison_filter.ignore_case - else f.notEquals(k_expression, v_expression) - ) - # we redefine the way not equals works for nulls - # now null!=null is true - expr_with_null = or_cond( - expr, f.xor(f.isNull(k_expression), f.isNull(v_expression)) - ) - return expr_with_null + expr = ( + f.notEquals(f.lower(k_expression), f.lower(v_expression)) + if item_filter.comparison_filter.ignore_case + else f.notEquals(k_expression, v_expression) + ) + # we redefine the way not equals works for nulls + # now null!=null is true + expr_with_null = or_cond(expr, f.xor(f.isNull(k_expression), f.isNull(v_expression))) + return expr_with_null if op == ComparisonFilter.OP_LIKE: if k.type == AttributeKey.Type.TYPE_ARRAY: like_fn = f.ilike if item_filter.comparison_filter.ignore_case else f.like @@ -791,15 +785,11 @@ def timestamp_in_range_condition(start_ts: int, end_ts: int) -> Expression: return and_cond( f.less( column("timestamp"), - f.toDateTime( - datetime.fromtimestamp(end_ts, tz=timezone.utc).strftime("%Y-%m-%d %H:%M:%S") - ), + f.toDateTime(datetime.fromtimestamp(end_ts, tz=UTC).strftime("%Y-%m-%d %H:%M:%S")), ), f.greaterOrEquals( column("timestamp"), - f.toDateTime( - datetime.fromtimestamp(start_ts, tz=timezone.utc).strftime("%Y-%m-%d %H:%M:%S") - ), + f.toDateTime(datetime.fromtimestamp(start_ts, tz=UTC).strftime("%Y-%m-%d %H:%M:%S")), ), ) @@ -847,11 +837,12 @@ def get_subscriptable_field(field: Expression) -> SubscriptableReference | None: """ if isinstance(field, SubscriptableReference): return field - elif isinstance(field, FunctionCall) and len(field.parameters) > 0: - if len(field.parameters) > 0 and isinstance( - field.parameters[0], SubscriptableReference - ): - return field.parameters[0] + if ( + isinstance(field, FunctionCall) + and len(field.parameters) > 0 + and isinstance(field.parameters[0], SubscriptableReference) + ): + return field.parameters[0] return None diff --git a/snuba/web/rpc/common/debug_info.py b/snuba/web/rpc/common/debug_info.py index c1dde075e95..a13110cc5af 100644 --- a/snuba/web/rpc/common/debug_info.py +++ b/snuba/web/rpc/common/debug_info.py @@ -1,5 +1,3 @@ -from typing import List - from sentry_protos.snuba.v1.downsampled_storage_pb2 import DownsampledStorageMeta from sentry_protos.snuba.v1.request_common_pb2 import ( QueryInfo, @@ -16,15 +14,14 @@ def _construct_meta_if_downsampled( - query_results: List[QueryResult], + query_results: list[QueryResult], ) -> DownsampledStorageMeta | None: highest_sampling_tier = Tier.TIER_NO_TIER for query_result in query_results: sampling_tier = query_result.extra.get("stats", {}).get("sampling_tier") - if sampling_tier: - if sampling_tier.value > highest_sampling_tier.value: - highest_sampling_tier = sampling_tier + if sampling_tier and sampling_tier.value > highest_sampling_tier.value: + highest_sampling_tier = sampling_tier return ( DownsampledStorageMeta( @@ -38,10 +35,10 @@ def _construct_meta_if_downsampled( def extract_response_meta( request_id: str, debug: bool, - query_results: List[QueryResult], - timers: List[Timer], + query_results: list[QueryResult], + timers: list[Timer], ) -> ResponseMeta: - query_info: List[QueryInfo] = [] + query_info: list[QueryInfo] = [] downsampled_storage_meta = _construct_meta_if_downsampled(query_results) @@ -67,7 +64,7 @@ def extract_response_meta( else ResponseMeta(request_id=request_id, query_info=query_info) ) - for query_result, timer in zip(query_results, timers): + for query_result, timer in zip(query_results, timers, strict=False): extra = getattr(query_result, "extra", None) or {} stats = extra.get("stats", {}) if isinstance(extra, dict) else {} result = getattr(query_result, "result", None) or {} diff --git a/snuba/web/rpc/common/exceptions.py b/snuba/web/rpc/common/exceptions.py index 73372877893..fe6c5034683 100644 --- a/snuba/web/rpc/common/exceptions.py +++ b/snuba/web/rpc/common/exceptions.py @@ -1,4 +1,4 @@ -from typing import Any, Union +from typing import Any from google.protobuf import any_pb2, struct_pb2 from sentry_protos.snuba.v1.error_pb2 import Error as ErrorProto @@ -50,7 +50,7 @@ def from_args( ) -def convert_rpc_exception_to_proto(exc: Union[RPCRequestException, QueryException]) -> ErrorProto: +def convert_rpc_exception_to_proto(exc: RPCRequestException | QueryException) -> ErrorProto: if isinstance(exc, RPCRequestException): s = struct_pb2.Struct() s.update(exc.details) diff --git a/snuba/web/rpc/proto_visitor.py b/snuba/web/rpc/proto_visitor.py index 17571557daf..b33c6cbe3d7 100644 --- a/snuba/web/rpc/proto_visitor.py +++ b/snuba/web/rpc/proto_visitor.py @@ -1,8 +1,9 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Callable from types import MethodType -from typing import Any, Callable, Generic, TypeVar +from typing import Any, Generic, TypeVar from google.protobuf.message import Message as ProtobufMessage from sentry_protos.snuba.v1.attribute_conditional_aggregation_pb2 import ( @@ -51,7 +52,7 @@ def accept(self, visitor: ProtoVisitor) -> None: ColumnWrapper(conditional.condition.right).accept(visitor) # Note: 'match' is a Python keyword, so use getattr if conditional.HasField("match"): - ColumnWrapper(getattr(conditional, "match")).accept(visitor) + ColumnWrapper(conditional.match).accept(visitor) if conditional.HasField("default"): ColumnWrapper(conditional.default).accept(visitor) @@ -118,7 +119,7 @@ def accept(self, visitor: ProtoVisitor) -> None: TraceItemFilterWrapper(f).accept(visitor) -class ProtoVisitor(ABC): +class ProtoVisitor(ABC): # noqa: B024 dynamic visit dispatch via __getattr__; ABC marks it non-instantiable by design """ Proto visitor design is split into two parts: 1. the visitor. Responsible for only executing work on the object it is visiting diff --git a/snuba/web/rpc/storage_routing/common.py b/snuba/web/rpc/storage_routing/common.py index 779be2432e3..758b0055990 100644 --- a/snuba/web/rpc/storage_routing/common.py +++ b/snuba/web/rpc/storage_routing/common.py @@ -8,9 +8,6 @@ def extract_message_meta(in_msg: ProtobufMessage) -> RequestMeta: if isinstance(in_msg, CreateSubscriptionRequest): return in_msg.time_series_request.meta - elif ( - hasattr(in_msg, "meta") and in_msg.HasField("meta") and isinstance(in_msg.meta, RequestMeta) - ): + if hasattr(in_msg, "meta") and in_msg.HasField("meta") and isinstance(in_msg.meta, RequestMeta): return in_msg.meta - else: - raise ValueError(f"Invalid message type: {type(in_msg)}") + raise ValueError(f"Invalid message type: {type(in_msg)}") diff --git a/snuba/web/rpc/storage_routing/load_retriever.py b/snuba/web/rpc/storage_routing/load_retriever.py index 43b8775f0dc..7839aca874d 100644 --- a/snuba/web/rpc/storage_routing/load_retriever.py +++ b/snuba/web/rpc/storage_routing/load_retriever.py @@ -1,7 +1,8 @@ import inspect import json +from collections.abc import Callable from functools import wraps -from typing import Any, Callable +from typing import Any import sentry_sdk diff --git a/snuba/web/rpc/storage_routing/routing_strategies/outcomes_based.py b/snuba/web/rpc/storage_routing/routing_strategies/outcomes_based.py index 69a002cd037..fc8c539d998 100644 --- a/snuba/web/rpc/storage_routing/routing_strategies/outcomes_based.py +++ b/snuba/web/rpc/storage_routing/routing_strategies/outcomes_based.py @@ -99,22 +99,15 @@ def get_item_types_in_query( item_types = set() # Handle TraceItemTableRequest - if isinstance(in_msg, TraceItemTableRequest): - if hasattr(in_msg, "trace_filters") and in_msg.trace_filters: - for trace_filter in in_msg.trace_filters: - item_types.add(trace_filter.item_type) - - # Handle TimeSeriesRequest - elif isinstance(in_msg, TimeSeriesRequest): + if isinstance(in_msg, (TraceItemTableRequest, TimeSeriesRequest)): if hasattr(in_msg, "trace_filters") and in_msg.trace_filters: for trace_filter in in_msg.trace_filters: item_types.add(trace_filter.item_type) # Handle GetTracesRequest - elif isinstance(in_msg, GetTracesRequest): - if hasattr(in_msg, "filters") and in_msg.filters: - for filter_item in in_msg.filters: - item_types.add(filter_item.item_type) + elif isinstance(in_msg, GetTracesRequest) and hasattr(in_msg, "filters") and in_msg.filters: + for filter_item in in_msg.filters: + item_types.add(filter_item.item_type) # Fallback to meta.trace_item_type if ( diff --git a/snuba/web/rpc/storage_routing/routing_strategies/outcomes_flex_time.py b/snuba/web/rpc/storage_routing/routing_strategies/outcomes_flex_time.py index 7e69f590a0f..9646e90f972 100644 --- a/snuba/web/rpc/storage_routing/routing_strategies/outcomes_flex_time.py +++ b/snuba/web/rpc/storage_routing/routing_strategies/outcomes_flex_time.py @@ -60,7 +60,7 @@ def _get_request_time_window(routing_context: RoutingContext) -> TimeWindow: meta = extract_message_meta(routing_context.in_msg) if routing_context.in_msg.HasField("page_token"): time_window = FlexibleTimeWindowPageWithFilters( - getattr(routing_context.in_msg, "page_token") + getattr(routing_context.in_msg, "page_token") # noqa: B009 # proto dynamic attr ).get_time_window() if time_window: return time_window @@ -148,7 +148,7 @@ def _adjust_time_window(self, routing_context: RoutingContext) -> TimeWindow: window_length = original_end_ts - original_start_ts start_timestamp_proto = TimestampProto( - seconds=original_end_ts - math.floor((window_length / factor)) + seconds=original_end_ts - math.floor(window_length / factor) ) end_timestamp_proto = TimestampProto(seconds=original_end_ts) return TimeWindow(start_timestamp_proto, end_timestamp_proto) diff --git a/snuba/web/rpc/storage_routing/routing_strategies/storage_routing.py b/snuba/web/rpc/storage_routing/routing_strategies/storage_routing.py index 648a8ab981c..fd9260b3e5c 100644 --- a/snuba/web/rpc/storage_routing/routing_strategies/storage_routing.py +++ b/snuba/web/rpc/storage_routing/routing_strategies/storage_routing.py @@ -2,17 +2,13 @@ import os from abc import ABC +from collections.abc import Callable from dataclasses import dataclass, field +from datetime import UTC from typing import ( Any, - Callable, - Dict, - List, NamedTuple, - Optional, - TypeAlias, TypedDict, - Union, cast, final, ) @@ -64,12 +60,10 @@ _START_ESTIMATION_MARK = "start_sampling_in_storage_estimation" _END_ESTIMATION_MARK = "end_sampling_in_storage_estimation" DEFAULT_STORAGE_ROUTING_CONFIG_PREFIX = "StorageRouting" -MetricsBackendType: TypeAlias = Callable[ - [str, Union[int, float], Optional[Dict[str, str]], Optional[str]], None -] +MetricsBackendType = Callable[[str, int | float, dict[str, str] | None, str | None], None] CBRS_HASH = "cbrs" -RoutedRequestType = Union[TimeSeriesRequest, TraceItemTableRequest] -ClickhouseQuerySettings = Dict[str, Any] +RoutedRequestType = TimeSeriesRequest | TraceItemTableRequest +ClickhouseQuerySettings = dict[str, Any] class _OrgOverridableSetting(NamedTuple): @@ -111,7 +105,7 @@ class RoutingContext: timer: Timer in_msg: ProtobufMessage query_id: str - query_result: Optional[QueryResult] = field(default=None) + query_result: QueryResult | None = field(default=None) extra_info: dict[str, Any] = field(default_factory=dict) allocation_policies_recommendations: dict[str, QuotaAllowance] = field(default_factory=dict) cluster_load_info: LoadInfo | None = field(default=None) @@ -139,12 +133,12 @@ def length_hours(self) -> float: return (self.end_timestamp.seconds - self.start_timestamp.seconds) / 3600 def __repr__(self) -> str: - from datetime import datetime, timezone + from datetime import datetime - start = datetime.fromtimestamp(self.start_timestamp.seconds, tz=timezone.utc).strftime( + start = datetime.fromtimestamp(self.start_timestamp.seconds, tz=UTC).strftime( "%Y-%m-%d %H:%M:%S" ) - end = datetime.fromtimestamp(self.end_timestamp.seconds, tz=timezone.utc).strftime( + end = datetime.fromtimestamp(self.end_timestamp.seconds, tz=UTC).strftime( "%Y-%m-%d %H:%M:%S" ) @@ -213,7 +207,7 @@ def get_stats_dict( def _construct_hacky_querylog_payload( - strategy: "BaseRoutingStrategy", routing_decision: RoutingDecision + strategy: BaseRoutingStrategy, routing_decision: RoutingDecision ) -> snuba_queries_v1.Querylog: cur_span = sentry_sdk.get_current_span() assert routing_decision.routing_context is not None @@ -285,7 +279,9 @@ class StrategyData(ConfigurableComponentData): class BaseRoutingStrategy(ConfigurableComponent, ABC): - def __init__(self, default_config_overrides: dict[str, Any] = {}) -> None: + def __init__(self, default_config_overrides: dict[str, Any] | None = None) -> None: + if default_config_overrides is None: + default_config_overrides = {} self._default_config_definitions = [ RoutingStrategyConfig( name="some_default_config", @@ -319,12 +315,11 @@ def _get_default_routing_decision_tier(self) -> Tier: if tier_int == 512: return Tier.TIER_512 - elif tier_int == 64: + if tier_int == 64: return Tier.TIER_64 - elif tier_int == 8: + if tier_int == 8: return Tier.TIER_8 - else: - return Tier.TIER_1 + return Tier.TIER_1 def additional_config_definitions(self) -> list[Configuration]: return self._overridden_additional_config_definitions @@ -344,7 +339,7 @@ def metrics(self) -> MetricsWrapper: ) @classmethod - def create_minimal_instance(cls, resource_identifier: str) -> "ConfigurableComponent": + def create_minimal_instance(cls, resource_identifier: str) -> ConfigurableComponent: return cls( default_config_overrides={}, ) @@ -403,7 +398,7 @@ def _record_value_in_span_and_DD( metrics_backend_func: MetricsBackendType, name: str, value: float | int, - tags: Dict[str, str] | None = None, + tags: dict[str, str] | None = None, ) -> None: name = _SAMPLING_IN_STORAGE_PREFIX + name metrics_backend_func(name, value, tags, None) @@ -437,7 +432,7 @@ def _get_org_clickhouse_setting_overrides( return overrides def _get_combined_allocation_policies_recommendations( - self, policy_recommendations: List[QuotaAllowance] + self, policy_recommendations: list[QuotaAllowance] ) -> CombinedAllocationPoliciesRecommendations: # decides how to combine the recommendations from the allocation policies settings = {} @@ -700,7 +695,7 @@ def _emit_routing_mistake(self, routing_decision: RoutingDecision) -> None: def to_dict(self) -> StrategyData: base_data = super().to_dict() policies = self.get_allocation_policies() + self.get_delete_allocation_policies() - return StrategyData(**base_data, policies_data=[policy.to_dict() for policy in policies]) # type: ignore + return StrategyData(**base_data, policies_data=[policy.to_dict() for policy in policies]) import_submodules_in_directory( diff --git a/snuba/web/rpc/storage_routing/routing_strategy_selector.py b/snuba/web/rpc/storage_routing/routing_strategy_selector.py index d3d20c75035..29ea6d19713 100644 --- a/snuba/web/rpc/storage_routing/routing_strategy_selector.py +++ b/snuba/web/rpc/storage_routing/routing_strategy_selector.py @@ -1,7 +1,8 @@ import hashlib import json +from collections.abc import Iterable from dataclasses import dataclass -from typing import Any, Iterable, Tuple +from typing import Any import sentry_sdk from google.protobuf.message import Message as ProtobufMessage @@ -32,7 +33,7 @@ class StorageRoutingConfig: version: int _routing_strategy_and_percentage_routed: dict[str, float] - def get_routing_strategy_and_percentage_routed(self) -> Iterable[Tuple[str, float]]: + def get_routing_strategy_and_percentage_routed(self) -> Iterable[tuple[str, float]]: return sorted(self._routing_strategy_and_percentage_routed.items()) @classmethod @@ -58,8 +59,10 @@ def from_json(cls, config_dict: dict[str, Any]) -> "StorageRoutingConfig": try: BaseRoutingStrategy.get_from_name(strategy_name)() - except Exception: - raise ValueError(f"{strategy_name} does not inherit from BaseRoutingStrategy") + except Exception as e: + raise ValueError( + f"{strategy_name} does not inherit from BaseRoutingStrategy" + ) from e routing_strategy_and_percentage_routed[strategy_name] = percentage total_percentage += percentage @@ -90,7 +93,7 @@ def get_storage_routing_config(self, in_msg: ProtobufMessage) -> StorageRoutingC organization_id = str(in_msg_meta.organization_id) try: overrides = json.loads(str(get_config(_STORAGE_ROUTING_CONFIG_OVERRIDE_KEY, "{}"))) - if organization_id in overrides.keys(): + if organization_id in overrides: return StorageRoutingConfig.from_json(overrides[organization_id]) config = str(get_config(_DEFAULT_STORAGE_ROUTING_CONFIG_KEY, "{}")) diff --git a/snuba/web/rpc/v1/create_subscription.py b/snuba/web/rpc/v1/create_subscription.py index b58c0f8d5d1..5be526867b9 100644 --- a/snuba/web/rpc/v1/create_subscription.py +++ b/snuba/web/rpc/v1/create_subscription.py @@ -1,5 +1,3 @@ -from typing import Type - from sentry_protos.snuba.v1.endpoint_create_subscription_pb2 import ( CreateSubscriptionRequest as CreateSubscriptionRequestProto, ) @@ -22,11 +20,11 @@ def version(cls) -> str: return "v1" @classmethod - def request_class(cls) -> Type[CreateSubscriptionRequestProto]: + def request_class(cls) -> type[CreateSubscriptionRequestProto]: return CreateSubscriptionRequestProto @classmethod - def response_class(cls) -> Type[CreateSubscriptionResponse]: + def response_class(cls) -> type[CreateSubscriptionResponse]: return CreateSubscriptionResponse def _execute(self, in_msg: CreateSubscriptionRequestProto) -> CreateSubscriptionResponse: diff --git a/snuba/web/rpc/v1/endpoint_delete_trace_items.py b/snuba/web/rpc/v1/endpoint_delete_trace_items.py index 299c518c092..b8cc987b02d 100644 --- a/snuba/web/rpc/v1/endpoint_delete_trace_items.py +++ b/snuba/web/rpc/v1/endpoint_delete_trace_items.py @@ -1,4 +1,5 @@ -from typing import Any, Dict, List, Optional, Sequence, Tuple, Type +from collections.abc import Sequence +from typing import Any from sentry_protos.snuba.v1.endpoint_delete_trace_items_pb2 import ( DeleteTraceItemsRequest, @@ -23,18 +24,17 @@ def _extract_attribute_value(comparison_filter: ComparisonFilter) -> Any: value_field = comparison_filter.value.WhichOneof("value") if value_field == "val_str": return comparison_filter.value.val_str - elif value_field == "val_int": + if value_field == "val_int": return comparison_filter.value.val_int - elif value_field == "val_double": + if value_field == "val_double": return comparison_filter.value.val_double - elif value_field == "val_bool": + if value_field == "val_bool": return comparison_filter.value.val_bool - elif value_field == "val_array": + if value_field == "val_array": return [_scalar_value(elem) for elem in comparison_filter.value.val_array.values] - elif value_field in ("val_str_array", "val_int_array", "val_double_array", "val_float_array"): + if value_field in ("val_str_array", "val_int_array", "val_double_array", "val_float_array"): return list(getattr(comparison_filter.value, value_field).values) - else: - raise BadSnubaRPCRequestException(f"Unsupported attribute value type: {value_field}") + raise BadSnubaRPCRequestException(f"Unsupported attribute value type: {value_field}") def _trace_item_filters_to_attribute_conditions( @@ -57,7 +57,7 @@ def _trace_item_filters_to_attribute_conditions( Raises: BadSnubaRPCRequestException: If unsupported filter types or operations are encountered """ - attributes: Dict[str, Tuple[AttributeKey, List[Any]]] = {} + attributes: dict[str, tuple[AttributeKey, list[Any]]] = {} for filter_with_type in filters: # Extract the actual filter from TraceItemFilterWithType @@ -102,11 +102,11 @@ def version(cls) -> str: return "v1" @classmethod - def request_class(cls) -> Type[DeleteTraceItemsRequest]: + def request_class(cls) -> type[DeleteTraceItemsRequest]: return DeleteTraceItemsRequest @classmethod - def response_class(cls) -> Type[DeleteTraceItemsResponse]: + def response_class(cls) -> type[DeleteTraceItemsResponse]: return DeleteTraceItemsResponse def _execute(self, request: DeleteTraceItemsRequest) -> DeleteTraceItemsResponse: @@ -132,12 +132,12 @@ def _execute(self, request: DeleteTraceItemsRequest) -> DeleteTraceItemsResponse } # Build base conditions that apply to all deletions - conditions: Dict[str, List[Any]] = { + conditions: dict[str, list[Any]] = { "organization_id": [request.meta.organization_id], "project_id": list(request.meta.project_ids), } - attribute_conditions: Optional[AttributeConditions] = None + attribute_conditions: AttributeConditions | None = None if has_trace_ids: # Delete by trace_ids (no attribute filtering) diff --git a/snuba/web/rpc/v1/endpoint_export_trace_items.py b/snuba/web/rpc/v1/endpoint_export_trace_items.py index 471047922e0..51dcb916a11 100644 --- a/snuba/web/rpc/v1/endpoint_export_trace_items.py +++ b/snuba/web/rpc/v1/endpoint_export_trace_items.py @@ -1,6 +1,7 @@ import uuid +from collections.abc import Iterable from datetime import datetime -from typing import Any, Dict, Iterable, NamedTuple, Type, cast +from typing import Any, Literal, NamedTuple, cast import sentry_sdk from google.protobuf.json_format import MessageToDict @@ -71,17 +72,17 @@ class FlexWindow(NamedTuple): @classmethod def from_filters(cls, filters: list[TraceItemFilter]) -> "FlexWindow": - _SCHEMA = [ + _SCHEMA: list[tuple[str, Literal["val_int"]]] = [ (FLEX_WIN_START, "val_int"), (FLEX_WIN_END, "val_int"), ] values = [] - for expected_filter, timestamp_filter in zip(_SCHEMA, filters): + for expected_filter, timestamp_filter in zip(_SCHEMA, filters, strict=False): filter_name, filter_value_type = expected_filter if ( not timestamp_filter.HasField("comparison_filter") or timestamp_filter.comparison_filter.key.name != filter_name - or not timestamp_filter.comparison_filter.value.HasField(filter_value_type) # type: ignore[arg-type] + or not timestamp_filter.comparison_filter.value.HasField(filter_value_type) ): raise ValueError("Invalid timestamp filter in page token") values.append(timestamp_filter.comparison_filter.value.val_int) @@ -144,7 +145,9 @@ def to_filters(self) -> list[TraceItemFilter]: value=AttributeValue(**{val_field: value}), # type: ignore[arg-type] ) ) - for (name, attr_type, val_field), value in zip(_KEYSET_CURSOR_SCHEMA, self) + for (name, attr_type, val_field), value in zip( + _KEYSET_CURSOR_SCHEMA, self, strict=False + ) ] @@ -387,18 +390,17 @@ def _build_snuba_request( def _to_any_value(value: Any) -> AnyValue: if isinstance(value, bool): return AnyValue(bool_value=value) - elif isinstance(value, int): + if isinstance(value, int): return AnyValue(int_value=value) - elif isinstance(value, float): + if isinstance(value, float): return AnyValue(double_value=value) - elif isinstance(value, str): + if isinstance(value, str): return AnyValue(string_value=value) - elif isinstance(value, list): + if isinstance(value, list): return AnyValue(array_value=ArrayValue(values=[_to_any_value(v) for v in value])) - elif isinstance(value, datetime): + if isinstance(value, datetime): return AnyValue(double_value=value.timestamp()) - else: - raise BadSnubaRPCRequestException(f"data type unknown: {type(value)}") + raise BadSnubaRPCRequestException(f"data type unknown: {type(value)}") class ProcessedResults(NamedTuple): @@ -406,7 +408,7 @@ class ProcessedResults(NamedTuple): keyset_cursor: KeysetCursor -def _convert_rows(rows: Iterable[Dict[str, Any]]) -> ProcessedResults: +def _convert_rows(rows: Iterable[dict[str, Any]]) -> ProcessedResults: items: list[TraceItem] = [] last_seen_project_id = 0 last_seen_item_type = TraceItemType.TRACE_ITEM_TYPE_UNSPECIFIED @@ -493,11 +495,11 @@ def version(cls) -> str: return "v1" @classmethod - def request_class(cls) -> Type[ExportTraceItemsRequest]: + def request_class(cls) -> type[ExportTraceItemsRequest]: return ExportTraceItemsRequest @classmethod - def response_class(cls) -> Type[ExportTraceItemsResponse]: + def response_class(cls) -> type[ExportTraceItemsResponse]: return ExportTraceItemsResponse def _execute(self, in_msg: ExportTraceItemsRequest) -> ExportTraceItemsResponse: diff --git a/snuba/web/rpc/v1/endpoint_get_trace.py b/snuba/web/rpc/v1/endpoint_get_trace.py index 411883bdca2..1f57e6169d0 100644 --- a/snuba/web/rpc/v1/endpoint_get_trace.py +++ b/snuba/web/rpc/v1/endpoint_get_trace.py @@ -1,8 +1,9 @@ import random import uuid +from collections.abc import Iterable from datetime import datetime from operator import attrgetter -from typing import Any, Dict, Iterable, NamedTuple, Optional, Type +from typing import Any, NamedTuple, Optional import sentry_sdk from google.protobuf.json_format import MessageToDict @@ -230,16 +231,14 @@ def _build_query( *attributes_array_selected_expressions(), ] selected_columns.extend( - map( - lambda col_name: SelectedExpression( - name=col_name, - expression=column( - col_name, - alias=f"selected_{col_name}", - ), + SelectedExpression( + name=col_name, + expression=column( + col_name, + alias=f"selected_{col_name}", ), - (NORMALIZED_COLUMNS_TO_INCLUDE_EAP_ITEMS), ) + for col_name in (NORMALIZED_COLUMNS_TO_INCLUDE_EAP_ITEMS) ) entity = Entity( @@ -379,28 +378,27 @@ def convert_to_attribute_value(value: Any) -> AttributeValue: return AttributeValue( val_bool=value, ) - elif isinstance(value, int): + if isinstance(value, int): return AttributeValue( val_int=value, ) - elif isinstance(value, float): + if isinstance(value, float): return AttributeValue( val_double=value, ) - elif isinstance(value, str): + if isinstance(value, str): return AttributeValue( val_str=value, ) - elif isinstance(value, (list, tuple)): + if isinstance(value, (list, tuple)): return AttributeValue( val_array=Array(values=[convert_to_attribute_value(v) for v in value]) ) - elif isinstance(value, datetime): + if isinstance(value, datetime): return AttributeValue( val_double=value.timestamp(), ) - else: - raise BadSnubaRPCRequestException(f"data type unknown: {type(value)}") + raise BadSnubaRPCRequestException(f"data type unknown: {type(value)}") def _value_to_attribute(key: str, value: Any) -> tuple[AttributeKey, AttributeValue]: @@ -412,7 +410,7 @@ def _value_to_attribute(key: str, value: Any) -> tuple[AttributeKey, AttributeVa ), convert_to_attribute_value(value), ) - elif isinstance(value, int): + if isinstance(value, int): return ( AttributeKey( name=key, @@ -420,7 +418,7 @@ def _value_to_attribute(key: str, value: Any) -> tuple[AttributeKey, AttributeVa ), convert_to_attribute_value(value), ) - elif isinstance(value, float): + if isinstance(value, float): return ( AttributeKey( name=key, @@ -428,7 +426,7 @@ def _value_to_attribute(key: str, value: Any) -> tuple[AttributeKey, AttributeVa ), convert_to_attribute_value(value), ) - elif isinstance(value, str): + if isinstance(value, str): return ( AttributeKey( name=key, @@ -436,12 +434,12 @@ def _value_to_attribute(key: str, value: Any) -> tuple[AttributeKey, AttributeVa ), convert_to_attribute_value(value), ) - elif isinstance(value, list): + if isinstance(value, list): return ( AttributeKey(name=key, type=AttributeKey.Type.TYPE_ARRAY), convert_to_attribute_value(value), ) - elif isinstance(value, datetime): + if isinstance(value, datetime): return ( AttributeKey( name=key, @@ -449,23 +447,18 @@ def _value_to_attribute(key: str, value: Any) -> tuple[AttributeKey, AttributeVa ), convert_to_attribute_value(value), ) - else: - raise BadSnubaRPCRequestException(f"data type unknown: {type(value)}") + raise BadSnubaRPCRequestException(f"data type unknown: {type(value)}") -ProcessedResults = NamedTuple( - "ProcessedResults", - [ - ("items", list[GetTraceResponse.Item]), - ("last_seen_timestamp_precise", float), - ("last_seen_id", str), - ], -) +class ProcessedResults(NamedTuple): + items: list[GetTraceResponse.Item] + last_seen_timestamp_precise: float + last_seen_id: str @with_span(op="function") def _process_results( - data: Iterable[Dict[str, Any]], + data: Iterable[dict[str, Any]], ) -> ProcessedResults: """ Used to process the results returned from clickhouse in two passes. @@ -496,7 +489,11 @@ def _process_results( attributes: dict[str, GetTraceResponse.Item.Attribute] = {} - def add_attribute(key: str, value: Any) -> None: + def add_attribute( + key: str, + value: Any, + attributes: dict[str, GetTraceResponse.Item.Attribute] = attributes, + ) -> None: attribute_key, attribute_value = _value_to_attribute(key, value) attributes[key] = GetTraceResponse.Item.Attribute( key=attribute_key, @@ -580,11 +577,11 @@ def version(cls) -> str: return "v1" @classmethod - def request_class(cls) -> Type[GetTraceRequest]: + def request_class(cls) -> type[GetTraceRequest]: return GetTraceRequest @classmethod - def response_class(cls) -> Type[GetTraceResponse]: + def response_class(cls) -> type[GetTraceResponse]: return GetTraceResponse def _execute(self, in_msg: GetTraceRequest) -> GetTraceResponse: diff --git a/snuba/web/rpc/v1/endpoint_get_traces.py b/snuba/web/rpc/v1/endpoint_get_traces.py index a7a7a9a18f1..269e132cee1 100644 --- a/snuba/web/rpc/v1/endpoint_get_traces.py +++ b/snuba/web/rpc/v1/endpoint_get_traces.py @@ -1,6 +1,7 @@ import uuid from collections import defaultdict -from typing import Any, Callable, Dict, Iterable, Optional, Type +from collections.abc import Callable, Iterable +from typing import Any from google.protobuf.internal.containers import RepeatedCompositeFieldContainer from google.protobuf.json_format import MessageToDict @@ -157,7 +158,7 @@ def _get_attribute_expression( def _attribute_to_expression( trace_attribute: TraceAttribute, - condition: Optional[Expression], + condition: Expression | None, request_meta: RequestMeta, ) -> Expression: def _get_root_span_attribute( @@ -300,7 +301,7 @@ def _get_earliest_frontend_span_attribute( clickhouse_type, alias=alias, ) - elif key == TraceAttribute.Key.KEY_END_TIMESTAMP: + if key == TraceAttribute.Key.KEY_END_TIMESTAMP: return f.cast( f.max( if_cond( @@ -327,43 +328,41 @@ def _get_earliest_frontend_span_attribute( clickhouse_type, alias=alias, ) - elif key == TraceAttribute.Key.KEY_TOTAL_ITEM_COUNT: + if key == TraceAttribute.Key.KEY_TOTAL_ITEM_COUNT: return f.count(alias=alias) - elif key == TraceAttribute.Key.KEY_FILTERED_ITEM_COUNT: + if key == TraceAttribute.Key.KEY_FILTERED_ITEM_COUNT: if condition: return f.countIf(condition, alias=alias) - else: - return f.count(alias=alias) - elif key == TraceAttribute.Key.KEY_ROOT_SPAN_NAME: + return f.count(alias=alias) + if key == TraceAttribute.Key.KEY_ROOT_SPAN_NAME: return _get_root_span_attribute("sentry.raw_description", AttributeKey.Type.TYPE_STRING) - elif key == TraceAttribute.Key.KEY_ROOT_SPAN_DURATION_MS: + if key == TraceAttribute.Key.KEY_ROOT_SPAN_DURATION_MS: return _get_root_span_attribute("sentry.duration_ms", AttributeKey.Type.TYPE_DOUBLE) - elif key == TraceAttribute.Key.KEY_ROOT_SPAN_PROJECT_ID: + if key == TraceAttribute.Key.KEY_ROOT_SPAN_PROJECT_ID: return _get_root_span_attribute("sentry.project_id", AttributeKey.Type.TYPE_INT) - elif key == TraceAttribute.Key.KEY_EARLIEST_SPAN_NAME: + if key == TraceAttribute.Key.KEY_EARLIEST_SPAN_NAME: return _get_earliest_span_attribute( "sentry.raw_description", AttributeKey.Type.TYPE_STRING ) - elif key == TraceAttribute.Key.KEY_EARLIEST_SPAN_PROJECT_ID: + if key == TraceAttribute.Key.KEY_EARLIEST_SPAN_PROJECT_ID: return _get_earliest_span_attribute("sentry.project_id", AttributeKey.Type.TYPE_INT) - elif key == TraceAttribute.Key.KEY_EARLIEST_SPAN_DURATION_MS: + if key == TraceAttribute.Key.KEY_EARLIEST_SPAN_DURATION_MS: return _get_earliest_span_attribute("sentry.duration_ms", AttributeKey.Type.TYPE_DOUBLE) - elif key == TraceAttribute.Key.KEY_EARLIEST_FRONTEND_SPAN: + if key == TraceAttribute.Key.KEY_EARLIEST_FRONTEND_SPAN: return _get_earliest_frontend_span_attribute( "sentry.raw_description", AttributeKey.Type.TYPE_STRING ) - elif key == TraceAttribute.Key.KEY_EARLIEST_FRONTEND_SPAN_PROJECT_ID: + if key == TraceAttribute.Key.KEY_EARLIEST_FRONTEND_SPAN_PROJECT_ID: return _get_earliest_frontend_span_attribute( "sentry.project_id", AttributeKey.Type.TYPE_INT ) - elif key == TraceAttribute.Key.KEY_EARLIEST_FRONTEND_SPAN_DURATION_MS: + if key == TraceAttribute.Key.KEY_EARLIEST_FRONTEND_SPAN_DURATION_MS: return _get_earliest_frontend_span_attribute( "sentry.duration_ms", AttributeKey.Type.TYPE_DOUBLE ) - elif key == TraceAttribute.Key.KEY_TRACE_ID: + if key == TraceAttribute.Key.KEY_TRACE_ID: return column("trace_id", alias="hex_trace_id") - else: - return f.cast(column(attribute_name), clickhouse_type, alias=alias) + return f.cast(column(attribute_name), clickhouse_type, alias=alias) raise BadSnubaRPCRequestException(f"{key} had an unknown or unset type: {trace_attribute.type}") @@ -371,9 +370,11 @@ def _get_earliest_frontend_span_attribute( def _build_snuba_request( request: GetTracesRequest, query: Query, - clickhouse_settings: dict[str, Any] = {}, + clickhouse_settings: dict[str, Any] | None = None, query_settings: QuerySettings | None = None, ) -> SnubaRequest: + if clickhouse_settings is None: + clickhouse_settings = {} query_settings = query_settings or ( setup_trace_query_settings() if request.meta.debug else HTTPQuerySettings() ) @@ -402,7 +403,7 @@ def _build_snuba_request( def _convert_results( request: GetTracesRequest, - data: Iterable[Dict[str, Any]], + data: Iterable[dict[str, Any]], ) -> list[GetTracesResponse.Trace]: res: list[GetTracesResponse.Trace] = [] column_ordering = { @@ -446,8 +447,8 @@ def _get_page_token( def _validate_order_by(in_msg: GetTracesRequest) -> None: - order_by_cols = set([ob.key for ob in in_msg.order_by]) - selected_columns = set([c.key for c in in_msg.attributes]) + order_by_cols = {ob.key for ob in in_msg.order_by} + selected_columns = {c.key for c in in_msg.attributes} if not order_by_cols.issubset(selected_columns): raise BadSnubaRPCRequestException( f"Ordered by columns {order_by_cols} not selected: {selected_columns}" @@ -460,11 +461,11 @@ def version(cls) -> str: return "v1" @classmethod - def request_class(cls) -> Type[GetTracesRequest]: + def request_class(cls) -> type[GetTracesRequest]: return GetTracesRequest @classmethod - def response_class(cls) -> Type[GetTracesResponse]: + def response_class(cls) -> type[GetTracesResponse]: return GetTracesResponse def _execute_with_subquery_optimization(self, in_msg: GetTracesRequest) -> GetTracesResponse: @@ -515,11 +516,10 @@ def _execute(self, in_msg: GetTracesRequest) -> GetTracesResponse: # Get a dict of trace IDs and timestamps. if use_cross_item_path: return self._execute_with_subquery_optimization(in_msg) - else: - trace_ids, trace_ids_query_result = self._get_trace_ids_for_single_item_query( - request=in_msg - ) - query_results.append(trace_ids_query_result) + trace_ids, trace_ids_query_result = self._get_trace_ids_for_single_item_query( + request=in_msg + ) + query_results.append(trace_ids_query_result) if len(trace_ids) == 0: response_meta = extract_response_meta( @@ -553,7 +553,7 @@ def _execute(self, in_msg: GetTracesRequest) -> GetTracesResponse: def _is_cross_event_query( self, filters: RepeatedCompositeFieldContainer[GetTracesRequest.TraceFilter] ) -> bool: - return len(set([f.item_type for f in filters])) > 1 + return len({f.item_type for f in filters}) > 1 def _get_trace_item_filter_expressions( self, filters: RepeatedCompositeFieldContainer[GetTracesRequest.TraceFilter] @@ -677,9 +677,7 @@ def _get_metadata_for_traces( trace_item_filters_expression = next(iter(filter_expressions_by_item_type.values())) item_type = next(iter(filter_expressions_by_item_type.keys())) elif len(filter_expressions_by_item_type) > 1: - trace_item_filters_expression = or_cond( - *[expression for expression in filter_expressions_by_item_type.values()] - ) + trace_item_filters_expression = or_cond(*list(filter_expressions_by_item_type.values())) else: item_type = TraceItemType.TRACE_ITEM_TYPE_SPAN @@ -793,9 +791,7 @@ def _get_metadata_for_traces_with_subquery( trace_item_filters_expression = next(iter(filter_expressions_by_item_type.values())) item_type = next(iter(filter_expressions_by_item_type.keys())) elif len(filter_expressions_by_item_type) > 1: - trace_item_filters_expression = or_cond( - *[expression for expression in filter_expressions_by_item_type.values()] - ) + trace_item_filters_expression = or_cond(*list(filter_expressions_by_item_type.values())) else: item_type = TraceItemType.TRACE_ITEM_TYPE_SPAN diff --git a/snuba/web/rpc/v1/endpoint_time_series.py b/snuba/web/rpc/v1/endpoint_time_series.py index 7df79e289c0..16a25e83317 100644 --- a/snuba/web/rpc/v1/endpoint_time_series.py +++ b/snuba/web/rpc/v1/endpoint_time_series.py @@ -1,5 +1,4 @@ import math -from typing import Type from sentry_protos.snuba.v1.endpoint_time_series_pb2 import ( Expression, @@ -19,25 +18,23 @@ preprocess_expression_labels, ) -_VALID_GRANULARITY_SECS = set( - [ - 15, - 30, - 60, # seconds - 2 * 60, - 5 * 60, - 10 * 60, - 15 * 60, - 30 * 60, # minutes - 1 * 3600, - 2 * 3600, - 3 * 3600, - 4 * 3600, - 6 * 3600, - 12 * 3600, - 24 * 3600, # hours - ] -) +_VALID_GRANULARITY_SECS = { + 15, + 30, + 60, # seconds + 2 * 60, + 5 * 60, + 10 * 60, + 15 * 60, + 30 * 60, # minutes + 1 * 3600, + 2 * 3600, + 3 * 3600, + 4 * 3600, + 6 * 3600, + 12 * 3600, + 24 * 3600, # hours +} # MAX 1 minute granularity over 7 days (10080 buckets) + additional buckets to allow for partial time buckets on _MAX_BUCKETS_IN_REQUEST = 10100 @@ -100,11 +97,11 @@ def version(cls) -> str: return "v1" @classmethod - def request_class(cls) -> Type[TimeSeriesRequest]: + def request_class(cls) -> type[TimeSeriesRequest]: return TimeSeriesRequest @classmethod - def response_class(cls) -> Type[TimeSeriesResponse]: + def response_class(cls) -> type[TimeSeriesResponse]: return TimeSeriesResponse def get_resolver( diff --git a/snuba/web/rpc/v1/endpoint_trace_item_attribute_names.py b/snuba/web/rpc/v1/endpoint_trace_item_attribute_names.py index 49835fb0d76..b61dff60066 100644 --- a/snuba/web/rpc/v1/endpoint_trace_item_attribute_names.py +++ b/snuba/web/rpc/v1/endpoint_trace_item_attribute_names.py @@ -1,5 +1,4 @@ import uuid -from typing import Type from google.protobuf.json_format import MessageToDict from sentry_protos.snuba.v1.endpoint_trace_item_attributes_pb2 import ( @@ -123,23 +122,23 @@ def _add_substring_match_optimization( if request.type == AttributeKey.Type.TYPE_STRING: return and_cond(condition, f.arrayExists(like_lambda, column("attributes_string"))) - elif request.type in ( + if request.type in ( AttributeKey.Type.TYPE_FLOAT, AttributeKey.Type.TYPE_DOUBLE, AttributeKey.Type.TYPE_INT, ): return and_cond(condition, f.arrayExists(like_lambda, column("attributes_float"))) - elif request.type == AttributeKey.Type.TYPE_BOOLEAN: + if request.type == AttributeKey.Type.TYPE_BOOLEAN: return and_cond(condition, f.arrayExists(like_lambda, column("attributes_bool"))) - else: # TYPE_UNSPECIFIED - check all arrays with OR - return and_cond( - condition, - or_cond( - f.arrayExists(like_lambda, column("attributes_string")), - f.arrayExists(like_lambda, column("attributes_float")), - f.arrayExists(like_lambda, column("attributes_bool")), - ), - ) + # TYPE_UNSPECIFIED - check all arrays with OR + return and_cond( + condition, + or_cond( + f.arrayExists(like_lambda, column("attributes_string")), + f.arrayExists(like_lambda, column("attributes_float")), + f.arrayExists(like_lambda, column("attributes_bool")), + ), + ) def get_co_occurring_attributes( @@ -377,11 +376,11 @@ def version(cls) -> str: return "v1" @classmethod - def request_class(cls) -> Type[TraceItemAttributeNamesRequest]: + def request_class(cls) -> type[TraceItemAttributeNamesRequest]: return TraceItemAttributeNamesRequest @classmethod - def response_class(cls) -> Type[TraceItemAttributeNamesResponse]: + def response_class(cls) -> type[TraceItemAttributeNamesResponse]: return TraceItemAttributeNamesResponse def _build_response( diff --git a/snuba/web/rpc/v1/endpoint_trace_item_details.py b/snuba/web/rpc/v1/endpoint_trace_item_details.py index 232326661c2..bb7c3679feb 100644 --- a/snuba/web/rpc/v1/endpoint_trace_item_details.py +++ b/snuba/web/rpc/v1/endpoint_trace_item_details.py @@ -1,5 +1,6 @@ import uuid -from typing import Any, Dict, Iterable, Tuple, Type +from collections.abc import Iterable +from typing import Any from google.protobuf.json_format import MessageToDict from google.protobuf.timestamp_pb2 import Timestamp @@ -131,8 +132,8 @@ def _build_snuba_request(request: TraceItemDetailsRequest) -> SnubaRequest: def _convert_results( - data: Iterable[Dict[str, Any]], -) -> Tuple[str, Timestamp, list[TraceItemDetailsAttribute]]: + data: Iterable[dict[str, Any]], +) -> tuple[str, Timestamp, list[TraceItemDetailsAttribute]]: row = next(iter(data)) item_id = row.pop("hex_item_id") dt = row.pop("timestamp") @@ -140,14 +141,13 @@ def _convert_results( timestamp.FromSeconds(dt) attrs = [] - if (val := row.pop("trace_id")) is not None: - if val != "0" * 32: - attrs.append( - TraceItemDetailsAttribute( - name="sentry.trace_id", - value=AttributeValue(val_str=str(uuid.UUID(val))), - ) + if (val := row.pop("trace_id")) is not None and val != "0" * 32: + attrs.append( + TraceItemDetailsAttribute( + name="sentry.trace_id", + value=AttributeValue(val_str=str(uuid.UUID(val))), ) + ) if (val := row.pop("organization_id")) is not None: attrs.append( TraceItemDetailsAttribute( @@ -200,11 +200,11 @@ def version(cls) -> str: return "v1" @classmethod - def request_class(cls) -> Type[TraceItemDetailsRequest]: + def request_class(cls) -> type[TraceItemDetailsRequest]: return TraceItemDetailsRequest @classmethod - def response_class(cls) -> Type[TraceItemDetailsResponse]: + def response_class(cls) -> type[TraceItemDetailsResponse]: return TraceItemDetailsResponse def _execute(self, in_msg: TraceItemDetailsRequest) -> TraceItemDetailsResponse: @@ -216,13 +216,12 @@ def _execute(self, in_msg: TraceItemDetailsRequest) -> TraceItemDetailsResponse: raise BadSnubaRPCRequestException("This endpoint requires item_id to be set.") if in_msg.trace_id == "": raise BadSnubaRPCRequestException("This endpoint requires trace_id to be set.") - else: - try: - _ = uuid.UUID(in_msg.trace_id) - except ValueError: - raise BadSnubaRPCRequestException( - "This endpoint requires trace_id to be a valid UUID." - ) + try: + _ = uuid.UUID(in_msg.trace_id) + except ValueError as e: + raise BadSnubaRPCRequestException( + "This endpoint requires trace_id to be a valid UUID." + ) from e snuba_request = _build_snuba_request(in_msg) res = run_query( @@ -232,11 +231,11 @@ def _execute(self, in_msg: TraceItemDetailsRequest) -> TraceItemDetailsResponse: ) try: item_id, timestamp, attributes = _convert_results(res.result.get("data", [])) - except StopIteration: + except StopIteration as e: raise RPCRequestException( status_code=404, message=f"no item found with ID={in_msg.item_id}", - ) + ) from e response_meta = extract_response_meta( in_msg.meta.request_id, in_msg.meta.debug, diff --git a/snuba/web/rpc/v1/endpoint_trace_item_stats.py b/snuba/web/rpc/v1/endpoint_trace_item_stats.py index efa91ec5948..3a1dfd9b79c 100644 --- a/snuba/web/rpc/v1/endpoint_trace_item_stats.py +++ b/snuba/web/rpc/v1/endpoint_trace_item_stats.py @@ -1,5 +1,3 @@ -from typing import Type - from sentry_protos.snuba.v1.downsampled_storage_pb2 import DownsampledStorageConfig from sentry_protos.snuba.v1.endpoint_trace_item_stats_pb2 import ( TraceItemStatsRequest, @@ -16,9 +14,9 @@ def downgrade_tier(tier: Tier) -> Tier: if tier == Tier.TIER_1: return Tier.TIER_8 - elif tier == Tier.TIER_8: + if tier == Tier.TIER_8: return Tier.TIER_64 - elif tier == Tier.TIER_64: + if tier == Tier.TIER_64: return Tier.TIER_512 return tier @@ -29,11 +27,11 @@ def version(cls) -> str: return "v1" @classmethod - def request_class(cls) -> Type[TraceItemStatsRequest]: + def request_class(cls) -> type[TraceItemStatsRequest]: return TraceItemStatsRequest @classmethod - def response_class(cls) -> Type[TraceItemStatsResponse]: + def response_class(cls) -> type[TraceItemStatsResponse]: return TraceItemStatsResponse def get_resolver( diff --git a/snuba/web/rpc/v1/endpoint_trace_item_table.py b/snuba/web/rpc/v1/endpoint_trace_item_table.py index 5f2a0cceac5..e77c86563f6 100644 --- a/snuba/web/rpc/v1/endpoint_trace_item_table.py +++ b/snuba/web/rpc/v1/endpoint_trace_item_table.py @@ -1,5 +1,3 @@ -from typing import Type - from sentry_protos.snuba.v1.downsampled_storage_pb2 import DownsampledStorageConfig from sentry_protos.snuba.v1.endpoint_trace_item_table_pb2 import ( Column, @@ -55,8 +53,8 @@ def _validate_select_and_groupby(in_msg: TraceItemTableRequest) -> None: if not in_msg.columns: raise BadSnubaRPCRequestException("At least one column must be specified in the request") - non_aggregted_columns = set([c.key.name for c in in_msg.columns if c.HasField("key")]) - grouped_by_columns = set([c.name for c in in_msg.group_by]) + non_aggregted_columns = {c.key.name for c in in_msg.columns if c.HasField("key")} + grouped_by_columns = {c.name for c in in_msg.group_by} vis = ContainsAggregateVisitor() TraceItemTableRequestWrapper(in_msg).accept(vis) @@ -147,7 +145,7 @@ def version(cls) -> str: return "v1" @classmethod - def request_class(cls) -> Type[TraceItemTableRequest]: + def request_class(cls) -> type[TraceItemTableRequest]: return TraceItemTableRequest def get_resolver( @@ -159,7 +157,7 @@ def get_resolver( ) @classmethod - def response_class(cls) -> Type[TraceItemTableResponse]: + def response_class(cls) -> type[TraceItemTableResponse]: return TraceItemTableResponse def _execute(self, in_msg: TraceItemTableRequest) -> TraceItemTableResponse: diff --git a/snuba/web/rpc/v1/resolvers/R_eap_items/resolver_time_series.py b/snuba/web/rpc/v1/resolvers/R_eap_items/resolver_time_series.py index 5e03379def2..da5a36d97d1 100644 --- a/snuba/web/rpc/v1/resolvers/R_eap_items/resolver_time_series.py +++ b/snuba/web/rpc/v1/resolvers/R_eap_items/resolver_time_series.py @@ -1,8 +1,9 @@ import uuid from collections import defaultdict +from collections.abc import Callable, Iterable from dataclasses import replace from datetime import datetime -from typing import Any, Callable, Dict, Iterable, Optional +from typing import Any import sentry_sdk from google.protobuf.json_format import MessageToDict @@ -86,7 +87,7 @@ def _get_attribute_key_to_expression_function( def _convert_result_timeseries( - request: TimeSeriesRequest, data: list[Dict[str, Any]] + request: TimeSeriesRequest, data: list[dict[str, Any]] ) -> Iterable[TimeSeries]: """This function takes the results of the clickhouse query and converts it to a list of TimeSeries objects. It also handles zerofilling data points where data was not present for a specific bucket. @@ -129,9 +130,9 @@ def _convert_result_timeseries( """ # the aggregations that we will include in the result - aggregation_labels = set([expr.label for expr in request.expressions]) + aggregation_labels = {expr.label for expr in request.expressions} - group_by_labels = set([attr.name for attr in request.group_by]) + group_by_labels = {attr.name for attr in request.group_by} # create a mapping with (all the group by attribute key,val pairs as strs, label name) # In the example in the docstring it would look like: @@ -144,7 +145,7 @@ def _convert_result_timeseries( # time_converted_to_integer_timestamp: row_data_for_that_time_bucket # } # } - result_timeseries_timestamp_to_row: defaultdict[tuple[str, str], dict[int, Dict[str, Any]]] = ( + result_timeseries_timestamp_to_row: defaultdict[tuple[str, str], dict[int, dict[str, Any]]] = ( defaultdict(dict) ) @@ -217,9 +218,9 @@ def _remove_non_requested_expressions( expressions: Iterable[ProtoExpression], result_timeseries: dict[tuple[str, str], TimeSeries], ) -> None: - requested_expressions = set([expr.label for expr in expressions]) + requested_expressions = {expr.label for expr in expressions} to_remove = [] - for timeseries_key in result_timeseries.keys(): + for timeseries_key in result_timeseries: if timeseries_key[1] not in requested_expressions: to_remove.append(timeseries_key) for timeseries_key in to_remove: @@ -339,7 +340,7 @@ def _proto_expression_to_ast_expression( def build_query( - request: TimeSeriesRequest, sampling_tier: Optional[Tier] = None, timer: Optional[Timer] = None + request: TimeSeriesRequest, sampling_tier: Tier | None = None, timer: Timer | None = None ) -> Query: entity = Entity( key=EntityKey("eap_items"), @@ -442,8 +443,8 @@ def build_query( def _build_snuba_request( request: TimeSeriesRequest, query_settings: HTTPQuerySettings, - sampling_tier: Optional[Tier] = None, - timer: Optional[Timer] = None, + sampling_tier: Tier | None = None, + timer: Timer | None = None, ) -> SnubaRequest: if request.meta.trace_item_type == TraceItemType.TRACE_ITEM_TYPE_LOG: team = "ourlogs" diff --git a/snuba/web/rpc/v1/resolvers/R_eap_items/resolver_trace_item_stats.py b/snuba/web/rpc/v1/resolvers/R_eap_items/resolver_trace_item_stats.py index 01b0df008fb..3e6f296cbf2 100644 --- a/snuba/web/rpc/v1/resolvers/R_eap_items/resolver_trace_item_stats.py +++ b/snuba/web/rpc/v1/resolvers/R_eap_items/resolver_trace_item_stats.py @@ -1,7 +1,8 @@ import uuid from collections import OrderedDict +from collections.abc import Iterable from datetime import datetime -from typing import Any, Dict, Iterable, Tuple +from typing import Any from google.protobuf.json_format import MessageToDict from google.protobuf.timestamp_pb2 import Timestamp @@ -65,12 +66,12 @@ def _transform_attr_distribution_results( - results: Iterable[Dict[str, Any]], + results: Iterable[dict[str, Any]], request_meta: RequestMeta, ) -> Iterable[AttributeDistribution]: # Maintain the order of keys, so it is in descending order # of most prevelant key-value pair. - res: OrderedDict[Tuple[str, str], AttributeDistribution] = OrderedDict() + res: OrderedDict[tuple[str, str], AttributeDistribution] = OrderedDict() for row in results: attr_key = row["attr_key"] diff --git a/snuba/web/rpc/v1/resolvers/R_eap_items/resolver_trace_item_table.py b/snuba/web/rpc/v1/resolvers/R_eap_items/resolver_trace_item_table.py index 8d7260fa573..0ffa722f3e3 100644 --- a/snuba/web/rpc/v1/resolvers/R_eap_items/resolver_trace_item_table.py +++ b/snuba/web/rpc/v1/resolvers/R_eap_items/resolver_trace_item_table.py @@ -1,7 +1,8 @@ import uuid +from collections.abc import Callable, Sequence from dataclasses import replace from itertools import islice -from typing import Any, Callable, List, Optional, Sequence +from typing import Any import sentry_sdk from google.protobuf.json_format import MessageToDict @@ -155,7 +156,7 @@ def transform_expressions(expression: Expression) -> Expression: ) return f.transform( f.CAST(f.ifNull(attribute_expression, literal("")), "String"), - literals_array(None, [literal(k) for k in context.value_map.keys()]), + literals_array(None, [literal(k) for k in context.value_map]), literals_array(None, [literal(v) for v in context.value_map.values()]), literal(context.default_value if context.default_value != "" else "unknown"), alias=context.to_column_name, @@ -191,7 +192,7 @@ def aggregation_filter_to_expression( raise BadSnubaRPCRequestException( "Cannot use formula and conditional aggregation in the same ComparisonFilter" ) - elif agg_filter.comparison_filter.HasField("formula"): + if agg_filter.comparison_filter.HasField("formula"): return op_expr( _formula_to_expression(agg_filter.comparison_filter.formula, request_meta), agg_filter.comparison_filter.val, @@ -231,7 +232,7 @@ def aggregation_filter_to_expression( def _convert_order_by( - groupby: List[Expression], + groupby: list[Expression], order_by: Sequence[TraceItemTableRequest.OrderBy], request_meta: RequestMeta, ) -> Sequence[OrderBy]: @@ -334,15 +335,18 @@ def _get_reliability_context_columns( context_cols.extend(_get_reliability_context_columns(col, request_meta)) # Note: 'match' is a Python keyword, so use getattr - for col in [getattr(conditional, "match"), conditional.default]: - if not col.HasField("formula") and not col.HasField("conditional_formula"): - if col.label: - context_cols.append( - SelectedExpression( - name=col.label, - expression=_column_to_expression(col, request_meta), - ) + for col in [conditional.match, conditional.default]: + if ( + not col.HasField("formula") + and not col.HasField("conditional_formula") + and col.label + ): + context_cols.append( + SelectedExpression( + name=col.label, + expression=_column_to_expression(col, request_meta), ) + ) context_cols.extend(_get_reliability_context_columns(col, request_meta)) return context_cols @@ -445,7 +449,7 @@ def _conditional_formula_to_expression( comparison_expr = COMPARISON_OP_TO_EXPR[condition.op](left_expr, right_expr) # 'match' is the value when condition is true, 'default' is when false # Note: 'match' is a Python keyword in 3.10+, but protobuf accesses it as an attribute - match_expr = _column_to_expression(getattr(conditional_formula, "match"), request_meta) + match_expr = _column_to_expression(conditional_formula.match, request_meta) default_expr = _column_to_expression(conditional_formula.default, request_meta) return if_cond(comparison_expr, match_expr, default_expr) @@ -457,7 +461,7 @@ def _column_to_expression(column: Column, request_meta: RequestMeta) -> Expressi """ if column.HasField("key"): return attribute_key_to_expression(column.key) - elif column.HasField("conditional_aggregation"): + if column.HasField("conditional_aggregation"): function_expr = aggregation_to_expression( column.conditional_aggregation, attribute_key_to_expression, @@ -483,23 +487,22 @@ def _column_to_expression(column: Column, request_meta: RequestMeta) -> Expressi # aggregation label may not be set and the column label takes priority anyways. function_expr = replace(function_expr, alias=column.label) return function_expr - elif column.HasField("formula"): + if column.HasField("formula"): formula_expr = _formula_to_expression(column.formula, request_meta) formula_expr = replace(formula_expr, alias=column.label) return formula_expr - elif column.HasField("conditional_formula"): + if column.HasField("conditional_formula"): conditional_expr = _conditional_formula_to_expression( column.conditional_formula, request_meta, ) conditional_expr = replace(conditional_expr, alias=column.label) return conditional_expr - elif column.HasField("literal"): + if column.HasField("literal"): return literal(column.literal.val_double) - else: - raise BadSnubaRPCRequestException( - "Column is not one of: aggregate, attribute key, formula, or conditional_formula" - ) + raise BadSnubaRPCRequestException( + "Column is not one of: aggregate, attribute key, formula, or conditional_formula" + ) def _get_offset_from_page_token(page_token: PageToken | None) -> int: @@ -513,8 +516,8 @@ def _get_offset_from_page_token(page_token: PageToken | None) -> int: def build_query( request: TraceItemTableRequest, time_window: TimeWindow | None = None, - sampling_tier: Optional[Tier] = None, - timer: Optional[Timer] = None, + sampling_tier: Tier | None = None, + timer: Timer | None = None, ) -> Query: entity = Entity( key=EntityKey("eap_items"), @@ -538,7 +541,7 @@ def build_query( item_type_conds = [f.equals(snuba_column("item_type"), request.meta.trace_item_type)] # Handle cross item queries by first getting trace IDs - additional_conditions: List[Expression] = [] + additional_conditions: list[Expression] = [] if request.trace_filters and timer is not None and sampling_tier is not None: trace_ids_sql, _ = get_trace_ids_sql_for_cross_item_query( request, request.meta, list(request.trace_filters), sampling_tier, timer @@ -614,30 +617,27 @@ def _get_page_token( return FlexibleTimeWindowPageWithFilters.create( request, time_window, response ).page_token - else: - if time_window.start_timestamp.seconds <= original_time_window.start_timestamp.seconds: - # this is the last window because our start timestamp is the same as the original start timestamp - # we tell the client that there is no more data to fetch - return PageToken(end_pagination=True) - else: - # there are no more rows in this window so we return the next window - # return the next window where the end timestamp is the start timestamp and the start timestamp is the original start timestamp - # the routing strategy will properly truncate the time window of the next request - return FlexibleTimeWindowPageWithFilters.create( - request, - TimeWindow(original_time_window.start_timestamp, time_window.start_timestamp), - response, - ).page_token - else: - return PageToken(offset=request.page_token.offset + num_rows_in_response) + if time_window.start_timestamp.seconds <= original_time_window.start_timestamp.seconds: + # this is the last window because our start timestamp is the same as the original start timestamp + # we tell the client that there is no more data to fetch + return PageToken(end_pagination=True) + # there are no more rows in this window so we return the next window + # return the next window where the end timestamp is the start timestamp and the start timestamp is the original start timestamp + # the routing strategy will properly truncate the time window of the next request + return FlexibleTimeWindowPageWithFilters.create( + request, + TimeWindow(original_time_window.start_timestamp, time_window.start_timestamp), + response, + ).page_token + return PageToken(offset=request.page_token.offset + num_rows_in_response) def _build_snuba_request( request: TraceItemTableRequest, query_settings: HTTPQuerySettings, time_window: TimeWindow | None = None, - sampling_tier: Optional[Tier] = None, - timer: Optional[Timer] = None, + sampling_tier: Tier | None = None, + timer: Timer | None = None, ) -> SnubaRequest: if request.meta.trace_item_type == TraceItemType.TRACE_ITEM_TYPE_LOG: team = "ourlogs" diff --git a/snuba/web/rpc/v1/resolvers/common/aggregation.py b/snuba/web/rpc/v1/resolvers/common/aggregation.py index 5cc7a0135cb..91c958ff6c9 100644 --- a/snuba/web/rpc/v1/resolvers/common/aggregation.py +++ b/snuba/web/rpc/v1/resolvers/common/aggregation.py @@ -3,9 +3,10 @@ import math from abc import ABC, abstractmethod from bisect import bisect_left +from collections.abc import Callable from dataclasses import dataclass from functools import cached_property -from typing import Any, Callable, Dict, List, Optional +from typing import Any from sentry_protos.snuba.v1.attribute_conditional_aggregation_pb2 import ( AttributeConditionalAggregation, @@ -130,12 +131,11 @@ def _resolve_field_and_existence( else: raise RuntimeError("expected existence_checks to never be empty, but it is") return field, existence - elif aggregation.key.type == AttributeKey.Type.TYPE_ARRAY: + if aggregation.key.type == AttributeKey.Type.TYPE_ARRAY: field = type_array_to_stored_array_json_path(aggregation.key) return field, f.notEmpty(field) - else: - field = attribute_key_to_expression(aggregation.key) - return field, get_field_existence_expression(field) + field = attribute_key_to_expression(aggregation.key) + return field, get_field_existence_expression(field) @dataclass(frozen=True) @@ -158,7 +158,7 @@ def reliability(self) -> Reliability.ValueType: @staticmethod def from_row( column_label: str, - row_data: Dict[str, Any], + row_data: dict[str, Any], ) -> ExtrapolationContext: value = row_data[column_label] is_extrapolated = False @@ -237,7 +237,7 @@ def reliability(self) -> Reliability.ValueType: # than the value, so reliability is low. if self.confidence_interval == 0: return Reliability.RELIABILITY_HIGH - elif self.value == 0: + if self.value == 0: return Reliability.RELIABILITY_LOW if abs(self.confidence_interval / self.value) <= CONFIDENCE_INTERVAL_THRESHOLD: @@ -291,7 +291,7 @@ class CustomColumnInformation: # A column that this custom column depends on or attached to. # For example, if we are computing the confidence interval for an aggregation column, we need to know for which column we are computing a confidence interval. - referenced_column: Optional[str] + referenced_column: str | None # Metadata about the custom column that can be used to encode additional information in the column. # E.g. the aggregation function type for the confidence interval column. @@ -306,7 +306,7 @@ def to_alias(self) -> str: return alias @staticmethod - def from_alias(alias: str) -> "CustomColumnInformation": + def from_alias(alias: str) -> CustomColumnInformation: if not alias.startswith(CUSTOM_COLUMN_PREFIX): raise ValueError(f"Alias {alias} does not start with {CUSTOM_COLUMN_PREFIX}") @@ -330,18 +330,19 @@ def _get_sampling_weight_expression( if extrapolation_mode == ExtrapolationMode.EXTRAPOLATION_MODE_CLIENT_ONLY: # Use client sample rate attribute, convert to weight (1/rate) return f.divide(1, client_sample_rate_column) - elif extrapolation_mode == ExtrapolationMode.EXTRAPOLATION_MODE_SERVER_ONLY: + if extrapolation_mode == ExtrapolationMode.EXTRAPOLATION_MODE_SERVER_ONLY: # Use server sample rate attribute, convert to weight (1/rate) return f.divide(1, server_sample_rate_column) - else: - # Default behavior for existing modes - always use sampling_factor now - return f.divide(1, sampling_factor_column) + # Default behavior for existing modes - always use sampling_factor now + return f.divide(1, sampling_factor_column) def get_attribute_confidence_interval_alias( aggregation: AttributeAggregation | AttributeConditionalAggregation, - additional_metadata: dict[str, str] = {}, + additional_metadata: dict[str, str] | None = None, ) -> str | None: + if additional_metadata is None: + additional_metadata = {} function_alias_map = { Function.FUNCTION_COUNT: "count", Function.FUNCTION_AVG: "avg", @@ -423,7 +424,7 @@ def get_count_column( ) -def _get_possible_percentiles(percentile: float, granularity: float, width: float) -> List[float]: +def _get_possible_percentiles(percentile: float, granularity: float, width: float) -> list[float]: """ Returns a list of possible percentiles to use for the confidence interval calculation from the range percentile - width to percentile + width, with a granularity of granularity. diff --git a/snuba/web/rpc/v1/resolvers/common/cross_item_queries.py b/snuba/web/rpc/v1/resolvers/common/cross_item_queries.py index fc6fb1389d7..0ad21cabde5 100644 --- a/snuba/web/rpc/v1/resolvers/common/cross_item_queries.py +++ b/snuba/web/rpc/v1/resolvers/common/cross_item_queries.py @@ -1,7 +1,7 @@ import uuid from google.protobuf.json_format import MessageToDict -from proto import Message # type: ignore +from proto import Message # type: ignore[import-untyped] from sentry_protos.snuba.v1.endpoint_get_traces_pb2 import GetTracesRequest from sentry_protos.snuba.v1.request_common_pb2 import ( RequestMeta, @@ -67,7 +67,7 @@ def get_trace_ids_sql_for_cross_item_query( """ filter_expressions = [] if trace_filters: - converted_trace_filters = [trace_filter for trace_filter in trace_filters] + converted_trace_filters = list(trace_filters) if isinstance(trace_filters[0], GetTracesRequest.TraceFilter): converted_trace_filters = [ TraceItemFilterWithType( diff --git a/snuba/web/rpc/v1/resolvers/common/formula_reliability.py b/snuba/web/rpc/v1/resolvers/common/formula_reliability.py index 31395a6fb78..fe48e99a71f 100644 --- a/snuba/web/rpc/v1/resolvers/common/formula_reliability.py +++ b/snuba/web/rpc/v1/resolvers/common/formula_reliability.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from datetime import datetime -from typing import Any, Dict +from typing import Any from google.protobuf.timestamp_pb2 import Timestamp from sentry_protos.snuba.v1.endpoint_time_series_pb2 import TimeSeriesRequest @@ -23,7 +23,7 @@ class FormulaReliabilityCalculator: def __init__( self, request: TimeSeriesRequest, - clickhouse_data: list[Dict[str, Any]], + clickhouse_data: list[dict[str, Any]], time_buckets: list[Timestamp], ) -> None: """ diff --git a/snuba/web/rpc/v1/resolvers/common/trace_item_table.py b/snuba/web/rpc/v1/resolvers/common/trace_item_table.py index dfc325ed00f..b399c24ba35 100644 --- a/snuba/web/rpc/v1/resolvers/common/trace_item_table.py +++ b/snuba/web/rpc/v1/resolvers/common/trace_item_table.py @@ -1,7 +1,8 @@ import json import re from collections import defaultdict -from typing import Any, Callable, Dict, Iterable +from collections.abc import Callable, Iterable +from typing import Any from sentry_protos.snuba.v1.endpoint_trace_item_table_pb2 import ( Column, @@ -72,20 +73,17 @@ def _get_converter_for_type( """Returns a converter function for the given attribute type.""" if key_type == AttributeKey.TYPE_BOOLEAN: return lambda x: AttributeValue(val_bool=bool(x)) - elif key_type == AttributeKey.TYPE_STRING: + if key_type == AttributeKey.TYPE_STRING: return lambda x: AttributeValue(val_str=str(x)) - elif key_type == AttributeKey.TYPE_INT: + if key_type == AttributeKey.TYPE_INT: return lambda x: AttributeValue(val_int=int(x)) - elif key_type == AttributeKey.TYPE_FLOAT: + if key_type == AttributeKey.TYPE_FLOAT: return lambda x: AttributeValue(val_float=float(x)) - elif key_type == AttributeKey.TYPE_DOUBLE: + if key_type == AttributeKey.TYPE_DOUBLE: return lambda x: AttributeValue(val_double=float(x)) - elif key_type == AttributeKey.TYPE_ARRAY: + if key_type == AttributeKey.TYPE_ARRAY: return _array_raw_to_attribute_value - else: - raise BadSnubaRPCRequestException( - f"unknown attribute type: {AttributeKey.Type.Name(key_type)}" - ) + raise BadSnubaRPCRequestException(f"unknown attribute type: {AttributeKey.Type.Name(key_type)}") def _get_double_converter() -> Callable[[Any], AttributeValue]: @@ -93,7 +91,7 @@ def _get_double_converter() -> Callable[[Any], AttributeValue]: return lambda x: AttributeValue(val_double=float(x)) -def _add_converter(column: Column, converters: Dict[str, Callable[[Any], AttributeValue]]) -> None: +def _add_converter(column: Column, converters: dict[str, Callable[[Any], AttributeValue]]) -> None: if column.HasField("key"): converters[column.label] = _get_converter_for_type(column.key.type) elif column.HasField("aggregation"): @@ -124,7 +122,7 @@ def _add_converter(column: Column, converters: Dict[str, Callable[[Any], Attribu _add_converter(conditional.condition.left, converters) _add_converter(conditional.condition.right, converters) if conditional.HasField("match"): - _add_converter(getattr(conditional, "match"), converters) + _add_converter(conditional.match, converters) if conditional.HasField("default"): _add_converter(conditional.default, converters) elif column.HasField("literal"): @@ -137,12 +135,12 @@ def _add_converter(column: Column, converters: Dict[str, Callable[[Any], Attribu def get_converters_for_columns( columns: Iterable[Column], -) -> Dict[str, Callable[[Any], AttributeValue]]: +) -> dict[str, Callable[[Any], AttributeValue]]: """ Returns a dictionary of column labels to their corresponding converters. Converters are functions that convert a value returned by a clickhouse query to an AttributeValue. """ - converters: Dict[str, Callable[[Any], AttributeValue]] = {} + converters: dict[str, Callable[[Any], AttributeValue]] = {} for column in columns: _add_converter(column, converters) return converters @@ -158,7 +156,7 @@ def _is_sub_column(result_column_name: str, column: Column) -> bool: def _get_reliabilities_for_formula( - column: Column, res: Dict[str, TraceItemColumnValues] + column: Column, res: dict[str, TraceItemColumnValues] ) -> list[Reliability.ValueType]: """ Compute and return the reliabilities for the given formula column, @@ -205,14 +203,14 @@ def _get_reliabilities_for_formula( def convert_results( - request: TraceItemTableRequest, data: Iterable[Dict[str, Any]] + request: TraceItemTableRequest, data: Iterable[dict[str, Any]] ) -> list[TraceItemColumnValues]: converters = get_converters_for_columns(request.columns) res: defaultdict[str, TraceItemColumnValues] = defaultdict(TraceItemColumnValues) for row in data: for column_name, value in row.items(): - if column_name in converters.keys(): + if column_name in converters: extrapolation_context = ExtrapolationContext.from_row(column_name, row) res[column_name].attribute_name = column_name if value is None: @@ -236,14 +234,11 @@ def convert_results( res[column.label].reliabilities.append(e) # remove any columns that were not explicitly requested by the user in the request - requested_column_labels = set(e.label for e in request.columns) + requested_column_labels = {e.label for e in request.columns} to_delete = list(filter(lambda k: k not in requested_column_labels, res.keys())) for name in to_delete: del res[name] column_ordering = {column.label: i for i, column in enumerate(request.columns)} - return list( - # we return the columns in the order they were requested - sorted(res.values(), key=lambda c: column_ordering.__getitem__(c.attribute_name)) - ) + return sorted(res.values(), key=lambda c: column_ordering.__getitem__(c.attribute_name)) diff --git a/snuba/web/rpc/v1/trace_item_attribute_values.py b/snuba/web/rpc/v1/trace_item_attribute_values.py index 7a3057cb8d9..ea7953523fe 100644 --- a/snuba/web/rpc/v1/trace_item_attribute_values.py +++ b/snuba/web/rpc/v1/trace_item_attribute_values.py @@ -1,5 +1,4 @@ import uuid -from typing import Type from google.protobuf.json_format import MessageToDict from sentry_protos.snuba.v1.endpoint_trace_item_attributes_pb2 import ( @@ -177,11 +176,11 @@ def version(cls) -> str: return "v1" @classmethod - def request_class(cls) -> Type[TraceItemAttributeValuesRequest]: + def request_class(cls) -> type[TraceItemAttributeValuesRequest]: return TraceItemAttributeValuesRequest @classmethod - def response_class(cls) -> Type[TraceItemAttributeValuesResponse]: + def response_class(cls) -> type[TraceItemAttributeValuesResponse]: return TraceItemAttributeValuesResponse def _execute(self, in_msg: TraceItemAttributeValuesRequest) -> TraceItemAttributeValuesResponse: diff --git a/snuba/web/rpc/v1/visitors/sparse_aggregate_attribute_transformer.py b/snuba/web/rpc/v1/visitors/sparse_aggregate_attribute_transformer.py index c3b23598f27..cecca405b14 100644 --- a/snuba/web/rpc/v1/visitors/sparse_aggregate_attribute_transformer.py +++ b/snuba/web/rpc/v1/visitors/sparse_aggregate_attribute_transformer.py @@ -42,31 +42,30 @@ def transform(self) -> TraceItemTableRequest: # get all the keys that are used in aggregates agg_keys = [] for column in self.req.columns: - if column.WhichOneof("column") == "conditional_aggregation": - # not supported for KeyExpression in conditional_aggregate - if column.conditional_aggregation.key != AttributeKey(): - agg_keys.append(column.conditional_aggregation.key) + # not supported for KeyExpression in conditional_aggregate + if ( + column.WhichOneof("column") == "conditional_aggregation" + and column.conditional_aggregation.key != AttributeKey() + ): + agg_keys.append(column.conditional_aggregation.key) if len(agg_keys) == 0: return self.req - else: - # add the exists filters for the agg_keys - filter_to_add = TraceItemFilter( - or_filter=OrFilter( - filters=[ - TraceItemFilter(exists_filter=ExistsFilter(key=key)) for key in agg_keys - ] - ) + # add the exists filters for the agg_keys + filter_to_add = TraceItemFilter( + or_filter=OrFilter( + filters=[TraceItemFilter(exists_filter=ExistsFilter(key=key)) for key in agg_keys] + ) + ) + # combine the new filters with the existing one + if self.req.HasField("filter"): + new_filter = TraceItemFilter( + and_filter=AndFilter(filters=[self.req.filter, filter_to_add]) ) - # combine the new filters with the existing one - if self.req.HasField("filter"): - new_filter = TraceItemFilter( - and_filter=AndFilter(filters=[self.req.filter, filter_to_add]) - ) - else: - new_filter = filter_to_add + else: + new_filter = filter_to_add - new_req = TraceItemTableRequest() - new_req.CopyFrom(self.req) - new_req.filter.CopyFrom(new_filter) - return new_req + new_req = TraceItemTableRequest() + new_req.CopyFrom(self.req) + new_req.filter.CopyFrom(new_filter) + return new_req diff --git a/snuba/web/rpc/v1/visitors/time_series_request_visitor.py b/snuba/web/rpc/v1/visitors/time_series_request_visitor.py index b46318c58cc..82fc42f7490 100644 --- a/snuba/web/rpc/v1/visitors/time_series_request_visitor.py +++ b/snuba/web/rpc/v1/visitors/time_series_request_visitor.py @@ -126,11 +126,14 @@ def visit_TraceItemFilter(self, node: TraceItemFilter) -> None: self.visit(f) elif node.HasField("comparison_filter"): k = node.comparison_filter.key - if k.name == "sentry.timestamp" and k.type == AttributeKey.TYPE_STRING: - if get_config("eap.reject_string_timestamp_filters", 1): - raise BadSnubaRPCRequestException( - "sentry.timestamp can only be compared to TYPE_INT or TYPE_DOUBLE, got TYPE_STRING" - ) + if ( + k.name == "sentry.timestamp" + and k.type == AttributeKey.TYPE_STRING + and get_config("eap.reject_string_timestamp_filters", 1) + ): + raise BadSnubaRPCRequestException( + "sentry.timestamp can only be compared to TYPE_INT or TYPE_DOUBLE, got TYPE_STRING" + ) class GetSubformulaLabelsVisitor(RequestVisitor): diff --git a/snuba/web/rpc/v1/visitors/visitor_v2.py b/snuba/web/rpc/v1/visitors/visitor_v2.py index 1b73492cc21..45eedf2b138 100644 --- a/snuba/web/rpc/v1/visitors/visitor_v2.py +++ b/snuba/web/rpc/v1/visitors/visitor_v2.py @@ -4,7 +4,7 @@ from google.protobuf.message import Message -class RequestVisitor(ABC): +class RequestVisitor(ABC): # noqa: B024 dynamic visit dispatch via getattr; ABC marks it non-instantiable by design """ When you call visitor.visit(msg), the visitor will call the appropriate visit function visit_TraceItemTableRequest, visit_Column, visit_AttributeAggregation, etc. diff --git a/snuba/web/views.py b/snuba/web/views.py index 96bd76b1e5f..cd5ae44e145 100644 --- a/snuba/web/views.py +++ b/snuba/web/views.py @@ -4,18 +4,10 @@ import functools import logging import time +from collections.abc import Callable, Mapping, MutableMapping, MutableSequence, Sequence from datetime import datetime from typing import ( Any, - Callable, - Dict, - Mapping, - MutableMapping, - MutableSequence, - Sequence, - Text, - Tuple, - Union, ) from uuid import UUID @@ -79,7 +71,7 @@ logger = logging.getLogger("snuba.api") # Flask wants a Dict, not a Mapping -RespTuple = Tuple[Text, int, Dict[Any, Any]] +RespTuple = tuple[str, int, dict[Any, Any]] def truncate_dataset(dataset: Dataset) -> None: @@ -114,7 +106,7 @@ def truncate_dataset(dataset: Dataset) -> None: @application.errorhandler(InvalidJsonRequestException) def handle_invalid_json(exception: InvalidJsonRequestException) -> Response: cause = getattr(exception, "__cause__", None) - data: Mapping[str, Mapping[str, Union[str, Sequence[str]]]] + data: Mapping[str, Mapping[str, str | Sequence[str]]] if isinstance(cause, json.JSONDecodeError): data = {"error": {"type": "json", "message": str(cause)}} elif isinstance(cause, jsonschema.ValidationError): @@ -134,8 +126,7 @@ def default_encode(value: Callable[..., str]) -> str: # generated by callables, rather than constants. if callable(value): return value() - else: - raise TypeError() + raise TypeError() return Response( json.dumps(data, indent=4, default=default_encode), @@ -215,12 +206,12 @@ def health() -> Response: return Response(health_info.body, health_info.status, health_info.content_type) -def parse_request_body(http_request: Request) -> Dict[str, Any]: +def parse_request_body(http_request: Request) -> dict[str, Any]: with sentry_sdk.start_span(description="parse_request_body", op="parse"): metrics.timing("http_request_body_length", len(http_request.data)) try: body = json.loads(http_request.data) - assert isinstance(body, Dict) + assert isinstance(body, dict) return body except json.JSONDecodeError as error: raise JsonDecodeException(str(error)) from error @@ -241,7 +232,7 @@ def _trace_transaction(dataset_name: str) -> None: scope.transaction.set_tag("referrer", http_request.referrer) -def _set_snql_api_error_tags(body: Dict[str, Any], http_referrer: str | None) -> None: +def _set_snql_api_error_tags(body: dict[str, Any], http_referrer: str | None) -> None: """Set Sentry tags for SnQL API error tracking. Tags all errors in the SnQL API with: @@ -274,16 +265,15 @@ def _set_snql_api_error_tags(body: Dict[str, Any], http_referrer: str | None) -> @application.route("/query", methods=["GET", "POST"]) @util.time_request("query") -def unqualified_query_view(*, timer: Timer) -> Union[Response, str, WerkzeugResponse]: +def unqualified_query_view(*, timer: Timer) -> Response | str | WerkzeugResponse: if http_request.method == "GET": return redirect(f"/{settings.DEFAULT_DATASET_NAME}/query", code=302) - elif http_request.method == "POST": + if http_request.method == "POST": body = parse_request_body(http_request) dataset_name = str(body.pop("dataset", settings.DEFAULT_DATASET_NAME)) _trace_transaction(dataset_name) return dataset_query(dataset_name, body, timer) - else: - assert False, "unexpected fallthrough" + raise AssertionError("unexpected fallthrough") @application.route("/rpc//", methods=["POST"]) @@ -291,44 +281,41 @@ def rpc(*, name: str, version: str) -> Response: result_proto = run_rpc_handler(name, version, http_request.data) if isinstance(result_proto, ErrorProto): return Response(result_proto.SerializeToString(), status=result_proto.code) - else: - return Response(result_proto.SerializeToString(), status=200) + return Response(result_proto.SerializeToString(), status=200) @application.route("//snql", methods=["GET", "POST"]) @util.time_request("query") -def snql_dataset_query_view(*, dataset: Dataset, timer: Timer) -> Union[Response, str]: +def snql_dataset_query_view(*, dataset: Dataset, timer: Timer) -> Response | str: if http_request.method == "GET": schema = RequestSchema.build(HTTPQuerySettings) return render_template( "query.html", query_template=json.dumps(schema.generate_template(), indent=4), ) - elif http_request.method == "POST": + if http_request.method == "POST": body = parse_request_body(http_request) dataset_name = get_dataset_name(dataset) _trace_transaction(dataset_name) _set_snql_api_error_tags(body, http_request.referrer) return dataset_query(dataset_name, body, timer) - else: - assert False, "unexpected fallthrough" + raise AssertionError("unexpected fallthrough") @application.route("//mql", methods=["GET", "POST"]) @util.time_request("query", {"mql": "true"}) -def mql_dataset_query_view(*, dataset: Dataset, timer: Timer) -> Union[Response, str]: +def mql_dataset_query_view(*, dataset: Dataset, timer: Timer) -> Response | str: if http_request.method == "POST": dataset_name = get_dataset_name(dataset) body = parse_request_body(http_request) _trace_transaction(dataset_name) return dataset_query(dataset_name, body, timer, is_mql=True) - else: - assert False, "unexpected fallthrough" + raise AssertionError("unexpected fallthrough") @application.route("/", methods=["DELETE"]) @util.time_request("delete_query") -def storage_delete(*, storage: WritableTableStorage, timer: Timer) -> Union[Response, str]: +def storage_delete(*, storage: WritableTableStorage, timer: Timer) -> Response | str: if http_request.method == "DELETE": body = parse_request_body(http_request) @@ -373,8 +360,7 @@ def storage_delete(*, storage: WritableTableStorage, timer: Timer) -> Union[Resp # i put the result inside "data" bc thats how sentry utils/snuba.py expects the result return Response(dump_payload({"data": payload}), 200, {"Content-Type": "application/json"}) - else: - assert False, "unexpected fallthrough" + raise AssertionError("unexpected fallthrough") def _sanitize_payload(payload: MutableMapping[str, Any], res: MutableMapping[str, Any]) -> None: @@ -420,7 +406,7 @@ def dump_payload(payload: MutableMapping[str, Any]) -> str: @with_span() def dataset_query( - dataset_name: str, body: Dict[str, Any], timer: Timer, is_mql: bool = False + dataset_name: str, body: dict[str, Any], timer: Timer, is_mql: bool = False ) -> Response: assert http_request.method == "POST" referrer = http_request.referrer or "" # mypy @@ -611,7 +597,7 @@ def eventstream(*, entity: Entity) -> RespTuple: record = json.loads(http_request.data) version = record[0] if version != 2: - raise RuntimeError("Unsupported protocol version: %s" % record) + raise RuntimeError(f"Unsupported protocol version: {record}") message: Message[KafkaPayload] = Message( BrokerValue( @@ -675,7 +661,7 @@ def commit(offsets: Mapping[Partition, int], force: bool = False) -> None: # TODO: This is a temporary workaround so that we return a more useful error when # attempting to write to a dataset where the migration hasn't been run. This should be # no longer necessary once we have more advanced dataset management in place. - raise InternalServerError(str(e), original_exception=e) + raise InternalServerError(str(e), original_exception=e) from e return ("ok", 200, {"Content-Type": "text/plain"}) @@ -689,6 +675,6 @@ def drop(*, dataset: Dataset) -> RespTuple: @application.route("/tests/error") def error() -> RespTuple: - 1 / 0 + 1 / 0 # noqa: B018 intentionally raises ZeroDivisionError to exercise the error path # unreachable. A valid response is added for mypy return ("error", 500, {"Content-Type": "text/plain"}) diff --git a/snuba/writer.py b/snuba/writer.py index 068290cb1bc..37bd04f30f6 100644 --- a/snuba/writer.py +++ b/snuba/writer.py @@ -2,7 +2,8 @@ import logging from abc import ABC, abstractmethod -from typing import Any, Generic, Iterable, List, Mapping, TypeVar +from collections.abc import Iterable, Mapping +from typing import Any, Generic, TypeVar from snuba.utils.codecs import Encoder, TDecoded, TEncoded @@ -47,7 +48,7 @@ def __init__( ): self.__writer = writer self.__buffer_size = buffer_size - self.__buffer: List[TEncoded] = [] + self.__buffer: list[TEncoded] = [] self.__encoder = encoder def __flush(self) -> None: diff --git a/test_distributed_migrations/conftest.py b/test_distributed_migrations/conftest.py index 93ce9efb869..d79a6a2e6e9 100644 --- a/test_distributed_migrations/conftest.py +++ b/test_distributed_migrations/conftest.py @@ -50,12 +50,8 @@ def pytest_configure() -> None: http_port=cluster_node["http_port"], storage_sets=cluster_node["storage_sets"], single_node=cluster_node["single_node"], - cluster_name=cluster_node["cluster_name"] if "cluster_name" in cluster_node else None, - distributed_cluster_name=( - cluster_node["distributed_cluster_name"] - if "distributed_cluster_name" in cluster_node - else None - ), + cluster_name=cluster_node.get("cluster_name", None), + distributed_cluster_name=(cluster_node.get("distributed_cluster_name", None)), secure=cluster_node.get("secure", False), ca_certs=cluster_node.get("ca_certs", None), verify=cluster_node.get("verify", False), diff --git a/test_initialization/test_initialize.py b/test_initialization/test_initialize.py index 8ac465eae9d..92e27ef1958 100644 --- a/test_initialization/test_initialize.py +++ b/test_initialization/test_initialize.py @@ -6,7 +6,7 @@ class TestInitialization: def test_init( self, - ): + ) -> None: # first make sure all the factories are not initialized # this is accessing private module variables but we don't have a # better way of knowing things are initialized (2022-10-27) diff --git a/tests/admin/clickhouse/test_querylog.py b/tests/admin/clickhouse/test_querylog.py index b5d998ccdfb..0e0d0f714e5 100644 --- a/tests/admin/clickhouse/test_querylog.py +++ b/tests/admin/clickhouse/test_querylog.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Type - import pytest from snuba import state @@ -32,7 +30,7 @@ def test_get_clickhouse_threads(config_val: str | int, expected_threads: int) -> ], ) @pytest.mark.redis_db -def test_get_clickhouse_threads_error(config_val: str | int, error: Type[Exception]) -> None: +def test_get_clickhouse_threads_error(config_val: str | int, error: type[Exception]) -> None: state.set_config("admin.querylog_threads", str(config_val)) with pytest.raises(error): _get_clickhouse_threads() diff --git a/tests/admin/clickhouse_migrations/test_api.py b/tests/admin/clickhouse_migrations/test_api.py index 8e45e92c0c1..3fd477dc159 100644 --- a/tests/admin/clickhouse_migrations/test_api.py +++ b/tests/admin/clickhouse_migrations/test_api.py @@ -1,8 +1,9 @@ from __future__ import annotations import tempfile +from collections.abc import Mapping, Sequence from dataclasses import asdict -from typing import Any, Mapping, Optional, Sequence, Type +from typing import Any from unittest.mock import patch import pytest @@ -35,13 +36,13 @@ def generate_migration_test_role( group: str, policy: str, override_resource: bool = False, - name: Optional[str] = None, + name: str | None = None, ) -> Role: if not name: name = f"{group}-{policy}" if policy == "all": - action: Type[MigrationAction] = ExecuteAllAction + action: type[MigrationAction] = ExecuteAllAction elif policy == "non_blocking": action = ExecuteNonBlockingAction else: @@ -264,21 +265,23 @@ def print_something(*args: Any, **kwargs: Any) -> None: ) if action == "reverse": - with patch.object(Runner, method) as mock_run_migration: - # allowed non blocking - with patch( + # allowed non blocking + with ( + patch.object(Runner, method) as mock_run_migration, + patch( "snuba.migrations.runner.Runner.get_status", return_value=(Status.IN_PROGRESS, None), - ): - migration_key = MigrationKey( - group=MigrationGroup.QUERYLOG, - migration_id="0001_querylog", - ) - response = admin_api.post(f"/migrations/querylog/{action}/0001_querylog") - assert response.status_code == 200 - mock_run_migration.assert_called_once_with( - migration_key, force=False, fake=False, dry_run=False - ) + ), + ): + migration_key = MigrationKey( + group=MigrationGroup.QUERYLOG, + migration_id="0001_querylog", + ) + response = admin_api.post(f"/migrations/querylog/{action}/0001_querylog") + assert response.status_code == 200 + mock_run_migration.assert_called_once_with( + migration_key, force=False, fake=False, dry_run=False + ) # allow dry runs with patch.object(Runner, method) as mock_run_migration: @@ -301,7 +304,7 @@ def test_get_iam_roles(caplog: Any) -> None: "snuba.admin.auth.DEFAULT_ROLES", [system_role, tool_role], ): - iam_file = tempfile.NamedTemporaryFile() + iam_file = tempfile.NamedTemporaryFile() # noqa: SIM115 handle reused across statements and explicitly closed below iam_file.write( json.dumps( { @@ -389,7 +392,7 @@ def test_get_iam_roles_cache() -> None: "snuba.admin.auth.DEFAULT_ROLES", [system_role, tool_role], ): - iam_file = tempfile.NamedTemporaryFile() + iam_file = tempfile.NamedTemporaryFile() # noqa: SIM115 handle reused across statements and explicitly closed below iam_file.write( json.dumps( { @@ -442,7 +445,7 @@ def test_get_iam_roles_cache() -> None: tool_role, ] - iam_file = tempfile.NamedTemporaryFile() + iam_file = tempfile.NamedTemporaryFile() # noqa: SIM115 handle reused across statements and explicitly closed below iam_file.write(json.dumps({"bindings": []}).encode("utf-8")) iam_file.flush() @@ -478,7 +481,7 @@ def test_get_iam_roles_cache_fail(mock_redis: Any) -> None: "snuba.admin.auth.DEFAULT_ROLES", [system_role, tool_role], ): - iam_file = tempfile.NamedTemporaryFile() + iam_file = tempfile.NamedTemporaryFile() # noqa: SIM115 handle reused across statements and explicitly closed below iam_file.write(json.dumps({"bindings": []}).encode("utf-8")) iam_file.flush() diff --git a/tests/admin/clickhouse_migrations/test_migration_checks.py b/tests/admin/clickhouse_migrations/test_migration_checks.py index acacd610af2..0678e3632fc 100644 --- a/tests/admin/clickhouse_migrations/test_migration_checks.py +++ b/tests/admin/clickhouse_migrations/test_migration_checks.py @@ -1,4 +1,4 @@ -from typing import Optional, Sequence, Tuple +from collections.abc import Sequence from unittest.mock import Mock, patch import pytest @@ -9,6 +9,7 @@ RunReason, RunResult, StatusChecker, + run_migration_checks_and_policies, ) from snuba.migrations.group_loader import DirectoryLoader, GroupLoader from snuba.migrations.groups import MigrationGroup @@ -51,7 +52,7 @@ def test_status_checker_run( mock_loader: Mock, migration_id: str, expected_allowed: bool, - expected_reason: Optional[RunReason], + expected_reason: RunReason | None, ) -> None: group = MigrationGroup("querylog") checker = StatusChecker(group, RUN_MIGRATIONS) @@ -84,7 +85,7 @@ def test_status_checker_reverse( mock_loader: Mock, migration_id: str, expected_allowed: bool, - expected_reason: Optional[ReverseReason], + expected_reason: ReverseReason | None, ) -> None: group = MigrationGroup("querylog") checker = StatusChecker(group, REVERSE_MIGRATIONS) @@ -122,9 +123,6 @@ def test_status_checker_errors() -> None: checker.can_reverse(MigrationKey(MigrationGroup("events"), migration_id)) -from snuba.admin.clickhouse.migration_checks import run_migration_checks_and_policies - - @patch( "snuba.admin.clickhouse.migration_checks.get_group_loader", return_value=group_loader(), @@ -146,9 +144,9 @@ def test_run_migration_checks_and_policies( group_loader: Mock, mock_checker: Mock, mock_runner: Mock, - policy_result: Tuple[bool, bool], - status_result: Tuple[bool, bool], - expected: Tuple[bool, bool], + policy_result: tuple[bool, bool], + status_result: tuple[bool, bool], + expected: tuple[bool, bool], ) -> None: mock_policy = Mock() checker = mock_checker() diff --git a/tests/admin/test_api.py b/tests/admin/test_api.py index 4fb913df028..937012f5fde 100644 --- a/tests/admin/test_api.py +++ b/tests/admin/test_api.py @@ -1,7 +1,8 @@ from __future__ import annotations +from collections.abc import Sequence from datetime import UTC, datetime, timedelta -from typing import Any, Sequence, Tuple, Type +from typing import Any from unittest import mock import pytest @@ -50,7 +51,7 @@ def admin_api() -> FlaskClient: @pytest.fixture(scope="session") -def rpc_test_setup() -> Tuple[Type[Any], Type[RPCEndpoint[Any, TimeSeriesResponse]]]: +def rpc_test_setup() -> tuple[type[Any], type[RPCEndpoint[Any, TimeSeriesResponse]]]: class TestRPC(RPCEndpoint[TimeSeriesRequest, TimeSeriesResponse]): @classmethod def version(cls) -> str: @@ -337,7 +338,7 @@ def test_query_trace_bad_query(admin_api: FlaskClient) -> None: "Exception: Unknown expression or function identifier" in data["error"]["message"] or "Exception: Missing columns" in data["error"]["message"] ) - assert "clickhouse" == data["error"]["type"] + assert data["error"]["type"] == "clickhouse" @pytest.mark.redis_db @@ -352,7 +353,7 @@ def test_query_trace_invalid_query(admin_api: FlaskClient) -> None: assert response.status_code == 400 data = json.loads(response.data) assert "; is not allowed in the query" in data["error"]["message"] - assert "validation" == data["error"]["type"] + assert data["error"]["type"] == "validation" @pytest.mark.redis_db @@ -924,7 +925,7 @@ def mock_record(user: Any, action: Any, data: Any, notify: Any) -> None: nonlocal auditlog_records auditlog_records.append((user, action, data, notify)) - def mock_get_from_name(strategy_name: str) -> Type[BaseRoutingStrategy]: + def mock_get_from_name(strategy_name: str) -> type[BaseRoutingStrategy]: return FakeRoutingStrategy with ( @@ -1295,7 +1296,7 @@ def test_clickhouse_node_info( response_data = json.loads(response.data) assert ( len(response_data) > 0 - and {k: response_data[0][k] for k in expected_result.keys()} == expected_result + and {k: response_data[0][k] for k in expected_result} == expected_result ) @@ -1365,7 +1366,7 @@ def test_clickhouse_system_settings( @pytest.mark.clickhouse_db def test_execute_rpc_endpoint_success( admin_api: FlaskClient, - rpc_test_setup: Tuple[Type[Any], Type[RPCEndpoint[Any, TimeSeriesResponse]]], + rpc_test_setup: tuple[type[Any], type[RPCEndpoint[Any, TimeSeriesResponse]]], ) -> None: MyRequest, TestRPC = rpc_test_setup @@ -1421,7 +1422,7 @@ def test_execute_rpc_endpoint_unknown_endpoint(admin_api: FlaskClient) -> None: @pytest.mark.redis_db def test_execute_rpc_endpoint_invalid_payload( admin_api: FlaskClient, - rpc_test_setup: Tuple[Type[Any], Type[RPCEndpoint[Any, TimeSeriesResponse]]], + rpc_test_setup: tuple[type[Any], type[RPCEndpoint[Any, TimeSeriesResponse]]], ) -> None: MyRequest, TestRPC = rpc_test_setup @@ -1438,7 +1439,7 @@ def test_execute_rpc_endpoint_invalid_payload( @pytest.mark.redis_db def test_execute_rpc_endpoint_org_id_not_allowed( admin_api: FlaskClient, - rpc_test_setup: Tuple[Type[Any], Type[RPCEndpoint[Any, TimeSeriesResponse]]], + rpc_test_setup: tuple[type[Any], type[RPCEndpoint[Any, TimeSeriesResponse]]], ) -> None: MyRequest, TestRPC = rpc_test_setup @@ -1479,5 +1480,5 @@ def test_list_rpc_endpoints(admin_api: FlaskClient) -> None: assert isinstance(endpoint[1], str) registered_endpoints = {tuple(name.split("__")) for name in RPCEndpoint.all_names()} - response_endpoints = set(tuple(endpoint) for endpoint in endpoint_names) + response_endpoints = {tuple(endpoint) for endpoint in endpoint_names} assert response_endpoints == registered_endpoints diff --git a/tests/admin/test_jwt.py b/tests/admin/test_jwt.py index 4b567c145a5..eb567927f6f 100644 --- a/tests/admin/test_jwt.py +++ b/tests/admin/test_jwt.py @@ -5,6 +5,7 @@ import time from unittest.mock import patch +import jwt import pytest from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import ec @@ -97,9 +98,9 @@ def test_validate_assertion_rejects_wrong_audience( with ( patch("snuba.admin.jwt._certs", return_value={KID: public_pem}), patch("snuba.admin.jwt._audience", return_value=AUDIENCE), + pytest.raises(jwt.InvalidAudienceError), ): - with pytest.raises(Exception): - validate_assertion(token) + validate_assertion(token) def test_validate_assertion_rejects_bad_signature( @@ -131,6 +132,6 @@ def test_validate_assertion_rejects_bad_signature( with ( patch("snuba.admin.jwt._certs", return_value={KID: other_pem}), patch("snuba.admin.jwt._audience", return_value=AUDIENCE), + pytest.raises(jwt.InvalidSignatureError), ): - with pytest.raises(Exception): - validate_assertion(token) + validate_assertion(token) diff --git a/tests/admin/test_migration_policies.py b/tests/admin/test_migration_policies.py index c70b507e871..edd73422dd7 100644 --- a/tests/admin/test_migration_policies.py +++ b/tests/admin/test_migration_policies.py @@ -1,4 +1,4 @@ -from typing import Sequence, Set +from collections.abc import Sequence import pytest @@ -42,12 +42,12 @@ ), ], ) -def test_get_group_policies(roles: Sequence[Role], expected_policies: Set[MigrationPolicy]) -> None: +def test_get_group_policies(roles: Sequence[Role], expected_policies: set[MigrationPolicy]) -> None: user = AdminUser("meredith@sentry.io", "123", roles=roles) results = get_migration_group_policies(user) - assert set(r.__class__ for r in results["test_migration"]) == set( + assert {r.__class__ for r in results["test_migration"]} == { e.__class__ for e in expected_policies - ) + } def test_get_migration_group_policies_sans_roles() -> None: diff --git a/tests/admin/test_querylog_audit_log.py b/tests/admin/test_querylog_audit_log.py index 4449e2aeea1..a3f86976290 100644 --- a/tests/admin/test_querylog_audit_log.py +++ b/tests/admin/test_querylog_audit_log.py @@ -34,7 +34,7 @@ def test_audit_log_failure() -> None: def failed_query(query: str, user: str) -> ClickhouseResult: raise Exception() - with pytest.raises(Exception): + with pytest.raises(Exception): # noqa: B017 decorated fn raises bare Exception; verifying audit_log re-raises it failed_query("test_bad_query", "test_bad_user") assert len(cap_logs) == 1 diff --git a/tests/admin/test_system_queries.py b/tests/admin/test_system_queries.py index d395bc743fa..14cc96b735e 100644 --- a/tests/admin/test_system_queries.py +++ b/tests/admin/test_system_queries.py @@ -1,13 +1,14 @@ import ast +import contextlib +from collections.abc import Sequence from pathlib import Path -from typing import Sequence from unittest.mock import patch import pytest from snuba import settings from snuba.admin.auth_roles import ROLES, Role -from snuba.admin.clickhouse.common import InvalidNodeError +from snuba.admin.clickhouse.common import InvalidCustomQuery, InvalidNodeError from snuba.admin.clickhouse.system_queries import ( UnauthorizedForSudo, is_valid_system_query, @@ -97,7 +98,7 @@ def test_is_valid_system_query(sql_query: str) -> None: ) @pytest.mark.events_db def test_invalid_system_query(sql_query: str) -> None: - with pytest.raises(Exception): + with pytest.raises(InvalidCustomQuery): is_valid_system_query( settings.CLUSTERS[0]["host"], int(settings.CLUSTERS[0]["port"]), @@ -147,7 +148,7 @@ def test_sudo_queries(sudo_query: str, expected: bool) -> None: False, ) # Should no-op else: - with pytest.raises(Exception): + with pytest.raises(InvalidCustomQuery): validate_query( settings.CLUSTERS[0]["host"], int(settings.CLUSTERS[0]["port"]), @@ -222,7 +223,7 @@ def run_query() -> None: with pytest.raises(UnauthorizedForSudo): run_query() elif not expect_valid: - with pytest.raises(Exception): + with pytest.raises(InvalidCustomQuery): run_query() else: run_query() @@ -353,7 +354,8 @@ def test_sudo_mode_skips_experimental_analyzer(sql_query: str, sudo_mode: bool) mock_result = type("MockResult", (), {"results": []})() mock_run.return_value = mock_result - try: + # We don't care if validation fails, we just want to check the query + with contextlib.suppress(Exception): is_valid_system_query( settings.CLUSTERS[0]["host"], int(settings.CLUSTERS[0]["port"]), @@ -362,8 +364,6 @@ def test_sudo_mode_skips_experimental_analyzer(sql_query: str, sudo_mode: bool) False, sudo_mode, ) - except Exception: - pass # We don't care if validation fails, we just want to check the query # Check that the EXPLAIN QUERY TREE was called calls = [call for call in mock_run.call_args_list if "EXPLAIN QUERY TREE" in str(call)] diff --git a/tests/assertions.py b/tests/assertions.py index a6f5a6e2944..5f0aba80da0 100644 --- a/tests/assertions.py +++ b/tests/assertions.py @@ -1,6 +1,7 @@ import operator +from collections.abc import Callable, Iterator from contextlib import contextmanager -from typing import Any, Callable, Iterator, TypeVar, cast +from typing import Any, TypeVar, cast T = TypeVar("T") diff --git a/tests/backends/metrics.py b/tests/backends/metrics.py index 7e18c2ce1d2..5d6bc26aaa6 100644 --- a/tests/backends/metrics.py +++ b/tests/backends/metrics.py @@ -1,4 +1,5 @@ -from typing import MutableSequence, NamedTuple, Optional, Union +from collections.abc import MutableSequence +from typing import NamedTuple from snuba.utils.metrics.backends.abstract import MetricsBackend from snuba.utils.metrics.types import Tags @@ -6,30 +7,30 @@ class Increment(NamedTuple): name: str - value: Union[int, float] - tags: Optional[Tags] - unit: Optional[str] = None + value: int | float + tags: Tags | None + unit: str | None = None class Gauge(NamedTuple): name: str - value: Union[int, float] - tags: Optional[Tags] - unit: Optional[str] = None + value: int | float + tags: Tags | None + unit: str | None = None class Timing(NamedTuple): name: str - value: Union[int, float] - tags: Optional[Tags] - unit: Optional[str] = None + value: int | float + tags: Tags | None + unit: str | None = None class Distribution(NamedTuple): name: str - value: Union[int, float] - tags: Optional[Tags] - unit: Optional[str] = None + value: int | float + tags: Tags | None + unit: str | None = None class Events(NamedTuple): @@ -37,7 +38,7 @@ class Events(NamedTuple): text: str alert_type: str priority: str - tags: Optional[Tags] + tags: Tags | None class TestingMetricsBackend(MetricsBackend): @@ -50,41 +51,41 @@ class TestingMetricsBackend(MetricsBackend): # TODO: This might make sense to extend the dummy metrics backend. def __init__(self) -> None: - self.calls: MutableSequence[Union[Increment, Gauge, Timing, Distribution, Events]] = [] + self.calls: MutableSequence[Increment | Gauge | Timing | Distribution | Events] = [] def increment( self, name: str, - value: Union[int, float] = 1, - tags: Optional[Tags] = None, - unit: Optional[str] = None, + value: int | float = 1, + tags: Tags | None = None, + unit: str | None = None, ) -> None: self.calls.append(Increment(name, value, tags, unit)) def gauge( self, name: str, - value: Union[int, float], - tags: Optional[Tags] = None, - unit: Optional[str] = None, + value: int | float, + tags: Tags | None = None, + unit: str | None = None, ) -> None: self.calls.append(Gauge(name, value, tags, unit)) def timing( self, name: str, - value: Union[int, float], - tags: Optional[Tags] = None, - unit: Optional[str] = None, + value: int | float, + tags: Tags | None = None, + unit: str | None = None, ) -> None: self.calls.append(Timing(name, value, tags, unit)) def distribution( self, name: str, - value: Union[int, float], - tags: Optional[Tags] = None, - unit: Optional[str] = None, + value: int | float, + tags: Tags | None = None, + unit: str | None = None, ) -> None: self.calls.append(Distribution(name, value, tags, unit)) @@ -94,6 +95,6 @@ def events( text: str, alert_type: str, priority: str, - tags: Optional[Tags] = None, + tags: Tags | None = None, ) -> None: self.calls.append(Events(title, text, alert_type, priority, tags)) diff --git a/tests/base.py b/tests/base.py index a4f2518cdb4..1c08a331f10 100644 --- a/tests/base.py +++ b/tests/base.py @@ -1,4 +1,5 @@ -from typing import Any, Callable +from collections.abc import Callable +from typing import Any class BaseApiTest: diff --git a/tests/cli/test_consumer.py b/tests/cli/test_consumer.py index b4c6dca0a05..b7320177092 100644 --- a/tests/cli/test_consumer.py +++ b/tests/cli/test_consumer.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from unittest.mock import Mock, patch from click.testing import CliRunner diff --git a/tests/cli/test_migrations.py b/tests/cli/test_migrations.py index 5150a3011b5..56784cc4401 100644 --- a/tests/cli/test_migrations.py +++ b/tests/cli/test_migrations.py @@ -1,4 +1,4 @@ -from typing import Optional, Sequence +from collections.abc import Sequence import pytest from click import Command @@ -7,7 +7,7 @@ from snuba.cli.migrations import list, migrate, reverse, reverse_in_progress, run -def _check_run(runner: CliRunner, func: Command, args: Optional[Sequence[str]] = None) -> None: +def _check_run(runner: CliRunner, func: Command, args: Sequence[str] | None = None) -> None: result = runner.invoke(func, args) assert result.exit_code == 0 diff --git a/tests/cli/test_subscriptions.py b/tests/cli/test_subscriptions.py index 9dbc3081c06..0b8cdbe0dc4 100644 --- a/tests/cli/test_subscriptions.py +++ b/tests/cli/test_subscriptions.py @@ -1,9 +1,12 @@ -from typing import Sequence +from collections.abc import Sequence from unittest.mock import Mock, patch from click import Command from click.testing import CliRunner +import snuba.cli.subscriptions_executor +import snuba.cli.subscriptions_scheduler +import snuba.cli.subscriptions_scheduler_executor import snuba.subscriptions.executor_consumer from snuba.cli.subscriptions_executor import subscriptions_executor from snuba.cli.subscriptions_scheduler import subscriptions_scheduler diff --git a/tests/clickhouse/optimize/test_optimize.py b/tests/clickhouse/optimize/test_optimize.py index 5a8e13aeaaf..cf5f60b7e54 100644 --- a/tests/clickhouse/optimize/test_optimize.py +++ b/tests/clickhouse/optimize/test_optimize.py @@ -1,6 +1,6 @@ import uuid +from collections.abc import Callable, Mapping from datetime import UTC, datetime, timedelta -from typing import Callable, Mapping from unittest.mock import Mock, patch import pytest @@ -259,21 +259,23 @@ def test_optimize_partition_runner_errors( expire_time=datetime.now() + timedelta(minutes=10), ) - with time_machine.travel(current_time, tick=False): - with patch( + with ( + time_machine.travel(current_time, tick=False), + patch( "snuba.clickhouse.optimize.optimize.optimize_partitions", side_effect=ClickhouseError(), - ): - with pytest.raises(ClickhouseError): - optimize.optimize_partition_runner( - clickhouse=clickhouse, - database=database, - table=table, - partitions=["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"], - default_parallel_threads=3, - tracker=tracker, - clickhouse_host="some-hostname.domain.com", - ) + ), + pytest.raises(ClickhouseError), + ): + optimize.optimize_partition_runner( + clickhouse=clickhouse, + database=database, + table=table, + partitions=["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"], + default_parallel_threads=3, + tracker=tracker, + clickhouse_host="some-hostname.domain.com", + ) # For ClickHouse 23.3 and 23.8 parts from previous test runs # interfere with following tests, so best to drop the tables @@ -351,20 +353,24 @@ def test_optimize_partitions_raises_exception_with_cutoff_time() -> None: dummy_partition = "(90,'2022-03-28')" tracker.update_all_partitions([dummy_partition]) - with time_machine.travel( - last_midnight + timedelta(hours=settings.OPTIMIZE_JOB_CUTOFF_TIME) + timedelta(minutes=15), - tick=False, + with ( + time_machine.travel( + last_midnight + + timedelta(hours=settings.OPTIMIZE_JOB_CUTOFF_TIME) + + timedelta(minutes=15), + tick=False, + ), + pytest.raises(OptimizedSchedulerTimeout), ): - with pytest.raises(OptimizedSchedulerTimeout): - optimize_partition_runner( - clickhouse=clickhouse_pool, - database=database, - table=table, - partitions=[dummy_partition], - default_parallel_threads=2, - tracker=tracker, - clickhouse_host="some-hostname.domain.com", - ) + optimize_partition_runner( + clickhouse=clickhouse_pool, + database=database, + table=table, + partitions=[dummy_partition], + default_parallel_threads=2, + tracker=tracker, + clickhouse_host="some-hostname.domain.com", + ) tracker.delete_all_states() settings.OPTIMIZE_JOB_CUTOFF_TIME = prev_job_cutoff_time diff --git a/tests/clickhouse/optimize/test_optimize_scheduler.py b/tests/clickhouse/optimize/test_optimize_scheduler.py index 0ffb8b806a6..3baa053b1ff 100644 --- a/tests/clickhouse/optimize/test_optimize_scheduler.py +++ b/tests/clickhouse/optimize/test_optimize_scheduler.py @@ -1,5 +1,5 @@ +from collections.abc import Sequence from datetime import UTC, datetime, timedelta -from typing import Sequence import pytest import time_machine @@ -228,11 +228,13 @@ def test_get_next_schedule( def test_get_next_schedule_raises_exception() -> None: with time_machine.travel(last_midnight, tick=False): optimize_scheduler = OptimizeScheduler(default_parallel_threads=1) - with time_machine.travel( - last_midnight - + timedelta(hours=settings.OPTIMIZE_JOB_CUTOFF_TIME) - + timedelta(minutes=20), - tick=False, + with ( + time_machine.travel( + last_midnight + + timedelta(hours=settings.OPTIMIZE_JOB_CUTOFF_TIME) + + timedelta(minutes=20), + tick=False, + ), + pytest.raises(OptimizedSchedulerTimeout), ): - with pytest.raises(OptimizedSchedulerTimeout): - optimize_scheduler.get_next_schedule(["(90,'2022-03-28')", "(90,'2022-03-21')"]) + optimize_scheduler.get_next_schedule(["(90,'2022-03-28')", "(90,'2022-03-21')"]) diff --git a/tests/clickhouse/optimize/test_optimize_tracker.py b/tests/clickhouse/optimize/test_optimize_tracker.py index ee03e5e38a9..2707ab9c1e5 100644 --- a/tests/clickhouse/optimize/test_optimize_tracker.py +++ b/tests/clickhouse/optimize/test_optimize_tracker.py @@ -1,7 +1,6 @@ import time import uuid from datetime import datetime, timedelta -from typing import Optional, Set from unittest.mock import call, patch import pytest @@ -52,9 +51,9 @@ def test_optimized_partition_tracker(tracker: OptimizedPartitionTracker) -> None: def assert_partitions( *, - all: Optional[Set[str]] = None, - completed: Optional[Set[str]] = None, - pending: Optional[Set[str]] = None, + all: set[str] | None = None, + completed: set[str] | None = None, + pending: set[str] | None = None, ) -> None: """ Assert partition status with sleep + retry. This is needed for when the diff --git a/tests/clickhouse/query_dsl/test_accessors.py b/tests/clickhouse/query_dsl/test_accessors.py index ba26fd9f65f..a118392ce2c 100644 --- a/tests/clickhouse/query_dsl/test_accessors.py +++ b/tests/clickhouse/query_dsl/test_accessors.py @@ -26,7 +26,7 @@ Literal(None, 1), ), ), - set([1]), + {1}, id="equals_case", ), pytest.param( @@ -38,12 +38,12 @@ FunctionCall(None, "array", tuple([Literal(None, i) for i in range(1, 5)])), ), ), - set([1, 2, 3, 4]), + {1, 2, 3, 4}, id="in_case", ), pytest.param( Literal(None, 1), - set([]), + set(), id="empty_case", ), ], diff --git a/tests/clickhouse/query_dsl/test_project_id.py b/tests/clickhouse/query_dsl/test_project_id.py index a5fffcd04ac..936324a1702 100644 --- a/tests/clickhouse/query_dsl/test_project_id.py +++ b/tests/clickhouse/query_dsl/test_project_id.py @@ -1,4 +1,5 @@ -from typing import Any, Mapping, MutableMapping, Optional, Sequence, Set, Tuple, cast +from collections.abc import Mapping, MutableMapping, Sequence +from typing import Any, cast import pytest from snuba_sdk.legacy import json_to_snql @@ -12,7 +13,7 @@ from snuba.query.query_settings import HTTPQuerySettings from snuba.query.snql.parser import parse_snql_query -test_cases: Sequence[Tuple[Mapping[str, Any], Optional[Set[int]]]] = [ +test_cases: Sequence[tuple[Mapping[str, Any], set[int] | None]] = [ ( { "selected_columns": ["event_id"], @@ -174,7 +175,7 @@ @pytest.mark.redis_db @pytest.mark.parametrize("query_body, expected_projects", test_cases) def test_find_projects( - query_body: MutableMapping[str, Any], expected_projects: Optional[Set[int]] + query_body: MutableMapping[str, Any], expected_projects: set[int] | None ) -> None: events = get_dataset("events") if expected_projects is None: diff --git a/tests/clickhouse/test_http.py b/tests/clickhouse/test_http.py index 9039e8e93ef..7ef84177bb5 100644 --- a/tests/clickhouse/test_http.py +++ b/tests/clickhouse/test_http.py @@ -20,11 +20,11 @@ def test_encode_preserves_column_order(values_encoder: ValuesRowEncoder) -> None "col1": FunctionCall( None, "test", - tuple([FunctionCall(None, "inner", tuple([Literal(None, "inner_arg")]))]), + (FunctionCall(None, "inner", (Literal(None, "inner_arg"),)),), ), } ) - assert encoded == "(test(inner('inner_arg')),5,'test_string')".encode("utf-8") + assert encoded == b"(test(inner('inner_arg')),5,'test_string')" def test_encode_fails_on_non_expression(values_encoder: ValuesRowEncoder) -> None: diff --git a/tests/clickhouse/test_native.py b/tests/clickhouse/test_native.py index 1eabd3518f1..bb6ba8a275b 100644 --- a/tests/clickhouse/test_native.py +++ b/tests/clickhouse/test_native.py @@ -1,6 +1,7 @@ import queue +from collections.abc import Callable from datetime import datetime, timedelta -from typing import Any, Callable +from typing import Any from unittest import mock import pytest @@ -46,11 +47,11 @@ def test_robust_concurrency_limit() -> None: assert connection.execute.call_count == 3, "Expected three attempts" -class TestError(errors.Error): # type: ignore +class TestError(errors.Error): # type: ignore[misc] code = 1 -class TestConcurrentError(errors.Error): # type: ignore +class TestConcurrentError(errors.Error): # type: ignore[misc] code = errors.ErrorCodes.TOO_MANY_SIMULTANEOUS_QUERIES diff --git a/tests/clickhouse/test_profile_events.py b/tests/clickhouse/test_profile_events.py index 0e3ccde6d92..4df8c187557 100644 --- a/tests/clickhouse/test_profile_events.py +++ b/tests/clickhouse/test_profile_events.py @@ -11,8 +11,8 @@ def test_hostname_resolves() -> None: - assert hostname_resolves("localhost") == True - assert hostname_resolves("invalid-hostname-that-doesnt-exist-123.local") == False + assert hostname_resolves("localhost") + assert not hostname_resolves("invalid-hostname-that-doesnt-exist-123.local") def test_parse_trace_for_query_ids() -> None: @@ -106,18 +106,20 @@ def test_gather_profile_events_retry_logic() -> None: "snuba.admin.clickhouse.profile_events.run_system_query_on_host_with_sql" ) as mock_query: mock_query.side_effect = [empty_result, empty_result, success_result] - with patch("snuba.admin.clickhouse.profile_events.hostname_resolves", return_value=True): - with patch("time.sleep") as mock_sleep: - from flask import Flask + with ( + patch("snuba.admin.clickhouse.profile_events.hostname_resolves", return_value=True), + patch("time.sleep") as mock_sleep, + ): + from flask import Flask - app = Flask(__name__) - with app.app_context(): - g.user = "test_user" + app = Flask(__name__) + with app.app_context(): + g.user = "test_user" - gather_profile_events(trace_output, "test_storage") + gather_profile_events(trace_output, "test_storage") - assert mock_query.call_count == 3 - assert mock_sleep.call_count == 2 + assert mock_query.call_count == 3 + assert mock_sleep.call_count == 2 - assert mock_sleep.call_args_list[0][0][0] == 2 - assert mock_sleep.call_args_list[1][0][0] == 4 + assert mock_sleep.call_args_list[0][0][0] == 2 + assert mock_sleep.call_args_list[1][0][0] == 4 diff --git a/tests/clickhouse/test_query_format.py b/tests/clickhouse/test_query_format.py index 53c9262e602..ece708a2a3f 100644 --- a/tests/clickhouse/test_query_format.py +++ b/tests/clickhouse/test_query_format.py @@ -1,4 +1,5 @@ -from typing import Any, Sequence +from collections.abc import Sequence +from typing import Any import pytest @@ -484,7 +485,7 @@ ), selected_columns=[ SelectedExpression("group_id", Column("group_id", "err", "group_id")), - SelectedExpression("events", FunctionCall("events", "count", tuple())), + SelectedExpression("events", FunctionCall("events", "count", ())), ], groupby=[Column(None, "groups", "id")], ), diff --git a/tests/clusters/fake_cluster.py b/tests/clusters/fake_cluster.py index 4be4782d5b9..fdf78f45f56 100644 --- a/tests/clusters/fake_cluster.py +++ b/tests/clusters/fake_cluster.py @@ -1,4 +1,5 @@ -from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Set, Tuple +from collections.abc import Mapping, MutableMapping, Sequence +from typing import Any from snuba.clickhouse.native import ClickhousePool, ClickhouseResult, Params from snuba.clusters.cluster import ( @@ -15,7 +16,7 @@ class ServerExplodedException(SerializableException): class FakeClickhousePool(ClickhousePool): def __init__(self, host_name: str) -> None: - self.__queries: List[str] = [] + self.__queries: list[str] = [] self.host = host_name def execute( @@ -23,8 +24,8 @@ def execute( query: str, params: Params = None, with_column_types: bool = False, - query_id: Optional[str] = None, - settings: Optional[Mapping[str, Any]] = None, + query_id: str | None = None, + settings: Mapping[str, Any] | None = None, types_check: bool = False, columnar: bool = False, capture_trace: bool = False, @@ -39,7 +40,7 @@ def get_queries(self) -> Sequence[str]: class FakeFailingClickhousePool(FakeClickhousePool): def __init__(self, host_name: str) -> None: - self.__queries: List[str] = [] + self.__queries: list[str] = [] self.host = host_name def execute( @@ -47,8 +48,8 @@ def execute( query: str, params: Params = None, with_column_types: bool = False, - query_id: Optional[str] = None, - settings: Optional[Mapping[str, Any]] = None, + query_id: str | None = None, + settings: Mapping[str, Any] | None = None, types_check: bool = False, columnar: bool = False, capture_trace: bool = False, @@ -70,14 +71,14 @@ def __init__( database: str, http_port: int, secure: bool, - ca_certs: Optional[str], - verify: Optional[bool], - storage_sets: Set[str], + ca_certs: str | None, + verify: bool | None, + storage_sets: set[str], single_node: bool, # The cluster name and distributed cluster name only apply if single_node is set to False - cluster_name: Optional[str] = None, - distributed_cluster_name: Optional[str] = None, - nodes: Optional[Sequence[Tuple[ClickhouseNode, bool]]] = None, + cluster_name: str | None = None, + distributed_cluster_name: str | None = None, + nodes: Sequence[tuple[ClickhouseNode, bool]] | None = None, ): super().__init__( host=host, @@ -98,7 +99,7 @@ def __init__( self.__cluster_name = cluster_name self.__nodes = {node.host_name: (node, healthy) for node, healthy in nodes} if nodes else {} self.__connections: MutableMapping[ - Tuple[ClickhouseNode, ClickhouseClientSettings], FakeClickhousePool + tuple[ClickhouseNode, ClickhouseClientSettings], FakeClickhousePool ] = {} def get_queries( diff --git a/tests/clusters/test_cluster.py b/tests/clusters/test_cluster.py index e8ef94dbdb6..7c00232dc34 100644 --- a/tests/clusters/test_cluster.py +++ b/tests/clusters/test_cluster.py @@ -1,5 +1,5 @@ import importlib -from typing import Generator +from collections.abc import Generator from unittest.mock import patch import pytest @@ -111,7 +111,7 @@ @pytest.fixture(autouse=True) -def setup_teardown(clickhouse_db: None) -> Generator[None, None, None]: +def setup_teardown(clickhouse_db: None) -> Generator[None]: yield importlib.reload(settings) importlib.reload(cluster) @@ -135,21 +135,29 @@ def test_clusters() -> None: @patch("snuba.settings.CLUSTERS", FULL_CONFIG) @pytest.mark.clickhouse_db def test_cache_partition() -> None: - get_storage( - StorageKey("transactions") - ).get_cluster().get_reader().cache_partition_id == "host_2_cache" + assert ( + get_storage(StorageKey("transactions")).get_cluster().get_reader().cache_partition_id + == "host_2_cache" + ) - get_storage(StorageKey("errors")).get_cluster().get_reader().cache_partition_id is None + assert get_storage(StorageKey("errors")).get_cluster().get_reader().cache_partition_id is None @patch("snuba.settings.CLUSTERS", FULL_CONFIG) @pytest.mark.clickhouse_db def test_query_settings_prefix() -> None: - get_storage( - StorageKey("transactions") - ).get_cluster().get_reader().get_query_settings_prefix() == "transactions" + assert ( + get_storage(StorageKey("transactions")) + .get_cluster() + .get_reader() + .get_query_settings_prefix() + == "transactions" + ) - get_storage(StorageKey("errors")).get_cluster().get_reader().get_query_settings_prefix() is None + assert ( + get_storage(StorageKey("errors")).get_cluster().get_reader().get_query_settings_prefix() + is None + ) @patch("snuba.settings.CLUSTERS", FULL_CONFIG) @@ -168,9 +176,8 @@ def test_disabled_cluster() -> None: cluster.get_cluster(StorageSetKey.OUTCOMES) - with patch("snuba.settings.ENABLE_DEV_FEATURES", False): - with pytest.raises(AssertionError): - cluster.get_cluster(StorageSetKey.OUTCOMES) + with patch("snuba.settings.ENABLE_DEV_FEATURES", False), pytest.raises(AssertionError): + cluster.get_cluster(StorageSetKey.OUTCOMES) @patch("snuba.settings.CLUSTERS", FULL_CONFIG) @@ -267,14 +274,14 @@ def test_sliced_cluster() -> None: res_cluster = cluster.get_cluster(StorageSetKey.GENERIC_METRICS_DISTRIBUTIONS, 1) - assert res_cluster.is_single_node() == True + assert res_cluster.is_single_node() assert res_cluster.get_database() == "slice_1_default" assert res_cluster.get_host() == "host_slice" assert res_cluster.get_port() == 9001 res_cluster_default = cluster.get_cluster(StorageSetKey.GENERIC_METRICS_DISTRIBUTIONS, 0) - assert res_cluster_default.is_single_node() == True + assert res_cluster_default.is_single_node() assert res_cluster_default.get_database() == "default" assert res_cluster_default.get_host() == "host_slice" assert res_cluster_default.get_port() == 9000 diff --git a/tests/conftest.py b/tests/conftest.py index ba6da0e1c73..bedc978b828 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,17 +1,8 @@ import json import traceback +from collections.abc import Callable, Generator, Sequence from typing import ( Any, - Callable, - Dict, - FrozenSet, - Generator, - List, - Optional, - Sequence, - Set, - Tuple, - Union, ) import pytest @@ -25,17 +16,17 @@ ) from snuba.core.initialize import initialize_snuba from snuba.datasets.factory import reset_dataset_factory -from snuba.datasets.schemas.tables import WritableTableSchema +from snuba.datasets.schemas.tables import TableSchema, WritableTableSchema from snuba.datasets.storages.factory import get_all_storage_keys, get_storage from snuba.datasets.storages.storage_key import StorageKey from snuba.environment import setup_sentry from snuba.migrations.groups import MigrationGroup from snuba.redis import all_redis_clients -NodeTableCache = Dict[Tuple[ClickhouseCluster, ClickhouseNode], Dict[str, str]] -CacheKey = Optional[FrozenSet[MigrationGroup]] +NodeTableCache = dict[tuple[ClickhouseCluster, ClickhouseNode], dict[str, str]] +CacheKey = frozenset[MigrationGroup] | None -DB_MIGRATIONS_CACHE: Dict[CacheKey, NodeTableCache] = {} +DB_MIGRATIONS_CACHE: dict[CacheKey, NodeTableCache] = {} def pytest_configure() -> None: @@ -65,12 +56,8 @@ def create_databases() -> None: verify=cluster["verify"], storage_sets=cluster["storage_sets"], single_node=cluster["single_node"], - cluster_name=cluster["cluster_name"] if "cluster_name" in cluster else None, - distributed_cluster_name=( - cluster["distributed_cluster_name"] - if "distributed_cluster_name" in cluster - else None - ), + cluster_name=cluster.get("cluster_name", None), + distributed_cluster_name=(cluster.get("distributed_cluster_name", None)), ) database_name = cluster["database"] @@ -110,7 +97,7 @@ def pytest_collection_modifyitems(items: Sequence[Any]) -> None: class BlockedObject: def __init__(self, message: str) -> None: - self.__failures: List[List[str]] = [] + self.__failures: list[list[str]] = [] self.__message = message def snuba_test_teardown(self) -> None: @@ -126,7 +113,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: @pytest.fixture -def block_redis_db(monkeypatch: pytest.MonkeyPatch) -> Generator[None, None, None]: +def block_redis_db(monkeypatch: pytest.MonkeyPatch) -> Generator[None]: from snuba.redis import _redis_clients blocked = BlockedObject( @@ -146,7 +133,7 @@ def block_redis_db(monkeypatch: pytest.MonkeyPatch) -> Generator[None, None, Non @pytest.fixture -def block_clickhouse_db(monkeypatch: pytest.MonkeyPatch) -> Generator[None, None, None]: +def block_clickhouse_db(monkeypatch: pytest.MonkeyPatch) -> Generator[None]: from snuba.clusters.cluster import ClickhouseCluster blocked = BlockedObject( @@ -164,7 +151,7 @@ def block_clickhouse_db(monkeypatch: pytest.MonkeyPatch) -> Generator[None, None @pytest.fixture -def redis_db(request: pytest.FixtureRequest) -> Generator[None, None, None]: +def redis_db(request: pytest.FixtureRequest) -> Generator[None]: if not request.node.get_closest_marker("redis_db"): # Make people use the marker explicitly so `-m` works on CLI pytest.fail("Need to use redis_db marker if redis_db fixture is used") @@ -212,7 +199,7 @@ def _apply_db_cache(cache_key: CacheKey) -> None: """Re-apply cached table-creation DDL. Uses IF NOT EXISTS for idempotency.""" for (cluster, node), tables in DB_MIGRATIONS_CACHE[cache_key].items(): connection = cluster.get_node_connection(ClickhouseClientSettings.MIGRATE, node) - for table_name, create_table_query in tables.items(): + for _table_name, create_table_query in tables.items(): idempotent_query = create_table_query.replace( "CREATE TABLE", "CREATE TABLE IF NOT EXISTS" ).replace( @@ -225,9 +212,9 @@ def _apply_db_cache(cache_key: CacheKey) -> None: def _run_db_fixture( request: pytest.FixtureRequest, marker_name: str, - groups: Optional[Sequence[MigrationGroup]], + groups: Sequence[MigrationGroup] | None, cache_key: CacheKey, -) -> Generator[None, None, None]: +) -> Generator[None]: """Shared body for clickhouse table creation fixtures. Args: @@ -276,7 +263,8 @@ def _clear_db() -> None: or storage_key == StorageKey.OUTCOMES_HOURLY or storage_key == StorageKey.OUTCOMES_DAILY ): - table_name = schema.get_local_table_name() # type: ignore + assert isinstance(schema, TableSchema) + table_name = schema.get_local_table_name() nodes = [*cluster.get_local_nodes(), *cluster.get_distributed_nodes()] for node in nodes: @@ -285,7 +273,7 @@ def _clear_db() -> None: def _drop_tables() -> None: - clusters: Set[ClickhouseCluster] = set() + clusters: set[ClickhouseCluster] = set() for storage_key in get_all_storage_keys(): storage = get_storage(storage_key) cluster = storage.get_cluster() @@ -308,7 +296,7 @@ def _drop_tables() -> None: @pytest.fixture def custom_clickhouse_db( request: pytest.FixtureRequest, -) -> Generator[None, None, None]: +) -> Generator[None]: if not request.node.get_closest_marker("custom_clickhouse_db"): # Make people use the marker explicitly so `-m` works on CLI pytest.fail( @@ -322,9 +310,7 @@ def custom_clickhouse_db( @pytest.fixture -def clickhouse_db( - request: pytest.FixtureRequest, create_databases: None -) -> Generator[None, None, None]: +def clickhouse_db(request: pytest.FixtureRequest, create_databases: None) -> Generator[None]: yield from _run_db_fixture( request=request, marker_name="clickhouse_db", @@ -334,9 +320,7 @@ def clickhouse_db( @pytest.fixture -def events_db( - request: pytest.FixtureRequest, create_databases: None -) -> Generator[None, None, None]: +def events_db(request: pytest.FixtureRequest, create_databases: None) -> Generator[None]: groups = [ MigrationGroup.EVENTS, MigrationGroup.TRANSACTIONS, @@ -353,7 +337,7 @@ def events_db( @pytest.fixture -def eap(request: pytest.FixtureRequest, create_databases: None) -> Generator[None, None, None]: +def eap(request: pytest.FixtureRequest, create_databases: None) -> Generator[None]: groups = [MigrationGroup.EVENTS_ANALYTICS_PLATFORM, MigrationGroup.OUTCOMES] yield from _run_db_fixture( request=request, @@ -364,9 +348,7 @@ def eap(request: pytest.FixtureRequest, create_databases: None) -> Generator[Non @pytest.fixture -def genmetrics_db( - request: pytest.FixtureRequest, create_databases: None -) -> Generator[None, None, None]: +def genmetrics_db(request: pytest.FixtureRequest, create_databases: None) -> Generator[None]: groups = [MigrationGroup.GENERIC_METRICS] yield from _run_db_fixture( request=request, @@ -377,7 +359,7 @@ def genmetrics_db( @pytest.fixture(autouse=True) -def clear_recorded_metrics() -> Generator[None, None, None]: +def clear_recorded_metrics() -> Generator[None]: from snuba.utils.metrics.backends.testing import clear_recorded_metric_calls yield @@ -398,7 +380,7 @@ def convert(data: str, entity: str) -> str: @pytest.fixture def _build_snql_post_methods( request: Any, - test_entity: Union[str, Tuple[str, str]], + test_entity: str | tuple[str, str], test_app: Any, convert_legacy_to_snql: Callable[[str, str], str], ) -> Callable[..., Any]: diff --git a/tests/consumers/test_consumer_builder.py b/tests/consumers/test_consumer_builder.py index 5592e1d3a15..ed87e7c60f9 100644 --- a/tests/consumers/test_consumer_builder.py +++ b/tests/consumers/test_consumer_builder.py @@ -116,7 +116,7 @@ @pytest.mark.parametrize("con_build", [consumer_builder, consumer_builder_with_opt]) -def test_consumer_builder_non_optional_attributes(con_build) -> None: # type: ignore +def test_consumer_builder_non_optional_attributes(con_build: ConsumerBuilder) -> None: # Ensures that the ConsumerBuilders are assigning a # not-None value to the required attributes @@ -139,7 +139,7 @@ def test_consumer_builder_non_optional_attributes(con_build) -> None: # type: i @pytest.mark.parametrize("con_build", [consumer_builder, consumer_builder_with_opt]) -def test_consumer_builder_optional_attributes(con_build) -> None: # type: ignore +def test_consumer_builder_optional_attributes(con_build: ConsumerBuilder) -> None: # Ensures that the ConsumerBuilders are assigning # some value, None or not, to the optional attributes @@ -147,16 +147,16 @@ def test_consumer_builder_optional_attributes(con_build) -> None: # type: ignor # are passed in, stronger checks are performed # in a separate test below - con_build.replacements_topic - con_build.commit_log_topic + con_build.replacements_topic # noqa: B018 access verifies attribute is assigned + con_build.commit_log_topic # noqa: B018 access verifies attribute is assigned - con_build.replacements_producer - con_build.commit_log_producer + con_build.replacements_producer # noqa: B018 access verifies attribute is assigned + con_build.commit_log_producer # noqa: B018 access verifies attribute is assigned - con_build.strict_offset_reset - con_build.processes - con_build.input_block_size - con_build.output_block_size + con_build.strict_offset_reset # noqa: B018 access verifies attribute is assigned + con_build.processes # noqa: B018 access verifies attribute is assigned + con_build.input_block_size # noqa: B018 access verifies attribute is assigned + con_build.output_block_size # noqa: B018 access verifies attribute is assigned @pytest.mark.events_db @@ -181,7 +181,7 @@ def test_run_processing_strategy() -> None: strategy.submit(message) # Wait for the commit - for i in range(10): + for _i in range(10): time.sleep(0.5) strategy.poll() if commit.call_count == 1: diff --git a/tests/consumers/test_message_processors.py b/tests/consumers/test_message_processors.py index 019fcee425d..7208a6ca3d1 100644 --- a/tests/consumers/test_message_processors.py +++ b/tests/consumers/test_message_processors.py @@ -3,7 +3,7 @@ import json import time from datetime import datetime -from typing import Type +from typing import Any, cast from unittest.mock import patch import pytest @@ -32,7 +32,7 @@ ], ) @patch("snuba.settings.DISCARD_OLD_EVENTS", False) -def test_message_processors(topic: str, processor: Type[DatasetMessageProcessor]) -> None: +def test_message_processors(topic: str, processor: type[DatasetMessageProcessor]) -> None: """ Tests the output of Python and Rust message processors is the same """ @@ -96,7 +96,7 @@ def test_replays_message_processor() -> None: millis_since_epoch = int(time.time() * 1000) rust_processed_message = bytes( - rust_snuba.process_message( # type: ignore + rust_snuba.process_message( # type: ignore[attr-defined] processor_name, data_bytes, partition, offset, millis_since_epoch )[0] ) @@ -116,18 +116,18 @@ def test_replays_message_processor() -> None: continue parsed_rust_message = json.loads(line) - parsed_python_message = python_processed_message.rows[0] + parsed_python_message = cast(dict[str, Any], python_processed_message.rows[0]) # timestamp is sometimes in different formats so we'll coerce. ts1 = parsed_rust_message.pop("timestamp", None) - ts2 = parsed_python_message.pop("timestamp", None) # type: ignore + ts2 = parsed_python_message.pop("timestamp", None) if isinstance(ts2, datetime): ts2 = int(ts2.timestamp()) assert ts1 == ts2 # replay_start_timestamp is sometimes in different formats so we'll coerce. sts1 = parsed_rust_message.pop("replay_start_timestamp", None) - sts2 = parsed_python_message.pop("replay_start_timestamp", None) # type: ignore + sts2 = parsed_python_message.pop("replay_start_timestamp", None) if isinstance(sts2, datetime): sts2 = int(sts2.timestamp()) assert sts1 == sts2 @@ -135,7 +135,7 @@ def test_replays_message_processor() -> None: # event_hash is generated by the consumer and not always consistent if # no segment_id is present so we'll coerce. parsed_rust_message.pop("event_hash", None) - parsed_python_message.pop("event_hash", None) # type: ignore + parsed_python_message.pop("event_hash", None) # The python message is a subset of the rust message which contains the complete # row definition. This is due to a defect in the python processor. We take the diff --git a/tests/consumers/test_schemas.py b/tests/consumers/test_schemas.py index 5d7f3bd9c99..0dd45d73614 100644 --- a/tests/consumers/test_schemas.py +++ b/tests/consumers/test_schemas.py @@ -1,6 +1,7 @@ +from collections.abc import Iterator from dataclasses import dataclass from datetime import datetime -from typing import Any, Iterator, Optional +from typing import Any import pytest import sentry_kafka_schemas @@ -24,7 +25,7 @@ class Case: example: Example processor: MessageProcessor - replacer_processor: Optional[ReplacerProcessor[Any]] + replacer_processor: ReplacerProcessor[Any] | None def __repr__(self) -> str: return repr(self.example) @@ -101,6 +102,6 @@ def test_has_kafka_schema() -> None: sentry_kafka_schemas.get_codec(topic_name) except sentry_kafka_schemas.SchemaNotFound: if topic_name in DEPRECATED_TOPICS: - print("Skipped validation for topic: %s" % topic_name) + print(f"Skipped validation for topic: {topic_name}") else: raise diff --git a/tests/consumers/test_utils.py b/tests/consumers/test_utils.py index 69bd5bd6bf3..b48d11e3249 100644 --- a/tests/consumers/test_utils.py +++ b/tests/consumers/test_utils.py @@ -2,7 +2,10 @@ import pytest from confluent_kafka import KafkaException -from confluent_kafka.admin import AdminClient, ClusterMetadata +from confluent_kafka.admin import ( # type: ignore[attr-defined] # ClusterMetadata lacks explicit re-export + AdminClient, + ClusterMetadata, +) from snuba import settings from snuba.consumers.utils import TopicNotFound, get_partition_count diff --git a/tests/datasets/cdc/test_groupassignee.py b/tests/datasets/cdc/test_groupassignee.py index 3939e3d22de..3ba6fb48738 100644 --- a/tests/datasets/cdc/test_groupassignee.py +++ b/tests/datasets/cdc/test_groupassignee.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import UTC, datetime import pytest @@ -136,7 +136,7 @@ class TestGroupassignee: "record_deleted": 0, "user_id": 1, "team_id": None, - "date_added": datetime(2019, 9, 19, 0, 17, 55, tzinfo=timezone.utc), + "date_added": datetime(2019, 9, 19, 0, 17, 55, tzinfo=UTC), } PROCESSED_UPDATE = { @@ -146,7 +146,7 @@ class TestGroupassignee: "record_deleted": 0, "user_id": 1, "team_id": None, - "date_added": datetime(2019, 9, 19, 0, 17, 55, tzinfo=timezone.utc), + "date_added": datetime(2019, 9, 19, 0, 17, 55, tzinfo=UTC), } DELETED = { @@ -167,7 +167,7 @@ def test_messages(self) -> None: ret = processor.process_message(self.INSERT_MSG, metadata) assert ret == InsertBatch( [self.PROCESSED], - datetime(2019, 9, 19, 0, 17, 55, 32443, tzinfo=timezone.utc), + datetime(2019, 9, 19, 0, 17, 55, 32443, tzinfo=UTC), ) write_processed_messages(self.storage, [ret]) results = ( @@ -189,7 +189,7 @@ def test_messages(self) -> None: ret = processor.process_message(self.UPDATE_MSG_NO_KEY_CHANGE, metadata) assert ret == InsertBatch( [self.PROCESSED], - datetime(2019, 9, 19, 0, 6, 56, 376853, tzinfo=timezone.utc), + datetime(2019, 9, 19, 0, 6, 56, 376853, tzinfo=UTC), ) # Tests an update with key change which becomes a two inserts: @@ -197,13 +197,13 @@ def test_messages(self) -> None: ret = processor.process_message(self.UPDATE_MSG_WITH_KEY_CHANGE, metadata) assert ret == InsertBatch( [self.DELETED, self.PROCESSED_UPDATE], - datetime(2019, 9, 19, 0, 6, 56, 376853, tzinfo=timezone.utc), + datetime(2019, 9, 19, 0, 6, 56, 376853, tzinfo=UTC), ) ret = processor.process_message(self.DELETE_MSG, metadata) assert ret == InsertBatch( [self.DELETED], - datetime(2019, 9, 19, 0, 17, 21, 447870, tzinfo=timezone.utc), + datetime(2019, 9, 19, 0, 17, 21, 447870, tzinfo=UTC), ) def test_bulk_load(self) -> None: diff --git a/tests/datasets/cdc/test_groupedmessage.py b/tests/datasets/cdc/test_groupedmessage.py index 2a2e62e660d..bec0b55b8dd 100644 --- a/tests/datasets/cdc/test_groupedmessage.py +++ b/tests/datasets/cdc/test_groupedmessage.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import UTC, datetime import pytest @@ -206,9 +206,9 @@ class TestGroupedMessage: "id": 74, "record_deleted": 0, "status": 0, - "last_seen": datetime(2019, 6, 19, 6, 46, 28, tzinfo=timezone.utc), - "first_seen": datetime(2019, 6, 19, 6, 45, 32, tzinfo=timezone.utc), - "active_at": datetime(2019, 6, 19, 6, 45, 32, tzinfo=timezone.utc), + "last_seen": datetime(2019, 6, 19, 6, 46, 28, tzinfo=UTC), + "first_seen": datetime(2019, 6, 19, 6, 45, 32, tzinfo=UTC), + "active_at": datetime(2019, 6, 19, 6, 45, 32, tzinfo=UTC), "first_release_id": None, } @@ -232,7 +232,7 @@ def test_messages(self) -> None: ret = processor.process_message(self.INSERT_MSG, metadata) assert ret == InsertBatch( [self.PROCESSED], - datetime(2019, 9, 19, 0, 17, 21, 447870, tzinfo=timezone.utc), + datetime(2019, 9, 19, 0, 17, 21, 447870, tzinfo=UTC), ) write_processed_messages(self.storage, [ret]) results = ( @@ -256,13 +256,13 @@ def test_messages(self) -> None: ret = processor.process_message(self.UPDATE_MSG, metadata) assert ret == InsertBatch( [self.PROCESSED], - datetime(2019, 9, 19, 0, 17, 21, 447870, tzinfo=timezone.utc), + datetime(2019, 9, 19, 0, 17, 21, 447870, tzinfo=UTC), ) ret = processor.process_message(self.DELETE_MSG, metadata) assert ret == InsertBatch( [self.DELETED], - datetime(2019, 9, 19, 0, 17, 21, 447870, tzinfo=timezone.utc), + datetime(2019, 9, 19, 0, 17, 21, 447870, tzinfo=UTC), ) def test_bulk_load(self) -> None: diff --git a/tests/datasets/configuration/test_entity_loader.py b/tests/datasets/configuration/test_entity_loader.py index e0e9973e3e6..5cfca1c3e95 100644 --- a/tests/datasets/configuration/test_entity_loader.py +++ b/tests/datasets/configuration/test_entity_loader.py @@ -12,6 +12,7 @@ ) from snuba.datasets.configuration.entity_builder import build_entity_from_config from snuba.query.expressions import Column, FunctionCall, Literal +from snuba.utils.schemas import FixedString from tests.datasets.configuration.utils import ConfigurationTest @@ -29,8 +30,12 @@ def test_entity_loader_fixed_string(self) -> None: ) columns = list(entity.get_data_model()) assert len(columns) == 3 - assert columns[0].type.length == 420 # type: ignore - assert columns[2].type.length == 69 # type: ignore + column_0_type = columns[0].type + column_2_type = columns[2].type + assert isinstance(column_0_type, FixedString) + assert isinstance(column_2_type, FixedString) + assert column_0_type.length == 420 + assert column_2_type.length == 69 def test_bad_configuration_broken_query_processor(self) -> None: with pytest.raises(JsonSchemaValueException): diff --git a/tests/datasets/configuration/test_storage_loader.py b/tests/datasets/configuration/test_storage_loader.py index 28afffd0f97..41863bbaae9 100644 --- a/tests/datasets/configuration/test_storage_loader.py +++ b/tests/datasets/configuration/test_storage_loader.py @@ -97,26 +97,24 @@ def test_processor_with_constructor(self) -> None: mapping_optimizer_qp, clickhouse_settings_override_qp, ) = storage.get_query_processors() - assert getattr(mapping_optimizer_qp, "_MappingOptimizer__column_name") == "a" - assert getattr(mapping_optimizer_qp, "_MappingOptimizer__hash_map_name") == "hashmap" - assert getattr(mapping_optimizer_qp, "_MappingOptimizer__killswitch") == "kill" + assert mapping_optimizer_qp._MappingOptimizer__column_name == "a" + assert mapping_optimizer_qp._MappingOptimizer__hash_map_name == "hashmap" + assert mapping_optimizer_qp._MappingOptimizer__killswitch == "kill" assert ( - getattr( - clickhouse_settings_override_qp, - "_ClickhouseSettingsOverride__settings", - )["max_rows_to_group_by"] + clickhouse_settings_override_qp._ClickhouseSettingsOverride__settings[ + "max_rows_to_group_by" + ] == 1000000 ) assert ( - getattr( - clickhouse_settings_override_qp, - "_ClickhouseSettingsOverride__settings", - )["group_by_overflow_mode"] + clickhouse_settings_override_qp._ClickhouseSettingsOverride__settings[ + "group_by_overflow_mode" + ] == "any" ) assert storage.required_time_column == "timestamp" assert len(policies := storage.get_allocation_policies()) == 2 - assert set([p.class_name() for p in policies]) == { + assert {p.class_name() for p in policies} == { "BytesScannedWindowAllocationPolicy", "PassthroughPolicy", } @@ -132,12 +130,10 @@ def test_processor_with_constructor(self) -> None: assert len(storage.get_deletion_processors()) == 1 column_filter_processor = storage.get_deletion_processors()[0] - assert getattr(column_filter_processor, "_ColumnFilterProcessor__column_filters") == { - "some_column" - } + assert column_filter_processor._ColumnFilterProcessor__column_filters == {"some_column"} assert len(delete_policies := storage.get_delete_allocation_policies()) == 1 - assert set([p.class_name() for p in delete_policies]) == { + assert {p.class_name() for p in delete_policies} == { "DeleteConcurrentRateLimitAllocationPolicy", } diff --git a/tests/datasets/configuration/utils.py b/tests/datasets/configuration/utils.py index 83750e94a23..3ad236d7df8 100644 --- a/tests/datasets/configuration/utils.py +++ b/tests/datasets/configuration/utils.py @@ -1,4 +1,4 @@ -from typing import Generator +from collections.abc import Generator import pytest @@ -7,6 +7,6 @@ class ConfigurationTest: @pytest.fixture(autouse=True) - def reset_configuration(self) -> Generator[None, None, None]: + def reset_configuration(self) -> Generator[None]: reset_dataset_factory() yield diff --git a/tests/datasets/entities/storage_selectors/test_errors.py b/tests/datasets/entities/storage_selectors/test_errors.py index 570e133b463..f0bab1e56a1 100644 --- a/tests/datasets/entities/storage_selectors/test_errors.py +++ b/tests/datasets/entities/storage_selectors/test_errors.py @@ -1,5 +1,3 @@ -from typing import List - import pytest from snuba import state @@ -104,7 +102,7 @@ def test_query_storage_selector( snql_query: str, dataset: Dataset, - storage_connections: List[EntityStorageConnection], + storage_connections: list[EntityStorageConnection], selector: QueryStorageSelector, use_readable: bool, expected_storage: Storage, diff --git a/tests/datasets/entities/storage_selectors/test_selector.py b/tests/datasets/entities/storage_selectors/test_selector.py index 287004dddbd..a250d955e73 100644 --- a/tests/datasets/entities/storage_selectors/test_selector.py +++ b/tests/datasets/entities/storage_selectors/test_selector.py @@ -1,5 +1,3 @@ -from typing import List - import pytest from snuba.datasets.dataset import Dataset @@ -62,7 +60,7 @@ def test_default_query_storage_selector( snql_query: str, dataset: Dataset, - storage_connections: List[EntityStorageConnection], + storage_connections: list[EntityStorageConnection], selector: QueryStorageSelector, expected_storage: Storage, ) -> None: diff --git a/tests/datasets/entities/test_entity_key.py b/tests/datasets/entities/test_entity_key.py index f277307aa0d..18e89477028 100644 --- a/tests/datasets/entities/test_entity_key.py +++ b/tests/datasets/entities/test_entity_key.py @@ -7,7 +7,7 @@ def test_entity_key() -> None: initialize_entity_factory() with pytest.raises(AttributeError): - EntityKey.NON_EXISTENT_ENTITY + EntityKey.NON_EXISTENT_ENTITY # noqa: B018 access triggers expected AttributeError assert ( REGISTERED_ENTITY_KEYS["GENERIC_METRICS_DISTRIBUTIONS"] == "generic_metrics_distributions" diff --git a/tests/datasets/entities/test_pluggable_entity.py b/tests/datasets/entities/test_pluggable_entity.py index 76835b3af88..49b7394001f 100644 --- a/tests/datasets/entities/test_pluggable_entity.py +++ b/tests/datasets/entities/test_pluggable_entity.py @@ -1,5 +1,5 @@ -from datetime import datetime, timedelta, timezone -from typing import Mapping +from collections.abc import Mapping +from datetime import UTC, datetime, timedelta import pytest @@ -29,15 +29,14 @@ from snuba.query.query_settings import HTTPQuerySettings from snuba.query.snql.parser import parse_snql_query from snuba.request import Request -from snuba.utils.schemas import AggregateFunction +from snuba.utils.schemas import AggregateFunction, DateTime, Nested, UInt from snuba.utils.schemas import Column as SchemaColumn -from snuba.utils.schemas import DateTime, Nested, UInt @pytest.fixture def start_time() -> datetime: return (datetime.utcnow() - timedelta(days=1)).replace( - hour=12, minute=15, second=0, microsecond=0, tzinfo=timezone.utc + hour=12, minute=15, second=0, microsecond=0, tzinfo=UTC ) diff --git a/tests/datasets/plans/test_cluster_selector.py b/tests/datasets/plans/test_cluster_selector.py index 19db631660d..e77576150de 100644 --- a/tests/datasets/plans/test_cluster_selector.py +++ b/tests/datasets/plans/test_cluster_selector.py @@ -1,5 +1,4 @@ import os -from typing import Optional from unittest.mock import patch import pytest @@ -163,7 +162,7 @@ def test_column_based_partition_selector( def test_should_use_mega_cluster( storage_set: StorageSetKey, logical_partition: int, - override_config: Optional[str], + override_config: str | None, expected: bool, ) -> None: if override_config: diff --git a/tests/datasets/plans/translator/test_mapping.py b/tests/datasets/plans/translator/test_mapping.py index 1354d1462a3..0366801f6f4 100644 --- a/tests/datasets/plans/translator/test_mapping.py +++ b/tests/datasets/plans/translator/test_mapping.py @@ -162,11 +162,7 @@ FunctionCall( "alias", "f", - ( - Column( - alias=None, table_name=None, column_name="users_crashed" - ), - ), + (Column(alias=None, table_name=None, column_name="users_crashed"),), ), ), ], @@ -284,9 +280,7 @@ "testF", ( Column(alias=None, table_name=None, column_name="platform"), - Column( - alias=None, table_name=None, column_name="tags_value" - ), + Column(alias=None, table_name=None, column_name="tags_value"), ), ), ), @@ -306,7 +300,7 @@ FunctionCall( "alias", "f2", - tuple(), + (), ), ), ], @@ -379,7 +373,7 @@ FunctionCall( "alias", "f2", - tuple(), + (), ), ), ], diff --git a/tests/datasets/storages/processors/test_replaced_groups.py b/tests/datasets/storages/processors/test_replaced_groups.py index fda74dc59bd..f2e20423fa5 100644 --- a/tests/datasets/storages/processors/test_replaced_groups.py +++ b/tests/datasets/storages/processors/test_replaced_groups.py @@ -1,5 +1,5 @@ +from collections.abc import Sequence from datetime import datetime, timedelta -from typing import Sequence import pytest @@ -139,25 +139,20 @@ def test_without_turbo_with_projects_needing_final(query: ClickhouseQuery) -> No ) query_settings = HTTPQuerySettings() - PostReplacementConsistencyEnforcer( - "project_id", ReplacerState.ERRORS - ).process_query(query, query_settings) + PostReplacementConsistencyEnforcer("project_id", ReplacerState.ERRORS).process_query( + query, query_settings + ) assert query.get_condition() == build_in("project_id", [2]) assert query.get_from_clause().final assert ( - query_settings.get_clickhouse_settings()[ - "do_not_merge_across_partitions_select_final" - ] - == 1 + query_settings.get_clickhouse_settings()["do_not_merge_across_partitions_select_final"] == 1 ) @pytest.mark.redis_db def test_without_turbo_without_projects_needing_final(query: ClickhouseQuery) -> None: - PostReplacementConsistencyEnforcer("project_id", None).process_query( - query, HTTPQuerySettings() - ) + PostReplacementConsistencyEnforcer("project_id", None).process_query(query, HTTPQuerySettings()) assert query.get_condition() == build_in("project_id", [2]) assert not query.get_from_clause().final @@ -171,16 +166,16 @@ def test_remove_final_subscriptions(query: ClickhouseQuery) -> None: ReplacementType.EXCLUDE_GROUPS, # Arbitrary replacement type, no impact on tests ) - PostReplacementConsistencyEnforcer( - "project_id", ReplacerState.ERRORS - ).process_query(query, SubscriptionQuerySettings()) + PostReplacementConsistencyEnforcer("project_id", ReplacerState.ERRORS).process_query( + query, SubscriptionQuerySettings() + ) assert query.get_condition() == build_in("project_id", [2]) assert query.get_from_clause().final state.set_config("skip_final_subscriptions_projects", "[2,3,4]") - PostReplacementConsistencyEnforcer( - "project_id", ReplacerState.ERRORS - ).process_query(query, SubscriptionQuerySettings()) + PostReplacementConsistencyEnforcer("project_id", ReplacerState.ERRORS).process_query( + query, SubscriptionQuerySettings() + ) assert not query.get_from_clause().final @@ -194,9 +189,9 @@ def test_not_many_groups_to_exclude(query: ClickhouseQuery) -> None: ReplacementType.EXCLUDE_GROUPS, # Arbitrary replacement type, no impact on tests ) - PostReplacementConsistencyEnforcer( - "project_id", ReplacerState.ERRORS - ).process_query(query, HTTPQuerySettings()) + PostReplacementConsistencyEnforcer("project_id", ReplacerState.ERRORS).process_query( + query, HTTPQuerySettings() + ) assert query.get_condition() == build_and( FunctionCall( @@ -230,9 +225,9 @@ def test_too_many_groups_to_exclude(query: ClickhouseQuery) -> None: ReplacementType.EXCLUDE_GROUPS, # Arbitrary replacement type, no impact on tests ) - PostReplacementConsistencyEnforcer( - "project_id", ReplacerState.ERRORS - ).process_query(query, HTTPQuerySettings()) + PostReplacementConsistencyEnforcer("project_id", ReplacerState.ERRORS).process_query( + query, HTTPQuerySettings() + ) assert query.get_condition() == build_in("project_id", [2]) assert query.get_from_clause().final diff --git a/tests/datasets/storages/test_storages.py b/tests/datasets/storages/test_storages.py index 38fa977bc95..433713aff54 100644 --- a/tests/datasets/storages/test_storages.py +++ b/tests/datasets/storages/test_storages.py @@ -5,4 +5,4 @@ def test_storage_key() -> None: with pytest.raises(AttributeError): - StorageKey.NON_EXISTENT_STORAGE + StorageKey.NON_EXISTENT_STORAGE # noqa: B018 access triggers expected AttributeError diff --git a/tests/datasets/test_context_promotion.py b/tests/datasets/test_context_promotion.py index 9a9899dbac7..70ea53aeb48 100644 --- a/tests/datasets/test_context_promotion.py +++ b/tests/datasets/test_context_promotion.py @@ -39,9 +39,7 @@ ], ) @pytest.mark.clickhouse_db -def test_span_id_promotion( - entity_name: str, dataset_name: str, expected_table_name: str -) -> None: +def test_span_id_promotion(entity_name: str, dataset_name: str, expected_table_name: str) -> None: """In order to save space in the contexts column and provide faster query performance, we promote span_id to a proper column and don't store it in the actual contexts object in the DB. @@ -104,9 +102,7 @@ def test_span_id_promotion( SelectedExpression( name="contexts[trace.span_id]", # the select converts the span_id into a lowecase hex string - expression=HexIntColumnProcessor(columns="span_id")._process_expressions( - column - ), + expression=HexIntColumnProcessor(columns="span_id")._process_expressions(column), ) ] @@ -116,9 +112,7 @@ def __init__(self) -> None: super().__init__() def visit_function_call(self, exp: FunctionCall) -> None: - if exp.function_name == "equals" and exp.parameters[0] == Column( - None, None, "span_id" - ): + if exp.function_name == "equals" and exp.parameters[0] == Column(None, None, "span_id"): self.found_span_condition = True # and here we can see that the hex string the client queried us with # has been converted to the correct uint64 diff --git a/tests/datasets/test_dataset_factory.py b/tests/datasets/test_dataset_factory.py index 5b2b96a59e5..0b21db62cf7 100644 --- a/tests/datasets/test_dataset_factory.py +++ b/tests/datasets/test_dataset_factory.py @@ -1,4 +1,4 @@ -from typing import Iterator +from collections.abc import Iterator import pytest @@ -38,7 +38,7 @@ def test_get_dataset() -> None: @pytest.fixture(scope="function") def disable_datasets() -> Iterator[None]: og_disabled = settings.DISABLED_DATASETS - settings.DISABLED_DATASETS = set(["events"]) + settings.DISABLED_DATASETS = {"events"} yield settings.DISABLED_DATASETS = og_disabled diff --git a/tests/datasets/test_discover.py b/tests/datasets/test_discover.py index b36eef291dc..20fe66a075e 100644 --- a/tests/datasets/test_discover.py +++ b/tests/datasets/test_discover.py @@ -1,4 +1,5 @@ -from typing import Any, MutableMapping +from collections.abc import MutableMapping +from typing import Any import pytest from snuba_sdk.legacy import json_to_snql diff --git a/tests/datasets/test_errors_processor.py b/tests/datasets/test_errors_processor.py index f5478177f37..d4474f0a7eb 100644 --- a/tests/datasets/test_errors_processor.py +++ b/tests/datasets/test_errors_processor.py @@ -3,9 +3,10 @@ import json import random import uuid +from collections.abc import Mapping, MutableMapping, Sequence from dataclasses import dataclass -from datetime import datetime, timedelta, timezone -from typing import Any, Mapping, MutableMapping, Never, Sequence +from datetime import UTC, datetime, timedelta +from typing import Any, Never from unittest.mock import ANY import pytest @@ -265,7 +266,7 @@ def serialize(self) -> tuple[int, str, Mapping[str, Any], dict[Never, Never]]: def build_result(self, meta: KafkaMessageMetadata) -> MutableMapping[str, Any]: expected_result = { "project_id": self.project_id, - "timestamp": int(self.timestamp.replace(tzinfo=timezone.utc).timestamp()), + "timestamp": int(self.timestamp.replace(tzinfo=UTC).timestamp()), "event_id": self.event_id, "platform": self.platform, "dist": self.dist, @@ -330,15 +331,15 @@ def build_result(self, meta: KafkaMessageMetadata) -> MutableMapping[str, Any]: ], "partition": meta.partition, "offset": meta.offset, - "message_timestamp": int(self.timestamp.replace(tzinfo=timezone.utc).timestamp()), - "timestamp_ms": int(self.timestamp.replace(tzinfo=timezone.utc).timestamp() * 1000), + "message_timestamp": int(self.timestamp.replace(tzinfo=UTC).timestamp()), + "timestamp_ms": int(self.timestamp.replace(tzinfo=UTC).timestamp() * 1000), "retention_days": 90, "deleted": 0, "group_id": self.group_id, "group_first_seen": int((self.timestamp - timedelta(days=2)).timestamp()), "primary_hash": "04233d08-ac90-cf6f-c015-b1be5932e7e2", "received": int( - self.received_timestamp.replace(tzinfo=timezone.utc) + self.received_timestamp.replace(tzinfo=UTC) .replace(tzinfo=None, microsecond=0) .timestamp() ), diff --git a/tests/datasets/test_errors_replacer.py b/tests/datasets/test_errors_replacer.py index b00c8dba62f..55e86b55755 100644 --- a/tests/datasets/test_errors_replacer.py +++ b/tests/datasets/test_errors_replacer.py @@ -1,8 +1,9 @@ import importlib import re import uuid +from collections.abc import Callable from datetime import datetime, timedelta -from typing import Any, Callable, Optional, Tuple, Union +from typing import Any import pytest import simplejson as json @@ -31,7 +32,7 @@ class BaseTest: @pytest.fixture - def test_entity(self) -> Union[str, Tuple[str, str]]: + def test_entity(self) -> str | tuple[str, str]: return "events" @pytest.fixture @@ -54,16 +55,16 @@ def setup_method(self) -> None: # Total query time range is 24h before to 24h after now to account # for local machine time zones - self.from_time = datetime.now().replace( - minute=0, second=0, microsecond=0 - ) - timedelta(days=1) + self.from_time = datetime.now().replace(minute=0, second=0, microsecond=0) - timedelta( + days=1 + ) self.to_time = self.from_time + timedelta(days=2) self.project_id = 1 self.event = get_raw_event() - def _wrap(self, msg: Tuple[Any, ...]) -> Message[KafkaPayload]: + def _wrap(self, msg: tuple[Any, ...]) -> Message[KafkaPayload]: return Message( BrokerValue( KafkaPayload(None, json.dumps(msg).encode("utf-8"), []), @@ -91,7 +92,7 @@ def _clear_redis_and_force_merge(self) -> None: clickhouse = cluster.get_query_connection(ClickhouseClientSettings.OPTIMIZE) run_optimize(clickhouse, self.storage, cluster.get_database()) - def _issue_count(self, project_id: int, group_id: Optional[int] = None) -> Any: + def _issue_count(self, project_id: int, group_id: int | None = None) -> Any: args = { "project": [project_id], "selected_columns": [], @@ -103,17 +104,15 @@ def _issue_count(self, project_id: int, group_id: Optional[int] = None) -> Any: } if group_id: - args.setdefault("conditions", list()).append(("group_id", "=", group_id)) + args.setdefault("conditions", []).append(("group_id", "=", group_id)) return json.loads(self.post(json.dumps(args)).data)["data"] - def _get_group_id(self, project_id: int, event_id: str) -> Optional[int]: + def _get_group_id(self, project_id: int, event_id: str) -> int | None: args = { "project": [project_id], "selected_columns": ["group_id"], - "conditions": [ - ["event_id", "=", str(uuid.UUID(event_id)).replace("-", "")] - ], + "conditions": [["event_id", "=", str(uuid.UUID(event_id)).replace("-", "")]], "from_date": self.from_time.isoformat(), "to_date": self.to_time.isoformat(), "tenant_ids": {"referrer": "r", "organization_id": 1234}, @@ -349,9 +348,7 @@ def test_process_offset_twice(self) -> None: "previous_group_id": 1, "new_group_id": 2, "hashes": ["a" * 32], - "datetime": datetime.utcnow().strftime( - PAYLOAD_DATETIME_FORMAT - ), + "datetime": datetime.utcnow().strftime(PAYLOAD_DATETIME_FORMAT), }, ) ).encode("utf-8"), @@ -441,9 +438,7 @@ def test_reset_consumer_group_offset_check(self) -> None: "previous_group_id": 1, "new_group_id": 2, "hashes": ["a" * 32], - "datetime": datetime.utcnow().strftime( - PAYLOAD_DATETIME_FORMAT - ), + "datetime": datetime.utcnow().strftime(PAYLOAD_DATETIME_FORMAT), }, ) ).encode("utf-8"), @@ -547,13 +542,15 @@ def test_delete_promoted_tag_process(self) -> None: assert ( re.sub("[\n ]+", " ", replacement.get_count_query("foo")).strip() - == "SELECT count() FROM %(table_name)s FINAL WHERE project_id = %(project_id)s AND received <= CAST('%(timestamp)s' AS DateTime) AND NOT deleted AND has(`tags.key`, %(tag_str)s)" - % query_args + == "SELECT count() FROM {table_name} FINAL WHERE project_id = {project_id} AND received <= CAST('{timestamp}' AS DateTime) AND NOT deleted AND has(`tags.key`, {tag_str})".format( + **query_args + ) ) assert ( re.sub("[\n ]+", " ", replacement.get_insert_query("foo")).strip() - == "INSERT INTO %(table_name)s (%(all_columns)s) SELECT %(select_columns)s FROM %(table_name)s FINAL WHERE project_id = %(project_id)s AND received <= CAST('%(timestamp)s' AS DateTime) AND NOT deleted AND has(`tags.key`, %(tag_str)s)" - % query_args + == "INSERT INTO {table_name} ({all_columns}) SELECT {select_columns} FROM {table_name} FINAL WHERE project_id = {project_id} AND received <= CAST('{timestamp}' AS DateTime) AND NOT deleted AND has(`tags.key`, {tag_str})".format( + **query_args + ) ) assert replacement.get_query_time_flags() == errors_replacer.NeedsFinal() @@ -585,20 +582,20 @@ def test_delete_unpromoted_tag_process(self) -> None: assert ( re.sub("[\n ]+", " ", replacement.get_count_query("foo")).strip() - == "SELECT count() FROM %(table_name)s FINAL WHERE project_id = %(project_id)s AND received <= CAST('%(timestamp)s' AS DateTime) AND NOT deleted AND has(`tags.key`, %(tag_str)s)" - % query_args + == "SELECT count() FROM {table_name} FINAL WHERE project_id = {project_id} AND received <= CAST('{timestamp}' AS DateTime) AND NOT deleted AND has(`tags.key`, {tag_str})".format( + **query_args + ) ) assert ( re.sub("[\n ]+", " ", replacement.get_insert_query("foo")).strip() - == "INSERT INTO %(table_name)s (%(all_columns)s) SELECT %(select_columns)s FROM %(table_name)s FINAL WHERE project_id = %(project_id)s AND received <= CAST('%(timestamp)s' AS DateTime) AND NOT deleted AND has(`tags.key`, %(tag_str)s)" - % query_args + == "INSERT INTO {table_name} ({all_columns}) SELECT {select_columns} FROM {table_name} FINAL WHERE project_id = {project_id} AND received <= CAST('{timestamp}' AS DateTime) AND NOT deleted AND has(`tags.key`, {tag_str})".format( + **query_args + ) ) assert replacement.get_query_time_flags() == errors_replacer.NeedsFinal() - @pytest.mark.parametrize( - "old_primary_hash", ["e3d704f3542b44a621ebed70dc0efe13", False, None] - ) + @pytest.mark.parametrize("old_primary_hash", ["e3d704f3542b44a621ebed70dc0efe13", False, None]) def test_tombstone_events_process(self, old_primary_hash) -> None: timestamp = datetime.now() message_kwargs = { @@ -617,9 +614,7 @@ def test_tombstone_events_process(self, old_primary_hash) -> None: _, replacement = meta_and_replacement old_primary_condition = ( - " AND primary_hash = 'e3d704f3-542b-44a6-21eb-ed70dc0efe13'" - if old_primary_hash - else "" + " AND primary_hash = 'e3d704f3-542b-44a6-21eb-ed70dc0efe13'" if old_primary_hash else "" ) query_args = { @@ -671,14 +666,16 @@ def test_replace_group_process(self) -> None: assert ( re.sub("[\n ]+", " ", replacement.get_count_query("foo")).strip() - == "SELECT count() FROM %(table_name)s FINAL PREWHERE event_id IN (%(event_ids)s) WHERE project_id = %(project_id)s AND NOT deleted" - % query_args + == "SELECT count() FROM {table_name} FINAL PREWHERE event_id IN ({event_ids}) WHERE project_id = {project_id} AND NOT deleted".format( + **query_args + ) ) assert ( re.sub("[\n ]+", " ", replacement.get_insert_query("foo")).strip() - == "INSERT INTO %(table_name)s (%(all_columns)s) SELECT %(select_columns)s FROM %(table_name)s FINAL PREWHERE event_id IN (%(event_ids)s) WHERE project_id = %(project_id)s AND NOT deleted" - % query_args + == "INSERT INTO {table_name} ({all_columns}) SELECT {select_columns} FROM {table_name} FINAL PREWHERE event_id IN ({event_ids}) WHERE project_id = {project_id} AND NOT deleted".format( + **query_args + ) ) assert replacement.get_query_time_flags() is None @@ -710,14 +707,16 @@ def test_replace_group_process_alternate_date(self) -> None: assert ( re.sub("[\n ]+", " ", replacement.get_count_query("foo")).strip() - == "SELECT count() FROM %(table_name)s FINAL PREWHERE event_id IN (%(event_ids)s) WHERE project_id = %(project_id)s AND NOT deleted" - % query_args + == "SELECT count() FROM {table_name} FINAL PREWHERE event_id IN ({event_ids}) WHERE project_id = {project_id} AND NOT deleted".format( + **query_args + ) ) assert ( re.sub("[\n ]+", " ", replacement.get_insert_query("foo")).strip() - == "INSERT INTO %(table_name)s (%(all_columns)s) SELECT %(select_columns)s FROM %(table_name)s FINAL PREWHERE event_id IN (%(event_ids)s) WHERE project_id = %(project_id)s AND NOT deleted" - % query_args + == "INSERT INTO {table_name} ({all_columns}) SELECT {select_columns} FROM {table_name} FINAL PREWHERE event_id IN ({event_ids}) WHERE project_id = {project_id} AND NOT deleted".format( + **query_args + ) ) assert replacement.get_query_time_flags() is None @@ -750,18 +749,18 @@ def test_merge_process(self) -> None: assert ( re.sub("[\n ]+", " ", replacement.get_count_query("foo")).strip() - == "SELECT count() FROM %(table_name)s FINAL PREWHERE group_id IN (%(previous_group_ids)s) WHERE project_id = %(project_id)s AND received <= CAST('%(timestamp)s' AS DateTime) AND NOT deleted" - % query_args + == "SELECT count() FROM {table_name} FINAL PREWHERE group_id IN ({previous_group_ids}) WHERE project_id = {project_id} AND received <= CAST('{timestamp}' AS DateTime) AND NOT deleted".format( + **query_args + ) ) assert ( re.sub("[\n ]+", " ", replacement.get_insert_query("foo")).strip() - == "INSERT INTO %(table_name)s (%(all_columns)s) SELECT %(select_columns)s FROM %(table_name)s FINAL PREWHERE group_id IN (%(previous_group_ids)s) WHERE project_id = %(project_id)s AND received <= CAST('%(timestamp)s' AS DateTime) AND NOT deleted" - % query_args + == "INSERT INTO {table_name} ({all_columns}) SELECT {select_columns} FROM {table_name} FINAL PREWHERE group_id IN ({previous_group_ids}) WHERE project_id = {project_id} AND received <= CAST('{timestamp}' AS DateTime) AND NOT deleted".format( + **query_args + ) ) - assert replacement.get_query_time_flags() == errors_replacer.ExcludeGroups( - [1, 2] - ) + assert replacement.get_query_time_flags() == errors_replacer.ExcludeGroups([1, 2]) def test_unmerge_process(self) -> None: timestamp = datetime.now() @@ -793,13 +792,15 @@ def test_unmerge_process(self) -> None: assert ( re.sub("[\n ]+", " ", replacement.get_count_query("foo")).strip() - == "SELECT count() FROM %(table_name)s FINAL PREWHERE primary_hash IN (%(hashes)s) WHERE group_id = %(previous_group_id)s AND project_id = %(project_id)s AND received <= CAST('%(timestamp)s' AS DateTime) AND NOT deleted" - % query_args + == "SELECT count() FROM {table_name} FINAL PREWHERE primary_hash IN ({hashes}) WHERE group_id = {previous_group_id} AND project_id = {project_id} AND received <= CAST('{timestamp}' AS DateTime) AND NOT deleted".format( + **query_args + ) ) assert ( re.sub("[\n ]+", " ", replacement.get_insert_query("foo")).strip() - == "INSERT INTO %(table_name)s (%(all_columns)s) SELECT %(select_columns)s FROM %(table_name)s FINAL PREWHERE primary_hash IN (%(hashes)s) WHERE group_id = %(previous_group_id)s AND project_id = %(project_id)s AND received <= CAST('%(timestamp)s' AS DateTime) AND NOT deleted" - % query_args + == "INSERT INTO {table_name} ({all_columns}) SELECT {select_columns} FROM {table_name} FINAL PREWHERE primary_hash IN ({hashes}) WHERE group_id = {previous_group_id} AND project_id = {project_id} AND received <= CAST('{timestamp}' AS DateTime) AND NOT deleted".format( + **query_args + ) ) assert replacement.get_query_time_flags() == errors_replacer.NeedsFinal() @@ -869,13 +870,15 @@ def test_delete_groups_process(self) -> None: } assert ( re.sub("[\n ]+", " ", replacement.get_count_query("foo")).strip() - == "SELECT count() FROM %(table_name)s FINAL PREWHERE group_id IN (%(group_ids)s) WHERE project_id = %(project_id)s AND received <= CAST('%(timestamp)s' AS DateTime) AND NOT deleted" - % query_args + == "SELECT count() FROM {table_name} FINAL PREWHERE group_id IN ({group_ids}) WHERE project_id = {project_id} AND received <= CAST('{timestamp}' AS DateTime) AND NOT deleted".format( + **query_args + ) ) assert ( re.sub("[\n ]+", " ", replacement.get_insert_query("foo")).strip() - == "INSERT INTO %(table_name)s (%(required_columns)s) SELECT %(select_columns)s FROM %(table_name)s FINAL PREWHERE group_id IN (%(group_ids)s) WHERE project_id = %(project_id)s AND received <= CAST('%(timestamp)s' AS DateTime) AND NOT deleted" - % query_args + == "INSERT INTO {table_name} ({required_columns}) SELECT {select_columns} FROM {table_name} FINAL PREWHERE group_id IN ({group_ids}) WHERE project_id = {project_id} AND received <= CAST('{timestamp}' AS DateTime) AND NOT deleted".format( + **query_args + ) ) assert replacement.get_project_id() == self.project_id assert replacement.get_query_time_flags() == errors_replacer.ExcludeGroups( @@ -904,9 +907,7 @@ def test_project_bypass(self) -> None: _, replacement = meta_and_replacement assert replacement is not None - set_config( - "replacements_bypass_projects", f"[{self.project_id + 1},{self.project_id}]" - ) + set_config("replacements_bypass_projects", f"[{self.project_id + 1},{self.project_id}]") meta_and_replacement = self.replacer.process_message(self._wrap(message)) assert meta_and_replacement is None delete_config("replacements_bypass_projects") diff --git a/tests/datasets/test_events.py b/tests/datasets/test_events.py index b3e5637edfc..3119e5e93f1 100644 --- a/tests/datasets/test_events.py +++ b/tests/datasets/test_events.py @@ -42,10 +42,8 @@ def test_tags_hash_map(self) -> None: tag1, tag2 = hashed[0] event = clickhouse.execute( - ( - f"SELECT replaceAll(toString(event_id), '-', '') FROM {table_name} WHERE has(_tags_hash_map, {tag1}) " - f"AND has(_tags_hash_map, {tag2})" - ) + f"SELECT replaceAll(toString(event_id), '-', '') FROM {table_name} WHERE has(_tags_hash_map, {tag1}) " + f"AND has(_tags_hash_map, {tag2})" ).results assert len(event) == 1 assert event[0][0] == self.event["data"]["id"] diff --git a/tests/datasets/test_functions_processor.py b/tests/datasets/test_functions_processor.py index 3a054103788..c31d21cf73d 100644 --- a/tests/datasets/test_functions_processor.py +++ b/tests/datasets/test_functions_processor.py @@ -1,6 +1,7 @@ +from collections.abc import Mapping, Sequence from dataclasses import dataclass -from datetime import datetime, timezone -from typing import Any, Mapping, Optional, Sequence +from datetime import UTC, datetime +from typing import Any from snuba.consumers.types import KafkaMessageMetadata from snuba.datasets.processors.functions_processor import FunctionsMessageProcessor @@ -14,7 +15,7 @@ class Function: package: str in_app: bool self_times_ns: Sequence[int] - thread_id: Optional[str] + thread_id: str | None def serialize(self) -> Mapping[str, Any]: return { @@ -29,19 +30,19 @@ def serialize(self) -> Mapping[str, Any]: @dataclass class ProfileFunctionsEvent: - environment: Optional[str] + environment: str | None functions: Sequence[Function] platform: str profile_id: str project_id: int - received: Optional[int] - release: Optional[str] + received: int | None + release: str | None retention_days: int timestamp: int transaction_name: str - start_timestamp: Optional[float] - end_timestamp: Optional[float] - profiling_type: Optional[str] + start_timestamp: float | None + end_timestamp: float | None + profiling_type: str | None def serialize(self) -> Mapping[str, Any]: return { @@ -65,7 +66,7 @@ class TestFunctionsProcessor: def test_process_message_functions(self) -> None: meta = KafkaMessageMetadata(offset=1, partition=2, timestamp=datetime(1970, 1, 1)) - now = int(datetime.now(timezone.utc).timestamp()) + now = int(datetime.now(UTC).timestamp()) message = ProfileFunctionsEvent( environment="prod", functions=[ diff --git a/tests/datasets/test_generic_metrics_processor.py b/tests/datasets/test_generic_metrics_processor.py index aa05fd7be8a..4634156f80f 100644 --- a/tests/datasets/test_generic_metrics_processor.py +++ b/tests/datasets/test_generic_metrics_processor.py @@ -1,5 +1,6 @@ -from datetime import datetime, timezone -from typing import Any, Iterable, Mapping, Tuple +from collections.abc import Iterable, Mapping +from datetime import UTC, datetime +from typing import Any import pytest @@ -9,7 +10,7 @@ GenericDistributionsMetricsProcessor, ) -timestamp = int(datetime.now(timezone.utc).timestamp()) +timestamp = int(datetime.now(UTC).timestamp()) MAPPING_META_COMMON = { "c": { @@ -28,7 +29,7 @@ def dis_processor() -> GenericDistributionsMetricsProcessor: return GenericDistributionsMetricsProcessor() -def sorted_tag_items(message: Mapping[str, Any]) -> Iterable[Tuple[str, int]]: +def sorted_tag_items(message: Mapping[str, Any]) -> Iterable[tuple[str, int]]: tags = message["tags"] return sorted(tags.items()) diff --git a/tests/datasets/test_group_attributes_join.py b/tests/datasets/test_group_attributes_join.py index d4ccceb774b..9e3e038ceb9 100644 --- a/tests/datasets/test_group_attributes_join.py +++ b/tests/datasets/test_group_attributes_join.py @@ -1,7 +1,8 @@ import uuid -from datetime import datetime, timedelta, timezone +from collections.abc import Mapping +from datetime import UTC, datetime, timedelta from functools import partial -from typing import Any, Dict, Mapping, Union +from typing import Any import pytest import simplejson as json @@ -46,9 +47,7 @@ def write_group_attribute_row(row: Mapping[str, Any]) -> None: def _convert_clickhouse_datetime_str(str: str) -> str: return ( - datetime.strptime(str, "%Y-%m-%d %H:%M:%S") - .replace(microsecond=0, tzinfo=timezone.utc) - .isoformat() + datetime.strptime(str, "%Y-%m-%d %H:%M:%S").replace(microsecond=0, tzinfo=UTC).isoformat() ) @@ -62,16 +61,16 @@ def setup_fixture(self, events_db: Any, redis_db: Any) -> None: self.app.post = partial(self.app.post, headers={"referer": "test"}) # type: ignore[method-assign] self.event = get_raw_event() self.project_id = self.event["project_id"] - self.base_time = datetime.utcnow().replace( - second=0, microsecond=0, tzinfo=timezone.utc - ) - timedelta(minutes=90) + self.base_time = datetime.utcnow().replace(second=0, microsecond=0, tzinfo=UTC) - timedelta( + minutes=90 + ) self.next_time = self.base_time + timedelta(minutes=95) self.events_storage = get_entity(EntityKey.EVENTS).get_writable_storage() assert self.events_storage is not None write_unprocessed_events(self.events_storage, [self.event]) - self.initial_group_attributes: Dict[str, Any] = { + self.initial_group_attributes: dict[str, Any] = { "deleted": False, "project_id": self.project_id, "group_id": self.event["group_id"], @@ -360,9 +359,9 @@ class TestSearchIssuesGroupAttributes(BaseApiTest): def setup_fixture(self, events_db: Any, redis_db: Any) -> None: self.app.post = partial(self.app.post, headers={"referer": "test"}) # type: ignore[method-assign] - self.base_time = datetime.utcnow().replace( - second=0, microsecond=0, tzinfo=timezone.utc - ) - timedelta(minutes=90) + self.base_time = datetime.utcnow().replace(second=0, microsecond=0, tzinfo=UTC) - timedelta( + minutes=90 + ) self.next_time = self.base_time + timedelta(minutes=95) self.occurrence = self.get_search_issue_occurrence(self.base_time) self.project_id = self.occurrence["project_id"] @@ -371,7 +370,7 @@ def setup_fixture(self, events_db: Any, redis_db: Any) -> None: assert self.search_issues_storage is not None write_unprocessed_events(self.search_issues_storage, [self.occurrence]) - self.initial_group_attributes: Dict[str, Any] = { + self.initial_group_attributes: dict[str, Any] = { "deleted": False, "project_id": self.project_id, "group_id": self.occurrence["group_id"], @@ -447,7 +446,7 @@ def query_search_issues_joined_group_attributes(self) -> Any: } def assert_joined_final( - clickhouse_query: Union[ClickhouseQuery, CompositeQuery[Table]], + clickhouse_query: ClickhouseQuery | CompositeQuery[Table], query_settings: QuerySettings, reader: Reader, cluster_name: str, diff --git a/tests/datasets/test_group_attributes_processor.py b/tests/datasets/test_group_attributes_processor.py index d6ec43109b3..c9f3737c815 100644 --- a/tests/datasets/test_group_attributes_processor.py +++ b/tests/datasets/test_group_attributes_processor.py @@ -1,5 +1,4 @@ from datetime import datetime -from typing import Optional import pytest from sentry_kafka_schemas.schema_types.group_attributes_v1 import ( @@ -39,15 +38,13 @@ def group_created() -> GroupAttributesSnapshot: class TestGroupAttributesMessageProcessor: - KAFKA_META = KafkaMessageMetadata( - offset=0, partition=0, timestamp=datetime(1970, 1, 1) - ) + KAFKA_META = KafkaMessageMetadata(offset=0, partition=0, timestamp=datetime(1970, 1, 1)) processor = GroupAttributesMessageProcessor() def process_message( self, message, kafka_meta: KafkaMessageMetadata = KAFKA_META - ) -> Optional[ProcessedMessage]: + ) -> ProcessedMessage | None: return self.processor.process_message(message, kafka_meta) def processed_single_row(self, message) -> WriterTableRow: diff --git a/tests/datasets/test_metrics_processing.py b/tests/datasets/test_metrics_processing.py index 19b281657f1..d90e707180b 100644 --- a/tests/datasets/test_metrics_processing.py +++ b/tests/datasets/test_metrics_processing.py @@ -50,7 +50,7 @@ "sumMergeIf", ( Column("_snuba_value", None, "value"), - FunctionCall(None, "cond", tuple()), + FunctionCall(None, "cond", ()), ), ), id="Test counters entity with mergeIf", @@ -75,7 +75,7 @@ "maxMergeIf", ( Column(None, None, "max"), - FunctionCall(None, "cond", tuple()), + FunctionCall(None, "cond", ()), ), ), id="Test distribution max with condition", @@ -100,7 +100,7 @@ "minMergeIf", ( Column(None, None, "min"), - FunctionCall(None, "cond", tuple()), + FunctionCall(None, "cond", ()), ), ), id="Test distribution min with condition", @@ -125,7 +125,7 @@ "avgMergeIf", ( Column(None, None, "avg"), - FunctionCall(None, "cond", tuple()), + FunctionCall(None, "cond", ()), ), ), id="Test distribution avg with condition", @@ -167,7 +167,7 @@ "quantilesMergeIf", tuple(Literal(None, quant) for quant in [0.5, 0.75, 0.9, 0.95, 0.99]), ), - (Column(None, None, "percentiles"), FunctionCall(None, "cond", tuple())), + (Column(None, None, "percentiles"), FunctionCall(None, "cond", ())), ), id="Test distribution quantiles", ), @@ -202,7 +202,7 @@ FunctionCall(None, "histogramMergeIf", (Literal(None, 250),)), ( Column(None, None, "histogram_buckets"), - FunctionCall(None, "cond", tuple()), + FunctionCall(None, "cond", ()), ), ), id="Test distribution histogram", @@ -210,9 +210,7 @@ ] -@pytest.mark.parametrize( - "entity_name, column_name, entity_key, translated_value", TEST_CASES -) +@pytest.mark.parametrize("entity_name, column_name, entity_key, translated_value", TEST_CASES) @pytest.mark.clickhouse_db def test_metrics_processing( entity_name: str, diff --git a/tests/datasets/test_metrics_processor.py b/tests/datasets/test_metrics_processor.py index 10b0aab884a..bf5bb9a0d5d 100644 --- a/tests/datasets/test_metrics_processor.py +++ b/tests/datasets/test_metrics_processor.py @@ -1,5 +1,6 @@ -from datetime import datetime, timezone -from typing import Any, Iterable, Mapping, Optional, Sequence, Tuple +from collections.abc import Iterable, Mapping, Sequence +from datetime import UTC, datetime +from typing import Any from unittest.mock import ANY import pytest @@ -16,12 +17,12 @@ MATERIALIZATION_VERSION = 4 -timestamp = int(datetime.now(timezone.utc).timestamp()) +timestamp = int(datetime.now(UTC).timestamp()) # expects that test is run in utc local time intermediate_timestamp = datetime.utcfromtimestamp(timestamp) -expected_timestamp = int(intermediate_timestamp.replace(tzinfo=timezone.utc).timestamp()) +expected_timestamp = int(intermediate_timestamp.replace(tzinfo=UTC).timestamp()) -sentry_received_timestamp = datetime.now(timezone.utc).timestamp() +sentry_received_timestamp = datetime.now(UTC).timestamp() expected_sentry_received_timestamp = datetime.utcfromtimestamp(sentry_received_timestamp) MAPPING_META_COMMON = { @@ -187,7 +188,7 @@ ) def test_metrics_polymorphic_processor( message: Mapping[str, Any], - expected_output: Optional[Sequence[Mapping[str, Any]]], + expected_output: Sequence[Mapping[str, Any]] | None, ) -> None: settings.DISABLED_DATASETS = set() @@ -227,7 +228,7 @@ def test_metrics_polymorphic_processor( ], ) def test_generic_metrics_sets_processor( - message: Mapping[str, Any], expected_output: Optional[Sequence[Mapping[str, Any]]] + message: Mapping[str, Any], expected_output: Sequence[Mapping[str, Any]] | None ) -> None: meta = KafkaMessageMetadata(offset=100, partition=1, timestamp=datetime(1970, 1, 1)) @@ -239,6 +240,6 @@ def test_generic_metrics_sets_processor( ) -def sorted_tag_items(message: Mapping[str, Any]) -> Iterable[Tuple[str, int]]: +def sorted_tag_items(message: Mapping[str, Any]) -> Iterable[tuple[str, int]]: tags = message["tags"] return sorted(tags.items()) diff --git a/tests/datasets/test_processors_idempotency.py b/tests/datasets/test_processors_idempotency.py index b7fd801d420..796c15aad79 100644 --- a/tests/datasets/test_processors_idempotency.py +++ b/tests/datasets/test_processors_idempotency.py @@ -1,5 +1,4 @@ from datetime import datetime -from typing import Tuple import pytest @@ -27,7 +26,7 @@ @pytest.mark.parametrize("message,processor", test_data) def test_processors_of_multistorage_consumer_are_idempotent( - message: Tuple[int, str, InsertEvent], processor: MessageProcessor + message: tuple[int, str, InsertEvent], processor: MessageProcessor ) -> None: """ Test that when the same message is provided to the processors, the result would be the same. That is the process diff --git a/tests/datasets/test_profiles_processor.py b/tests/datasets/test_profiles_processor.py index 24fc363cc85..c06f1140072 100644 --- a/tests/datasets/test_profiles_processor.py +++ b/tests/datasets/test_profiles_processor.py @@ -1,7 +1,8 @@ import uuid +from collections.abc import Mapping from dataclasses import asdict, dataclass from datetime import datetime -from typing import Any, Mapping, Optional +from typing import Any import pytest @@ -17,7 +18,7 @@ class ProfileEvent: transaction_id: str received: int profile_id: str - android_api_level: Optional[int] + android_api_level: int | None device_classification: str device_locale: str device_manufacturer: str @@ -27,7 +28,7 @@ class ProfileEvent: device_os_version: str architecture: str duration_ns: int - environment: Optional[str] + environment: str | None platform: str trace_id: str transaction_name: str @@ -50,9 +51,7 @@ def build_result(self, meta: KafkaMessageMetadata) -> Mapping[str, Any]: class TestProfilesProcessor: def test_missing_profile_id(self) -> None: - meta = KafkaMessageMetadata( - offset=1, partition=0, timestamp=datetime(1970, 1, 1) - ) + meta = KafkaMessageMetadata(offset=1, partition=0, timestamp=datetime(1970, 1, 1)) message = ProfileEvent( android_api_level=None, architecture="aarch64", @@ -83,13 +82,11 @@ def test_missing_profile_id(self) -> None: del payload["profile_id"] processor = ProfilesMessageProcessor() - with pytest.raises(Exception): + with pytest.raises(Exception): # noqa: B017 missing field may raise various exception types processor.process_message(payload, meta) def test_valid_message(self) -> None: - meta = KafkaMessageMetadata( - offset=0, partition=0, timestamp=datetime(1970, 1, 1) - ) + meta = KafkaMessageMetadata(offset=0, partition=0, timestamp=datetime(1970, 1, 1)) message = ProfileEvent( android_api_level=None, architecture="aarch64", diff --git a/tests/datasets/test_search_issues_processor.py b/tests/datasets/test_search_issues_processor.py index 7cf47205a36..9587887888b 100644 --- a/tests/datasets/test_search_issues_processor.py +++ b/tests/datasets/test_search_issues_processor.py @@ -1,8 +1,9 @@ import copy import uuid from collections import OrderedDict -from datetime import datetime, timedelta, timezone -from typing import Any, MutableMapping, Union +from collections.abc import MutableMapping +from datetime import UTC, datetime, timedelta +from typing import Any import pytest from snuba_sdk.legacy import json_to_snql @@ -51,9 +52,7 @@ def message_base() -> SearchIssueEvent: class TestSearchIssuesMessageProcessor: - KAFKA_META = KafkaMessageMetadata( - offset=0, partition=0, timestamp=datetime(1970, 1, 1) - ) + KAFKA_META = KafkaMessageMetadata(offset=0, partition=0, timestamp=datetime(1970, 1, 1)) processor = SearchIssuesMessageProcessor() REQUIRED_COLUMNS = { @@ -74,14 +73,10 @@ class TestSearchIssuesMessageProcessor: "message", } - def process_message( - self, message, version=2, operation="insert", kafka_meta=KAFKA_META - ): + def process_message(self, message, version=2, operation="insert", kafka_meta=KAFKA_META): return self.processor.process_message((version, operation, message), kafka_meta) - def assert_required_columns( - self, processed: Union[InsertBatch, ReplacementBatch, None] - ): + def assert_required_columns(self, processed: InsertBatch | ReplacementBatch | None): assert processed assert len(processed.rows) == 1 assert processed.rows[0].keys() > self.REQUIRED_COLUMNS @@ -111,9 +106,7 @@ def test_extract_client_timestamp(self, message_base): del missing_client_timestamp["datetime"] with_data_client_timestamp = copy.deepcopy(missing_client_timestamp) - with_data_client_timestamp["data"][ - "client_timestamp" - ] = datetime.now().timestamp() + with_data_client_timestamp["data"]["client_timestamp"] = datetime.now().timestamp() with_event_datetime = copy.deepcopy(missing_client_timestamp) with_event_datetime["datetime"] = datetime.now().isoformat() + "Z" @@ -128,12 +121,8 @@ def test_extract_timestamp_ms(self, message_base): processed = self.process_message(message_base) self.assert_required_columns(processed) insert_row = processed.rows[0] - client_timestamp_utc = insert_row["client_timestamp"].replace( - tzinfo=timezone.utc - ) - assert insert_row["timestamp_ms"] == int( - client_timestamp_utc.timestamp() * 1000 - ) + client_timestamp_utc = insert_row["client_timestamp"].replace(tzinfo=UTC) + assert insert_row["timestamp_ms"] == int(client_timestamp_utc.timestamp() * 1000) def test_extract_user(self, message_base): message_with_user = message_base @@ -148,12 +137,12 @@ def test_extract_user(self, message_base): insert_row = processed.rows[0] assert ( insert_row.items() - > dict( - user_name="user", - user_id="1", - user_email="test@example.com", - ip_address_v4="127.0.0.1", - ).items() + > { + "user_name": "user", + "user_id": "1", + "user_email": "test@example.com", + "ip_address_v4": "127.0.0.1", + }.items() ) def test_extract_user_empty(self, message_base): @@ -163,11 +152,11 @@ def test_extract_user_empty(self, message_base): insert_row = processed.rows[0] assert ( insert_row.items() - > dict( - user_name=None, - user_id=None, - user_email=None, - ).items() + > { + "user_name": None, + "user_id": None, + "user_email": None, + }.items() ) def test_extract_promoted_user_from_tag(self, message_base): @@ -247,10 +236,7 @@ def test_extract_http(self, message_base): self.assert_required_columns(processed) insert_row = processed.rows[0] assert "http_method" in insert_row and insert_row["http_method"] == "GET" - assert ( - "http_referer" in insert_row - and insert_row["http_referer"] == "http://example.com" - ) + assert "http_referer" in insert_row and insert_row["http_referer"] == "http://example.com" def test_extract_sdk(self, message_base): message_base["data"]["sdk"] = { @@ -274,12 +260,8 @@ def test_extract_context_null_dicts(self, message_base): processed = self.process_message(message_base) self.assert_required_columns(processed) insert_row = processed.rows[0] - assert "contexts.key" in insert_row and insert_row["contexts.key"] == [ - "scalar.string" - ] - assert "contexts.value" in insert_row and insert_row["contexts.value"] == [ - "scalar_value" - ] + assert "contexts.key" in insert_row and insert_row["contexts.key"] == ["scalar.string"] + assert "contexts.value" in insert_row and insert_row["contexts.value"] == ["scalar_value"] def test_extract_context_filters_non_dict(self, message_base): message_base["data"]["contexts"] = { @@ -425,9 +407,7 @@ def test_extract_transaction_duration(self, message_base): assert insert_row["transaction_duration"] == 0 now = datetime.utcnow() - message_base["data"]["start_timestamp"] = int( - (now - timedelta(seconds=10)).timestamp() - ) + message_base["data"]["start_timestamp"] = int((now - timedelta(seconds=10)).timestamp()) message_base["data"]["timestamp"] = int(now.timestamp()) processed = self.process_message(message_base) self.assert_required_columns(processed) @@ -450,9 +430,7 @@ def test_extract_profile_id(self, message_base): assert insert_row["profile_id"] == ensure_uuid(profile_id) for invalid_profile_id in ["", "im a little tea pot", 1, 1.1]: - message_base["data"]["contexts"]["profile"][ - "profile_id" - ] = invalid_profile_id + message_base["data"]["contexts"]["profile"]["profile_id"] = invalid_profile_id with pytest.raises(ValueError): self.process_message(message_base) diff --git a/tests/datasets/test_table_storage.py b/tests/datasets/test_table_storage.py index e5a0b1f99db..0bbf6e5c3d1 100644 --- a/tests/datasets/test_table_storage.py +++ b/tests/datasets/test_table_storage.py @@ -1,3 +1,4 @@ +import pytest from confluent_kafka.admin import AdminClient from snuba.datasets.storages.factory import get_writable_storage @@ -10,7 +11,7 @@ from snuba.utils.streams.topics import Topic as SnubaTopic -def test_get_physical_topic_name(monkeypatch) -> None: # type: ignore +def test_get_physical_topic_name(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setitem( SLICED_KAFKA_TOPIC_MAP, ("ingest-replay-events", 2), "ingest-replay-events-2" ) diff --git a/tests/datasets/test_transaction_processor.py b/tests/datasets/test_transaction_processor.py index c5bd2695cc3..6a28454d89a 100644 --- a/tests/datasets/test_transaction_processor.py +++ b/tests/datasets/test_transaction_processor.py @@ -1,8 +1,9 @@ import uuid +from collections.abc import Mapping, Sequence from copy import deepcopy from dataclasses import dataclass -from datetime import datetime, timedelta, timezone -from typing import Any, Mapping, Optional, Sequence, Tuple +from datetime import UTC, datetime, timedelta +from typing import Any from unittest.mock import ANY import pytest @@ -27,30 +28,30 @@ class TransactionEvent: start_timestamp: float timestamp: float platform: str - dist: Optional[str] - user_name: Optional[str] - user_id: Optional[str] - user_email: Optional[str] - ipv6: Optional[str] - ipv4: Optional[str] - environment: Optional[str] + dist: str | None + user_name: str | None + user_id: str | None + user_email: str | None + ipv6: str | None + ipv4: str | None + environment: str | None release: str - sdk_name: Optional[str] - sdk_version: Optional[str] - http_method: Optional[str] - http_referer: Optional[str] + sdk_name: str | None + sdk_version: str | None + http_method: str | None + http_referer: str | None geo: Mapping[str, str] status: str - transaction_source: Optional[str] + transaction_source: str | None app_start_type: str = "warm" has_app_ctx: bool = True - profile_id: Optional[str] = None - profiler_id: Optional[str] = None - thread_id: Optional[int | str] = None - replay_id: Optional[str] = None - received: Optional[float] = None + profile_id: str | None = None + profiler_id: str | None = None + thread_id: int | str | None = None + replay_id: str | None = None + received: float | None = None - def get_trace_context(self) -> Optional[Mapping[str, Any]]: + def get_trace_context(self) -> Mapping[str, Any] | None: context = { "sampled": True, "trace_id": self.trace_id, @@ -68,13 +69,12 @@ def get_trace_context(self) -> Optional[Mapping[str, Any]]: return context - def get_app_context(self) -> Optional[Mapping[str, str]]: + def get_app_context(self) -> Mapping[str, str] | None: if self.has_app_ctx: return {"start_type": self.app_start_type} - else: - return None + return None - def get_profile_context(self) -> Optional[Mapping[str, str]]: + def get_profile_context(self) -> Mapping[str, str] | None: context = {} if self.profile_id is not None: @@ -88,12 +88,12 @@ def get_profile_context(self) -> Optional[Mapping[str, str]]: return None - def get_replay_context(self) -> Optional[Mapping[str, str]]: + def get_replay_context(self) -> Mapping[str, str] | None: if self.replay_id is None: return None return {"replay_id": self.replay_id} - def serialize(self) -> Tuple[int, str, Mapping[str, Any]]: + def serialize(self) -> tuple[int, str, Mapping[str, Any]]: return ( 2, "insert", @@ -221,9 +221,7 @@ def build_result(self, meta: KafkaMessageMetadata) -> Mapping[str, Any]: start_timestamp = datetime.utcfromtimestamp(self.start_timestamp) finish_timestamp = datetime.utcfromtimestamp(self.timestamp) - spans = sorted( - [(self.op, int("a" * 16, 16), 1.2345), ("http", int("b" * 16, 16), 0.1234)] - ) + spans = sorted([(self.op, int("a" * 16, 16), 1.2345), ("http", int("b" * 16, 16), 0.1234)]) ret = { "deleted": 0, @@ -240,9 +238,7 @@ def build_result(self, meta: KafkaMessageMetadata) -> Mapping[str, Any]: "start_ms": int(start_timestamp.microsecond / 1000), "finish_ts": finish_timestamp, "finish_ms": int(finish_timestamp.microsecond / 1000), - "duration": int( - (finish_timestamp - start_timestamp).total_seconds() * 1000 - ), + "duration": int((finish_timestamp - start_timestamp).total_seconds() * 1000), "platform": self.platform, "environment": self.environment, "release": self.release, @@ -317,8 +313,8 @@ def build_result(self, meta: KafkaMessageMetadata) -> Mapping[str, Any]: @pytest.mark.clickhouse_db @pytest.mark.redis_db class TestTransactionsProcessor: - def __get_timestamps(self) -> Tuple[float, float]: - timestamp = datetime.now(tz=timezone.utc) - timedelta(seconds=5) + def __get_timestamps(self) -> tuple[float, float]: + timestamp = datetime.now(tz=UTC) - timedelta(seconds=5) start_timestamp = timestamp - timedelta(seconds=5) return (start_timestamp.timestamp(), timestamp.timestamp()) @@ -334,9 +330,7 @@ def __get_transaction_event(self) -> TransactionEvent: op="navigation", timestamp=finish, start_timestamp=start, - received=( - datetime.now(tz=timezone.utc) - timedelta(seconds=15) - ).timestamp(), + received=(datetime.now(tz=UTC) - timedelta(seconds=15)).timestamp(), platform="python", dist="", user_name="me", @@ -369,9 +363,7 @@ def test_skip_non_transactions(self) -> None: # Force an invalid event payload[2]["data"]["type"] = "error" - meta = KafkaMessageMetadata( - offset=1, partition=2, timestamp=datetime(1970, 1, 1) - ) + meta = KafkaMessageMetadata(offset=1, partition=2, timestamp=datetime(1970, 1, 1)) processor = TransactionsMessageProcessor() assert processor.process_message(payload, meta) is None @@ -381,9 +373,7 @@ def test_missing_trace_context(self) -> None: # Force an invalid event del payload[2]["data"]["contexts"] - meta = KafkaMessageMetadata( - offset=1, partition=2, timestamp=datetime(1970, 1, 1) - ) + meta = KafkaMessageMetadata(offset=1, partition=2, timestamp=datetime(1970, 1, 1)) processor = TransactionsMessageProcessor() assert processor.process_message(payload, meta) is None @@ -393,9 +383,7 @@ def test_base_process(self) -> None: message = self.__get_transaction_event() - meta = KafkaMessageMetadata( - offset=1, partition=2, timestamp=datetime(1970, 1, 1) - ) + meta = KafkaMessageMetadata(offset=1, partition=2, timestamp=datetime(1970, 1, 1)) assert TransactionsMessageProcessor().process_message( message.serialize(), meta ) == InsertBatch([message.build_result(meta)], ANY) @@ -407,9 +395,7 @@ def test_too_many_spans(self) -> None: set_config("max_spans_per_transaction", 1) message = self.__get_transaction_event() - meta = KafkaMessageMetadata( - offset=1, partition=2, timestamp=datetime(1970, 1, 1) - ) + meta = KafkaMessageMetadata(offset=1, partition=2, timestamp=datetime(1970, 1, 1)) payload = message.serialize() @@ -421,9 +407,9 @@ def test_too_many_spans(self) -> None: result["spans.exclusive_time"] = [0] result["spans.exclusive_time_32"] = [1.2345] - assert TransactionsMessageProcessor().process_message( - payload, meta - ) == InsertBatch([result], ANY) + assert TransactionsMessageProcessor().process_message(payload, meta) == InsertBatch( + [result], ANY + ) settings.TRANSACT_SKIP_CONTEXT_STORE = old_skip_context def test_missing_transaction_source(self) -> None: @@ -436,9 +422,7 @@ def test_missing_transaction_source(self) -> None: # Remove transaction_info del payload_wo_transaction_info[2]["data"]["transaction_info"] - meta = KafkaMessageMetadata( - offset=1, partition=2, timestamp=datetime(1970, 1, 1) - ) + meta = KafkaMessageMetadata(offset=1, partition=2, timestamp=datetime(1970, 1, 1)) actual_message = TransactionsMessageProcessor().process_message( payload_wo_transaction_info, meta ) @@ -447,12 +431,8 @@ def test_missing_transaction_source(self) -> None: # Remove transaction_info.source del payload_wo_source[2]["data"]["transaction_info"]["source"] - meta = KafkaMessageMetadata( - offset=1, partition=2, timestamp=datetime(1970, 1, 1) - ) - actual_message = TransactionsMessageProcessor().process_message( - payload_wo_source, meta - ) + meta = KafkaMessageMetadata(offset=1, partition=2, timestamp=datetime(1970, 1, 1)) + actual_message = TransactionsMessageProcessor().process_message(payload_wo_source, meta) assert actual_message.rows[0]["transaction_source"] == "" def test_app_ctx_none(self) -> None: @@ -462,9 +442,7 @@ def test_app_ctx_none(self) -> None: message = self.__get_transaction_event() message.has_app_ctx = False - meta = KafkaMessageMetadata( - offset=1, partition=2, timestamp=datetime(1970, 1, 1) - ) + meta = KafkaMessageMetadata(offset=1, partition=2, timestamp=datetime(1970, 1, 1)) assert TransactionsMessageProcessor().process_message( message.serialize(), meta ) == InsertBatch([message.build_result(meta)], ANY) @@ -478,14 +456,10 @@ def test_replay_id_as_tag(self) -> None: message = self.__get_transaction_event() payload = message.serialize() - payload[2]["data"]["tags"].append( - ["replayId", "d2731f8ed8934c6fa5253e450915aa12"] - ) + payload[2]["data"]["tags"].append(["replayId", "d2731f8ed8934c6fa5253e450915aa12"]) del payload[2]["data"]["contexts"]["replay"] - meta = KafkaMessageMetadata( - offset=1, partition=2, timestamp=datetime(1970, 1, 1) - ) + meta = KafkaMessageMetadata(offset=1, partition=2, timestamp=datetime(1970, 1, 1)) result = message.build_result(meta) # when the replay_id is sent as a tag instead of a context, @@ -495,9 +469,9 @@ def test_replay_id_as_tag(self) -> None: result["tags.key"].insert(1, "replayId") result["tags.value"].insert(1, "d2731f8ed8934c6fa5253e450915aa12") - assert TransactionsMessageProcessor().process_message( - payload, meta - ) == InsertBatch([result], ANY) + assert TransactionsMessageProcessor().process_message(payload, meta) == InsertBatch( + [result], ANY + ) def test_replay_id_as_tag_and_context(self) -> None: """ @@ -509,13 +483,9 @@ def test_replay_id_as_tag_and_context(self) -> None: message = self.__get_transaction_event() payload = message.serialize() - payload[2]["data"]["tags"].append( - ["replayId", "d2731f8ed8934c6fa5253e450915aa12"] - ) + payload[2]["data"]["tags"].append(["replayId", "d2731f8ed8934c6fa5253e450915aa12"]) - meta = KafkaMessageMetadata( - offset=1, partition=2, timestamp=datetime(1970, 1, 1) - ) + meta = KafkaMessageMetadata(offset=1, partition=2, timestamp=datetime(1970, 1, 1)) result = message.build_result(meta) # when the replay_id is sent as a tag instead of a context, @@ -525,9 +495,9 @@ def test_replay_id_as_tag_and_context(self) -> None: result["tags.key"].insert(1, "replayId") result["tags.value"].insert(1, "d2731f8ed8934c6fa5253e450915aa12") - assert TransactionsMessageProcessor().process_message( - payload, meta - ) == InsertBatch([result], ANY) + assert TransactionsMessageProcessor().process_message(payload, meta) == InsertBatch( + [result], ANY + ) def test_replay_id_as_invalid_tag(self) -> None: """ @@ -541,9 +511,7 @@ def test_replay_id_as_invalid_tag(self) -> None: del payload[2]["data"]["contexts"]["replay"] payload[2]["data"]["tags"].append(["replayId", "I_AM_NOT_A_UUID"]) - meta = KafkaMessageMetadata( - offset=1, partition=2, timestamp=datetime(1970, 1, 1) - ) + meta = KafkaMessageMetadata(offset=1, partition=2, timestamp=datetime(1970, 1, 1)) result = message.build_result(meta) del result["replay_id"] @@ -552,9 +520,9 @@ def test_replay_id_as_invalid_tag(self) -> None: result["tags.key"].insert(1, "replayId") result["tags.value"].insert(1, "I_AM_NOT_A_UUID") - assert TransactionsMessageProcessor().process_message( - payload, meta - ) == InsertBatch([result], ANY) + assert TransactionsMessageProcessor().process_message(payload, meta) == InsertBatch( + [result], ANY + ) def test_trace_data_is_none(self) -> None: """ @@ -566,9 +534,7 @@ def test_trace_data_is_none(self) -> None: # Force an invalid event payload[2]["data"]["contexts"]["trace"]["data"] = None - meta = KafkaMessageMetadata( - offset=1, partition=2, timestamp=datetime(1970, 1, 1) - ) + meta = KafkaMessageMetadata(offset=1, partition=2, timestamp=datetime(1970, 1, 1)) result = message.build_result(meta) diff --git a/tests/datasets/validation/test_datetime_condition_validator.py b/tests/datasets/validation/test_datetime_condition_validator.py index 4725d6df52b..07ebbe18a16 100644 --- a/tests/datasets/validation/test_datetime_condition_validator.py +++ b/tests/datasets/validation/test_datetime_condition_validator.py @@ -1,5 +1,4 @@ import datetime -from typing import Optional from unittest.mock import MagicMock import pytest @@ -42,7 +41,7 @@ @pytest.mark.parametrize("condition", required_column_tests) -def test_datetime_column_validation(condition: Optional[Expression]) -> None: +def test_datetime_column_validation(condition: Expression | None) -> None: query = LogicalQuery( QueryEntity(EntityKey.EVENTS, get_entity(EntityKey.EVENTS).get_data_model()), selected_columns=[ @@ -84,7 +83,7 @@ def test_datetime_column_validation(condition: Optional[Expression]) -> None: def test_invalid_datetime_column_validation() -> None: old_logger = logger.warning mock_logger = MagicMock() - logger.warning = mock_logger # type: ignore + logger.warning = mock_logger # type: ignore[method-assign] for condition, message in invalid_tests: query = LogicalQuery( QueryEntity(EntityKey.EVENTS, get_entity(EntityKey.EVENTS).get_data_model()), @@ -100,4 +99,4 @@ def test_invalid_datetime_column_validation() -> None: mock_logger.assert_called_with(message) mock_logger.reset_mock() - logger.warning = old_logger # type: ignore + logger.warning = old_logger # type: ignore[method-assign] diff --git a/tests/datasets/validation/test_entity_validation.py b/tests/datasets/validation/test_entity_validation.py index 8c7994e06f8..685dbd0e59b 100644 --- a/tests/datasets/validation/test_entity_validation.py +++ b/tests/datasets/validation/test_entity_validation.py @@ -1,5 +1,4 @@ import datetime -from typing import Optional import pytest @@ -71,7 +70,7 @@ @pytest.mark.parametrize("key, condition", required_column_tests) -def test_entity_required_column_validation(key: EntityKey, condition: Optional[Expression]) -> None: +def test_entity_required_column_validation(key: EntityKey, condition: Expression | None) -> None: query = LogicalQuery( QueryEntity(key, get_entity(key).get_data_model()), selected_columns=[ @@ -104,7 +103,7 @@ def test_entity_required_column_validation(key: EntityKey, condition: Optional[E @pytest.mark.parametrize("key, condition", invalid_required_column_tests) def test_entity_required_column_validation_failure( - key: EntityKey, condition: Optional[Expression] + key: EntityKey, condition: Expression | None ) -> None: query = LogicalQuery( QueryEntity(key, get_entity(key).get_data_model()), @@ -157,7 +156,7 @@ def test_entity_required_column_validation_failure( @pytest.mark.parametrize("key, condition", required_str_column_tests) def test_entity_required_str_column_validation( - key: EntityKey, condition: Optional[Expression] + key: EntityKey, condition: Expression | None ) -> None: query = LogicalQuery( QueryEntity(key, get_entity(key).get_data_model()), diff --git a/tests/datasets/validation/test_illegal_aggregate_conditions_validation.py b/tests/datasets/validation/test_illegal_aggregate_conditions_validation.py index ec94bb04289..210f60ecb11 100644 --- a/tests/datasets/validation/test_illegal_aggregate_conditions_validation.py +++ b/tests/datasets/validation/test_illegal_aggregate_conditions_validation.py @@ -1,5 +1,4 @@ import datetime -from typing import Optional import pytest @@ -133,8 +132,8 @@ @pytest.mark.parametrize("condition, having, exception", tests) @pytest.mark.redis_db def test_illegal_aggregate_in_condition_validator( - condition: Optional[Expression], - having: Optional[Expression], + condition: Expression | None, + having: Expression | None, exception: Exception, ) -> None: query = LogicalQuery( diff --git a/tests/datasets/validation/test_no_time_condition_validator.py b/tests/datasets/validation/test_no_time_condition_validator.py index 6290cfddf99..a83b65ba6c0 100644 --- a/tests/datasets/validation/test_no_time_condition_validator.py +++ b/tests/datasets/validation/test_no_time_condition_validator.py @@ -25,7 +25,7 @@ ] -@pytest.mark.parametrize("key, condition", tests) # type: ignore +@pytest.mark.parametrize("key, condition", tests) def test_no_time_based_validation(key: EntityKey, condition: Expression) -> None: entity = get_entity(key) query = LogicalQuery( @@ -71,10 +71,8 @@ def test_no_time_based_validation(key: EntityKey, condition: Expression) -> None ] -@pytest.mark.parametrize("key, condition", invalid_tests) # type: ignore -def test_no_time_based_validation_failure( - key: EntityKey, condition: Expression -) -> None: +@pytest.mark.parametrize("key, condition", invalid_tests) +def test_no_time_based_validation_failure(key: EntityKey, condition: Expression) -> None: entity = get_entity(key) query = LogicalQuery( QueryEntity(key, entity.get_data_model()), diff --git a/tests/datasets/validation/test_subscription_clauses_validator.py b/tests/datasets/validation/test_subscription_clauses_validator.py index 71de25ecf6e..6c80b1c9137 100644 --- a/tests/datasets/validation/test_subscription_clauses_validator.py +++ b/tests/datasets/validation/test_subscription_clauses_validator.py @@ -23,13 +23,9 @@ tests = [ pytest.param( LogicalQuery( - QueryEntity( - EntityKey.EVENTS, get_entity(EntityKey.EVENTS).get_data_model() - ), + QueryEntity(EntityKey.EVENTS, get_entity(EntityKey.EVENTS).get_data_model()), selected_columns=[ - SelectedExpression( - "time", Column("_snuba_timestamp", None, "timestamp") - ), + SelectedExpression("time", Column("_snuba_timestamp", None, "timestamp")), ], condition=binary_condition( "equals", @@ -92,9 +88,7 @@ ), pytest.param( LogicalQuery( - QueryEntity( - EntityKey.EVENTS, get_entity(EntityKey.EVENTS).get_data_model() - ), + QueryEntity(EntityKey.EVENTS, get_entity(EntityKey.EVENTS).get_data_model()), selected_columns=[ SelectedExpression("count", FunctionCall("_snuba_count", "count", ())), ], @@ -110,7 +104,7 @@ ] -@pytest.mark.parametrize("query", tests) # type: ignore +@pytest.mark.parametrize("query", tests) def test_subscription_clauses_validation(query: LogicalQuery) -> None: entity = get_entity(query.get_from_clause().key) subscription_validators = entity.get_subscription_validators() @@ -145,9 +139,7 @@ def test_subscription_clauses_validation(query: LogicalQuery) -> None: ), pytest.param( LogicalQuery( - QueryEntity( - EntityKey.EVENTS, get_entity(EntityKey.EVENTS).get_data_model() - ), + QueryEntity(EntityKey.EVENTS, get_entity(EntityKey.EVENTS).get_data_model()), selected_columns=[ SelectedExpression("count", FunctionCall("_snuba_count", "count", ())), ], @@ -166,9 +158,7 @@ def test_subscription_clauses_validation(query: LogicalQuery) -> None: ), pytest.param( LogicalQuery( - QueryEntity( - EntityKey.EVENTS, get_entity(EntityKey.EVENTS).get_data_model() - ), + QueryEntity(EntityKey.EVENTS, get_entity(EntityKey.EVENTS).get_data_model()), selected_columns=[ SelectedExpression("count", FunctionCall("_snuba_count", "count", ())), ], @@ -177,11 +167,7 @@ def test_subscription_clauses_validation(query: LogicalQuery) -> None: Column("_snuba_project_id", None, "project_id"), Literal(None, 1), ), - order_by=[ - OrderBy( - OrderByDirection.ASC, Column("_snuba_timestamp", None, "timestamp") - ) - ], + order_by=[OrderBy(OrderByDirection.ASC, Column("_snuba_timestamp", None, "timestamp"))], ), id="no orderby clauses", ), @@ -296,7 +282,7 @@ def test_subscription_clauses_validation(query: LogicalQuery) -> None: ] -@pytest.mark.parametrize("query", invalid_tests) # type: ignore +@pytest.mark.parametrize("query", invalid_tests) def test_subscription_clauses_validation_failure(query: LogicalQuery) -> None: entity = get_entity(query.get_from_clause().key) subscription_validators = entity.get_subscription_validators() diff --git a/tests/fixtures.py b/tests/fixtures.py index 41408599fc3..151a334e65e 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -3,9 +3,10 @@ import calendar import json import uuid -from datetime import datetime, timedelta, timezone +from collections.abc import Mapping +from datetime import UTC, datetime, timedelta from hashlib import md5 -from typing import Any, Mapping, Tuple +from typing import Any from snuba import settings from snuba.processor import InsertEvent @@ -19,7 +20,7 @@ def get_raw_event() -> InsertEvent: event_id = str(uuid.uuid4().hex) message = "Caught exception!" - unique = "%s:%s" % (str(PROJECT_ID), event_id) + unique = f"{str(PROJECT_ID)}:{event_id}" primary_hash = md5(unique.encode("utf-8")).hexdigest() platform = "java" event_datetime = (now - timedelta(seconds=2)).strftime( @@ -190,12 +191,12 @@ def get_raw_event() -> InsertEvent: def get_raw_transaction(span_id: str | None = None) -> Mapping[str, Any]: - now = datetime.utcnow().replace(minute=0, second=0, microsecond=0, tzinfo=timezone.utc) + now = datetime.utcnow().replace(minute=0, second=0, microsecond=0, tzinfo=UTC) start_timestamp = now - timedelta(seconds=3) end_timestamp = now - timedelta(seconds=2) event_received = now - timedelta(seconds=1) trace_id = uuid.UUID("7400045b-25c4-43b8-8591-4600aa83ad04") - span_id = "8841662216cc598b" if not span_id else span_id + span_id = span_id if span_id else "8841662216cc598b" unique = "100" primary_hash = md5(unique.encode("utf-8")).hexdigest() app_start_type = "warm.prewarmed" @@ -296,11 +297,7 @@ def get_raw_transaction(span_id: str | None = None) -> Mapping[str, Any]: def get_replay_event(replay_id: str | None = None) -> Mapping[str, Any]: replay_id = replay_id if replay_id else str(uuid.UUID("e5e062bf2e1d4afd96fd2f90b6770431")) - now = ( - datetime.utcnow() - .replace(minute=0, second=0, microsecond=0, tzinfo=timezone.utc) - .timestamp() - ) + now = datetime.utcnow().replace(minute=0, second=0, microsecond=0, tzinfo=UTC).timestamp() return { "type": "replay_event", @@ -359,14 +356,14 @@ def get_replay_event(replay_id: str | None = None) -> Mapping[str, Any]: } -def get_raw_error_message() -> Tuple[int, str, InsertEvent, Any]: +def get_raw_error_message() -> tuple[int, str, InsertEvent, Any]: """ Get an error message which can be passed to the processors. """ return (2, "insert", get_raw_event(), {}) -def get_raw_transaction_message() -> Tuple[int, str, Mapping[str, Any]]: +def get_raw_transaction_message() -> tuple[int, str, Mapping[str, Any]]: """ Get a transaction message which can be passed to the processors. """ diff --git a/tests/helpers.py b/tests/helpers.py index 52c66061f31..acef978aa6e 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,5 +1,6 @@ +from collections.abc import Mapping, MutableSequence, Sequence from datetime import datetime -from typing import Any, Mapping, MutableSequence, Sequence, Union +from typing import Any from snuba.clickhouse.http import JSONRowEncoder from snuba.consumers.types import KafkaMessageMetadata @@ -28,7 +29,7 @@ def write_processed_messages( def write_unprocessed_events( - storage: WritableStorage, events: Sequence[Union[InsertEvent, Mapping[str, Any]]] + storage: WritableStorage, events: Sequence[InsertEvent | Mapping[str, Any]] ) -> None: processor = storage.get_table_writer().get_stream_loader().get_processor() processed_messages = [] @@ -44,7 +45,7 @@ def write_unprocessed_events( def write_raw_unprocessed_events( storage: WritableStorage, - events: Sequence[Union[InsertEvent, Mapping[str, Any] | bytes]], + events: Sequence[InsertEvent | (Mapping[str, Any] | bytes)], ) -> None: processor = storage.get_table_writer().get_stream_loader().get_processor() diff --git a/tests/lw_deletions/test_formatters.py b/tests/lw_deletions/test_formatters.py index aae5ae5089a..165630fd0af 100644 --- a/tests/lw_deletions/test_formatters.py +++ b/tests/lw_deletions/test_formatters.py @@ -1,4 +1,4 @@ -from typing import Sequence, Type +from collections.abc import Sequence import pytest from sentry_protos.snuba.v1.trace_item_attribute_pb2 import AttributeKey @@ -76,7 +76,7 @@ def create_delete_query_message( def test_search_issues_formatter( messages: Sequence[DeleteQueryMessage], expected_formatted: Sequence[ConditionsBag], - formatter: Type[Formatter], + formatter: type[Formatter], ) -> None: formatted = formatter().format(messages) assert formatted == expected_formatted @@ -102,7 +102,7 @@ def test_search_issues_formatter( def test_eap_items_formatter_identity_conditions( messages: Sequence[DeleteQueryMessage], expected_formatted: Sequence[ConditionsBag], - formatter: Type[Formatter], + formatter: type[Formatter], ) -> None: formatted = formatter().format(messages) assert formatted == expected_formatted diff --git a/tests/lw_deletions/test_lw_deletions.py b/tests/lw_deletions/test_lw_deletions.py index 8bb2ff33fd3..38234211870 100644 --- a/tests/lw_deletions/test_lw_deletions.py +++ b/tests/lw_deletions/test_lw_deletions.py @@ -1,7 +1,7 @@ from __future__ import annotations +from collections.abc import Iterator from datetime import datetime, timedelta -from typing import Iterator from unittest.mock import Mock, patch import pytest diff --git a/tests/lw_deletions/test_off_peak.py b/tests/lw_deletions/test_off_peak.py index 6377fa533ca..6b12ecae692 100644 --- a/tests/lw_deletions/test_off_peak.py +++ b/tests/lw_deletions/test_off_peak.py @@ -1,6 +1,6 @@ from __future__ import annotations -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta from unittest.mock import MagicMock import pytest @@ -30,8 +30,8 @@ def _tomorrow_at(hour: int) -> datetime: """Return tomorrow at the given UTC hour. Always in the future so time_machine.travel moves the clock forward and the snuba.state memoize cache naturally expires.""" - tomorrow = datetime.now(timezone.utc).date() + timedelta(days=1) - return datetime(tomorrow.year, tomorrow.month, tomorrow.day, hour, tzinfo=timezone.utc) + tomorrow = datetime.now(UTC).date() + timedelta(days=1) + return datetime(tomorrow.year, tomorrow.month, tomorrow.day, hour, tzinfo=UTC) def _make_message() -> Message[KafkaPayload]: diff --git a/tests/manual_jobs/test_extract_span_data.py b/tests/manual_jobs/test_extract_span_data.py index da1ec0af100..282b72ce05f 100644 --- a/tests/manual_jobs/test_extract_span_data.py +++ b/tests/manual_jobs/test_extract_span_data.py @@ -1,11 +1,12 @@ import random import uuid +from collections.abc import Mapping from datetime import datetime, timedelta -from typing import Any, Mapping +from typing import Any import pytest -from snuba.datasets.storages.factory import get_storage +from snuba.datasets.storages.factory import get_writable_storage from snuba.datasets.storages.storage_key import StorageKey from snuba.manual_jobs import JobSpec from snuba.manual_jobs.job_status import JobStatus @@ -99,14 +100,14 @@ def test_extract_span_data() -> None: minutes=180 ) organization_ids = [0, 1] - spans_storage = get_storage(StorageKey("eap_spans")) + spans_storage = get_writable_storage(StorageKey("eap_spans")) messages = [ _gen_message(BASE_TIME - timedelta(minutes=i), organization_id) for organization_id in organization_ids for i in range(20) ] - write_raw_unprocessed_events(spans_storage, messages) # type: ignore + write_raw_unprocessed_events(spans_storage, messages) assert ( run_job( diff --git a/tests/migrations/autogeneration/test_generate_python_migration.py b/tests/migrations/autogeneration/test_generate_python_migration.py index dd8950e02cb..17776968501 100644 --- a/tests/migrations/autogeneration/test_generate_python_migration.py +++ b/tests/migrations/autogeneration/test_generate_python_migration.py @@ -8,7 +8,7 @@ def mockstoragewithcolumns(cols: list[str]) -> Any: - colstr = ",\n ".join([s for s in cols]) + colstr = ",\n ".join(list(cols)) storage = f""" version: v1 kind: writable_storage diff --git a/tests/migrations/test_check_dangerous.py b/tests/migrations/test_check_dangerous.py index 72e3c97ce6c..2eb99e77a13 100644 --- a/tests/migrations/test_check_dangerous.py +++ b/tests/migrations/test_check_dangerous.py @@ -1,5 +1,6 @@ import logging -from typing import Any, Sequence, Tuple +from collections.abc import Sequence +from typing import Any from unittest.mock import Mock, patch import pytest @@ -49,7 +50,7 @@ def create_table(self, create_databases: None) -> None: " ENGINE = MergeTree ORDER BY name" ) - def _make_modify_op(self, column: Column[MigrationModifiers]) -> Tuple[SqlOperation, Any]: + def _make_modify_op(self, column: Column[MigrationModifiers]) -> tuple[SqlOperation, Any]: op = ModifyColumn( StorageSetKey.EVENTS, self.table_name, column, target=OperationTarget.LOCAL ) diff --git a/tests/migrations/test_connect.py b/tests/migrations/test_connect.py index 098fc741236..d73a544b446 100644 --- a/tests/migrations/test_connect.py +++ b/tests/migrations/test_connect.py @@ -14,7 +14,7 @@ ) from snuba.migrations.groups import MigrationGroup -_ALL_STORAGE_SET_KEYS = set([s.value for s in StorageSetKey]) +_ALL_STORAGE_SET_KEYS = {s.value for s in StorageSetKey} _REMAINING_STORAGE_SET_KEYS = _ALL_STORAGE_SET_KEYS - {"events", "querylog"} _QUERYLOG_CLUSTER = cluster.ClickhouseCluster( @@ -102,21 +102,21 @@ def test_get_clickhouse_clusters_for_migration_group(override_cluster: Any) -> N [ReadinessState.PARTIAL], [_QUERYLOG_CLUSTER, _EVENTS_CLUSTER], [_QUERYLOG_CLUSTER], - set([StorageSetKey.QUERYLOG]), + {StorageSetKey.QUERYLOG}, id="partial only", ), pytest.param( [ReadinessState.COMPLETE], [_QUERYLOG_CLUSTER, _EVENTS_CLUSTER], [_EVENTS_CLUSTER], - set([StorageSetKey.EVENTS]), + {StorageSetKey.EVENTS}, id="complete only", ), pytest.param( [ReadinessState.COMPLETE, ReadinessState.PARTIAL], [_QUERYLOG_CLUSTER, _EVENTS_CLUSTER], [_QUERYLOG_CLUSTER, _EVENTS_CLUSTER], - set([StorageSetKey.EVENTS, StorageSetKey.QUERYLOG]), + {StorageSetKey.EVENTS, StorageSetKey.QUERYLOG}, id="complete and partial", ), ], diff --git a/tests/migrations/test_legacy_use.py b/tests/migrations/test_legacy_use.py index 7571258e639..a9baabf35ed 100644 --- a/tests/migrations/test_legacy_use.py +++ b/tests/migrations/test_legacy_use.py @@ -1,11 +1,11 @@ -from typing import List, Mapping, Tuple +from collections.abc import Mapping import pytest from snuba.migrations import migration from snuba.migrations.groups import MigrationGroup, get_group_loader -all_migrations: List[Tuple[str, MigrationGroup, migration.ClickhouseNodeMigration]] = [] +all_migrations: list[tuple[str, MigrationGroup, migration.ClickhouseNodeMigration]] = [] for group in MigrationGroup: group_loader = get_group_loader(group) for migration_id in group_loader.get_migrations(): diff --git a/tests/migrations/test_operations.py b/tests/migrations/test_operations.py index 9db078aa4e2..0ce7fe977b3 100644 --- a/tests/migrations/test_operations.py +++ b/tests/migrations/test_operations.py @@ -1,5 +1,5 @@ +from collections.abc import Callable, Sequence from logging import Logger -from typing import Callable, Sequence from unittest import mock from unittest.mock import Mock, patch diff --git a/tests/migrations/test_parse_schema.py b/tests/migrations/test_parse_schema.py index e4584ebc4da..dccdea5a3f6 100644 --- a/tests/migrations/test_parse_schema.py +++ b/tests/migrations/test_parse_schema.py @@ -1,5 +1,3 @@ -from typing import Tuple - import pytest from snuba.clickhouse.columns import ( @@ -152,8 +150,8 @@ @pytest.mark.parametrize("input, expected_output", test_data) def test_parse_column( - input: Tuple[str, str, str, str], - expected_output: Tuple[Tuple[str, str, str, str, str], ColumnType[Modifiers]], + input: tuple[str, str, str, str], + expected_output: tuple[tuple[str, str, str, str, str], ColumnType[Modifiers]], ) -> None: (input_name, input_type, default_expr, codec_expr) = input assert _get_column(input_name, input_type, default_expr, codec_expr) == expected_output diff --git a/tests/migrations/test_policies.py b/tests/migrations/test_policies.py index 8683f9dccdd..e0eed3bb41e 100644 --- a/tests/migrations/test_policies.py +++ b/tests/migrations/test_policies.py @@ -1,5 +1,5 @@ +from collections.abc import Mapping from datetime import datetime, timedelta -from typing import Mapping from unittest.mock import Mock, patch import pytest @@ -99,8 +99,8 @@ def test_policies( ) def test_experimental_groups(self, mock_readiness_state: Mock) -> None: migration_key = MigrationKey(MigrationGroup("test_migration"), "0001_create_test_table") - assert NonBlockingMigrationsPolicy().can_run(migration_key) == True - assert NonBlockingMigrationsPolicy().can_reverse(migration_key) == True + assert NonBlockingMigrationsPolicy().can_run(migration_key) + assert NonBlockingMigrationsPolicy().can_reverse(migration_key) @patch( "snuba.migrations.runner.Runner.get_status", @@ -115,11 +115,11 @@ def test_pending_migration_reverse(self, mock_get_status: Mock) -> None: """ # non-blocking migration migration_key = MigrationKey(MigrationGroup("events"), "0016_drop_legacy_events") - assert NonBlockingMigrationsPolicy().can_reverse(migration_key) == True + assert NonBlockingMigrationsPolicy().can_reverse(migration_key) # blocking migration migration_key = code_migration_key() - assert NonBlockingMigrationsPolicy().can_reverse(migration_key) == False + assert not NonBlockingMigrationsPolicy().can_reverse(migration_key) @patch( "snuba.migrations.runner.Runner.get_status", @@ -139,8 +139,8 @@ def test_completed_migration_reverse(self, mock_get_status: Mock) -> None: """ # non-blocking migration migration_key = MigrationKey(MigrationGroup("events"), "0016_drop_legacy_events") - assert NonBlockingMigrationsPolicy().can_reverse(migration_key) == True + assert NonBlockingMigrationsPolicy().can_reverse(migration_key) # blocking migration migration_key = code_migration_key() - assert NonBlockingMigrationsPolicy().can_reverse(migration_key) == False + assert not NonBlockingMigrationsPolicy().can_reverse(migration_key) diff --git a/tests/migrations/test_runner.py b/tests/migrations/test_runner.py index 56f21511697..853a30f7752 100644 --- a/tests/migrations/test_runner.py +++ b/tests/migrations/test_runner.py @@ -1,6 +1,7 @@ import importlib +from collections.abc import Generator from datetime import datetime -from typing import Any, Generator +from typing import Any from unittest.mock import MagicMock, patch import pytest @@ -34,7 +35,7 @@ def _drop_all_tables() -> None: @pytest.fixture(autouse=True) -def setup_teardown() -> Generator[None, None, None]: +def setup_teardown() -> Generator[None]: _drop_all_tables() yield _drop_all_tables() @@ -70,22 +71,18 @@ def test_get_status() -> None: def test_show_all() -> None: runner = Runner() assert all( - [ - migration.status == Status.NOT_STARTED - for (_, group_migrations) in runner.show_all() - for migration in group_migrations - ] + migration.status == Status.NOT_STARTED + for (_, group_migrations) in runner.show_all() + for migration in group_migrations ) # only need to run migrations for the system table to # test show_all, can fake the status for the rest runner.run_all(force=True, group=MigrationGroup.SYSTEM) runner.run_all(force=True, fake=True) assert all( - [ - migration.status == Status.COMPLETED - for (_, group_migrations) in runner.show_all() - for migration in group_migrations - ] + migration.status == Status.COMPLETED + for (_, group_migrations) in runner.show_all() + for migration in group_migrations ) @@ -98,7 +95,7 @@ def test_show_all_for_groups() -> None: assert len(results) == 1 group, migrations = results[0] assert group == MigrationGroup("system") - assert all([migration.status == Status.NOT_STARTED for migration in migrations]) + assert all(migration.status == Status.NOT_STARTED for migration in migrations) runner.run_migration(migration_key, force=True) results = runner.show_all(["system"]) @@ -106,7 +103,7 @@ def test_show_all_for_groups() -> None: assert len(results) == 1 group, migrations = results[0] assert group == MigrationGroup("system") - assert all([migration.status == Status.COMPLETED for migration in migrations]) + assert all(migration.status == Status.COMPLETED for migration in migrations) @pytest.mark.custom_clickhouse_db @@ -332,7 +329,7 @@ def test_reverse_idempotency_all() -> None: if migration.group != MigrationGroup.SYSTEM: runner.reverse_migration(migration, force=True) - def reverse_twice() -> None: + def reverse_twice(migration: MigrationKey = migration) -> None: # reverse again to ensure idempotency runner.run_migration(migration, fake=True) runner.reverse_migration(migration, force=True) diff --git a/tests/migrations/test_runner_individual.py b/tests/migrations/test_runner_individual.py index b0aea333d5a..25d826dd814 100644 --- a/tests/migrations/test_runner_individual.py +++ b/tests/migrations/test_runner_individual.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Any, Dict, Generator, Optional, Sequence, cast +from collections.abc import Generator, Sequence +from typing import Any, cast import pytest @@ -34,7 +35,7 @@ def _drop_all_tables() -> None: @pytest.fixture(autouse=True) -def setup_teardown(clickhouse_db: None) -> Generator[None, None, None]: +def setup_teardown(clickhouse_db: None) -> Generator[None]: _drop_all_tables() yield _drop_all_tables() @@ -225,8 +226,8 @@ def run_prior_migrations( def perform_select_query( columns: Sequence[str], table: str, - where: Optional[Dict[str, str]], - limit: Optional[str], + where: dict[str, str] | None, + limit: str | None, connection: ClickhousePool, ) -> Sequence[Any]: """Performs a SELECT query, with optional WHERE and LIMIT clauses diff --git a/tests/migrations/test_validator.py b/tests/migrations/test_validator.py index 004c0b07b00..49b9a630afb 100644 --- a/tests/migrations/test_validator.py +++ b/tests/migrations/test_validator.py @@ -1,6 +1,7 @@ import copy +from collections.abc import Iterator, Sequence from contextlib import _GeneratorContextManager, contextmanager -from typing import Any, Iterator, Sequence, Union +from typing import Any from unittest.mock import Mock, patch import pytest @@ -85,7 +86,7 @@ class TestValidateMigrations: drop_col_local_op = DropColumn(storage, "test_local_table", "col") drop_col_dist_op = DropColumn(storage, "test_dist_table", "col") - def _dist_to_local(self, op: Union[CreateTable, AddColumn, DropColumn]) -> str: + def _dist_to_local(self, op: CreateTable | AddColumn | DropColumn) -> str: if op.table_name == "test_dist_table": return "test_local_table" if op.table_name == "test_dist_table2": @@ -269,18 +270,17 @@ class TestMigrationNew(migration.ClickhouseNodeMigration): def forwards_ops(self) -> Sequence[SqlOperation]: if forwards_local_first_val: return (*forwards_local, *forwards_dist) - else: - return (*forwards_dist, *forwards_local) + return (*forwards_dist, *forwards_local) def backwards_ops(self) -> Sequence[SqlOperation]: if backwards_local_first_val: return (*backwards_local, *backwards_dist) - else: - return (*backwards_dist, *backwards_local) + return (*backwards_dist, *backwards_local) with expectation_new as err: validate_migration_order(TestMigrationNew()) if err_msg: + assert err is not None assert str(err.value) == err_msg @@ -300,7 +300,7 @@ def test_conflicts(mock_get_local_table_name: Mock, mock_get_cluster: Mock) -> N ) mock_get_cluster.return_value = mock_cluster - def _dist_to_local(op: Union[CreateTable, AddColumn, DropColumn]) -> str: + def _dist_to_local(op: CreateTable | AddColumn | DropColumn) -> str: if op.table_name == "test_dist_table": return "test_local_table" if op.table_name == "test_dist_table2": diff --git a/tests/pipeline/conftest.py b/tests/pipeline/conftest.py index ffbe5e00a09..e6a2616edfd 100644 --- a/tests/pipeline/conftest.py +++ b/tests/pipeline/conftest.py @@ -37,12 +37,12 @@ def process_query(self, query: LogicalQuery, query_settings: QuerySettings) -> N @pytest.fixture -def mock_query_storage() -> Generator[QueryStorage, None, None]: +def mock_query_storage() -> Generator[QueryStorage]: yield QueryStorage(key=_MOCK_STORAGE_KEY) @pytest.fixture -def mock_storage() -> Generator[ReadableTableStorage, None, None]: +def mock_storage() -> Generator[ReadableTableStorage]: # create a storage storkey = _MOCK_STORAGE_KEY storsetkey = StorageSetKey("mockstorageset") @@ -78,7 +78,7 @@ def mock_storage() -> Generator[ReadableTableStorage, None, None]: @pytest.fixture def mock_entity( mock_storage: ReadableTableStorage, -) -> Generator[PluggableEntity, None, None]: +) -> Generator[PluggableEntity]: # setup entkey = EntityKey("mock_entity") entity = PluggableEntity( diff --git a/tests/pipeline/test_execution_stage.py b/tests/pipeline/test_execution_stage.py index 0fec3519970..9afaf247e93 100644 --- a/tests/pipeline/test_execution_stage.py +++ b/tests/pipeline/test_execution_stage.py @@ -261,11 +261,7 @@ def test_disable_max_query_size_check(ch_query: Query) -> None: settings = HTTPQuerySettings() timer = Timer("test") metadata = get_fake_metadata() - cluster_name = ( - snubasettings.CLUSTERS[0]["cluster_name"] - if "cluster_name" in snubasettings.CLUSTERS[0] - else "test_cluster" - ) + cluster_name = snubasettings.CLUSTERS[0].get("cluster_name", "test_cluster") # Lowering this should make the query too big... state.set_config(MAX_QUERY_SIZE_BYTES_CONFIG, 1) diff --git a/tests/pipeline/test_pipeline_stage.py b/tests/pipeline/test_pipeline_stage.py index c120b87c39f..9e9fa35e66f 100644 --- a/tests/pipeline/test_pipeline_stage.py +++ b/tests/pipeline/test_pipeline_stage.py @@ -1,5 +1,3 @@ -from typing import Optional - import pytest from snuba.pipeline.query_pipeline import ( @@ -18,7 +16,7 @@ def _process_data(self, pipe_input: QueryPipelineResult[int]) -> int: return check_input_and_multiply(pipe_input.data) -def check_input_and_multiply(num: Optional[int]) -> int: +def check_input_and_multiply(num: int | None) -> int: if num == 0 or num is None: raise Exception("Input cannot be zero") return num * 2 diff --git a/tests/pipeline/test_storage_processing_stage.py b/tests/pipeline/test_storage_processing_stage.py index 98f489d6de2..ba70ec44b4a 100644 --- a/tests/pipeline/test_storage_processing_stage.py +++ b/tests/pipeline/test_storage_processing_stage.py @@ -33,7 +33,7 @@ def process_query(self, query: Query, query_settings: QuerySettings) -> None: @pytest.fixture -def mock_storage() -> Generator[ReadableTableStorage, None, None]: +def mock_storage() -> Generator[ReadableTableStorage]: # Create a fake storage mock_storage = ReadableTableStorage( storage_key=StorageKey("mockstorage"), @@ -125,9 +125,11 @@ def test_default_subscriptable(mock_storage: ReadableTableStorage) -> None: ) ) assert not result.error + schema = mock_storage.get_schema() + assert isinstance(schema, TableSchema) expected = Query( from_clause=Table( - table_name=mock_storage.get_schema().get_table_name(), # type: ignore + table_name=schema.get_table_name(), schema=mock_storage.get_schema().get_columns(), storage_key=mock_storage.get_storage_key(), allocation_policies=mock_storage.get_allocation_policies(), diff --git a/tests/pipeline/test_storage_query_identity_translate.py b/tests/pipeline/test_storage_query_identity_translate.py index cb90ab1856b..d5f48939b88 100644 --- a/tests/pipeline/test_storage_query_identity_translate.py +++ b/tests/pipeline/test_storage_query_identity_translate.py @@ -53,4 +53,7 @@ def test_translate_composite(mock_storage: Storage, mock_query_storage: QuerySto limit=100, ) storage_query = try_translate_storage_query(input_query) - assert isinstance(storage_query.get_from_clause().get_from_clause(), Table) # type: ignore + assert isinstance(storage_query, CompositeQuery) + from_clause = storage_query.get_from_clause() + assert isinstance(from_clause, CompositeQuery) + assert isinstance(from_clause.get_from_clause(), Table) diff --git a/tests/query/allocation_policies/test_allocation_policy_base.py b/tests/query/allocation_policies/test_allocation_policy_base.py index c8cd916962b..a6c2273bbee 100644 --- a/tests/query/allocation_policies/test_allocation_policy_base.py +++ b/tests/query/allocation_policies/test_allocation_policy_base.py @@ -64,10 +64,7 @@ class SomeAllocationPolicy(PassthroughPolicy): def test_passthrough_allows_queries() -> None: DEFAULT_PASSTHROUGH_POLICY.set_config_value("max_threads", 420) assert DEFAULT_PASSTHROUGH_POLICY.get_quota_allowance({}, "deadbeef").can_run - assert ( - DEFAULT_PASSTHROUGH_POLICY.get_quota_allowance({}, "deadbeef").max_threads - == 420 - ) + assert DEFAULT_PASSTHROUGH_POLICY.get_quota_allowance({}, "deadbeef").max_threads == 420 class RejectingEverythingAllocationPolicy(PassthroughPolicy): @@ -121,9 +118,7 @@ def _update_quota_balance( class InvalidTenantAllocationPolicy(PassthroughPolicy): def _get_quota_allowance(self, tenant_ids: dict[str, str | int], query_id: str): - raise InvalidTenantsForAllocationPolicy.from_args( - tenant_ids, self.__class__.__name__ - ) + raise InvalidTenantsForAllocationPolicy.from_args(tenant_ids, self.__class__.__name__) def _update_quota_balance( self, @@ -131,19 +126,21 @@ def _update_quota_balance( query_id: str, result_or_error: QueryResultOrError, ): - raise InvalidTenantsForAllocationPolicy.from_args( - tenant_ids, self.__class__.__name__ - ) + raise InvalidTenantsForAllocationPolicy.from_args(tenant_ids, self.__class__.__name__) def test_passes_through_on_error() -> None: with pytest.raises(AttributeError): - BadlyWrittenAllocationPolicy( - StorageKey("something"), [], {} - ).get_quota_allowance({}, query_id="deadbeef") + BadlyWrittenAllocationPolicy(StorageKey("something"), [], {}).get_quota_allowance( + {}, query_id="deadbeef" + ) with pytest.raises(ValueError): - BadlyWrittenAllocationPolicy(StorageKey("something"), [], {}).update_quota_balance(None, None, None) # type: ignore + BadlyWrittenAllocationPolicy(StorageKey("something"), [], {}).update_quota_balance( + None, # type: ignore[arg-type] + None, # type: ignore[arg-type] + None, # type: ignore[arg-type] + ) # should not raise even though the implementation is buggy (this is the production setting) with mock.patch("snuba.settings.RAISE_ON_ALLOCATION_POLICY_FAILURES", False): @@ -153,10 +150,10 @@ def test_passes_through_on_error() -> None: .can_run ) - BadlyWrittenAllocationPolicy( - StorageKey("something"), [], {} - ).update_quota_balance( - None, None, None # type: ignore + BadlyWrittenAllocationPolicy(StorageKey("something"), [], {}).update_quota_balance( + None, # type: ignore[arg-type] + None, # type: ignore[arg-type] + None, # type: ignore[arg-type] ) assert ( @@ -165,7 +162,11 @@ def test_passes_through_on_error() -> None: .can_run ) - InvalidTenantAllocationPolicy(StorageKey("Something"), [], {}).update_quota_balance({"some": "tenant"}, "12345", None) # type: ignore + InvalidTenantAllocationPolicy(StorageKey("Something"), [], {}).update_quota_balance( + {"some": "tenant"}, + "12345", + None, # type: ignore[arg-type] + ) @pytest.mark.redis_db @@ -195,10 +196,7 @@ def test_bad_config_keys() -> None: with pytest.raises(InvalidConfig) as err: policy.get_config_value("does_not_exist") - assert ( - str(err.value) - == "'does_not_exist' is not a valid config for PassthroughPolicy!" - ) + assert str(err.value) == "'does_not_exist' is not a valid config for PassthroughPolicy!" class SomeParametrizedConfigPolicy(AllocationPolicy): @@ -252,7 +250,7 @@ def test_bad_config_key_in_redis(self) -> None: # the bad configs are logged assert len(captured.records) == 3 - logs = set([record.getMessage() for record in captured.records]) + logs = {record.getMessage() for record in captured.records} assert logs == { "AllocationPolicy could not deserialize a key: something.SomeParametrizedConfigPolicy.my_bad_config.org:10,ref:ref", "AllocationPolicy could not deserialize a key: something.SomeParametrizedConfigPolicy.my_param_config.org:10", @@ -269,13 +267,8 @@ def policy() -> AllocationPolicy: @pytest.mark.redis_db def test_config_validation(policy: AllocationPolicy) -> None: with pytest.raises(InvalidConfig) as err: - policy.set_config_value( - config_key="my_config", value=10, params={"bad_param": 10} - ) - assert ( - str(err.value) - == "'my_config' takes no params for SomeParametrizedConfigPolicy!" - ) + policy.set_config_value(config_key="my_config", value=10, params={"bad_param": 10}) + assert str(err.value) == "'my_config' takes no params for SomeParametrizedConfigPolicy!" with pytest.raises(InvalidConfig) as err: policy.set_config_value(config_key="my_config", value="lol") assert ( @@ -432,9 +425,7 @@ def test_default_config_override() -> None: StorageKey("some_storage"), [], {"my_param_config": 420, "is_enforced": 0} ) assert ( - policy.get_config_value( - "my_param_config", params={"org": 1, "ref": "a"}, validate=True - ) + policy.get_config_value("my_param_config", params={"org": 1, "ref": "a"}, validate=True) == 420 ) assert policy.get_config_value("is_enforced") == 0 @@ -443,17 +434,11 @@ def test_default_config_override() -> None: @pytest.mark.redis_db def test_bad_defaults() -> None: with pytest.raises(ValueError): - SomeParametrizedConfigPolicy( - StorageKey("some_storage"), [], {"is_enforced": "0"} - ) + SomeParametrizedConfigPolicy(StorageKey("some_storage"), [], {"is_enforced": "0"}) with pytest.raises(ValueError): - SomeParametrizedConfigPolicy( - StorageKey("some_storage"), [], {"is_active": False} - ) + SomeParametrizedConfigPolicy(StorageKey("some_storage"), [], {"is_active": False}) with pytest.raises(ValueError): - SomeParametrizedConfigPolicy( - StorageKey("some_storage"), [], {"my_param_config": False} - ) + SomeParametrizedConfigPolicy(StorageKey("some_storage"), [], {"my_param_config": False}) @pytest.mark.redis_db @@ -521,29 +506,20 @@ def test_is_not_enforced() -> None: assert throttle_policy.get_quota_allowance(tenant_ids, "deadbeef").max_threads == 1 throttle_policy.set_config_value(config_key="is_enforced", value=0) - assert ( - throttle_policy.get_quota_allowance(tenant_ids, "deadbeef").max_threads - == MAX_THREADS - ) + assert throttle_policy.get_quota_allowance(tenant_ids, "deadbeef").max_threads == MAX_THREADS rejected_metrics = get_recorded_metric_calls( "increment", "allocation_policy.db_request_rejected" ) assert len(rejected_metrics) == 2 - assert ( - rejected_metrics[0].tags["policy_class"] - == "RejectingEverythingAllocationPolicy" - ) + assert rejected_metrics[0].tags["policy_class"] == "RejectingEverythingAllocationPolicy" assert rejected_metrics[0].tags["is_enforced"] == "True" assert rejected_metrics[1].tags["is_enforced"] == "False" throttled_metrics = get_recorded_metric_calls( "increment", "allocation_policy.db_request_throttled" ) assert len(throttled_metrics) == 2, throttled_metrics - assert ( - throttled_metrics[0].tags["policy_class"] - == "ThrottleEverythingAllocationPolicy" - ) + assert throttled_metrics[0].tags["policy_class"] == "ThrottleEverythingAllocationPolicy" assert throttled_metrics[0].tags["is_enforced"] == "True" assert throttled_metrics[1].tags["is_enforced"] == "False" @@ -568,9 +544,7 @@ def test_configs_with_delimiter_values() -> None: def test_cannot_use_escape_sequences() -> None: policy = SomeParametrizedConfigPolicy(StorageKey("something"), [], {}) with pytest.raises(InvalidConfig): - policy.set_config_value( - "my_param_config", 5, {"ref": "a__dot_literal__.b.c", "org": 1} - ) + policy.set_config_value("my_param_config", 5, {"ref": "a__dot_literal__.b.c", "org": 1}) class TestComponentNameBackwardsCompatibility: diff --git a/tests/query/allocation_policies/test_bytes_scanned_window_allocation_policy.py b/tests/query/allocation_policies/test_bytes_scanned_window_allocation_policy.py index 86e0dc345fd..46bf6379460 100644 --- a/tests/query/allocation_policies/test_bytes_scanned_window_allocation_policy.py +++ b/tests/query/allocation_policies/test_bytes_scanned_window_allocation_policy.py @@ -225,10 +225,7 @@ def test_passthrough_subscriptions(policy: AllocationPolicy) -> None: def test_single_thread_referrers(policy: AllocationPolicy) -> None: _configure_policy(policy) tenant_ids: dict[str, str | int] = {"referrer": "delete-events-from-file"} - assert ( - policy.get_quota_allowance(tenant_ids=tenant_ids, query_id=QUERY_ID).max_threads - == 1 - ) + assert policy.get_quota_allowance(tenant_ids=tenant_ids, query_id=QUERY_ID).max_threads == 1 policy.update_quota_balance( tenant_ids, QUERY_ID, @@ -240,10 +237,7 @@ def test_single_thread_referrers(policy: AllocationPolicy) -> None: error=None, ), ) - assert ( - policy.get_quota_allowance(tenant_ids=tenant_ids, query_id=QUERY_ID).max_threads - == 1 - ) + assert policy.get_quota_allowance(tenant_ids=tenant_ids, query_id=QUERY_ID).max_threads == 1 @pytest.mark.redis_db @@ -281,4 +275,4 @@ def test_cross_org(policy: AllocationPolicy) -> None: ) # make sure that this can be called with cross org queries # and nothing raises - policy.update_quota_balance(tenant_ids, QUERY_ID, None) # type: ignore + policy.update_quota_balance(tenant_ids, QUERY_ID, None) # type: ignore[arg-type] diff --git a/tests/query/allocation_policies/test_concurrent_rate_limit_policy.py b/tests/query/allocation_policies/test_concurrent_rate_limit_policy.py index 866daa17ed1..1b0dce81a4f 100644 --- a/tests/query/allocation_policies/test_concurrent_rate_limit_policy.py +++ b/tests/query/allocation_policies/test_concurrent_rate_limit_policy.py @@ -49,9 +49,7 @@ def policy() -> ConcurrentRateLimitAllocationPolicy: @pytest.mark.redis_db def test_rate_limit_concurrent(policy: ConcurrentRateLimitAllocationPolicy) -> None: for i in range(MAX_CONCURRENT_QUERIES): - policy.get_quota_allowance( - tenant_ids={"organization_id": 123}, query_id=f"abc{i}" - ) + policy.get_quota_allowance(tenant_ids={"organization_id": 123}, query_id=f"abc{i}") quota_allowance = policy.get_quota_allowance( tenant_ids={"organization_id": 123}, query_id=f"abc{MAX_CONCURRENT_QUERIES}" @@ -94,7 +92,9 @@ def test_configure_max_query_duration( time.sleep(sleep_time) assert policy.get_quota_allowance( tenant_ids={"organization_id": 123}, query_id="abc2" - ).can_run, "max_query_duration_s is set to {max_query_duration_s}, test sleeps for {sleep_time} seconds, the first query should have no longer been counted towards the concurrent limit" + ).can_run, ( + "max_query_duration_s is set to {max_query_duration_s}, test sleeps for {sleep_time} seconds, the first query should have no longer been counted towards the concurrent limit" + ) @pytest.mark.redis_db @@ -103,9 +103,7 @@ def test_rate_limit_concurrent_complete_query( ) -> None: # submit the max concurrent queries for i in range(MAX_CONCURRENT_QUERIES): - policy.get_quota_allowance( - tenant_ids={"organization_id": 123}, query_id=f"abc{i}" - ) + policy.get_quota_allowance(tenant_ids={"organization_id": 123}, query_id=f"abc{i}") # cant submit anymore quota_allowance = policy.get_quota_allowance( @@ -139,9 +137,7 @@ def test_update_quota_balance(policy: ConcurrentRateLimitAllocationPolicy) -> No # when a query is finished (in whatever state), it is no longer counted as a concurrent query for i in range(MAX_CONCURRENT_QUERIES): - policy.get_quota_allowance( - tenant_ids={"organization_id": 123}, query_id=f"abc{i}" - ) + policy.get_quota_allowance(tenant_ids={"organization_id": 123}, query_id=f"abc{i}") for i in range(MAX_CONCURRENT_QUERIES): policy.update_quota_balance( @@ -262,7 +258,7 @@ def test_apply_overrides( for i in range(expected_concurrent_limit): policy.get_quota_allowance(tenant_ids=tenant_ids, query_id=f"{i}") allowance = policy.get_quota_allowance( - tenant_ids=tenant_ids, query_id=f"{expected_concurrent_limit+1}" + tenant_ids=tenant_ids, query_id=f"{expected_concurrent_limit + 1}" ) assert not allowance.can_run and allowance.max_threads == 0 assert allowance.explanation["overrides"] == expected_overrides @@ -303,9 +299,7 @@ def test_override_isolation( query_id="uniq_string_2", ) except Exception: - pytest.fail( - "overridden query was finished, another one should have been able to run" - ) + pytest.fail("overridden query was finished, another one should have been able to run") # finish a non-overidden query policy.update_quota_balance( @@ -343,7 +337,7 @@ def test_cross_org(policy: ConcurrentRateLimitAllocationPolicy) -> None: ) # make sure that this can be called with cross org queries # and nothing raises - policy.update_quota_balance(tenant_ids, "c", None) # type: ignore + policy.update_quota_balance(tenant_ids, "c", None) # type: ignore[arg-type] @pytest.mark.redis_db diff --git a/tests/query/allocation_policies/test_cross_org_policy.py b/tests/query/allocation_policies/test_cross_org_policy.py index d8acd9d723e..ff418fab9a3 100644 --- a/tests/query/allocation_policies/test_cross_org_policy.py +++ b/tests/query/allocation_policies/test_cross_org_policy.py @@ -191,4 +191,7 @@ def test_throttle_cross_org_query_with_unregistered_referrer(self): query_id="2", ) assert not allowance.can_run - assert allowance.explanation["cross_org_query"] == "This referrer is not registered for the current storage generic_metrics_distributions, if you want to increase its limits, register it in the yaml of the CrossOrgQueryAllocationPolicy" # type: ignore + assert ( + allowance.explanation["cross_org_query"] + == "This referrer is not registered for the current storage generic_metrics_distributions, if you want to increase its limits, register it in the yaml of the CrossOrgQueryAllocationPolicy" + ) diff --git a/tests/query/allocation_policies/test_per_referrer.py b/tests/query/allocation_policies/test_per_referrer.py index e1a4e2acd1b..d102adec992 100644 --- a/tests/query/allocation_policies/test_per_referrer.py +++ b/tests/query/allocation_policies/test_per_referrer.py @@ -24,13 +24,9 @@ def test_policy_pass_basic(self): ) policy.set_config_value("default_concurrent_request_per_referrer", 2) - policy.get_quota_allowance( - tenant_ids={"referrer": "statistical_detectors"}, query_id="1" - ) + policy.get_quota_allowance(tenant_ids={"referrer": "statistical_detectors"}, query_id="1") - policy.get_quota_allowance( - tenant_ids={"referrer": "statistical_detectors"}, query_id="2" - ) + policy.get_quota_allowance(tenant_ids={"referrer": "statistical_detectors"}, query_id="2") quota_allowance = policy.get_quota_allowance( tenant_ids={"referrer": "statistical_detectors"}, query_id="3" diff --git a/tests/query/data_source/test_join.py b/tests/query/data_source/test_join.py index d59b454b1e2..8d1d5ee13ec 100644 --- a/tests/query/data_source/test_join.py +++ b/tests/query/data_source/test_join.py @@ -13,9 +13,7 @@ from snuba.query.expressions import Column from snuba.query.logical import Query -ERRORS_SCHEMA = ColumnSet( - [("event_id", UUID()), ("message", String()), ("group_id", UInt(32))] -) +ERRORS_SCHEMA = ColumnSet([("event_id", UUID()), ("message", String()), ("group_id", UInt(32))]) GROUPS_SCHEMA = ColumnSet([("id", UInt(32)), ("message", String())]) diff --git a/tests/query/formatters/test_query.py b/tests/query/formatters/test_query.py index e70c8b9f236..8da56faa2bf 100644 --- a/tests/query/formatters/test_query.py +++ b/tests/query/formatters/test_query.py @@ -1,10 +1,7 @@ -from typing import Union - import pytest -from snuba.clickhouse.columns import ColumnSet +from snuba.clickhouse.columns import ColumnSet, UInt from snuba.clickhouse.columns import SchemaModifiers as Modifiers -from snuba.clickhouse.columns import UInt from snuba.clickhouse.query import Query as ClickhouseQuery from snuba.datasets.entities.entity_key import EntityKey from snuba.datasets.storages.storage_key import StorageKey @@ -61,14 +58,10 @@ from_clause=Entity(EntityKey.EVENTS, EVENTS_SCHEMA, 0.5), selected_columns=[ SelectedExpression("c1", Column("_snuba_c1", "t", "c")), - SelectedExpression( - "f1", FunctionCall("_snuba_f1", "f", (Column(None, "t", "c2"),)) - ), + SelectedExpression("f1", FunctionCall("_snuba_f1", "f", (Column(None, "t", "c2"),))), ], array_join=Column(None, None, "col"), - condition=binary_condition( - "equals", Column(None, None, "c4"), Literal(None, "asd") - ), + condition=binary_condition("equals", Column(None, None, "c4"), Literal(None, "asd")), groupby=[Column(None, "t", "c4")], having=binary_condition("equals", Column(None, None, "c6"), Literal(None, "asd2")), order_by=[OrderBy(OrderByDirection.ASC, Column(None, "t", "c"))], @@ -116,9 +109,7 @@ CompositeQuery( from_clause=LOGICAL_QUERY, selected_columns=[ - SelectedExpression( - "f", FunctionCall("f", "avg", (Column(None, "t", "c"),)) - ) + SelectedExpression("f", FunctionCall("f", "avg", (Column(None, "t", "c"),))) ], ), [ @@ -187,9 +178,7 @@ from_clause=CompositeQuery( from_clause=SIMPLE_SELECT_QUERY, selected_columns=[ - SelectedExpression( - "f", FunctionCall("f", "avg", (Column(None, "t", "c"),)) - ) + SelectedExpression("f", FunctionCall("f", "avg", (Column(None, "t", "c"),))) ], ), selected_columns=[SelectedExpression("tc", Column(None, "t", "c"))], @@ -258,10 +247,10 @@ @pytest.mark.parametrize("query, formatted", TEST_JOIN) def test_query_formatter( - query: Union[ProcessableQuery, CompositeQuery[Entity]], + query: ProcessableQuery | CompositeQuery[Entity], formatted: TExpression, ) -> None: - formatted_query = format_query(query) # type: ignore + formatted_query = format_query(query) # type: ignore[arg-type] assert formatted_query == formatted # make sure there are no empty lines assert [line for line in formatted_query if not line] == [] diff --git a/tests/query/joins/equivalence_schema.py b/tests/query/joins/equivalence_schema.py index 65eb5040bd8..78599393210 100644 --- a/tests/query/joins/equivalence_schema.py +++ b/tests/query/joins/equivalence_schema.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import UUID, ColumnSet, String, UInt from snuba.datasets.entities.entity_key import EntityKey diff --git a/tests/query/joins/join_structures.py b/tests/query/joins/join_structures.py index 09216f2786f..1e716275d1d 100644 --- a/tests/query/joins/join_structures.py +++ b/tests/query/joins/join_structures.py @@ -1,4 +1,5 @@ -from typing import Optional, Sequence, TypeVar +from collections.abc import Sequence +from typing import TypeVar from snuba.clickhouse.columns import ColumnSet from snuba.clickhouse.query import Query as ClickhouseQuery @@ -29,13 +30,13 @@ def build_node( alias: str, from_clause: Entity, selected_columns: Sequence[SelectedExpression], - condition: Optional[Expression], - granularity: Optional[int] = None, + condition: Expression | None, + granularity: int | None = None, ) -> IndividualNode[Entity]: return IndividualNode( alias=alias, - data_source=EntityQuery.from_query( - LogicalQuery( # type: ignore + data_source=EntityQuery.from_query( # type: ignore[arg-type] + LogicalQuery( from_clause=from_clause, selected_columns=selected_columns, condition=condition, @@ -47,8 +48,8 @@ def build_node( def events_node( selected_columns: Sequence[SelectedExpression], - condition: Optional[Expression] = None, - granularity: Optional[int] = None, + condition: Expression | None = None, + granularity: int | None = None, ) -> IndividualNode[Entity]: return build_node( "ev", @@ -61,8 +62,8 @@ def events_node( def groups_node( selected_columns: Sequence[SelectedExpression], - condition: Optional[Expression] = None, - granularity: Optional[int] = None, + condition: Expression | None = None, + granularity: int | None = None, ) -> IndividualNode[Entity]: return build_node( "gr", @@ -77,8 +78,8 @@ def build_clickhouse_node( alias: str, from_clause: Table, selected_columns: Sequence[SelectedExpression], - condition: Optional[Expression], - groupby: Optional[Sequence[Expression]] = None, + condition: Expression | None, + groupby: Sequence[Expression] | None = None, ) -> IndividualNode[Table]: return IndividualNode( alias=alias, @@ -93,8 +94,8 @@ def build_clickhouse_node( def clickhouse_events_node( selected_columns: Sequence[SelectedExpression], - condition: Optional[Expression] = None, - groupby: Optional[Sequence[Expression]] = None, + condition: Expression | None = None, + groupby: Sequence[Expression] | None = None, ) -> IndividualNode[Table]: return build_clickhouse_node( "ev", @@ -107,7 +108,7 @@ def clickhouse_events_node( def clickhouse_groups_node( selected_columns: Sequence[SelectedExpression], - condition: Optional[Expression] = None, + condition: Expression | None = None, ) -> IndividualNode[Table]: return build_clickhouse_node( "gr", @@ -119,7 +120,7 @@ def clickhouse_groups_node( def clickhouse_assignees_node( selected_columns: Sequence[SelectedExpression], - condition: Optional[Expression] = None, + condition: Expression | None = None, ) -> IndividualNode[Table]: return build_clickhouse_node( "as", diff --git a/tests/query/joins/test_branch_cutter.py b/tests/query/joins/test_branch_cutter.py index 3ac7c4ffc5d..9612dba7b38 100644 --- a/tests/query/joins/test_branch_cutter.py +++ b/tests/query/joins/test_branch_cutter.py @@ -1,4 +1,4 @@ -from typing import Generator +from collections.abc import Generator import pytest @@ -750,7 +750,7 @@ def test_branch_cutter( expression: Expression, expected: SubExpression, main_expr: MainQueryExpression ) -> None: - def alias_generator() -> Generator[str, None, None]: + def alias_generator() -> Generator[str]: i = 0 while True: i += 1 diff --git a/tests/query/joins/test_equivalence_adder.py b/tests/query/joins/test_equivalence_adder.py index bfc8d751cd5..a4cd54bf50b 100644 --- a/tests/query/joins/test_equivalence_adder.py +++ b/tests/query/joins/test_equivalence_adder.py @@ -46,9 +46,7 @@ def test_classify_and_replace() -> None: assert condition.transform( partial(_replace_col, "ev", "project_id", "gr", "project_id") - ) == binary_condition( - ConditionFunctions.EQ, Column(None, "gr", "project_id"), Literal(None, 1) - ) + ) == binary_condition(ConditionFunctions.EQ, Column(None, "gr", "project_id"), Literal(None, 1)) ENTITY_GROUP_JOIN = JoinClause( @@ -67,19 +65,13 @@ def test_classify_and_replace() -> None: TEST_REPLACEMENT = [ pytest.param( - binary_condition( - ConditionFunctions.EQ, Column(None, "ev", "event_id"), Literal(None, 1) - ), + binary_condition(ConditionFunctions.EQ, Column(None, "ev", "event_id"), Literal(None, 1)), ENTITY_GROUP_JOIN, - binary_condition( - ConditionFunctions.EQ, Column(None, "ev", "event_id"), Literal(None, 1) - ), + binary_condition(ConditionFunctions.EQ, Column(None, "ev", "event_id"), Literal(None, 1)), id="No condition to add", ), pytest.param( - binary_condition( - ConditionFunctions.EQ, Column(None, "ev", "event_id"), Literal(None, 1) - ), + binary_condition(ConditionFunctions.EQ, Column(None, "ev", "event_id"), Literal(None, 1)), JoinClause( IndividualNode("ev", EntitySource(EntityKey.EVENTS, EVENTS_SCHEMA, None)), IndividualNode("ev2", EntitySource(EntityKey.EVENTS, EVENTS_SCHEMA, None)), @@ -109,9 +101,7 @@ def test_classify_and_replace() -> None: id="Self join. Duplicate condition", ), pytest.param( - binary_condition( - ConditionFunctions.EQ, Column(None, "ev", "project_id"), Literal(None, 1) - ), + binary_condition(ConditionFunctions.EQ, Column(None, "ev", "project_id"), Literal(None, 1)), ENTITY_GROUP_JOIN, combine_and_conditions( [ @@ -255,9 +245,7 @@ def test_classify_and_replace() -> None: ] -@pytest.mark.parametrize( - "initial_condition, join_clause, expected_expr", TEST_REPLACEMENT -) +@pytest.mark.parametrize("initial_condition, join_clause, expected_expr", TEST_REPLACEMENT) def test_add_equivalent_condition( initial_condition: Expression, join_clause: JoinClause[EntitySource], diff --git a/tests/query/joins/test_equivalences.py b/tests/query/joins/test_equivalences.py index f5c63eb4d83..e90977b43ff 100644 --- a/tests/query/joins/test_equivalences.py +++ b/tests/query/joins/test_equivalences.py @@ -28,9 +28,7 @@ pytest.param( JoinClause( IndividualNode("ev", EntitySource(EntityKey.EVENTS, EVENTS_SCHEMA, None)), - IndividualNode( - "gr", EntitySource(EntityKey.GROUPEDMESSAGE, GROUPS_SCHEMA, None) - ), + IndividualNode("gr", EntitySource(EntityKey.GROUPEDMESSAGE, GROUPS_SCHEMA, None)), [ JoinCondition( JoinConditionExpression("ev", "group_id"), @@ -59,12 +57,8 @@ pytest.param( JoinClause( JoinClause( - IndividualNode( - "ev", EntitySource(EntityKey.EVENTS, EVENTS_SCHEMA, None) - ), - IndividualNode( - "as", EntitySource(EntityKey.GROUPASSIGNEE, GROUPS_ASSIGNEE, None) - ), + IndividualNode("ev", EntitySource(EntityKey.EVENTS, EVENTS_SCHEMA, None)), + IndividualNode("as", EntitySource(EntityKey.GROUPASSIGNEE, GROUPS_ASSIGNEE, None)), [ JoinCondition( JoinConditionExpression("ev", "group_id"), @@ -74,9 +68,7 @@ JoinType.INNER, None, ), - IndividualNode( - "gr", EntitySource(EntityKey.GROUPEDMESSAGE, GROUPS_SCHEMA, None) - ), + IndividualNode("gr", EntitySource(EntityKey.GROUPEDMESSAGE, GROUPS_SCHEMA, None)), [ JoinCondition( JoinConditionExpression("ev", "group_id"), @@ -117,12 +109,8 @@ pytest.param( JoinClause( JoinClause( - IndividualNode( - "ev", EntitySource(EntityKey.EVENTS, EVENTS_SCHEMA, None) - ), - IndividualNode( - "gr", EntitySource(EntityKey.GROUPEDMESSAGE, GROUPS_SCHEMA, None) - ), + IndividualNode("ev", EntitySource(EntityKey.EVENTS, EVENTS_SCHEMA, None)), + IndividualNode("gr", EntitySource(EntityKey.GROUPEDMESSAGE, GROUPS_SCHEMA, None)), [ JoinCondition( JoinConditionExpression("ev", "group_id"), @@ -132,9 +120,7 @@ JoinType.INNER, None, ), - IndividualNode( - "as", EntitySource(EntityKey.GROUPASSIGNEE, GROUPS_ASSIGNEE, None) - ), + IndividualNode("as", EntitySource(EntityKey.GROUPASSIGNEE, GROUPS_ASSIGNEE, None)), [ JoinCondition( JoinConditionExpression("gr", "user_id"), @@ -176,9 +162,7 @@ @pytest.mark.parametrize("join, graph", TEST_CASES) -def test_find_equivalences( - join: JoinClause[EntitySource], graph: EquivalenceGraph -) -> None: +def test_find_equivalences(join: JoinClause[EntitySource], graph: EquivalenceGraph) -> None: override_entity_map(EntityKey.EVENTS, Events()) override_entity_map(EntityKey.GROUPEDMESSAGE, GroupedMessage()) override_entity_map(EntityKey.GROUPASSIGNEE, GroupAssignee()) diff --git a/tests/query/joins/test_metrics_subqueries.py b/tests/query/joins/test_metrics_subqueries.py index 67ad37d359a..8f7bba06665 100644 --- a/tests/query/joins/test_metrics_subqueries.py +++ b/tests/query/joins/test_metrics_subqueries.py @@ -181,9 +181,7 @@ def test_subquery_generator_metrics() -> None: assert selected_columns == expected_outer_query_selected # Test outer conditions - assert ( - original_query.get_condition() is None - ), "all conditions should be pushed down" + assert original_query.get_condition() is None, "all conditions should be pushed down" # Test outer groupby expected_outer_query_groupby = [] @@ -235,9 +233,7 @@ def test_subquery_generator_metrics() -> None: assert selected_columns == expected_lhs_selected # The ordering of conditions is not guaranteed so we sort them by alias before asserting - flattened_conditions: list[Expression] = get_first_level_and_conditions( - lhs.get_condition() - ) + flattened_conditions: list[Expression] = get_first_level_and_conditions(lhs.get_condition()) flattened_conditions.sort(key=lambda x: x.alias) assert flattened_conditions == [ f.greaterOrEquals( @@ -315,9 +311,7 @@ def test_subquery_generator_metrics() -> None: assert selected_columns == expected_rhs_selected # Test rhs conditions - flattened_conditions: list[Expression] = get_first_level_and_conditions( - rhs.get_condition() - ) + flattened_conditions: list[Expression] = get_first_level_and_conditions(rhs.get_condition()) flattened_conditions.sort(key=lambda x: x.alias) assert flattened_conditions == [ f.greaterOrEquals( diff --git a/tests/query/joins/test_semi_join.py b/tests/query/joins/test_semi_join.py index cd9b50163a2..26237498cf7 100644 --- a/tests/query/joins/test_semi_join.py +++ b/tests/query/joins/test_semi_join.py @@ -1,4 +1,5 @@ -from typing import Mapping, Optional, cast +from collections.abc import Mapping +from typing import cast import pytest @@ -207,10 +208,10 @@ @pytest.mark.parametrize("query, expected_semi_join", TEST_CASES) def test_subquery_generator( query: CompositeQuery[Table], - expected_semi_join: Mapping[str, Optional[JoinModifier]], + expected_semi_join: Mapping[str, JoinModifier | None], ) -> None: def assert_transformation( - clause: JoinClause[Table], expected: Mapping[str, Optional[JoinModifier]] + clause: JoinClause[Table], expected: Mapping[str, JoinModifier | None] ) -> None: right_alias = clause.right_node.alias assert right_alias in expected and clause.join_modifier == expected[right_alias], ( diff --git a/tests/query/joins/test_subqueries.py b/tests/query/joins/test_subqueries.py index 97836eafc34..832319e6c69 100644 --- a/tests/query/joins/test_subqueries.py +++ b/tests/query/joins/test_subqueries.py @@ -44,9 +44,7 @@ ), right_node=IndividualNode( alias="gr", - data_source=Entity( - EntityKey.GROUPEDMESSAGE, ColumnSet(GROUPS_SCHEMA.columns), None - ), + data_source=Entity(EntityKey.GROUPEDMESSAGE, ColumnSet(GROUPS_SCHEMA.columns), None), ), keys=[ JoinCondition( @@ -116,9 +114,7 @@ "_snuba_group_id", Column("_snuba_group_id", None, "id"), ), - SelectedExpression( - "_snuba_id", Column("_snuba_id", None, "id") - ), + SelectedExpression("_snuba_id", Column("_snuba_id", None, "id")), ], ), ), @@ -131,9 +127,7 @@ (Column("_snuba_ev.event_id", "ev", "_snuba_ev.event_id"),), ), ), - SelectedExpression( - "group_id", Column("_snuba_group_id", "gr", "_snuba_group_id") - ), + SelectedExpression("group_id", Column("_snuba_group_id", "gr", "_snuba_group_id")), ], ), id="Basic join with select", @@ -166,9 +160,7 @@ "_snuba_group_id", Column("_snuba_group_id", None, "id"), ), - SelectedExpression( - "_snuba_id", Column("_snuba_id", None, "id") - ), + SelectedExpression("_snuba_id", Column("_snuba_id", None, "id")), ], condition=binary_condition( ConditionFunctions.EQ, @@ -178,9 +170,7 @@ ), ), selected_columns=[ - SelectedExpression( - "group_id", Column("_snuba_group_id", "gr", "_snuba_group_id") - ), + SelectedExpression("group_id", Column("_snuba_group_id", "gr", "_snuba_group_id")), ], ), id="Query with condition", @@ -201,9 +191,7 @@ binary_condition( ConditionFunctions.EQ, Column("_snuba_group_id", "gr", "id"), - FunctionCall( - None, "f", (Column("_snuba_e_group_id", "ev", "group_id"),) - ), + FunctionCall(None, "f", (Column("_snuba_e_group_id", "ev", "group_id"),)), ), ), ), @@ -231,9 +219,7 @@ "_snuba_group_id", Column("_snuba_group_id", None, "id"), ), - SelectedExpression( - "_snuba_id", Column("_snuba_id", None, "id") - ), + SelectedExpression("_snuba_id", Column("_snuba_id", None, "id")), ], binary_condition( ConditionFunctions.EQ, @@ -243,9 +229,7 @@ ), ), selected_columns=[ - SelectedExpression( - "group_id", Column("_snuba_group_id", "gr", "_snuba_group_id") - ), + SelectedExpression("group_id", Column("_snuba_group_id", "gr", "_snuba_group_id")), ], condition=binary_condition( ConditionFunctions.EQ, @@ -291,9 +275,7 @@ "_snuba_group_id", Column("_snuba_group_id", None, "id"), ), - SelectedExpression( - "_snuba_id", Column("_snuba_id", None, "id") - ), + SelectedExpression("_snuba_id", Column("_snuba_id", None, "id")), ], ), ), @@ -318,9 +300,7 @@ join_type=JoinType.INNER, ), selected_columns=[ - SelectedExpression( - "group_id", Column("_snuba_group_id", "gr", "_snuba_group_id") - ), + SelectedExpression("group_id", Column("_snuba_group_id", "gr", "_snuba_group_id")), ], ), id="Multi entity join", @@ -413,9 +393,7 @@ None, "greater", ( - FunctionCall( - None, "min", (Column("_snuba_gen_2", "ev", "_snuba_gen_2"),) - ), + FunctionCall(None, "min", (Column("_snuba_gen_2", "ev", "_snuba_gen_2"),)), Literal(None, "sometime"), ), ), @@ -460,9 +438,7 @@ from_clause=events_groups_join( events_node( [ - SelectedExpression( - "_snuba_a_col", Column("_snuba_a_col", None, "column") - ), + SelectedExpression("_snuba_a_col", Column("_snuba_a_col", None, "column")), SelectedExpression( "_snuba_another_func", FunctionCall( @@ -483,16 +459,12 @@ "_snuba_another_col", Column("_snuba_another_col", None, "another_column"), ), - SelectedExpression( - "_snuba_id", Column("_snuba_id", None, "id") - ), + SelectedExpression("_snuba_id", Column("_snuba_id", None, "id")), ] ), ), selected_columns=[ - SelectedExpression( - "a_col", Column("_snuba_a_col", "ev", "_snuba_a_col") - ), + SelectedExpression("a_col", Column("_snuba_a_col", "ev", "_snuba_a_col")), SelectedExpression( "a_func", FunctionCall( @@ -527,17 +499,11 @@ def test_subquery_generator( generate_subqueries(original_query) - original_map = cast( - JoinClause[Entity], original_query.get_from_clause() - ).get_alias_node_map() - processed_map = cast( - JoinClause[Entity], processed_query.get_from_clause() - ).get_alias_node_map() + original_map = cast(JoinClause[Entity], original_query.get_from_clause()).get_alias_node_map() + processed_map = cast(JoinClause[Entity], processed_query.get_from_clause()).get_alias_node_map() for k, node in original_map.items(): - report = cast(LogicalQuery, node.data_source).equals( - processed_map[k].data_source - ) + report = cast(LogicalQuery, node.data_source).equals(processed_map[k].data_source) assert report[0], f"Failed equality {k}: {report[1]}" report = original_query.equals(processed_query) diff --git a/tests/query/parser/test_formula_mql_query.py b/tests/query/parser/test_formula_mql_query.py index 27113a650b3..04adce6f8c4 100644 --- a/tests/query/parser/test_formula_mql_query.py +++ b/tests/query/parser/test_formula_mql_query.py @@ -46,9 +46,7 @@ def time_expression( "toStartOfInterval", ( Column("_snuba_timestamp", table_alias, "timestamp"), - FunctionCall( - None, "toIntervalSecond", (Literal(None, to_interval_seconds),) - ), + FunctionCall(None, "toIntervalSecond", (Literal(None, to_interval_seconds),)), Literal(None, "Universal"), ), ) @@ -59,9 +57,7 @@ def subscriptable_expression( ) -> SubscriptableReference: return SubscriptableReference( alias=f"_snuba_tags_raw[{tag_key}]", - column=Column( - alias="_snuba_tags_raw", table_name=table_alias, column_name="tags_raw" - ), + column=Column(alias="_snuba_tags_raw", table_name=table_alias, column_name="tags_raw"), key=Literal(alias=None, value=tag_key), ) @@ -129,9 +125,7 @@ def condition(table_alias: str | None = None) -> list[FunctionCall]: } -def timeseries( - agg: str, metric_id: int, condition: FunctionCall | None = None -) -> FunctionCall: +def timeseries(agg: str, metric_id: int, condition: FunctionCall | None = None) -> FunctionCall: metric_condition = FunctionCall( None, "equals", @@ -176,7 +170,9 @@ def metric_id_condition(metric_id: int, table_alias: str | None = None) -> Funct def tag_column(tag: str, table_alias: str | None = None) -> SubscriptableReference: - tag_val = mql_context.get("indexer_mappings").get(tag) # type: ignore + indexer_mappings = mql_context.get("indexer_mappings") + assert isinstance(indexer_mappings, dict) + tag_val = indexer_mappings.get(tag) return SubscriptableReference( alias=f"_snuba_tags_raw[{tag_val}]", column=Column( @@ -233,9 +229,7 @@ def test_simple_formula() -> None: metric_condition1 = metric_id_condition(123456, "d0") metric_condition2 = metric_id_condition(123456, "d1") formula_condition = combine_and_conditions( - condition("d0") - + condition("d1") - + [tag_condition, metric_condition1, metric_condition2] + condition("d0") + condition("d1") + [tag_condition, metric_condition1, metric_condition2] ) expected = CompositeQuery( @@ -318,12 +312,8 @@ def test_bracket_on_formula() -> None: ), keys=[ JoinCondition( - left=JoinConditionExpression( - table_alias="d3", column="d3.time" - ), - right=JoinConditionExpression( - table_alias="d2", column="d2.time" - ), + left=JoinConditionExpression(table_alias="d3", column="d3.time"), + right=JoinConditionExpression(table_alias="d2", column="d2.time"), ) ], join_type=JoinType.INNER, @@ -466,12 +456,8 @@ def test_multiple_filter_same_groupby_formula() -> None: ), keys=[ JoinCondition( - left=JoinConditionExpression( - table_alias="d1", column="tags_raw[222222]" - ), - right=JoinConditionExpression( - table_alias="d0", column="tags_raw[222222]" - ), + left=JoinConditionExpression(table_alias="d1", column="tags_raw[222222]"), + right=JoinConditionExpression(table_alias="d0", column="tags_raw[222222]"), ), JoinCondition( left=JoinConditionExpression(table_alias="d1", column="d1.time"), @@ -490,12 +476,8 @@ def test_multiple_filter_same_groupby_formula() -> None: ) tag_condition3 = binary_condition( "or", - binary_condition( - "equals", tag_column("transaction", "d1"), Literal(None, "prod") - ), - binary_condition( - "equals", tag_column("status_code", "d1"), Literal(None, "400") - ), + binary_condition("equals", tag_column("transaction", "d1"), Literal(None, "prod")), + binary_condition("equals", tag_column("status_code", "d1"), Literal(None, "400")), ) metric_condition1 = metric_id_condition(123456, "d0") metric_condition2 = metric_id_condition(123456, "d1") @@ -804,12 +786,8 @@ def test_groupby() -> None: ), keys=[ JoinCondition( - left=JoinConditionExpression( - table_alias="d1", column="tags_raw[333333]" - ), - right=JoinConditionExpression( - table_alias="d0", column="tags_raw[333333]" - ), + left=JoinConditionExpression(table_alias="d1", column="tags_raw[333333]"), + right=JoinConditionExpression(table_alias="d0", column="tags_raw[333333]"), ), JoinCondition( left=JoinConditionExpression(table_alias="d1", column="d1.time"), @@ -826,9 +804,7 @@ def test_groupby() -> None: metric_condition1 = metric_id_condition(123456, "d0") metric_condition2 = metric_id_condition(123456, "d1") formula_condition = combine_and_conditions( - condition("d0") - + condition("d1") - + [tag_condition, metric_condition1, metric_condition2] + condition("d0") + condition("d1") + [tag_condition, metric_condition1, metric_condition2] ) expected = CompositeQuery( @@ -911,12 +887,8 @@ def test_groupby_with_totals() -> None: ), keys=[ JoinCondition( - left=JoinConditionExpression( - table_alias="d1", column="tags_raw[333333]" - ), - right=JoinConditionExpression( - table_alias="d0", column="tags_raw[333333]" - ), + left=JoinConditionExpression(table_alias="d1", column="tags_raw[333333]"), + right=JoinConditionExpression(table_alias="d0", column="tags_raw[333333]"), ) ], join_type=JoinType.INNER, @@ -929,9 +901,7 @@ def test_groupby_with_totals() -> None: metric_condition1 = metric_id_condition(123456, "d0") metric_condition2 = metric_id_condition(123456, "d1") formula_condition = combine_and_conditions( - condition("d0") - + condition("d1") - + [tag_condition, metric_condition1, metric_condition2] + condition("d0") + condition("d1") + [tag_condition, metric_condition1, metric_condition2] ) expected = CompositeQuery( @@ -1021,9 +991,7 @@ def test_onesided_groupby() -> None: metric_condition1 = metric_id_condition(123456, "d0") metric_condition2 = metric_id_condition(123456, "d1") formula_condition = combine_and_conditions( - condition("d0") - + condition("d1") - + [tag_condition, metric_condition1, metric_condition2] + condition("d0") + condition("d1") + [tag_condition, metric_condition1, metric_condition2] ) expected = CompositeQuery( @@ -1305,9 +1273,7 @@ def test_curried_aggregate_formula() -> None: metric_condition1 = metric_id_condition(123456, "d0") metric_condition2 = metric_id_condition(123456, "d1") formula_condition = combine_and_conditions( - condition("d0") - + condition("d1") - + [tag_condition, metric_condition1, metric_condition2] + condition("d0") + condition("d1") + [tag_condition, metric_condition1, metric_condition2] ) expected = CompositeQuery( @@ -1386,9 +1352,7 @@ def test_formula_no_groupby_no_interval_with_totals() -> None: metric_condition1 = metric_id_condition(123456, "d0") metric_condition2 = metric_id_condition(123456, "d1") formula_condition = combine_and_conditions( - condition("d0") - + condition("d1") - + [tag_condition, metric_condition1, metric_condition2] + condition("d0") + condition("d1") + [tag_condition, metric_condition1, metric_condition2] ) expected = CompositeQuery( @@ -1453,9 +1417,7 @@ def test_formula_onesided_groupby_no_interval_with_totals() -> None: metric_condition1 = metric_id_condition(123456, "d0") metric_condition2 = metric_id_condition(123456, "d1") formula_condition = combine_and_conditions( - condition("d0") - + condition("d1") - + [tag_condition, metric_condition1, metric_condition2] + condition("d0") + condition("d1") + [tag_condition, metric_condition1, metric_condition2] ) expected = CompositeQuery( @@ -1495,9 +1457,7 @@ def test_formula_extrapolation_with_nested_functions() -> None: None, "apdex", ( - FunctionCall( - None, "avg_weighted", (Column("_snuba_value", "d0", "value"),) - ), + FunctionCall(None, "avg_weighted", (Column("_snuba_value", "d0", "value"),)), Literal(None, 123.0), ), ), @@ -1577,12 +1537,10 @@ def test_formula_extrapolation_with_nested_functions() -> None: mql_context_with_extrapolation = deepcopy(mql_context) mql_context_with_extrapolation["extrapolate"] = True - mql_context_with_extrapolation["indexer_mappings"][ - "c:transactions/duration@millisecond" - ] = 123456 - - query = parse_mql_query( - str(query_body), mql_context_with_extrapolation, generic_metrics + mql_context_with_extrapolation["indexer_mappings"]["c:transactions/duration@millisecond"] = ( + 123456 ) + + query = parse_mql_query(str(query_body), mql_context_with_extrapolation, generic_metrics) eq, reason = query.equals(expected) assert eq, reason diff --git a/tests/query/parser/test_invalid_legacy_query.py b/tests/query/parser/test_invalid_legacy_query.py index 57d0d6dc216..a27b9bdc9ac 100644 --- a/tests/query/parser/test_invalid_legacy_query.py +++ b/tests/query/parser/test_invalid_legacy_query.py @@ -1,4 +1,5 @@ -from typing import Any, MutableMapping, Type +from collections.abc import MutableMapping +from typing import Any import pytest from snuba_sdk.legacy import json_to_snql @@ -28,7 +29,7 @@ @pytest.mark.parametrize("query_body, expected_exception", test_cases) def test_failures( query_body: MutableMapping[str, Any], - expected_exception: Type[InvalidQueryException], + expected_exception: type[InvalidQueryException], ) -> None: with pytest.raises(expected_exception): json_to_snql(query_body, "events") diff --git a/tests/query/parser/test_parser.py b/tests/query/parser/test_parser.py index 1d6eb65e4b8..fd80fec9382 100644 --- a/tests/query/parser/test_parser.py +++ b/tests/query/parser/test_parser.py @@ -8,6 +8,7 @@ """ from datetime import datetime +from typing import cast import pytest @@ -114,9 +115,7 @@ def test_mql() -> None: column("project_id", None, "_snuba_project_id"), f.tuple(literal(1)), ), - in_cond( - column("org_id", None, "_snuba_org_id"), f.tuple(literal(1)) - ), + in_cond(column("org_id", None, "_snuba_org_id"), f.tuple(literal(1))), ), ), and_cond( @@ -125,17 +124,11 @@ def test_mql() -> None: column("use_case_id", None, "_snuba_use_case_id"), literal("transactions"), ), - f.equals( - column("granularity", None, "_snuba_granularity"), literal(60) - ), + f.equals(column("granularity", None, "_snuba_granularity"), literal(60)), ), and_cond( - f.equals( - column("metric_id", None, "_snuba_metric_id"), literal(123456) - ), - in_cond( - tags_raw["888"], f.tuple(literal("dist1"), literal("dist2")) - ), + f.equals(column("metric_id", None, "_snuba_metric_id"), literal(123456)), + in_cond(tags_raw["888"], f.tuple(literal("dist1"), literal("dist2"))), ), ), ), @@ -222,9 +215,7 @@ def test_mql_extrapolate() -> None: column("project_id", None, "_snuba_project_id"), f.tuple(literal(1)), ), - in_cond( - column("org_id", None, "_snuba_org_id"), f.tuple(literal(1)) - ), + in_cond(column("org_id", None, "_snuba_org_id"), f.tuple(literal(1))), ), ), and_cond( @@ -233,17 +224,11 @@ def test_mql_extrapolate() -> None: column("use_case_id", None, "_snuba_use_case_id"), literal("transactions"), ), - f.equals( - column("granularity", None, "_snuba_granularity"), literal(60) - ), + f.equals(column("granularity", None, "_snuba_granularity"), literal(60)), ), and_cond( - f.equals( - column("metric_id", None, "_snuba_metric_id"), literal(123456) - ), - in_cond( - tags_raw["888"], f.tuple(literal("dist1"), literal("dist2")) - ), + f.equals(column("metric_id", None, "_snuba_metric_id"), literal(123456)), + in_cond(tags_raw["888"], f.tuple(literal("dist1"), literal("dist2"))), ), ), ), @@ -329,9 +314,7 @@ def test_mql_wildcards() -> None: column("project_id", None, "_snuba_project_id"), f.tuple(literal(1)), ), - in_cond( - column("org_id", None, "_snuba_org_id"), f.tuple(literal(1)) - ), + in_cond(column("org_id", None, "_snuba_org_id"), f.tuple(literal(1))), ), ), and_cond( @@ -340,14 +323,10 @@ def test_mql_wildcards() -> None: column("use_case_id", None, "_snuba_use_case_id"), literal("transactions"), ), - f.equals( - column("granularity", None, "_snuba_granularity"), literal(60) - ), + f.equals(column("granularity", None, "_snuba_granularity"), literal(60)), ), and_cond( - f.equals( - column("metric_id", None, "_snuba_metric_id"), literal(123456) - ), + f.equals(column("metric_id", None, "_snuba_metric_id"), literal(123456)), f.like(tags_raw["42"], literal("before_wildcard_%")), ), ), @@ -434,9 +413,7 @@ def test_mql_negated_wildcards() -> None: column("project_id", None, "_snuba_project_id"), f.tuple(literal(1)), ), - in_cond( - column("org_id", None, "_snuba_org_id"), f.tuple(literal(1)) - ), + in_cond(column("org_id", None, "_snuba_org_id"), f.tuple(literal(1))), ), ), and_cond( @@ -445,14 +422,10 @@ def test_mql_negated_wildcards() -> None: column("use_case_id", None, "_snuba_use_case_id"), literal("transactions"), ), - f.equals( - column("granularity", None, "_snuba_granularity"), literal(60) - ), + f.equals(column("granularity", None, "_snuba_granularity"), literal(60)), ), and_cond( - f.equals( - column("metric_id", None, "_snuba_metric_id"), literal(123456) - ), + f.equals(column("metric_id", None, "_snuba_metric_id"), literal(123456)), f.notLike(tags_raw["42"], literal("before_wildcard_%")), ), ), @@ -508,9 +481,7 @@ def test_formula_mql() -> None: "offset": None, } - def timeseries( - agg: str, metric_id: int, condition: FunctionCall | None = None - ) -> FunctionCall: + def timeseries(agg: str, metric_id: int, condition: FunctionCall | None = None) -> FunctionCall: metric_condition = FunctionCall( None, "equals", @@ -543,7 +514,7 @@ def timeseries( ) def tag_column(tag: str) -> SubscriptableReference: - tag_val = mql_context.get("indexer_mappings").get(tag) # type: ignore + tag_val = cast("dict[str, int]", mql_context["indexer_mappings"]).get(tag) return SubscriptableReference( alias=f"_snuba_tags_raw[{tag_val}]", column=Column( @@ -561,9 +532,7 @@ def tag_column(tag: str) -> SubscriptableReference: timeseries( "sumIf", 123456, - binary_condition( - "equals", tag_column("status_code"), Literal(None, "200") - ), + binary_condition("equals", tag_column("status_code"), Literal(None, "200")), ), timeseries("sumIf", 123456), "_snuba_aggregate_value", @@ -640,9 +609,7 @@ def tag_column(tag: str) -> SubscriptableReference: ), and_cond( and_cond( - in_cond( - column("org_id", "d0", "_snuba_org_id"), f.tuple(literal(1)) - ), + in_cond(column("org_id", "d0", "_snuba_org_id"), f.tuple(literal(1))), f.equals( column("use_case_id", "d0", "_snuba_use_case_id"), literal("transactions"), @@ -689,9 +656,7 @@ def tag_column(tag: str) -> SubscriptableReference: column("granularity", "d1", "_snuba_granularity"), literal(60), ), - f.equals( - NestedColumn("tags_raw", "d0")["222222"], literal("200") - ), + f.equals(NestedColumn("tags_raw", "d0")["222222"], literal("200")), ), and_cond( f.equals( @@ -740,9 +705,7 @@ def snql_conditions_with_default(*conditions: str) -> str: foo(zoo(offset)) AS offset WHERE {conditions} ORDER BY group_id ASC -""".format( - conditions=snql_conditions_with_default("foo(issue_id) AS group_id = 1") - ) +""".format(conditions=snql_conditions_with_default("foo(issue_id) AS group_id = 1")) expected = Query( QueryEntity(EntityKey.EVENTS, get_entity(EntityKey.EVENTS).get_data_model()), selected_columns=[ @@ -877,16 +840,12 @@ def snql_conditions_with_default(*conditions: str) -> str: ] return " AND ".join(list(conditions) + DEFAULT_TEST_QUERY_CONDITIONS) - conds = " OR ".join( - ["(group_id=268128807 AND group_id=268128807)" for i in range(NUM_CONDS)] - ) + conds = " OR ".join(["(group_id=268128807 AND group_id=268128807)" for i in range(NUM_CONDS)]) snql = """ MATCH (events) SELECT group_id, goo(partition) AS issue_id, foo(zoo(offset)) AS offset WHERE {conditions} ORDER BY group_id ASC - """.format( - conditions=snql_conditions_with_default(f"({conds})") - ) + """.format(conditions=snql_conditions_with_default(f"({conds})")) parse_snql_query(snql, get_dataset("events")) diff --git a/tests/query/parser/unit_tests/test_parse_snql_query_initial.py b/tests/query/parser/unit_tests/test_parse_snql_query_initial.py index 3df080f1d59..910977947dd 100644 --- a/tests/query/parser/unit_tests/test_parse_snql_query_initial.py +++ b/tests/query/parser/unit_tests/test_parse_snql_query_initial.py @@ -3,8 +3,6 @@ This tests the first stage of the SnQL parsing pipeline, which looks like SnQL->AST. """ -from typing import Type - import pytest from snuba.datasets.entities.entity_key import EntityKey @@ -900,6 +898,6 @@ def test_autogenerated(body: str, expected: Query | CompositeQuery[Entity]) -> N @pytest.mark.parametrize("body, expected_error", failure_cases) -def test_autogenerated_invalid(body: str, expected_error: Type[Exception]) -> None: +def test_autogenerated_invalid(body: str, expected_error: type[Exception]) -> None: with pytest.raises(expected_error): parse_snql_query_initial(body) diff --git a/tests/query/parser/unit_tests/test_post_process_and_validate_query.py b/tests/query/parser/unit_tests/test_post_process_and_validate_query.py index 9693887aa0d..60d169aba4a 100644 --- a/tests/query/parser/unit_tests/test_post_process_and_validate_query.py +++ b/tests/query/parser/unit_tests/test_post_process_and_validate_query.py @@ -4,7 +4,6 @@ """ from datetime import datetime -from typing import Type import pytest @@ -1952,7 +1951,7 @@ def test_autogenerated_invalid( QuerySettings | None, CustomProcessors | None, ], - expected_error: Type[Exception], + expected_error: type[Exception], ) -> None: query, dataset, settings, custom_processing = theinput timer = Timer("snql_pipeline") diff --git a/tests/query/parser/unit_tests/test_resolver_visitor.py b/tests/query/parser/unit_tests/test_resolver_visitor.py index 480241361ff..180ff613ff8 100644 --- a/tests/query/parser/unit_tests/test_resolver_visitor.py +++ b/tests/query/parser/unit_tests/test_resolver_visitor.py @@ -1,4 +1,4 @@ -from typing import Mapping +from collections.abc import Mapping import pytest @@ -16,7 +16,7 @@ TEST_CASES = [ pytest.param( Column(alias="a", table_name=None, column_name="a"), - {"b": FunctionCall(alias="b", function_name="f", parameters=tuple())}, + {"b": FunctionCall(alias="b", function_name="f", parameters=())}, False, Column(alias="a", table_name=None, column_name="a"), id="Simple Column - do nothing", @@ -30,9 +30,9 @@ ), pytest.param( Column(alias=None, table_name=None, column_name="ref"), - {"ref": FunctionCall(alias="ref", function_name="f", parameters=tuple())}, + {"ref": FunctionCall(alias="ref", function_name="f", parameters=())}, False, - FunctionCall(alias="ref", function_name="f", parameters=tuple()), + FunctionCall(alias="ref", function_name="f", parameters=()), id="Alias resolves to a simple function", ), pytest.param( @@ -142,7 +142,7 @@ pytest.param( CurriedFunctionCall( alias=None, - internal_function=FunctionCall(alias=None, function_name="f", parameters=tuple()), + internal_function=FunctionCall(alias=None, function_name="f", parameters=()), parameters=(Column(alias=None, table_name=None, column_name="a"),), ), { @@ -155,7 +155,7 @@ False, CurriedFunctionCall( alias=None, - internal_function=FunctionCall(alias=None, function_name="f", parameters=tuple()), + internal_function=FunctionCall(alias=None, function_name="f", parameters=()), parameters=( FunctionCall( alias="a", @@ -171,7 +171,7 @@ { "a": Lambda( alias="a", - parameters=tuple(), + parameters=(), transformation=FunctionCall( alias="b", function_name="f", @@ -183,7 +183,7 @@ True, Lambda( alias="a", - parameters=tuple(), + parameters=(), transformation=FunctionCall( alias="b", function_name="f", diff --git a/tests/query/parser/validation/test_functions.py b/tests/query/parser/validation/test_functions.py index 50d4a053169..3f757907482 100644 --- a/tests/query/parser/validation/test_functions.py +++ b/tests/query/parser/validation/test_functions.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Mapping, Optional, Sequence, Type +from collections.abc import Mapping, Sequence from unittest.mock import MagicMock import pytest @@ -80,7 +80,7 @@ def validate( def test_functions( default_validators: Mapping[str, FunctionCallValidator], entity_validators: Mapping[str, FunctionCallValidator], - exception: Optional[Type[InvalidExpressionException]], + exception: type[InvalidExpressionException] | None, ) -> None: fn_cached = functions.default_validators functions.default_validators = default_validators @@ -89,7 +89,7 @@ def test_functions( entity_return.return_value = entity_validators events_entity = get_entity(EntityKey.EVENTS) cached = events_entity.get_function_call_validators - setattr(events_entity, "get_function_call_validators", entity_return) + events_entity.get_function_call_validators = entity_return data_source = QueryEntity(EntityKey.EVENTS, ColumnSet([])) expression = FunctionCall( @@ -102,7 +102,7 @@ def test_functions( FunctionCallsValidator().validate(expression, data_source) # TODO: This should use fixture to do this - setattr(events_entity, "get_function_call_validators", cached) + events_entity.get_function_call_validators = cached functions.default_validators = fn_cached diff --git a/tests/query/processors/query_builders.py b/tests/query/processors/query_builders.py index aa7305cb9ff..b64d06b2d1c 100644 --- a/tests/query/processors/query_builders.py +++ b/tests/query/processors/query_builders.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Optional, Sequence +from collections.abc import Sequence from snuba.clickhouse.columns import ColumnSet from snuba.clickhouse.query import Query as ClickhouseQuery @@ -14,9 +14,9 @@ def build_query( - selected_columns: Optional[Sequence[Expression]] = None, - condition: Optional[Expression] = None, - having: Optional[Expression] = None, + selected_columns: Sequence[Expression] | None = None, + condition: Expression | None = None, + having: Expression | None = None, ) -> ClickhouseQuery: return ClickhouseQuery( Table( diff --git a/tests/query/processors/test_array_has_optimizer.py b/tests/query/processors/test_array_has_optimizer.py index b92326b081b..f94b9d3d1ef 100644 --- a/tests/query/processors/test_array_has_optimizer.py +++ b/tests/query/processors/test_array_has_optimizer.py @@ -1,5 +1,3 @@ -from typing import Optional - import pytest from snuba.clickhouse.query import Query as ClickhouseQuery @@ -193,7 +191,7 @@ @pytest.mark.parametrize("query, expected_conditions", array_has_tests) def test_array_has_optimizer( query: ClickhouseQuery, - expected_conditions: Optional[Expression], + expected_conditions: Expression | None, ) -> None: query_settings = HTTPQuerySettings() array_has_processor = ArrayHasOptimizer(["spans.op", "spans.group"]) diff --git a/tests/query/processors/test_arrayjoin_optimizer.py b/tests/query/processors/test_arrayjoin_optimizer.py index 50c58aa2f30..c43b3e1f052 100644 --- a/tests/query/processors/test_arrayjoin_optimizer.py +++ b/tests/query/processors/test_arrayjoin_optimizer.py @@ -1,6 +1,6 @@ import uuid +from collections.abc import Sequence from datetime import datetime -from typing import Optional, Sequence import pytest @@ -37,9 +37,9 @@ def build_query( - selected_columns: Optional[Sequence[Expression]] = None, - condition: Optional[Expression] = None, - having: Optional[Expression] = None, + selected_columns: Sequence[Expression] | None = None, + condition: Expression | None = None, + having: Expression | None = None, ) -> ClickhouseQuery: return ClickhouseQuery( None, @@ -437,8 +437,9 @@ def parse_and_process(snql_query: str) -> ClickhouseQuery: ) .data ) - ArrayJoinKeyValueOptimizer("tags").process_query(clickhouse_query, request.query_settings) # type: ignore - return clickhouse_query # type: ignore + assert isinstance(clickhouse_query, ClickhouseQuery) + ArrayJoinKeyValueOptimizer("tags").process_query(clickhouse_query, request.query_settings) + return clickhouse_query @pytest.mark.redis_db diff --git a/tests/query/processors/test_arrayjoin_spans_optimizer.py b/tests/query/processors/test_arrayjoin_spans_optimizer.py index 59173becaa9..4df60de540a 100644 --- a/tests/query/processors/test_arrayjoin_spans_optimizer.py +++ b/tests/query/processors/test_arrayjoin_spans_optimizer.py @@ -1,5 +1,3 @@ -from typing import List, Optional, Set - import pytest from snuba.clickhouse.query import Query as ClickhouseQuery @@ -113,7 +111,7 @@ @pytest.mark.parametrize("query, expected_result", spans_op_filter_tests) def test_get_single_column_filters( - query: ClickhouseQuery, expected_result: Set[Expression] + query: ClickhouseQuery, expected_result: set[Expression] ) -> None: """ Test the algorithm identifies conditions on op/group that can potentially @@ -133,9 +131,7 @@ def test_get_single_column_filters( condition=binary_condition( ConditionFunctions.EQ, spans_op_group_col, - FunctionCall( - None, "tuple", (Literal(None, "db"), Literal(None, "a" * 16)) - ), + FunctionCall(None, "tuple", (Literal(None, "db"), Literal(None, "a" * 16))), ), ), {("db", "a" * 16)}, @@ -268,20 +264,17 @@ def test_get_single_column_filters( @pytest.mark.parametrize("query, expected_result", spans_op_group_tuple_filter_tests) def test_get_multiple_columns_filters( - query: ClickhouseQuery, expected_result: Set[Expression] + query: ClickhouseQuery, expected_result: set[Expression] ) -> None: """ Test the algorithm identifies conditions on the tuple (op, group) that can potentially be pre-filtered through arrayFilter. """ - assert ( - set(get_multiple_columns_filters(query, ("spans.op", "spans.group"))) - == expected_result - ) + assert set(get_multiple_columns_filters(query, ("spans.op", "spans.group"))) == expected_result def array_join_col(ops=None, groups=None, op_groups=None): - conditions: List[Expression] = [] + conditions: list[Expression] = [] argument_name = "arg" argument = Argument(None, argument_name) @@ -300,9 +293,7 @@ def array_join_col(ops=None, groups=None, op_groups=None): binary_condition( ConditionFunctions.IN, tupleElement(None, argument, Literal(None, 2)), - FunctionCall( - None, "tuple", tuple(Literal(None, group) for group in groups) - ), + FunctionCall(None, "tuple", tuple(Literal(None, group) for group in groups)), ) ) @@ -322,9 +313,7 @@ def array_join_col(ops=None, groups=None, op_groups=None): None, "tuple", tuple( - FunctionCall( - None, "tuple", (Literal(None, op), Literal(None, group)) - ) + FunctionCall(None, "tuple", (Literal(None, op), Literal(None, group))) for op, group in op_groups ), ), @@ -338,9 +327,7 @@ def array_join_col(ops=None, groups=None, op_groups=None): Lambda( None, ("x", "y", "z"), - FunctionCall( - None, "tuple", tuple(Argument(None, arg) for arg in ("x", "y", "z")) - ), + FunctionCall(None, "tuple", tuple(Argument(None, arg) for arg in ("x", "y", "z"))), ), Column(None, None, "spans.op"), Column(None, None, "spans.group"), @@ -369,9 +356,7 @@ def array_join_col(ops=None, groups=None, op_groups=None): id="no spans columns in select clause", ), pytest.param( - build_query( - selected_columns=[spans_op_col, spans_group_col, spans_exclusive_time_col] - ), + build_query(selected_columns=[spans_op_col, spans_group_col, spans_exclusive_time_col]), [ SelectedExpression( "spans_op", tupleElement("spans_op", array_join_col(), Literal(None, 1)) @@ -382,9 +367,7 @@ def array_join_col(ops=None, groups=None, op_groups=None): ), SelectedExpression( "spans_exclusive_time", - tupleElement( - "spans_exclusive_time", array_join_col(), Literal(None, 3) - ), + tupleElement("spans_exclusive_time", array_join_col(), Literal(None, 3)), ), ], None, @@ -406,15 +389,11 @@ def array_join_col(ops=None, groups=None, op_groups=None): ), SelectedExpression( "spans_group", - tupleElement( - "spans_group", array_join_col(ops=["db"]), Literal(None, 2) - ), + tupleElement("spans_group", array_join_col(ops=["db"]), Literal(None, 2)), ), SelectedExpression( "spans_exclusive_time", - tupleElement( - "spans_exclusive_time", array_join_col(ops=["db"]), Literal(None, 3) - ), + tupleElement("spans_exclusive_time", array_join_col(ops=["db"]), Literal(None, 3)), ), ], binary_condition( @@ -439,15 +418,11 @@ def array_join_col(ops=None, groups=None, op_groups=None): [ SelectedExpression( "spans_op", - tupleElement( - "spans_op", array_join_col(ops=["db", "http"]), Literal(None, 1) - ), + tupleElement("spans_op", array_join_col(ops=["db", "http"]), Literal(None, 1)), ), SelectedExpression( "spans_group", - tupleElement( - "spans_group", array_join_col(ops=["db", "http"]), Literal(None, 2) - ), + tupleElement("spans_group", array_join_col(ops=["db", "http"]), Literal(None, 2)), ), SelectedExpression( "spans_exclusive_time", @@ -466,9 +441,7 @@ def array_join_col(ops=None, groups=None, op_groups=None): FunctionCall(None, "has", (spans_ops, Literal(None, "http"))), ), in_condition( - tupleElement( - "spans_op", array_join_col(ops=["db", "http"]), Literal(None, 1) - ), + tupleElement("spans_op", array_join_col(ops=["db", "http"]), Literal(None, 1)), [Literal(None, "db"), Literal(None, "http")], ), ), @@ -486,15 +459,11 @@ def array_join_col(ops=None, groups=None, op_groups=None): [ SelectedExpression( "spans_op", - tupleElement( - "spans_op", array_join_col(groups=["a" * 16]), Literal(None, 1) - ), + tupleElement("spans_op", array_join_col(groups=["a" * 16]), Literal(None, 1)), ), SelectedExpression( "spans_group", - tupleElement( - "spans_group", array_join_col(groups=["a" * 16]), Literal(None, 2) - ), + tupleElement("spans_group", array_join_col(groups=["a" * 16]), Literal(None, 2)), ), SelectedExpression( "spans_exclusive_time", @@ -510,9 +479,7 @@ def array_join_col(ops=None, groups=None, op_groups=None): FunctionCall(None, "has", (spans_groups, Literal(None, "a" * 16))), binary_condition( ConditionFunctions.EQ, - tupleElement( - "spans_group", array_join_col(groups=["a" * 16]), Literal(None, 2) - ), + tupleElement("spans_group", array_join_col(groups=["a" * 16]), Literal(None, 2)), Literal(None, "a" * 16), ), ), @@ -576,9 +543,7 @@ def array_join_col(ops=None, groups=None, op_groups=None): condition=binary_condition( ConditionFunctions.EQ, spans_op_group_col, - FunctionCall( - None, "tuple", (Literal(None, "db"), Literal(None, "a" * 16)) - ), + FunctionCall(None, "tuple", (Literal(None, "db"), Literal(None, "a" * 16))), ), ), [ @@ -632,9 +597,7 @@ def array_join_col(ops=None, groups=None, op_groups=None): ), ), ), - FunctionCall( - None, "tuple", (Literal(None, "db"), Literal(None, "a" * 16)) - ), + FunctionCall(None, "tuple", (Literal(None, "db"), Literal(None, "a" * 16))), ), ), id="simple equals filter on op + group", @@ -712,16 +675,12 @@ def array_join_col(ops=None, groups=None, op_groups=None): ( tupleElement( "spans_op", - array_join_col( - op_groups=[("db", "a" * 16), ("http", "b" * 16)] - ), + array_join_col(op_groups=[("db", "a" * 16), ("http", "b" * 16)]), Literal(None, 1), ), tupleElement( "spans_group", - array_join_col( - op_groups=[("db", "a" * 16), ("http", "b" * 16)] - ), + array_join_col(op_groups=[("db", "a" * 16), ("http", "b" * 16)]), Literal(None, 2), ), ), @@ -754,17 +713,13 @@ def array_join_col(ops=None, groups=None, op_groups=None): ) def test_spans_processor( query: ClickhouseQuery, - expected_selected_columns: List[SelectedExpression], - expected_conditions: Optional[Expression], + expected_selected_columns: list[SelectedExpression], + expected_conditions: Expression | None, ) -> None: query_settings = HTTPQuerySettings() - bloom_filter_processor = BloomFilterOptimizer( - "spans", ["op", "group"], ["exclusive_time"] - ) + bloom_filter_processor = BloomFilterOptimizer("spans", ["op", "group"], ["exclusive_time"]) bloom_filter_processor.process_query(query, query_settings) - array_join_processor = ArrayJoinOptimizer( - "spans", ["op", "group"], ["exclusive_time"] - ) + array_join_processor = ArrayJoinOptimizer("spans", ["op", "group"], ["exclusive_time"]) array_join_processor.process_query(query, query_settings) assert query.get_selected_columns() == expected_selected_columns assert query.get_condition() == expected_conditions diff --git a/tests/query/processors/test_clickhouse_settings_override.py b/tests/query/processors/test_clickhouse_settings_override.py index c5b20d9681c..7ce392f124c 100644 --- a/tests/query/processors/test_clickhouse_settings_override.py +++ b/tests/query/processors/test_clickhouse_settings_override.py @@ -1,4 +1,5 @@ -from typing import Any, MutableMapping +from collections.abc import MutableMapping +from typing import Any import pytest diff --git a/tests/query/processors/test_empty_tag_condition_processor.py b/tests/query/processors/test_empty_tag_condition_processor.py index 34a9b77daef..e8ca30b0160 100644 --- a/tests/query/processors/test_empty_tag_condition_processor.py +++ b/tests/query/processors/test_empty_tag_condition_processor.py @@ -154,7 +154,7 @@ ] -@pytest.mark.parametrize("query, expected", test_data) # type: ignore +@pytest.mark.parametrize("query, expected", test_data) def test_empty_tag_condition(query: Query, expected: Expression) -> None: query_settings = HTTPQuerySettings() processor = EmptyTagConditionProcessor(column_name="tags.key") diff --git a/tests/query/processors/test_fixedstring_array_column_processor.py b/tests/query/processors/test_fixedstring_array_column_processor.py index c8157eb215e..3596225168a 100644 --- a/tests/query/processors/test_fixedstring_array_column_processor.py +++ b/tests/query/processors/test_fixedstring_array_column_processor.py @@ -60,7 +60,7 @@ def test_uuid_array_column_processor( condition=expected, ) - FixedStringArrayColumnProcessor(set(["column1", "column2"]), 32).process_query( + FixedStringArrayColumnProcessor({"column1", "column2"}, 32).process_query( unprocessed_query, HTTPQuerySettings() ) assert unprocessed_query.get_selected_columns() == [ diff --git a/tests/query/processors/test_granularity_processor.py b/tests/query/processors/test_granularity_processor.py index c62c5526288..f34c362a0d8 100644 --- a/tests/query/processors/test_granularity_processor.py +++ b/tests/query/processors/test_granularity_processor.py @@ -1,5 +1,4 @@ from datetime import datetime -from typing import List, Optional import pytest @@ -54,7 +53,7 @@ def test_granularity_added( entity_key: EntityKey, column: str, - requested_granularity: Optional[int], + requested_granularity: int | None, query_granularity: int, ) -> None: query = Query( @@ -119,7 +118,7 @@ def test_granularity_added( def test_granularity_added_in_condition( entity_key: EntityKey, column: str, - requested_granularity: Optional[int], + requested_granularity: int | None, query_granularity: int, ) -> None: query = Query( @@ -188,8 +187,8 @@ def test_granularity_added_in_condition( def test_multiple_granularities_added_in_condition( entity_key: EntityKey, column: str, - requested_granularity: List[int], - query_granularity: List[int], + requested_granularity: list[int], + query_granularity: list[int], ) -> None: query_with_multiple_conditions = Query( QueryEntity(entity_key, ColumnSet([])), @@ -309,7 +308,7 @@ def test_multiple_granularities_added_in_condition( def test_granularity_enum_mapping( entity_key: EntityKey, column: str, - requested_granularity: Optional[int], + requested_granularity: int | None, query_granularity: int, ) -> None: query = Query( @@ -376,7 +375,7 @@ def test_granularity_enum_mapping( def test_granularity_enum_mapping_in_condition( entity_key: EntityKey, column: str, - requested_granularity: Optional[int], + requested_granularity: int | None, query_granularity: int, ) -> None: query = Query( @@ -447,8 +446,8 @@ def test_granularity_enum_mapping_in_condition( def test_multiple_granularities_enum_mapping_in_condition( entity_key: EntityKey, column: str, - requested_granularity: List[int], - query_granularity: List[int], + requested_granularity: list[int], + query_granularity: list[int], ) -> None: query = Query( QueryEntity(entity_key, ColumnSet([])), diff --git a/tests/query/processors/test_handled_functions.py b/tests/query/processors/test_handled_functions.py index 538ae682697..573b80ba307 100644 --- a/tests/query/processors/test_handled_functions.py +++ b/tests/query/processors/test_handled_functions.py @@ -28,7 +28,7 @@ def test_handled_processor() -> None: FunctionCall( "result", "isHandled", - tuple(), + (), ), ), ], @@ -106,7 +106,7 @@ def test_not_handled_processor() -> None: FunctionCall( "result", "notHandled", - tuple(), + (), ), ), ], diff --git a/tests/query/processors/test_hexint_column_processor.py b/tests/query/processors/test_hexint_column_processor.py index 890e9713e38..80b975c25b2 100644 --- a/tests/query/processors/test_hexint_column_processor.py +++ b/tests/query/processors/test_hexint_column_processor.py @@ -73,7 +73,7 @@ def test_hexint_column_processor(unprocessed: Expression, formatted_value: str) ) hex = f.hex(column("column1")) - HexIntColumnProcessor(set(["column1"])).process_query(unprocessed_query, HTTPQuerySettings()) + HexIntColumnProcessor({"column1"}).process_query(unprocessed_query, HTTPQuerySettings()) assert unprocessed_query.get_selected_columns() == [ SelectedExpression( "column1", diff --git a/tests/query/processors/test_low_cardinality_processor.py b/tests/query/processors/test_low_cardinality_processor.py index f8c79374186..ea222223f75 100644 --- a/tests/query/processors/test_low_cardinality_processor.py +++ b/tests/query/processors/test_low_cardinality_processor.py @@ -26,9 +26,7 @@ def test_low_cardinality_processor() -> None: ), ) - LowCardinalityProcessor(["environment"]).process_query( - unprocessed, HTTPQuerySettings() - ) + LowCardinalityProcessor(["environment"]).process_query(unprocessed, HTTPQuerySettings()) assert expected.get_condition() == unprocessed.get_condition() @@ -49,7 +47,5 @@ def test_low_cardinality_processor_with_tags() -> None: ), ) - LowCardinalityProcessor(["tags[environment]"]).process_query( - unprocessed, HTTPQuerySettings() - ) + LowCardinalityProcessor(["tags[environment]"]).process_query(unprocessed, HTTPQuerySettings()) assert expected.get_condition() == unprocessed.get_condition() diff --git a/tests/query/processors/test_mandatory_condition_applier.py b/tests/query/processors/test_mandatory_condition_applier.py index 0b065f19bfe..b03aff6d946 100644 --- a/tests/query/processors/test_mandatory_condition_applier.py +++ b/tests/query/processors/test_mandatory_condition_applier.py @@ -1,5 +1,4 @@ import copy -from typing import List import pytest @@ -52,7 +51,7 @@ @pytest.mark.parametrize("table, mand_conditions", test_data) -def test_mand_conditions(table: str, mand_conditions: List[FunctionCall]) -> None: +def test_mand_conditions(table: str, mand_conditions: list[FunctionCall]) -> None: query = Query( Table( table, diff --git a/tests/query/processors/test_null_column_caster.py b/tests/query/processors/test_null_column_caster.py index 26d8beb6591..8a682177deb 100644 --- a/tests/query/processors/test_null_column_caster.py +++ b/tests/query/processors/test_null_column_caster.py @@ -3,9 +3,8 @@ import pytest -from snuba.clickhouse.columns import ColumnSet, DateTime +from snuba.clickhouse.columns import ColumnSet, DateTime, String, UInt from snuba.clickhouse.columns import SchemaModifiers as Modifiers -from snuba.clickhouse.columns import String, UInt from snuba.clickhouse.query import Query from snuba.clusters.storage_sets import StorageSetKey from snuba.datasets.readiness_state import ReadinessState @@ -82,9 +81,7 @@ selected_columns=[ SelectedExpression( name="_snuba_count_unique_sdk_version", - expression=FunctionCall( - None, "uniq", (Column(None, None, "mismatched1"),) - ), + expression=FunctionCall(None, "uniq", (Column(None, None, "mismatched1"),)), ) ], ), @@ -118,9 +115,7 @@ selected_columns=[ SelectedExpression( name="_snuba_count_unique_sdk_version", - expression=FunctionCall( - None, "uniq", (Column(None, None, "mismatched2"),) - ), + expression=FunctionCall(None, "uniq", (Column(None, None, "mismatched2"),)), ) ], ), @@ -154,9 +149,7 @@ selected_columns=[ SelectedExpression( name="_snuba_count_unique_sdk_version", - expression=FunctionCall( - None, "uniq", (Column(None, None, "not_mismatched"),) - ), + expression=FunctionCall(None, "uniq", (Column(None, None, "not_mismatched"),)), ) ], ), @@ -165,9 +158,7 @@ selected_columns=[ SelectedExpression( name="_snuba_count_unique_sdk_version", - expression=FunctionCall( - None, "uniq", (Column(None, None, "not_mismatched"),) - ), + expression=FunctionCall(None, "uniq", (Column(None, None, "not_mismatched"),)), ) ], ), @@ -208,10 +199,9 @@ def _mock_get_storage(storage_key: StorageKey) -> ReadableTableStorage: if storage_key == StorageKey("storage1"): return Storage1 - elif storage_key == StorageKey("storage2"): + if storage_key == StorageKey("storage2"): return Storage2 - else: - raise Exception("UNKNOWN STORAGE KEY " + str(storage_key)) + raise Exception("UNKNOWN STORAGE KEY " + str(storage_key)) def test_find_mismatched_columns(): diff --git a/tests/query/processors/test_prewhere.py b/tests/query/processors/test_prewhere.py index c668ee23cd5..906a00957cc 100644 --- a/tests/query/processors/test_prewhere.py +++ b/tests/query/processors/test_prewhere.py @@ -1,4 +1,5 @@ -from typing import Any, MutableMapping, Optional, Sequence +from collections.abc import MutableMapping, Sequence +from typing import Any import pytest from snuba_sdk.legacy import json_to_snql @@ -220,8 +221,8 @@ def test_prewhere( query_body: MutableMapping[str, Any], keys: Sequence[str], omit_if_final_keys: Sequence[str], - new_ast_condition: Optional[Expression], - new_prewhere_ast_condition: Optional[Expression], + new_ast_condition: Expression | None, + new_prewhere_ast_condition: Expression | None, final: bool, ) -> None: settings.MAX_PREWHERE_CONDITIONS = 2 @@ -252,11 +253,7 @@ def test_prewhere( def verify_expressions(top_level: Expression, expected: Expression) -> bool: actual_conds = get_first_level_and_conditions(top_level) expected_conds = get_first_level_and_conditions(expected) - for cond in expected_conds: - if cond not in actual_conds: - return False - - return True + return all(cond in actual_conds for cond in expected_conds) if new_ast_condition: condition = query.get_condition() diff --git a/tests/query/processors/test_timeseries_processor.py b/tests/query/processors/test_timeseries_processor.py index 317e1ecdd65..288d8a689b8 100644 --- a/tests/query/processors/test_timeseries_processor.py +++ b/tests/query/processors/test_timeseries_processor.py @@ -1,5 +1,3 @@ -from typing import Optional - import pytest from snuba.clickhouse.columns import ColumnSet @@ -168,9 +166,9 @@ ) def test_timeseries_format_expressions( granularity: int, - condition: Optional[FunctionCall], + condition: FunctionCall | None, exp_column: FunctionCall, - exp_condition: Optional[FunctionCall], + exp_condition: FunctionCall | None, formatted_column: str, formatted_condition: str, ) -> None: diff --git a/tests/query/processors/test_tuple_unaliaser.py b/tests/query/processors/test_tuple_unaliaser.py index 9d7e9d889f5..d4c9368be21 100644 --- a/tests/query/processors/test_tuple_unaliaser.py +++ b/tests/query/processors/test_tuple_unaliaser.py @@ -18,7 +18,7 @@ def equals(op1: Expression, op2: Expression, alias: str | None = None) -> FunctionCall: - return FunctionCall(alias, "eq", tuple([op1, op2])) + return FunctionCall(alias, "eq", (op1, op2)) def some_tuple(alias: str | None): @@ -66,9 +66,7 @@ def identity(expression: Expression) -> Expression: identity( equals( # alias of the tuple of internal function is removed (it is not useful) - tupleElement( - None, some_tuple(alias="ayyy"), Literal(None, 1) - ), + tupleElement(None, some_tuple(alias="ayyy"), Literal(None, 1)), Literal(None, 300), ) ) @@ -83,9 +81,7 @@ def identity(expression: Expression) -> Expression: identity( equals( # alias of the tuple of internal function is removed (it is not useful) - tupleElement( - None, some_tuple(alias=None), Literal(None, 1) - ), + tupleElement(None, some_tuple(alias=None), Literal(None, 1)), Literal(None, 300), ) ) @@ -98,9 +94,7 @@ def identity(expression: Expression) -> Expression: pytest.param( build_query( selected_columns=[ - CurriedFunctionCall( - None, some_tuple(alias="foo"), (Literal(None, "4"),) - ) + CurriedFunctionCall(None, some_tuple(alias="foo"), (Literal(None, "4"),)) ] ), build_query( diff --git a/tests/query/processors/test_uuid_array_column_processor.py b/tests/query/processors/test_uuid_array_column_processor.py index 93c296cceb0..cc6d29ce11c 100644 --- a/tests/query/processors/test_uuid_array_column_processor.py +++ b/tests/query/processors/test_uuid_array_column_processor.py @@ -107,7 +107,7 @@ def test_uuid_array_column_processor( condition=expected, ) - UUIDArrayColumnProcessor(set(["column1", "column2"])).process_query( + UUIDArrayColumnProcessor({"column1", "column2"}).process_query( unprocessed_query, HTTPQuerySettings() ) assert unprocessed_query.get_selected_columns() == [ diff --git a/tests/query/processors/test_uuid_column_processor.py b/tests/query/processors/test_uuid_column_processor.py index 5de15e25ce9..037403af1ea 100644 --- a/tests/query/processors/test_uuid_column_processor.py +++ b/tests/query/processors/test_uuid_column_processor.py @@ -275,7 +275,7 @@ def test_uuid_column_processor( condition=expected, ) - UUIDColumnProcessor(set(["column1", "column2"])).process_query( + UUIDColumnProcessor({"column1", "column2"}).process_query( unprocessed_query, HTTPQuerySettings() ) assert unprocessed_query.get_selected_columns() == [ @@ -333,6 +333,6 @@ def test_invalid_uuid(unprocessed: Expression) -> None: ) with pytest.raises(ColumnTypeError): - UUIDColumnProcessor(set(["column1", "column2"])).process_query( + UUIDColumnProcessor({"column1", "column2"}).process_query( unprocessed_query, HTTPQuerySettings() ) diff --git a/tests/query/snql/test_invalid_queries.py b/tests/query/snql/test_invalid_queries.py index dda2942298c..384107174bb 100644 --- a/tests/query/snql/test_invalid_queries.py +++ b/tests/query/snql/test_invalid_queries.py @@ -1,5 +1,4 @@ import re -from typing import Optional from unittest import mock import pytest @@ -102,7 +101,7 @@ def test_failures(query_body: str, message: str) -> None: "bookmark": (EntityKey.GROUPEDMESSAGE, "first_release_id"), } - def events_mock(relationship: str) -> Optional[JoinRelationship]: + def events_mock(relationship: str) -> JoinRelationship | None: if relationship not in mapping: return None entity_key, rhs_column = mapping[relationship] diff --git a/tests/query/snql/test_joins.py b/tests/query/snql/test_joins.py index 7cea98d70f9..bbbc57e48a1 100644 --- a/tests/query/snql/test_joins.py +++ b/tests/query/snql/test_joins.py @@ -1,5 +1,5 @@ import uuid -from typing import Sequence, Tuple, Union +from collections.abc import Sequence import pytest @@ -29,7 +29,7 @@ def node(alias: str, name: str) -> IndividualNode[QueryEntity]: def join_clause( - lhs_alias: str, lhs: Union[str, JoinClause[QueryEntity]], rhs: str + lhs_alias: str, lhs: str | JoinClause[QueryEntity], rhs: str ) -> JoinClause[QueryEntity]: rhs_alias, rhs = rhs.split(":", 1) return JoinClause( @@ -192,7 +192,7 @@ def join_clause( @pytest.mark.parametrize("clauses, expected", test_cases) -def test_joins(clauses: Sequence[Tuple[str, str]], expected: JoinClause[QueryEntity]) -> None: +def test_joins(clauses: Sequence[tuple[str, str]], expected: JoinClause[QueryEntity]) -> None: relationships = [] for clause in clauses: diff --git a/tests/query/snql/test_query.py b/tests/query/snql/test_query.py index e33ea66b88b..7af4f3718f2 100644 --- a/tests/query/snql/test_query.py +++ b/tests/query/snql/test_query.py @@ -312,7 +312,7 @@ def build_cond(tn: str) -> str: ), SelectedExpression( "count", - FunctionCall("_snuba_count", "count", tuple()), + FunctionCall("_snuba_count", "count", ()), ), ], groupby=[ @@ -1177,8 +1177,7 @@ def build_cond(tn: str) -> str: id="Multi multi join match", ), pytest.param( - "MATCH { MATCH (events) SELECT count() AS count BY title WHERE %s } SELECT max(count) AS max_count" - % added_condition, + f"MATCH {{ MATCH (events) SELECT count() AS count BY title WHERE {added_condition} }} SELECT max(count) AS max_count", CompositeQuery( from_clause=LogicalQuery( QueryEntity( @@ -1187,7 +1186,7 @@ def build_cond(tn: str) -> str: ), selected_columns=[ SelectedExpression("title", Column("_snuba_title", None, "title")), - SelectedExpression("count", FunctionCall("_snuba_count", "count", tuple())), + SelectedExpression("count", FunctionCall("_snuba_count", "count", ())), ], groupby=[Column("_snuba_title", None, "title")], condition=required_condition, @@ -1208,14 +1207,13 @@ def build_cond(tn: str) -> str: id="sub query match", ), pytest.param( - """MATCH { - MATCH { - MATCH (events) SELECT count() AS count BY title WHERE %s - } + f"""MATCH {{ + MATCH {{ + MATCH (events) SELECT count() AS count BY title WHERE {added_condition} + }} SELECT max(count) AS max_count - } - SELECT min(max_count) AS min_count""" - % added_condition, + }} + SELECT min(max_count) AS min_count""", CompositeQuery( from_clause=CompositeQuery( from_clause=LogicalQuery( @@ -1225,7 +1223,7 @@ def build_cond(tn: str) -> str: ), selected_columns=[ SelectedExpression("title", Column("_snuba_title", None, "title")), - SelectedExpression("count", FunctionCall("_snuba_count", "count", tuple())), + SelectedExpression("count", FunctionCall("_snuba_count", "count", ())), ], groupby=[Column("_snuba_title", None, "title")], condition=required_condition, @@ -1325,14 +1323,14 @@ def build_cond(tn: str) -> str: ), SelectedExpression( "count", - FunctionCall("_snuba_count", "count", tuple()), + FunctionCall("_snuba_count", "count", ()), ), ], groupby=[Column("_snuba_tags_key", None, "tags_key")], order_by=[ OrderBy( OrderByDirection.DESC, - FunctionCall("_snuba_count", "count", tuple()), + FunctionCall("_snuba_count", "count", ()), ), OrderBy( OrderByDirection.ASC, @@ -1405,7 +1403,7 @@ def build_cond(tn: str) -> str: selected_columns=[ SelectedExpression( "times_seen", - FunctionCall("_snuba_times_seen", "count", tuple()), + FunctionCall("_snuba_times_seen", "count", ()), ), ], limit=1000, @@ -1440,7 +1438,7 @@ def build_cond(tn: str) -> str: selected_columns=[ SelectedExpression( "times_seen", - FunctionCall("_snuba_times_seen", "count", tuple()), + FunctionCall("_snuba_times_seen", "count", ()), ), ], limit=1000, @@ -1475,7 +1473,7 @@ def build_cond(tn: str) -> str: selected_columns=[ SelectedExpression( "times_seen", - FunctionCall("_snuba_times_seen", "count", tuple()), + FunctionCall("_snuba_times_seen", "count", ()), ), ], limit=1000, @@ -1549,7 +1547,7 @@ def build_cond(tn: str) -> str: ), SelectedExpression( "count", - FunctionCall("_snuba_count", "count", tuple()), + FunctionCall("_snuba_count", "count", ()), ), ], groupby=[Column("_snuba_transaction_name", None, "transaction_name")], diff --git a/tests/query/snql/test_query_column_validation.py b/tests/query/snql/test_query_column_validation.py index 083fe53bd14..4f3f5c8e622 100644 --- a/tests/query/snql/test_query_column_validation.py +++ b/tests/query/snql/test_query_column_validation.py @@ -1,5 +1,6 @@ import datetime -from typing import Any, Generator +from collections.abc import Generator +from typing import Any import pytest @@ -43,9 +44,7 @@ ), selected_columns=[ SelectedExpression("title", Column("_snuba_title", None, "title")), - SelectedExpression( - "count", FunctionCall("_snuba_count", "count", tuple()) - ), + SelectedExpression("count", FunctionCall("_snuba_count", "count", ())), ], groupby=[Column("_snuba_title", None, "title")], condition=binary_condition( @@ -120,19 +119,13 @@ selected_columns=[ SelectedExpression( "4-5", - FunctionCall( - "_snuba_4-5", "minus", (Literal(None, 4), Literal(None, 5)) - ), - ), - SelectedExpression( - "e.event_id", Column("_snuba_e.event_id", "e", "event_id") + FunctionCall("_snuba_4-5", "minus", (Literal(None, 4), Literal(None, 5))), ), + SelectedExpression("e.event_id", Column("_snuba_e.event_id", "e", "event_id")), ], condition=and_cond( and_cond( - f.equals( - column("project_id", "e", "_snuba_e.project_id"), literal(1) - ), + f.equals(column("project_id", "e", "_snuba_e.project_id"), literal(1)), f.greaterOrEquals( column("timestamp", "e", "_snuba_e.timestamp"), literal(datetime.datetime(2021, 1, 1, 0, 0)), @@ -144,9 +137,7 @@ column("timestamp", "e", "_snuba_e.timestamp"), literal(datetime.datetime(2021, 1, 3, 0, 0)), ), - f.equals( - column("project_id", "t", "_snuba_t.project_id"), literal(1) - ), + f.equals(column("project_id", "t", "_snuba_t.project_id"), literal(1)), ), and_cond( f.greaterOrEquals( @@ -324,9 +315,7 @@ ], condition=and_cond( and_cond( - f.equals( - column("project_id", None, "_snuba_project_id"), literal(1) - ), + f.equals(column("project_id", None, "_snuba_project_id"), literal(1)), f.greaterOrEquals( column("timestamp", None, "_snuba_timestamp"), literal(datetime.datetime(2021, 1, 1, 0, 0)), @@ -364,9 +353,7 @@ ], condition=and_cond( and_cond( - f.equals( - column("project_id", None, "_snuba_project_id"), literal(1) - ), + f.equals(column("project_id", None, "_snuba_project_id"), literal(1)), f.greaterOrEquals( column("timestamp", None, "_snuba_timestamp"), literal(datetime.datetime(2021, 1, 1, 0, 0)), @@ -397,7 +384,7 @@ @pytest.fixture(autouse=True) -def set_configs(redis_db: None) -> Generator[None, None, None]: +def set_configs(redis_db: None) -> Generator[None]: old_max = state.get_config("max_days") old_align = state.get_config("date_align_seconds") state.set_config("max_days", 5) diff --git a/tests/query/snql/test_storage_query.py b/tests/query/snql/test_storage_query.py index 897dffa5f0b..bc0e9461de4 100644 --- a/tests/query/snql/test_storage_query.py +++ b/tests/query/snql/test_storage_query.py @@ -65,13 +65,9 @@ def build_cond(tn: str) -> str: selected_columns=[ SelectedExpression( "4-5", - FunctionCall( - "_snuba_4-5", "minus", (Literal(None, 4), Literal(None, 5)) - ), - ), - SelectedExpression( - "trace_id", Column("_snuba_trace_id", None, "trace_id") + FunctionCall("_snuba_4-5", "minus", (Literal(None, 4), Literal(None, 5))), ), + SelectedExpression("trace_id", Column("_snuba_trace_id", None, "trace_id")), ], granularity=60, condition=required_condition, @@ -87,9 +83,7 @@ def build_cond(tn: str) -> str: Query( QueryStorage(key=StorageKey("eap_items")), selected_columns=[ - SelectedExpression( - "trace_id", Column("_snuba_trace_id", None, "trace_id") - ), + SelectedExpression("trace_id", Column("_snuba_trace_id", None, "trace_id")), ], granularity=None, condition=and_cond( @@ -111,9 +105,7 @@ def build_cond(tn: str) -> str: Query( QueryStorage(key=StorageKey("eap_items"), sample=0.1), selected_columns=[ - SelectedExpression( - "trace_id", Column("_snuba_trace_id", None, "trace_id") - ), + SelectedExpression("trace_id", Column("_snuba_trace_id", None, "trace_id")), ], granularity=None, condition=and_cond( @@ -130,10 +122,9 @@ def build_cond(tn: str) -> str: id="basic_query-sample", ), pytest.param( - """MATCH { - MATCH STORAGE(eap_items) SELECT trace_id, duration_ms WHERE %s LIMIT 100 - } SELECT max(duration_ms) AS max_duration LIMIT 100""" - % added_condition, + f"""MATCH {{ + MATCH STORAGE(eap_items) SELECT trace_id, duration_ms WHERE {added_condition} LIMIT 100 + }} SELECT max(duration_ms) AS max_duration LIMIT 100""", CompositeQuery( selected_columns=[ SelectedExpression( @@ -148,9 +139,7 @@ def build_cond(tn: str) -> str: from_clause=Query( QueryStorage(key=StorageKey("eap_items")), selected_columns=[ - SelectedExpression( - "trace_id", Column("_snuba_trace_id", None, "trace_id") - ), + SelectedExpression("trace_id", Column("_snuba_trace_id", None, "trace_id")), SelectedExpression( "duration_ms", Column("_snuba_duration_ms", None, "duration_ms") ), @@ -165,15 +154,12 @@ def build_cond(tn: str) -> str: id="composite_query", ), pytest.param( - """ MATCH STORAGE(eap_items) SELECT trace_id, duration_ms AS duration WHERE %s LIMIT 100""" - % added_condition, + f""" MATCH STORAGE(eap_items) SELECT trace_id, duration_ms AS duration WHERE {added_condition} LIMIT 100""", StorageQuery.from_query( Query( QueryStorage(key=StorageKey("eap_items")), selected_columns=[ - SelectedExpression( - "trace_id", Column("_snuba_trace_id", None, "trace_id") - ), + SelectedExpression("trace_id", Column("_snuba_trace_id", None, "trace_id")), SelectedExpression( "duration", Column("_snuba_duration_ms", None, "duration_ms") ), diff --git a/tests/query/test_expressions.py b/tests/query/test_expressions.py index 0d3a44825c8..ec718ff0761 100644 --- a/tests/query/test_expressions.py +++ b/tests/query/test_expressions.py @@ -1,7 +1,6 @@ import uuid from dataclasses import replace from datetime import datetime -from typing import Set import pytest @@ -191,7 +190,7 @@ def test_hash() -> None: function_2 = CurriedFunctionCall(None, function_1, (column1,)) lm = Lambda(None, ("x", "y"), FunctionCall(None, "test", (Argument(None, "x"),))) - s: Set[Expression] = set() + s: set[Expression] = set() s.add(column1) s.add(column2) s.add(function_1) diff --git a/tests/query/test_matcher.py b/tests/query/test_matcher.py index 7b33da0ecb6..ef46344dae4 100644 --- a/tests/query/test_matcher.py +++ b/tests/query/test_matcher.py @@ -1,5 +1,3 @@ -from typing import Optional - import pytest from snuba.query.expressions import Column as ColumnExpr @@ -272,7 +270,7 @@ "relevant_and_wrong", ( FunctionCallExpr(None, "f", (ColumnExpr(None, None, "my_col"),)), - FunctionCallExpr(None, "bla", tuple()), + FunctionCallExpr(None, "bla", ()), ), ), None, @@ -294,13 +292,13 @@ "f_name", ( FunctionCallExpr(None, "f", (ColumnExpr(None, None, "my_col"),)), - FunctionCallExpr(None, "second_name", tuple()), + FunctionCallExpr(None, "second_name", ()), ), ), MatchResult( { "second_function_name": "second_name", - "second_function": FunctionCallExpr(None, "second_name", tuple()), + "second_function": FunctionCallExpr(None, "second_name", ()), }, ), ), @@ -366,7 +364,7 @@ def test_base_expression( name: str, pattern: Pattern[Expression], expression: Expression, - expected_result: Optional[MatchResult], + expected_result: MatchResult | None, ) -> None: res = pattern.match(expression) assert res == expected_result @@ -390,11 +388,11 @@ def test_accessors() -> None: "f_name", ( FunctionCallExpr(None, "f", (ColumnExpr(None, None, "my_col"),)), - FunctionCallExpr(None, "second_name", tuple()), + FunctionCallExpr(None, "second_name", ()), ), ) ) assert result is not None - assert result.expression("second_function") == FunctionCallExpr(None, "second_name", tuple()) + assert result.expression("second_function") == FunctionCallExpr(None, "second_name", ()) assert result.scalar("second_function_name") == "second_name" diff --git a/tests/query/test_query.py b/tests/query/test_query.py index ef718e943af..c96cb75670a 100644 --- a/tests/query/test_query.py +++ b/tests/query/test_query.py @@ -69,7 +69,7 @@ def test_query_experiments() -> None: query.set_experiments({"optimization1": True}) assert query.get_experiments() == {"optimization1": True} - assert query.get_experiment_value("optimization1") == True + assert query.get_experiment_value("optimization1") assert query.get_experiment_value("optimization2") is None query.add_experiment("optimization2", "group1") diff --git a/tests/query/test_query_ast.py b/tests/query/test_query_ast.py index 87aec987140..291dde43ade 100644 --- a/tests/query/test_query_ast.py +++ b/tests/query/test_query_ast.py @@ -1,4 +1,5 @@ -from typing import Any, MutableMapping +from collections.abc import MutableMapping +from typing import Any import pytest from snuba_sdk.legacy import json_to_snql @@ -118,9 +119,7 @@ def replace(exp: Expression) -> Expression: expected_query = Query( Table("my_table", ColumnSet([]), storage_key=StorageKey("dontmatter")), selected_columns=[ - SelectedExpression( - "alias", FunctionCall("alias", "tag", (Literal(None, "f1"),)) - ) + SelectedExpression("alias", FunctionCall("alias", "tag", (Literal(None, "f1"),))) ], array_join=None, condition=binary_condition( @@ -144,9 +143,7 @@ def replace(exp: Expression) -> Expression: assert query.get_having() == expected_query.get_having() assert query.get_orderby() == expected_query.get_orderby() - assert list(query.get_all_expressions()) == list( - expected_query.get_all_expressions() - ) + assert list(query.get_all_expressions()) == list(expected_query.get_all_expressions()) def test_get_all_columns_legacy() -> None: @@ -335,9 +332,7 @@ def test_quoted_column_regex_allows_for_mri_format() -> None: @pytest.mark.parametrize("query_body, expected_result", VALIDATION_TESTS) -def test_alias_validation( - query_body: MutableMapping[str, Any], expected_result: bool -) -> None: +def test_alias_validation(query_body: MutableMapping[str, Any], expected_result: bool) -> None: events = get_dataset("events") request = json_to_snql(query_body, "events") request.validate() diff --git a/tests/query/test_visitor.py b/tests/query/test_visitor.py index db46545e859..4be7405a932 100644 --- a/tests/query/test_visitor.py +++ b/tests/query/test_visitor.py @@ -1,4 +1,4 @@ -from typing import Iterable, List +from collections.abc import Iterable from snuba.query.expressions import ( Argument, @@ -18,33 +18,33 @@ class DummyVisitor(ExpressionVisitor[Iterable[Expression]]): def __init__(self) -> None: - self.__visited_nodes: List[Expression] = [] + self.__visited_nodes: list[Expression] = [] - def get_visited_nodes(self) -> List[Expression]: + def get_visited_nodes(self) -> list[Expression]: return self.__visited_nodes - def visit_literal(self, exp: Literal) -> List[Expression]: + def visit_literal(self, exp: Literal) -> list[Expression]: self.__visited_nodes.append(exp) return [exp] - def visit_column(self, exp: Column) -> List[Expression]: + def visit_column(self, exp: Column) -> list[Expression]: self.__visited_nodes.append(exp) return [exp] - def visit_subscriptable_reference(self, exp: SubscriptableReference) -> List[Expression]: + def visit_subscriptable_reference(self, exp: SubscriptableReference) -> list[Expression]: self.__visited_nodes.append(exp) return [exp, *exp.column.accept(self), *exp.key.accept(self)] - def visit_function_call(self, exp: FunctionCall) -> List[Expression]: - ret: List[Expression] = [] + def visit_function_call(self, exp: FunctionCall) -> list[Expression]: + ret: list[Expression] = [] self.__visited_nodes.append(exp) ret.append(exp) for param in exp.parameters: ret.extend(param.accept(self)) return ret - def visit_curried_function_call(self, exp: CurriedFunctionCall) -> List[Expression]: - ret: List[Expression] = [] + def visit_curried_function_call(self, exp: CurriedFunctionCall) -> list[Expression]: + ret: list[Expression] = [] self.__visited_nodes.append(exp) ret.append(exp) ret.extend(exp.internal_function.accept(self)) @@ -52,24 +52,24 @@ def visit_curried_function_call(self, exp: CurriedFunctionCall) -> List[Expressi ret.extend(param.accept(self)) return ret - def visit_argument(self, exp: Argument) -> List[Expression]: + def visit_argument(self, exp: Argument) -> list[Expression]: self.__visited_nodes.append(exp) return [exp] - def visit_lambda(self, exp: Lambda) -> List[Expression]: + def visit_lambda(self, exp: Lambda) -> list[Expression]: self.__visited_nodes.append(exp) self.__visited_nodes.extend(exp.transformation.accept(self)) - ret: List[Expression] = [exp] + ret: list[Expression] = [exp] ret.extend(exp.transformation.accept(self)) return ret - def visit_dangerous_raw_sql(self, exp: DangerousRawSQL) -> List[Expression]: + def visit_dangerous_raw_sql(self, exp: DangerousRawSQL) -> list[Expression]: self.__visited_nodes.append(exp) return [exp] - def visit_json_path(self, exp: JsonPath) -> List[Expression]: + def visit_json_path(self, exp: JsonPath) -> list[Expression]: self.__visited_nodes.append(exp) - ret: List[Expression] = [exp] + ret: list[Expression] = [exp] ret.extend(exp.base.accept(self)) return ret diff --git a/tests/query/validation/test_signature.py b/tests/query/validation/test_signature.py index ae3d2513725..9f999b4b845 100644 --- a/tests/query/validation/test_signature.py +++ b/tests/query/validation/test_signature.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence import pytest diff --git a/tests/querylog/test_query_metadata.py b/tests/querylog/test_query_metadata.py index 37f9a6e28d9..eda8b389085 100644 --- a/tests/querylog/test_query_metadata.py +++ b/tests/querylog/test_query_metadata.py @@ -1,6 +1,6 @@ """Unit tests for query_metadata module""" -from typing import Any, Dict +from typing import Any from clickhouse_driver.errors import ErrorCodes @@ -39,7 +39,7 @@ def test_get_request_status_clickhouse_error_generic(self) -> None: def test_get_request_status_too_many_bytes_with_policy(self) -> None: """Test that TOO_MANY_BYTES with allocation policy flag returns RATE_LIMITED""" error = ClickhouseError("Too many bytes", code=ErrorCodes.TOO_MANY_BYTES) - context: Dict[str, Any] = {"max_bytes_to_read_set_by_policy": True} + context: dict[str, Any] = {"max_bytes_to_read_set_by_policy": True} status = get_request_status(error, context) assert status.status == RequestStatus.RATE_LIMITED assert status.slo == SLO.FOR @@ -47,7 +47,7 @@ def test_get_request_status_too_many_bytes_with_policy(self) -> None: def test_get_request_status_too_many_bytes_without_policy(self) -> None: """Test that TOO_MANY_BYTES without allocation policy flag returns ERROR""" error = ClickhouseError("Too many bytes", code=ErrorCodes.TOO_MANY_BYTES) - context: Dict[str, Any] = {"max_bytes_to_read_set_by_policy": False} + context: dict[str, Any] = {"max_bytes_to_read_set_by_policy": False} status = get_request_status(error, context) assert status.status == RequestStatus.ERROR assert status.slo == SLO.AGAINST @@ -62,7 +62,7 @@ def test_get_request_status_too_many_bytes_no_context(self) -> None: def test_get_request_status_too_many_bytes_empty_context(self) -> None: """Test that TOO_MANY_BYTES with empty context returns ERROR""" error = ClickhouseError("Too many bytes", code=ErrorCodes.TOO_MANY_BYTES) - context: Dict[str, Any] = {} + context: dict[str, Any] = {} status = get_request_status(error, context) assert status.status == RequestStatus.ERROR assert status.slo == SLO.AGAINST diff --git a/tests/replacer/test_cluster_replacements.py b/tests/replacer/test_cluster_replacements.py index 1eff00086ff..5bdd9f6d98b 100644 --- a/tests/replacer/test_cluster_replacements.py +++ b/tests/replacer/test_cluster_replacements.py @@ -1,17 +1,10 @@ from __future__ import annotations from collections import defaultdict +from collections.abc import Callable, Generator, Mapping, MutableMapping, Sequence from concurrent.futures.thread import ThreadPoolExecutor from typing import ( Any, - Callable, - Generator, - List, - Mapping, - MutableMapping, - Optional, - Sequence, - Tuple, ) import pytest @@ -81,7 +74,7 @@ def override_cluster( monkeypatch: pytest.MonkeyPatch, events_db: None, redis_db: None, -) -> Generator[Callable[[bool], FakeClickhouseCluster], None, None]: +) -> Generator[Callable[[bool], FakeClickhouseCluster]]: with monkeypatch.context() as m: def override(healthy: bool) -> FakeClickhouseCluster: @@ -145,13 +138,13 @@ def parse_message( cls, message: ReplacementMessage[Any], context: ReplacementContext, - ) -> Optional[DummyReplacement]: + ) -> DummyReplacement | None: return cls() - def get_count_query(self, table_name: str) -> Optional[str]: + def get_count_query(self, table_name: str) -> str | None: return f"SELECT count() FROM {table_name} FINAL WHERE event_id = '6f0ccc03-6efb-4f7c-8005-d0c992106b31'" - def get_insert_query(self, table_name: str) -> Optional[str]: + def get_insert_query(self, table_name: str) -> str | None: required_columns = "project_id, timestamp, event_id" select_columns = "project_id, timestamp, event_id, group_id, primary_hash" @@ -346,11 +339,11 @@ def test_load_balancing(override_cluster: Callable[[bool], FakeClickhouseCluster @pytest.mark.redis_db @pytest.mark.events_db def test_local_executor( - nodes: Mapping[int, Sequence[Tuple[ClickhouseNode, bool]]], + nodes: Mapping[int, Sequence[tuple[ClickhouseNode, bool]]], backup_connection: ClickhousePool, expected_queries: Mapping[str, Sequence[str]], ) -> None: - queries: MutableMapping[str, List[str]] = defaultdict(list) + queries: MutableMapping[str, list[str]] = defaultdict(list) def run_query( connection: ClickhousePool, @@ -361,7 +354,7 @@ def run_query( connection.execute_robust(query) queries[connection.host].append(query) - all_nodes: List[Tuple[ClickhouseNode, bool]] = [] + all_nodes: list[tuple[ClickhouseNode, bool]] = [] for shard_nodes in nodes.values(): all_nodes.extend(shard_nodes) diff --git a/tests/replacer/test_load_balancer.py b/tests/replacer/test_load_balancer.py index 19a8e3bec4d..4bc52ee79a6 100644 --- a/tests/replacer/test_load_balancer.py +++ b/tests/replacer/test_load_balancer.py @@ -1,4 +1,4 @@ -from typing import Mapping, Sequence +from collections.abc import Mapping, Sequence import pytest diff --git a/tests/replacer/test_replacements_and_expiry.py b/tests/replacer/test_replacements_and_expiry.py index 53806f5c09e..7f8aa494d7f 100644 --- a/tests/replacer/test_replacements_and_expiry.py +++ b/tests/replacer/test_replacements_and_expiry.py @@ -29,7 +29,7 @@ def test_project_does_not_expire_within_expiry(self) -> None: set_config_auto_replacements_bypass_projects([1], self.proj1_add_time) assert set( get_config_auto_replacements_bypass_projects(self.proj1_expiry - timedelta(minutes=1)) - ) == set([1]) + ) == {1} @pytest.mark.redis_db def test_project_expires_after_expiry(self) -> None: @@ -49,10 +49,10 @@ def test_multiple_projects(self) -> None: set_config_auto_replacements_bypass_projects([2], self.proj2_add_time) assert set( get_config_auto_replacements_bypass_projects(self.proj1_expiry - timedelta(minutes=1)) - ) == set([1, 2]) + ) == {1, 2} assert set( get_config_auto_replacements_bypass_projects(self.proj1_expiry + timedelta(minutes=1)) - ) == set([2]) + ) == {2} @pytest.mark.redis_db def test_expiry_does_not_update(self) -> None: @@ -82,13 +82,16 @@ def test_expiry_window_changes(self, mock: mock.MagicMock) -> None: # project 1 expires after 5 minutes assert set( get_config_auto_replacements_bypass_projects(self.proj1_add_time + timedelta(minutes=6)) - ) == set([2]) + ) == {2} # project 2 expires at 10 minutes assert set( get_config_auto_replacements_bypass_projects(self.proj2_add_time + timedelta(minutes=9)) - ) == set([2]) - assert set( - get_config_auto_replacements_bypass_projects( - self.proj2_add_time + timedelta(minutes=11) + ) == {2} + assert ( + set( + get_config_auto_replacements_bypass_projects( + self.proj2_add_time + timedelta(minutes=11) + ) ) - ) == set([]) + == set() + ) diff --git a/tests/request/test_build_request.py b/tests/request/test_build_request.py index 5f448d0b020..36d46f4c24a 100644 --- a/tests/request/test_build_request.py +++ b/tests/request/test_build_request.py @@ -1,7 +1,7 @@ from __future__ import annotations from datetime import datetime -from typing import Any, Dict +from typing import Any import pytest @@ -64,7 +64,7 @@ @pytest.mark.parametrize("body, condition", TESTS) -def test_build_request(body: Dict[str, Any], condition: Expression) -> None: +def test_build_request(body: dict[str, Any], condition: Expression) -> None: dataset = get_dataset("events") entity = get_entity(EntityKey.EVENTS) schema = RequestSchema.build(HTTPQuerySettings) @@ -86,7 +86,7 @@ def test_build_request(body: Dict[str, Any], condition: Expression) -> None: name="time", expression=Column(alias="_snuba_time", table_name=None, column_name="time"), ), - SelectedExpression("count", FunctionCall("_snuba_count", "count", tuple())), + SelectedExpression("count", FunctionCall("_snuba_count", "count", ())), ], condition=condition, groupby=[Column("_snuba_time", None, "time")], @@ -97,7 +97,7 @@ def test_build_request(body: Dict[str, Any], condition: Expression) -> None: assert request.referrer == "my_request" assert dict(request.original_body) == body status, differences = request.query.equals(expected_query) - assert status == True, f"Query mismatch: {differences}" + assert status, f"Query mismatch: {differences}" TENANT_ID_TESTS = [ diff --git a/tests/settings/test_settings.py b/tests/settings/test_settings.py index f31a7183fa1..72035e58555 100644 --- a/tests/settings/test_settings.py +++ b/tests/settings/test_settings.py @@ -1,6 +1,6 @@ import importlib from copy import deepcopy -from typing import Any, Dict +from typing import Any from unittest.mock import patch import pytest @@ -15,7 +15,7 @@ from snuba.utils.streams.topics import Topic -def build_settings_dict() -> Dict[str, Any]: +def build_settings_dict() -> dict[str, Any]: # Build a dictionary with all variables defined in settings. all_settings = { key: value @@ -34,7 +34,7 @@ def test_invalid_storage() -> None: try: validate_settings(all_settings) except Exception as exc: - assert False, f"'validate_settings' raised an exception {exc}" + raise AssertionError(f"'validate_settings' raised an exception {exc}") from exc finally: cluster[0]["storage_sets"].remove("non_existing_storage") diff --git a/tests/state/test_cache.py b/tests/state/test_cache.py index 8d8460c4631..ecba6c94fb2 100644 --- a/tests/state/test_cache.py +++ b/tests/state/test_cache.py @@ -2,9 +2,10 @@ import random import time +from collections.abc import Callable from concurrent.futures import Future from threading import Thread -from typing import Any, Callable +from typing import Any from unittest import mock import pytest @@ -118,12 +119,12 @@ def test_short_circuit(backend: Cache[bytes]) -> None: assert backend.get(key) is None with assert_changes(lambda: function.call_count, 0, 1): - backend.get_readthrough(key, function, noop) == value + assert backend.get_readthrough(key, function, noop) == value assert backend.get(key) is None with assert_changes(lambda: function.call_count, 1, 2): - backend.get_readthrough(key, function, noop) == value + assert backend.get_readthrough(key, function, noop) == value @pytest.mark.redis_db @@ -144,12 +145,12 @@ def test_get_readthrough(backend: Cache[bytes]) -> None: assert backend.get(key) is None with assert_changes(lambda: function.call_count, 0, 1): - backend.get_readthrough(key, function, noop) == value + assert backend.get_readthrough(key, function, noop) == value assert backend.get(key) == value with assert_does_not_change(lambda: function.call_count, 1): - backend.get_readthrough(key, function, noop) == value + assert backend.get_readthrough(key, function, noop) == value @pytest.mark.redis_db @@ -172,7 +173,7 @@ def test_get_readthrough_set_wait(backend: Cache[bytes]) -> None: def function() -> bytes: time.sleep(1) - return f"{random.random()}".encode("utf-8") + return f"{random.random()}".encode() def worker() -> bytes: return backend.get_readthrough(key, function, noop) @@ -264,7 +265,9 @@ def test_set_fails_open(backend: Cache[bytes]) -> None: "error", [ResponseError("OOM command not allowed under OOM prevention."), RedisTimeoutError()] ) def test_dont_record_expected_errors(backend: Cache[bytes], error: Exception) -> None: - with mock.patch.object(redis_client, "set", side_effect=error): - with mock.patch.object(sentry_sdk, "capture_exception") as capture_exception: - backend.get_readthrough("key", lambda: b"value", noop) - capture_exception.assert_not_called() + with ( + mock.patch.object(redis_client, "set", side_effect=error), + mock.patch.object(sentry_sdk, "capture_exception") as capture_exception, + ): + backend.get_readthrough("key", lambda: b"value", noop) + capture_exception.assert_not_called() diff --git a/tests/state/test_rate_limit.py b/tests/state/test_rate_limit.py index c0250fb4795..0ca6b520452 100644 --- a/tests/state/test_rate_limit.py +++ b/tests/state/test_rate_limit.py @@ -2,7 +2,7 @@ import time import uuid -from typing import Any, Tuple +from typing import Any from unittest.mock import patch import pytest @@ -36,13 +36,17 @@ def test_ratelimit_aggregator(self, rate_limit_shards: Any) -> None: rate_limit_params1 = RateLimitParameters("foo", "bar", None, 1) rate_limit_params2 = RateLimitParameters("foo", "bar", None, 0) - with pytest.raises(RateLimitExceeded): - with RateLimitAggregator([rate_limit_params1, rate_limit_params2]): - pass + with ( + pytest.raises(RateLimitExceeded), + RateLimitAggregator([rate_limit_params1, rate_limit_params2]), + ): + pass - with pytest.raises(RateLimitExceeded): - with RateLimitAggregator([rate_limit_params2, rate_limit_params1]): - pass + with ( + pytest.raises(RateLimitExceeded), + RateLimitAggregator([rate_limit_params2, rate_limit_params1]), + ): + pass @pytest.mark.redis_db def test_concurrent_limit(self, rate_limit_shards: Any) -> None: @@ -54,9 +58,8 @@ def test_concurrent_limit(self, rate_limit_shards: Any) -> None: # 0 concurrent limit rate_limit_params = RateLimitParameters("foo", "bar", None, 0) - with pytest.raises(RateLimitExceeded): - with rate_limit(rate_limit_params): - pass + with pytest.raises(RateLimitExceeded), rate_limit(rate_limit_params): + pass # Concurrent limit 1 with consecutive queries should not raise rate_limit_params = RateLimitParameters("foo", "bar", None, 1) @@ -70,18 +73,22 @@ def test_concurrent_limit(self, rate_limit_shards: Any) -> None: # Concurrent limit with concurrent queries rate_limit_params = RateLimitParameters("foo", "bar", None, 1) - with pytest.raises(RateLimitExceeded): - with rate_limit(rate_limit_params): - with rate_limit(rate_limit_params): - pass + with ( + pytest.raises(RateLimitExceeded), + rate_limit(rate_limit_params), + rate_limit(rate_limit_params), + ): + pass # Concurrent with different buckets should not raise rate_limit_params1 = RateLimitParameters("foo", "bar", None, 1) rate_limit_params2 = RateLimitParameters("shoe", "star", None, 1) - with RateLimitAggregator([rate_limit_params1]): - with RateLimitAggregator([rate_limit_params2]): - pass + with ( + RateLimitAggregator([rate_limit_params1]), + RateLimitAggregator([rate_limit_params2]), + ): + pass @pytest.mark.redis_db def test_fails_open(self, rate_limit_shards: Any) -> None: @@ -96,33 +103,34 @@ def test_per_second_limit(self, rate_limit_shards: Any) -> None: bucket = uuid.uuid4() rate_limit_params = RateLimitParameters("foo", str(bucket), 1, None) # Create 30 queries at time 0, should all be allowed - with patch.object(state.time, "time", lambda: 0): # type: ignore + with patch.object(state.time, "time", lambda: 0): # type: ignore[attr-defined] for _ in range(30): with rate_limit(rate_limit_params) as stats: assert stats is not None # Create another 30 queries at time 30, should also be allowed - with patch.object(state.time, "time", lambda: 30): # type: ignore + with patch.object(state.time, "time", lambda: 30): # type: ignore[attr-defined] for _ in range(30): with rate_limit(rate_limit_params) as stats: assert stats is not None - with patch.object(state.time, "time", lambda: 60): # type: ignore + with patch.object(state.time, "time", lambda: 60): # type: ignore[attr-defined] # 1 more query should be allowed at T60 because it does not make the previous # rate exceed 1/sec until it has finished. with rate_limit(rate_limit_params) as stats: assert stats is not None # But the next one should not be allowed - with pytest.raises(RateLimitExceeded): - with rate_limit(rate_limit_params): - pass + with pytest.raises(RateLimitExceeded), rate_limit(rate_limit_params): + pass # Another query at time 61 should be allowed because the first 30 queries # have fallen out of the lookback window - with patch.object(state.time, "time", lambda: 61): # type: ignore - with rate_limit(rate_limit_params) as stats: - assert stats is not None + with ( + patch.object(state.time, "time", lambda: 61), # type: ignore[attr-defined] + rate_limit(rate_limit_params) as stats, + ): + assert stats is not None @pytest.mark.redis_db def test_aggregator(self, rate_limit_shards: Any) -> None: @@ -137,17 +145,21 @@ def test_aggregator(self, rate_limit_shards: Any) -> None: rate_limit_params_outer = RateLimitParameters("foo", "bar", None, 0) rate_limit_params_inner = RateLimitParameters("foo", "bar", None, 5) - with pytest.raises(RateLimitExceeded): - with RateLimitAggregator([rate_limit_params_outer, rate_limit_params_inner]): - pass + with ( + pytest.raises(RateLimitExceeded), + RateLimitAggregator([rate_limit_params_outer, rate_limit_params_inner]), + ): + pass # raise when the outer rate limit should fail rate_limit_params_outer = RateLimitParameters("foo", "bar", None, 5) rate_limit_params_inner = RateLimitParameters("foo", "bar", None, 0) - with pytest.raises(RateLimitExceeded): - with RateLimitAggregator([rate_limit_params_outer, rate_limit_params_inner]): - pass + with ( + pytest.raises(RateLimitExceeded), + RateLimitAggregator([rate_limit_params_outer, rate_limit_params_inner]), + ): + pass @pytest.mark.redis_db def test_rate_limit_container(self) -> None: @@ -175,7 +187,7 @@ def test_bypass_rate_limit(self) -> None: @pytest.mark.redis_db def test_rate_limit_exceptions(self) -> None: params = RateLimitParameters("foo", "bar", None, 5) - bucket = "{}{}".format(state.ratelimit_prefix, params.bucket) + bucket = f"{state.ratelimit_prefix}{params.bucket}" def count() -> int: return int(get_redis_client(RedisClientKey.RATE_LIMITER).zcount(bucket, "-inf", "+inf")) @@ -185,17 +197,16 @@ def count() -> int: assert count() == 1 - with pytest.raises(RateLimitExceeded): - with rate_limit(params): - assert count() == 2 - raise RateLimitExceeded("stuff") # simulate an inner rate limiter failing + with pytest.raises(RateLimitExceeded), rate_limit(params): + assert count() == 2 + raise RateLimitExceeded("stuff") # simulate an inner rate limiter failing assert count() == 2 @pytest.mark.redis_db def test_rate_limit_ttl(self) -> None: params = RateLimitParameters("foo", "bar", None, 5) - bucket = "{}{}".format(state.ratelimit_prefix, params.bucket) + bucket = f"{state.ratelimit_prefix}{params.bucket}" with rate_limit(params): pass @@ -216,18 +227,17 @@ def test_rate_limit_ttl(self) -> None: tests, ) @pytest.mark.redis_db -def test_rate_limit_failures(vals: Tuple[int, int, int], rate_limit_shards: Any) -> None: +def test_rate_limit_failures(vals: tuple[int, int, int], rate_limit_shards: Any) -> None: params = [] for i, v in enumerate(vals): params.append(RateLimitParameters(f"foo{i}", f"bar{i}", None, v)) - with pytest.raises(RateLimitExceeded): - with RateLimitAggregator(params): - pass + with pytest.raises(RateLimitExceeded), RateLimitAggregator(params): + pass now = time.time() for p in params: - bucket = "{}{}".format(state.ratelimit_prefix, p.bucket) + bucket = f"{state.ratelimit_prefix}{p.bucket}" count = get_redis_client(RedisClientKey.RATE_LIMITER).zcount( bucket, now - state.rate_lookback_s, now + state.rate_lookback_s ) diff --git a/tests/state/test_state.py b/tests/state/test_state.py index 13ef5e3ede6..28643734ecd 100644 --- a/tests/state/test_state.py +++ b/tests/state/test_state.py @@ -13,9 +13,9 @@ class TestState: def setup_method(self) -> None: from snuba.web.views import application - assert application.testing == True + assert application.testing self.app = application.test_client() - self.app.post = partial(self.app.post, headers={"referer": "test"}) # type: ignore + self.app.post = partial(self.app.post, headers={"referer": "test"}) # type: ignore[method-assign] @pytest.mark.redis_db def test_config(self) -> None: @@ -97,7 +97,7 @@ def test_config_types(self) -> None: @pytest.mark.redis_db def test_memoize(self) -> None: - @state.memoize(0.1) # type: ignore + @state.memoize(0.1) # type: ignore[arg-type] def rand(config_key: str = "test") -> float: return random.random() @@ -109,7 +109,7 @@ def rand(config_key: str = "test") -> float: @pytest.mark.redis_db def test_memoize_with_args(self) -> None: - @state.memoize(0.1) # type: ignore + @state.memoize(0.1) # type: ignore[arg-type] def rand(config_key: str = "test1") -> str: return f"{random.random()}:{config_key}" diff --git a/tests/subscriptions/__init__.py b/tests/subscriptions/__init__.py index 858fbc963cc..b21128d5322 100644 --- a/tests/subscriptions/__init__.py +++ b/tests/subscriptions/__init__.py @@ -116,10 +116,10 @@ def setup_teardown(self, request: pytest.FixtureRequest) -> None: write_raw_unprocessed_events(items_storage, extra_messages + messages) -def __entity_eq__(self: Entity, other: object) -> bool: +def __entity_eq__(self: object, other: object) -> bool: if not isinstance(other, Entity): return False return isinstance(self, type(other)) -Entity.__eq__ = __entity_eq__ # type: ignore +Entity.__eq__ = __entity_eq__ # type: ignore[method-assign] diff --git a/tests/subscriptions/entity_subscriptions/test_entity_subscriptions.py b/tests/subscriptions/entity_subscriptions/test_entity_subscriptions.py index aebdef7c10f..05972f9a057 100644 --- a/tests/subscriptions/entity_subscriptions/test_entity_subscriptions.py +++ b/tests/subscriptions/entity_subscriptions/test_entity_subscriptions.py @@ -1,4 +1,5 @@ -from typing import Any, Mapping, Optional, Type, Union +from collections.abc import Mapping +from typing import Any import pytest @@ -144,10 +145,10 @@ @pytest.mark.parametrize("entity_key, query, metadata, exception, offset", TESTS) def test_entity_subscription_processors( entity_key: EntityKey, - query: Union[CompositeQuery[QueryEntity], Query], + query: CompositeQuery[QueryEntity] | Query, metadata: Mapping[str, Any], - exception: Optional[Type[Exception]], - offset: Optional[int], + exception: type[Exception] | None, + offset: int | None, ) -> None: entity = get_entity(entity_key) subscription_processors = entity.get_subscription_processors() @@ -156,7 +157,7 @@ def test_entity_subscription_processors( for processor in subscription_processors: if exception is not None: with pytest.raises(exception): - processor.to_dict(metadata) == {} + processor.to_dict(metadata) else: if isinstance(processor, AddColumnCondition): processor.process(query, metadata, offset) @@ -169,10 +170,10 @@ def test_entity_subscription_processors( @pytest.mark.parametrize("entity_key, query, metadata, exception, offset", TESTS) def test_entity_subscription_validators( entity_key: EntityKey, - query: Union[CompositeQuery[QueryEntity], Query], + query: CompositeQuery[QueryEntity] | Query, metadata: Mapping[str, Any], - exception: Optional[Type[Exception]], - offset: Optional[int], + exception: type[Exception] | None, + offset: int | None, ) -> None: entity = get_entity(entity_key) subscription_validators = entity.get_subscription_validators() diff --git a/tests/subscriptions/entity_subscriptions/test_entity_subscriptions_data.py b/tests/subscriptions/entity_subscriptions/test_entity_subscriptions_data.py index d36c6d52f12..db8f2fc1d8c 100644 --- a/tests/subscriptions/entity_subscriptions/test_entity_subscriptions_data.py +++ b/tests/subscriptions/entity_subscriptions/test_entity_subscriptions_data.py @@ -1,4 +1,4 @@ -from typing import List, Tuple, cast +from typing import cast from uuid import UUID import pytest @@ -69,10 +69,10 @@ def test_entity_subscriptions_data() -> None: ] assert len(stores) == 1 - assert len([s for s in stores[0].all()]) == 1 + assert len(list(stores[0].all())) == 1 result = cast( - List[Tuple[UUID, SubscriptionData]], + list[tuple[UUID, SubscriptionData]], RedisSubscriptionDataStore( redis_client, entity_key, diff --git a/tests/subscriptions/test_builder_mode_state.py b/tests/subscriptions/test_builder_mode_state.py index 2837ad36ce5..09a19245270 100644 --- a/tests/subscriptions/test_builder_mode_state.py +++ b/tests/subscriptions/test_builder_mode_state.py @@ -1,5 +1,5 @@ +from collections.abc import Sequence from datetime import timedelta -from typing import Sequence, Tuple import pytest @@ -94,7 +94,7 @@ @pytest.mark.redis_db def test_state_changes( general_mode: str, - subscriptions: Sequence[Tuple[Subscription, int]], + subscriptions: Sequence[tuple[Subscription, int]], expected_modes: Sequence[TaskBuilderMode], ) -> None: prev_threshold = settings.MAX_RESOLUTION_FOR_JITTER diff --git a/tests/subscriptions/test_codecs.py b/tests/subscriptions/test_codecs.py index f8c8c781c86..52612c3fb0d 100644 --- a/tests/subscriptions/test_codecs.py +++ b/tests/subscriptions/test_codecs.py @@ -2,8 +2,9 @@ import json import uuid +from collections.abc import Callable, Mapping from datetime import datetime -from typing import Any, Callable, Mapping +from typing import Any import pytest from google.protobuf.message import Message as ProtobufMessage diff --git a/tests/subscriptions/test_combined_scheduler_executor.py b/tests/subscriptions/test_combined_scheduler_executor.py index 4dfaa7f20ab..67095b6e299 100644 --- a/tests/subscriptions/test_combined_scheduler_executor.py +++ b/tests/subscriptions/test_combined_scheduler_executor.py @@ -89,7 +89,7 @@ def test_combined_scheduler_and_executor(tmpdir: Path) -> None: strategy.submit(message) # Wait for the query to be executed and the result message produced - for i in range(10): + for _i in range(10): time.sleep(0.5) strategy.poll() if commit.call_count == 2: diff --git a/tests/subscriptions/test_data.py b/tests/subscriptions/test_data.py index 4d16abebabb..a24a5c22a8a 100644 --- a/tests/subscriptions/test_data.py +++ b/tests/subscriptions/test_data.py @@ -1,5 +1,4 @@ from datetime import datetime -from typing import Optional, Type import pytest from sentry_protos.snuba.v1.endpoint_create_subscription_pb2 import ( @@ -182,9 +181,9 @@ class TestBuildRequestBase: def compare_conditions( self, subscription: SubscriptionData, - exception: Optional[Type[Exception]], + exception: type[Exception] | None, aggregate: str, - value: Optional[int | float], + value: int | float | None, ) -> None: timer = Timer("test") if exception is not None: @@ -195,7 +194,7 @@ def compare_conditions( 100, timer, ) - subscription.run_query(self.dataset, request, timer) # type: ignore + subscription.run_query(self.dataset, request, timer) # type: ignore[arg-type] return request = subscription.build_request( @@ -204,7 +203,7 @@ def compare_conditions( 100, timer, ) - result = subscription.run_query(self.dataset, request, timer) # type: ignore + result = subscription.run_query(self.dataset, request, timer) # type: ignore[arg-type] assert result.result["data"][0][aggregate] == value @@ -216,8 +215,8 @@ class TestBuildRequest(BaseSubscriptionTest, TestBuildRequestBase): def test_conditions( self, subscription: SubscriptionData, - expected_value: Optional[int | float], - exception: Optional[Type[Exception]], + expected_value: int | float | None, + exception: type[Exception] | None, ) -> None: self.compare_conditions(subscription, exception, "count", expected_value) @@ -227,7 +226,7 @@ def test_conditions( def test_conditions_eap( self, subscription: SubscriptionData, - expected_value: Optional[int | float], - exception: Optional[Type[Exception]], + expected_value: int | float | None, + exception: type[Exception] | None, ) -> None: self.compare_conditions(subscription, exception, "count", expected_value) diff --git a/tests/subscriptions/test_executor_consumer.py b/tests/subscriptions/test_executor_consumer.py index 15c66353650..3e81c130200 100644 --- a/tests/subscriptions/test_executor_consumer.py +++ b/tests/subscriptions/test_executor_consumer.py @@ -1,8 +1,8 @@ import json import time import uuid +from collections.abc import Iterator, Mapping from datetime import datetime, timedelta -from typing import Iterator, Mapping, Optional from unittest import mock import pytest @@ -97,7 +97,7 @@ def on_partitions_assigned(partitions: Mapping[Partition, int]) -> None: # We need to wait for the consumer to receive partitions otherwise, # when we try to consume messages, we will not find anything. # Subscription is an async process. - assert assigned == True, "Did not receive assignment within 10 attempts" + assert assigned, "Did not receive assignment within 10 attempts" consumer_group = str(uuid.uuid1().hex) auto_offset_reset = "latest" @@ -115,7 +115,7 @@ def on_partitions_assigned(partitions: Mapping[Partition, int]) -> None: TestingMetricsBackend(), None, ) - for i in range(1, 5): + for _i in range(1, 5): # Give time to the executor to subscribe time.sleep(1) executor._run_once() @@ -172,8 +172,8 @@ def on_partitions_assigned(partitions: Mapping[Partition, int]) -> None: def generate_message( entity_key: EntityKey, - subscription_identifier: Optional[SubscriptionIdentifier] = None, - bad_query: Optional[bool] = False, + subscription_identifier: SubscriptionIdentifier | None = None, + bad_query: bool | None = False, ) -> Iterator[Message[KafkaPayload]]: codec = SubscriptionScheduledTaskEncoder() epoch = datetime(1970, 1, 1) diff --git a/tests/subscriptions/test_filter_subscriptions.py b/tests/subscriptions/test_filter_subscriptions.py index 3241f93e0d2..58a627ddabc 100644 --- a/tests/subscriptions/test_filter_subscriptions.py +++ b/tests/subscriptions/test_filter_subscriptions.py @@ -2,7 +2,6 @@ import uuid from datetime import timedelta from random import randint -from typing import MutableSequence from unittest.mock import patch import pytest @@ -35,18 +34,21 @@ def build_subscription(resolution: timedelta, org_id: int) -> Subscription: @pytest.fixture -def expected_subs() -> MutableSequence[Subscription]: +def expected_subs() -> list[Subscription]: return [build_subscription(timedelta(minutes=1), 2) for count in range(randint(1, 50))] @pytest.fixture -def extra_subs() -> MutableSequence[Subscription]: +def extra_subs() -> list[Subscription]: return [build_subscription(timedelta(minutes=3), 1) for count in range(randint(1, 50))] @patch("snuba.settings.SLICED_STORAGE_SETS", {"events": 3}) @patch("snuba.settings.LOGICAL_PARTITION_MAPPING", {"events": {0: 0, 1: 1, 2: 2}}) -def test_filter_subscriptions(expected_subs, extra_subs) -> None: # type: ignore +def test_filter_subscriptions( + expected_subs: list[Subscription], + extra_subs: list[Subscription], +) -> None: importlib.reload(scheduler) filtered_subs = filter_subscriptions( diff --git a/tests/subscriptions/test_scheduler.py b/tests/subscriptions/test_scheduler.py index 9bbba7198c1..5863b32b396 100644 --- a/tests/subscriptions/test_scheduler.py +++ b/tests/subscriptions/test_scheduler.py @@ -1,6 +1,6 @@ import uuid +from collections.abc import Callable, Collection from datetime import datetime, timedelta -from typing import Callable, Collection, Optional, Tuple import pytest @@ -54,7 +54,7 @@ def build_tick(self, lower: timedelta, upper: timedelta) -> Tick: ), ) - def sort_key(self, task: ScheduledSubscriptionTask) -> Tuple[datetime, uuid.UUID]: + def sort_key(self, task: ScheduledSubscriptionTask) -> tuple[datetime, uuid.UUID]: return task.timestamp, task.task.subscription.identifier.uuid def run_test( @@ -63,9 +63,7 @@ def run_test( start: timedelta, end: timedelta, expected: Collection[ScheduledSubscriptionTask], - sort_key: Optional[ - Callable[[ScheduledSubscriptionTask], Tuple[datetime, uuid.UUID]] - ] = None, + sort_key: Callable[[ScheduledSubscriptionTask], tuple[datetime, uuid.UUID]] | None = None, entity_key: EntityKey = EntityKey.EVENTS, ) -> None: tick = self.build_tick(start, end) diff --git a/tests/subscriptions/test_scheduler_consumer.py b/tests/subscriptions/test_scheduler_consumer.py index d5b9be20019..ae553358dd6 100644 --- a/tests/subscriptions/test_scheduler_consumer.py +++ b/tests/subscriptions/test_scheduler_consumer.py @@ -4,9 +4,10 @@ import logging import time import uuid +from collections.abc import Mapping from datetime import UTC, datetime, timedelta from pathlib import Path -from typing import Any, Mapping, Optional +from typing import Any from unittest import mock import pytest @@ -378,7 +379,7 @@ def test_tick_time_shift() -> None: pytest.param(timedelta(minutes=-5), id="with time shift"), ], ) -def test_tick_consumer(time_shift: Optional[timedelta]) -> None: +def test_tick_consumer(time_shift: timedelta | None) -> None: clock = MockedClock() broker: Broker[KafkaPayload] = Broker(MemoryMessageStorage(), clock) diff --git a/tests/subscriptions/test_scheduler_processing_strategy.py b/tests/subscriptions/test_scheduler_processing_strategy.py index cc865e678e4..defa0f8f657 100644 --- a/tests/subscriptions/test_scheduler_processing_strategy.py +++ b/tests/subscriptions/test_scheduler_processing_strategy.py @@ -1,8 +1,8 @@ import uuid from collections import deque +from collections.abc import Sequence from concurrent.futures import Future from datetime import datetime, timedelta -from typing import Optional, Sequence from unittest import mock import pytest @@ -206,7 +206,7 @@ def test_tick_buffer_wait_slowest() -> None: def make_message_for_next_step( - message: Message[Tick], offset_to_commit: Optional[int] + message: Message[Tick], offset_to_commit: int | None ) -> Message[CommittableTick]: return message.replace(CommittableTick(message.payload, offset_to_commit)) diff --git a/tests/subscriptions/test_store.py b/tests/subscriptions/test_store.py index 9a72125177d..f6bced6ed6a 100644 --- a/tests/subscriptions/test_store.py +++ b/tests/subscriptions/test_store.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from uuid import uuid1 import pytest diff --git a/tests/subscriptions/test_subscription.py b/tests/subscriptions/test_subscription.py index 867df9968bf..0476e10e1e1 100644 --- a/tests/subscriptions/test_subscription.py +++ b/tests/subscriptions/test_subscription.py @@ -1,4 +1,4 @@ -from typing import List, Tuple, cast +from typing import cast from uuid import UUID import pytest @@ -83,7 +83,7 @@ def test(self, subscription: SubscriptionData) -> None: identifier = creator.create(subscription, self.timer) assert ( cast( - List[Tuple[UUID, SubscriptionData]], + list[tuple[UUID, SubscriptionData]], RedisSubscriptionDataStore( get_redis_client(RedisClientKey.SUBSCRIPTION_STORE), self.entity_key, @@ -279,7 +279,7 @@ def test(self, subscription: SubscriptionData, entity_key: EntityKey) -> None: identifier = creator.create(subscription, self.timer) assert ( cast( - List[Tuple[UUID, SubscriptionData]], + list[tuple[UUID, SubscriptionData]], RedisSubscriptionDataStore( get_redis_client(RedisClientKey.SUBSCRIPTION_STORE), entity_key, @@ -316,7 +316,7 @@ def test(self) -> None: identifier = creator.create(subscription, Timer("test")) assert ( cast( - List[Tuple[UUID, SubscriptionData]], + list[tuple[UUID, SubscriptionData]], RedisSubscriptionDataStore( get_redis_client(RedisClientKey.SUBSCRIPTION_STORE), self.entity_key, @@ -381,7 +381,7 @@ def test_rpc_subscription_creator(self) -> None: identifier = creator.create(subscription, self.timer) assert ( cast( - List[Tuple[UUID, SubscriptionData]], + list[tuple[UUID, SubscriptionData]], RedisSubscriptionDataStore( get_redis_client(RedisClientKey.SUBSCRIPTION_STORE), EntityKey.EAP_ITEMS, diff --git a/tests/subscriptions/test_task_builder.py b/tests/subscriptions/test_task_builder.py index 3bf9ae6f21e..0dacc28f55d 100644 --- a/tests/subscriptions/test_task_builder.py +++ b/tests/subscriptions/test_task_builder.py @@ -1,5 +1,5 @@ +from collections.abc import Sequence from datetime import datetime, timedelta -from typing import Sequence, Tuple import pytest @@ -219,9 +219,9 @@ def test_sequences( builder: TaskBuilder, primary_builder_config: str, - sequence_in: Sequence[Tuple[int, Subscription]], - task_sequence: Sequence[Tuple[int, ScheduledSubscriptionTask]], - metrics: Sequence[Tuple[str, int, Tags]], + sequence_in: Sequence[tuple[int, Subscription]], + task_sequence: Sequence[tuple[int, ScheduledSubscriptionTask]], + metrics: Sequence[tuple[str, int, Tags]], ) -> None: """ Tries to execute the task builder on several sequences of diff --git a/tests/subscriptions/test_types.py b/tests/subscriptions/test_types.py index 2351fe2c6d8..c93589f8d28 100644 --- a/tests/subscriptions/test_types.py +++ b/tests/subscriptions/test_types.py @@ -11,7 +11,7 @@ def test_interval_validation() -> None: Interval(10, 1) with pytest.raises(InvalidRangeError): - Interval(1, None) # type: ignore + Interval(1, None) # type: ignore[type-var] assert e.value.lower == 10 assert e.value.upper == 1 diff --git a/tests/test_api.py b/tests/test_api.py index f0a589360a4..e71022e4016 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -3,8 +3,9 @@ import calendar import time import uuid -from datetime import datetime, timedelta, timezone -from typing import Any, Callable, Generator, List, Sequence, Tuple, Union +from collections.abc import Callable, Generator, Sequence +from datetime import UTC, datetime, timedelta +from typing import Any from unittest.mock import MagicMock, patch import pytest @@ -36,7 +37,7 @@ @pytest.mark.redis_db class SimpleAPITest(BaseApiTest): @pytest.fixture(autouse=True) - def setup_teardown(self, events_db: None, redis_db: None) -> Generator[None, None, None]: + def setup_teardown(self, events_db: None, redis_db: None) -> Generator[None]: # values for test data self.project_ids = [1, 2, 3] # 3 projects self.environments = ["prød", "test"] # 2 environments @@ -153,15 +154,14 @@ def redis_db_size(self, redis_client: RedisClientType) -> int: dbsize: int | dict[str, int] = redis_client.dbsize() if isinstance(dbsize, dict): return sum(dbsize.values()) - else: - return dbsize + return dbsize @pytest.mark.events_db @pytest.mark.redis_db class TestApi(SimpleAPITest): @pytest.fixture - def test_entity(self) -> Union[str, Tuple[str, str]]: + def test_entity(self) -> str | tuple[str, str]: return "events" @pytest.fixture @@ -181,7 +181,7 @@ def test_count(self) -> None: .get_cluster() .get_query_connection(ClickhouseClientSettings.QUERY) ) - res = clickhouse.execute("SELECT count() FROM %s" % self.table).results + res = clickhouse.execute(f"SELECT count() FROM {self.table}").results assert res[0][0] == 330 rollup_mins = 60 @@ -253,9 +253,7 @@ def test_time_alignment(self) -> None: "granularity": 60, "selected_columns": ["time"], "groupby": "time", - "from_date": (self.base_time + skew) - .replace(tzinfo=timezone.utc) - .isoformat(), + "from_date": (self.base_time + skew).replace(tzinfo=UTC).isoformat(), "to_date": ( self.base_time + skew + timedelta(minutes=self.minutes) ).isoformat(), @@ -1127,23 +1125,21 @@ def test_promoted_expansion(self) -> None: result_map = {d["tags_key"]: d for d in result["data"]} # Result contains both promoted and regular tags - assert set(result_map.keys()) == set( - [ - # Promoted tags - "environment", - "sentry:dist", - "sentry:release", - "os.rooted", - "os.name", - # User (nested) tags - "foo", - "foo.bar", - # Note this is a nested (user-provided) os_name tag and is - # unrelated to the fact that we happen to store the - # `os.name` tag as an `os_name` column. - "os_name", - ] - ) + assert set(result_map.keys()) == { + # Promoted tags + "environment", + "sentry:dist", + "sentry:release", + "os.rooted", + "os.name", + # User (nested) tags + "foo", + "foo.bar", + # Note this is a nested (user-provided) os_name tag and is + # unrelated to the fact that we happen to store the + # `os.name` tag as an `os_name` column. + "os_name", + } # Reguar (nested) tag assert result_map["foo"]["count"] == 180 @@ -1461,7 +1457,7 @@ def test_duplicate_column(self) -> None: assert result["meta"] == [{"name": "timestamp", "type": "DateTime"}] def test_exception_captured_by_sentry(self) -> None: - events: List[Any] = [] + events: list[Any] = [] with Hub(Client(transport=events.append)): # This endpoint should return 500 as it internally raises an exception response = self.app.get("/tests/error") @@ -1473,7 +1469,7 @@ def test_exception_captured_by_sentry(self) -> None: def test_consistent(self) -> None: state.set_config("consistent_override", "test_override=0;another=0.5") state.set_config("read_through_cache.short_circuit", 1) - query_data = { + query_data: dict[str, Any] = { "project": 2, "tenant_ids": {"referrer": "test_query", "organization_id": 1234}, "aggregations": [["count()", "", "aggregate"]], @@ -1487,10 +1483,10 @@ def test_consistent(self) -> None: response = json.loads(self.post(query, referrer="test_query").data) assert response["stats"]["consistent"] - query_data["tenant_ids"]["referrer"] = "test_override" # type: ignore + query_data["tenant_ids"]["referrer"] = "test_override" query = json.dumps(query_data) response = json.loads(self.post(query, referrer="test_override").data) - assert response["stats"]["consistent"] == False + assert not response["stats"]["consistent"] def test_gracefully_handle_multiple_conditions_on_same_column(self) -> None: response = self.post( @@ -1623,7 +1619,7 @@ def test_test_endpoints(self) -> None: assert len(clickhouse.execute(f"SELECT * FROM {self.table}").results) == 0 def test_max_limit(self) -> None: - with pytest.raises(Exception): + with pytest.raises(Exception): # noqa: B017 over-limit request may surface as various exception types self.post( json.dumps( { @@ -1750,7 +1746,7 @@ def test_tenant_ids(self) -> None: partition, ).all() )[0] - assert data.tenant_ids == dict() # not saved to the redis store + assert data.tenant_ids == {} # not saved to the redis store assert "tenant_ids" not in data.to_dict() # doesn't show up in dictified data @pytest.mark.clickhouse_db @@ -1831,7 +1827,7 @@ def test_invalid_dataset_and_entity_combination(self) -> None: def test_time_error(self) -> None: resp = self.app.post( - "{}/{}/subscriptions".format(self.dataset_name, self.entity_key), + f"{self.dataset_name}/{self.entity_key}/subscriptions", data=json.dumps( { "project_id": 1, @@ -1857,7 +1853,7 @@ def test_with_bad_snql(self) -> None: with patch("snuba.subscriptions.subscription.uuid1") as uuid4: uuid4.return_value = expected_uuid resp = self.app.post( - "{}/{}/subscriptions".format(self.dataset_name, self.entity_key), + f"{self.dataset_name}/{self.entity_key}/subscriptions", data=json.dumps( { "project_id": 1, diff --git a/tests/test_api_status.py b/tests/test_api_status.py index 7d0a2c6ae6c..bb44f76d53b 100644 --- a/tests/test_api_status.py +++ b/tests/test_api_status.py @@ -1,5 +1,6 @@ +from collections.abc import Callable from datetime import datetime, timedelta -from typing import Any, Callable +from typing import Any from unittest.mock import MagicMock, patch import pytest diff --git a/tests/test_cleanup.py b/tests/test_cleanup.py index 4f8e5b9400d..5cbc8d3feb7 100644 --- a/tests/test_cleanup.py +++ b/tests/test_cleanup.py @@ -1,6 +1,6 @@ import uuid +from collections.abc import Callable from datetime import UTC, datetime, timedelta -from typing import Callable, Optional from unittest.mock import MagicMock, patch import pytest @@ -61,7 +61,7 @@ def test_main_cases( self, current_time: MagicMock, storage_key: StorageKey, - create_event_row_for_date: Callable[[datetime, Optional[int]], InsertBatch], + create_event_row_for_date: Callable[[datetime, int | None], InsertBatch], ) -> None: def to_monday(d: datetime) -> datetime: rounded = d - timedelta(days=d.weekday()) @@ -119,14 +119,12 @@ def to_monday(d: datetime) -> datetime: write_processed_messages(storage, [create_event_row_for_date(one_week_ago, 30)]) parts = cleanup.get_active_partitions(clickhouse, storage, database, table) - assert {(p.date, p.retention_days) for p in parts} == set( - [ - (to_monday(thirteen_weeks_ago), 90), - (to_monday(three_weeks_ago), 90), - (to_monday(one_week_ago), 30), - (to_monday(base), 90), - ] - ) + assert {(p.date, p.retention_days) for p in parts} == { + (to_monday(thirteen_weeks_ago), 90), + (to_monday(three_weeks_ago), 90), + (to_monday(one_week_ago), 30), + (to_monday(base), 90), + } stale = cleanup.filter_stale_partitions(parts) assert [(p.date, p.retention_days) for p in stale] == [(to_monday(thirteen_weeks_ago), 90)] @@ -134,30 +132,27 @@ def to_monday(d: datetime) -> datetime: five_weeks_ago = base - timedelta(days=7 * 5) write_processed_messages(storage, [create_event_row_for_date(five_weeks_ago, 30)]) parts = cleanup.get_active_partitions(clickhouse, storage, database, table) - assert {(p.date, p.retention_days) for p in parts} == set( - [ - (to_monday(thirteen_weeks_ago), 90), - (to_monday(five_weeks_ago), 30), - (to_monday(three_weeks_ago), 90), - (to_monday(one_week_ago), 30), - (to_monday(base), 90), - ] - ) + assert {(p.date, p.retention_days) for p in parts} == { + (to_monday(thirteen_weeks_ago), 90), + (to_monday(five_weeks_ago), 30), + (to_monday(three_weeks_ago), 90), + (to_monday(one_week_ago), 30), + (to_monday(base), 90), + } stale = cleanup.filter_stale_partitions(parts) - assert {(p.date, p.retention_days) for p in stale} == set( - [(to_monday(thirteen_weeks_ago), 90), (to_monday(five_weeks_ago), 30)] - ) + assert {(p.date, p.retention_days) for p in stale} == { + (to_monday(thirteen_weeks_ago), 90), + (to_monday(five_weeks_ago), 30), + } cleanup.drop_partitions(clickhouse, database, table, stale, dry_run=False) parts = cleanup.get_active_partitions(clickhouse, storage, database, table) - assert {(p.date, p.retention_days) for p in parts} == set( - [ - (to_monday(three_weeks_ago), 90), - (to_monday(one_week_ago), 30), - (to_monday(base), 90), - ] - ) + assert {(p.date, p.retention_days) for p in parts} == { + (to_monday(three_weeks_ago), 90), + (to_monday(one_week_ago), 30), + (to_monday(base), 90), + } @pytest.mark.parametrize( "storage_key, create_event_row_for_date", @@ -168,7 +163,7 @@ def test_midnight_error_case( self, current_time: MagicMock, storage_key: StorageKey, - create_event_row_for_date: Callable[[datetime, Optional[int]], InsertBatch], + create_event_row_for_date: Callable[[datetime, int | None], InsertBatch], ) -> None: """ This test is simulating a failure case that happened in production, where when the script ran, diff --git a/tests/test_cli.py b/tests/test_cli.py index 85148d869a6..22db564b1ed 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -3,7 +3,7 @@ import time -class TestCli(object): +class TestCli: def test_consumer_cli(self) -> None: """ Check that the consumer daemon runs until it is killed diff --git a/tests/test_configurable_component.py b/tests/test_configurable_component.py index 163100ee7c2..419584dc533 100644 --- a/tests/test_configurable_component.py +++ b/tests/test_configurable_component.py @@ -116,9 +116,9 @@ def test_component_name(self, test_component: SomeConfigurableComponent) -> None def test_config_definitions(self, test_component: SomeConfigurableComponent) -> None: """Test that config_definitions returns all configurations.""" - assert set( - ["default_config_1", "additional_config_1", "override_config_for_org_id"] - ) == set(test_component.config_definitions().keys()) + assert {"default_config_1", "additional_config_1", "override_config_for_org_id"} == set( + test_component.config_definitions().keys() + ) def test_get_current_configs(self, test_component: SomeConfigurableComponent) -> None: """Test that get_current_configs returns the correct configs.""" diff --git a/tests/test_consumer.py b/tests/test_consumer.py index b6519c95fe6..c66eb6c92e0 100644 --- a/tests/test_consumer.py +++ b/tests/test_consumer.py @@ -2,10 +2,10 @@ import itertools import json import pickle +from collections.abc import MutableSequence from datetime import datetime from pathlib import Path from pickle import PickleBuffer -from typing import MutableSequence from unittest.mock import Mock, call import pytest @@ -88,7 +88,7 @@ def write_step() -> ProcessedMessageBatchWriter: commit_function = Mock() strategy = factory.create_with_partitions(commit_function, {}) - for i in range(3): + for _i in range(3): strategy.poll() strategy.submit(next(messages)) diff --git a/tests/test_discover_api.py b/tests/test_discover_api.py index 5c31dea7190..98dd737e81d 100644 --- a/tests/test_discover_api.py +++ b/tests/test_discover_api.py @@ -1,5 +1,6 @@ -from datetime import datetime, timedelta, timezone -from typing import Any, Callable, Tuple, Union +from collections.abc import Callable +from datetime import UTC, datetime, timedelta +from typing import Any import pytest import simplejson as json @@ -18,7 +19,7 @@ @pytest.mark.redis_db class TestDiscoverApi(BaseApiTest): @pytest.fixture - def test_entity(self) -> Union[str, Tuple[str, str]]: + def test_entity(self) -> str | tuple[str, str]: # This can be overridden in the post function return ( "discover_events", @@ -40,9 +41,9 @@ def _setup_method(self, request: Any, events_db: None) -> None: self.event = get_raw_event() self.project_id = self.event["project_id"] self.skew = timedelta(minutes=180) - self.base_time = datetime.utcnow().replace( - second=0, microsecond=0, tzinfo=timezone.utc - ) - timedelta(minutes=90) + self.base_time = datetime.utcnow().replace(second=0, microsecond=0, tzinfo=UTC) - timedelta( + minutes=90 + ) events_storage = get_entity(EntityKey.EVENTS).get_writable_storage() assert events_storage is not None @@ -1771,7 +1772,7 @@ def test_symbolicated_in_app(self) -> None: data = json.loads(response.data) assert response.status_code == 200 assert len(data["data"]) == 1 - assert data["data"][0]["symbolicated_in_app"] == True + assert data["data"][0]["symbolicated_in_app"] def test_timestamp_ms_query(self) -> None: response = self.post( @@ -1829,7 +1830,7 @@ class TestDiscoverAPIEntitySelection(TestDiscoverApi): """ @pytest.fixture - def test_entity(self) -> Union[str, Tuple[str, str]]: + def test_entity(self) -> str | tuple[str, str]: # This can be overridden in the post function return "discover" diff --git a/tests/test_generic_metrics_api.py b/tests/test_generic_metrics_api.py index bb6efae3518..ccf6cac66cd 100644 --- a/tests/test_generic_metrics_api.py +++ b/tests/test_generic_metrics_api.py @@ -1,7 +1,8 @@ import itertools import json -from datetime import datetime, timedelta, timezone -from typing import Any, Callable, Iterable, Mapping, Tuple, Union +from collections.abc import Callable, Iterable, Mapping +from datetime import UTC, datetime, timedelta +from typing import Any import pytest from pytest import approx @@ -26,7 +27,7 @@ def utc_yesterday_12_15() -> datetime: return (datetime.utcnow() - timedelta(days=1)).replace( - hour=12, minute=15, second=0, microsecond=0, tzinfo=timezone.utc + hour=12, minute=15, second=0, microsecond=0, tzinfo=UTC ) @@ -36,7 +37,7 @@ def utc_yesterday_12_15() -> datetime: def gen_string() -> str: global placeholder_counter placeholder_counter += 1 - return "placeholder{:04d}".format(placeholder_counter) + return f"placeholder{placeholder_counter:04d}" SHARED_TAGS: Mapping[str, str] = { @@ -83,7 +84,7 @@ def test_app(self) -> Any: return self.app @pytest.fixture - def test_entity(self) -> Union[str, Tuple[str, str]]: + def test_entity(self) -> str | tuple[str, str]: return "generic_metrics_sets" @pytest.fixture(autouse=True) @@ -216,7 +217,7 @@ def test_app(self) -> Any: return self.app @pytest.fixture - def test_entity(self) -> Union[str, Tuple[str, str]]: + def test_entity(self) -> str | tuple[str, str]: return "generic_metrics_distributions" @pytest.fixture(autouse=True) @@ -481,7 +482,7 @@ def test_app(self) -> Any: return self.app @pytest.fixture - def test_entity(self) -> Union[str, Tuple[str, str]]: + def test_entity(self) -> str | tuple[str, str]: return "generic_metrics_counters" @pytest.fixture(autouse=True) @@ -620,7 +621,7 @@ def test_app(self) -> Any: return self.app @pytest.fixture - def test_entity(self) -> Union[str, Tuple[str, str]]: + def test_entity(self) -> str | tuple[str, str]: return "generic_metrics_counters" @pytest.fixture(autouse=True) @@ -760,7 +761,7 @@ def test_app(self) -> Any: return self.app @pytest.fixture - def test_entity(self) -> Union[str, Tuple[str, str]]: + def test_entity(self) -> str | tuple[str, str]: return "generic_metrics_gauges" @pytest.fixture(autouse=True) diff --git a/tests/test_group_attributes_api.py b/tests/test_group_attributes_api.py index fe3befb6772..b2b12823988 100644 --- a/tests/test_group_attributes_api.py +++ b/tests/test_group_attributes_api.py @@ -1,5 +1,6 @@ +from collections.abc import Callable from datetime import datetime -from typing import Any, Callable, Tuple, Union +from typing import Any import pytest import simplejson as json @@ -43,7 +44,7 @@ def kafka_metadata() -> KafkaMessageMetadata: class TestGroupAttributesSnQLApi(SimpleAPITest, BaseApiTest, ConfigurationTest): @pytest.fixture - def test_entity(self) -> Union[str, Tuple[str, str]]: + def test_entity(self) -> str | tuple[str, str]: return "group_attributes" @pytest.fixture diff --git a/tests/test_metrics_api.py b/tests/test_metrics_api.py index 9d29d348b19..0fa795fca54 100644 --- a/tests/test_metrics_api.py +++ b/tests/test_metrics_api.py @@ -1,5 +1,6 @@ -from datetime import datetime, timedelta, timezone -from typing import Any, Callable, Generator, Optional, Tuple, Union, cast +from collections.abc import Callable, Generator +from datetime import UTC, datetime, timedelta +from typing import Any, cast import pytest import simplejson as json @@ -64,7 +65,7 @@ def teardown_common() -> None: def utc_yesterday_12_15() -> datetime: return (datetime.utcnow() - timedelta(days=1)).replace( - hour=12, minute=15, second=0, microsecond=0, tzinfo=timezone.utc + hour=12, minute=15, second=0, microsecond=0, tzinfo=UTC ) @@ -80,13 +81,13 @@ def test_app(self) -> Any: return self.app @pytest.fixture - def test_entity(self) -> Union[str, Tuple[str, str]]: + def test_entity(self) -> str | tuple[str, str]: return "metrics_counters" @pytest.fixture(autouse=True) def setup_teardown( self, _build_snql_post_methods: Callable[[str], Any], clickhouse_db: None - ) -> Generator[None, None, None]: + ) -> Generator[None]: self.post = _build_snql_post_methods # values for test data @@ -153,11 +154,11 @@ def generate_counters(self) -> None: def build_simple_query( self, - metric_id: Optional[int] = None, - org_id: Optional[int] = None, - start_time: Optional[str] = None, - end_time: Optional[str] = None, - granularity: Optional[int] = None, + metric_id: int | None = None, + org_id: int | None = None, + start_time: str | None = None, + end_time: str | None = None, + granularity: int | None = None, ) -> str: if not metric_id: metric_id = self.metric_id @@ -237,7 +238,7 @@ def test_app(self) -> Any: return self.app @pytest.fixture - def test_entity(self) -> Union[str, Tuple[str, str]]: + def test_entity(self) -> str | tuple[str, str]: return "org_metrics_counters" @pytest.fixture(autouse=True) @@ -254,11 +255,9 @@ def setup_method(self, test_method: Any) -> None: self.skew = timedelta(seconds=self.seconds) - self.base_time = datetime.utcnow().replace( - minute=0, second=0, microsecond=0, tzinfo=timezone.utc - ) + self.base_time = datetime.utcnow().replace(minute=0, second=0, microsecond=0, tzinfo=UTC) self.sentry_received_timestamp = datetime.utcnow().replace( - minute=0, second=0, microsecond=0, tzinfo=timezone.utc + minute=0, second=0, microsecond=0, tzinfo=UTC ) self.storage = cast( WritableTableStorage, @@ -303,10 +302,10 @@ def generate_counters(self) -> None: def build_simple_query( self, - metric_id: Optional[int] = None, - start_time: Optional[str] = None, - end_time: Optional[str] = None, - granularity: Optional[int] = None, + metric_id: int | None = None, + start_time: str | None = None, + end_time: str | None = None, + granularity: int | None = None, ) -> str: if not metric_id: metric_id = self.metric_id @@ -377,7 +376,7 @@ def test_app(self) -> Any: return self.app @pytest.fixture - def test_entity(self) -> Union[str, Tuple[str, str]]: + def test_entity(self) -> str | tuple[str, str]: return "metrics_sets" @pytest.fixture(autouse=True) @@ -440,20 +439,15 @@ def generate_sets(self) -> None: write_processed_messages(self.storage, events) def test_sets_basic(self) -> None: - query_str = """MATCH (metrics_sets) + query_str = f"""MATCH (metrics_sets) SELECT uniq(value) AS unique_values BY project_id, org_id - WHERE org_id = {org_id} + WHERE org_id = {self.org_id} AND project_id = 1 - AND metric_id = {metric_id} - AND timestamp >= toDateTime('{start_time}') - AND timestamp < toDateTime('{end_time}') + AND metric_id = {self.metric_id} + AND timestamp >= toDateTime('{(self.base_time - self.skew).isoformat()}') + AND timestamp < toDateTime('{(self.base_time + self.skew).isoformat()}') GRANULARITY 60 - """.format( - metric_id=self.metric_id, - org_id=self.org_id, - start_time=(self.base_time - self.skew).isoformat(), - end_time=(self.base_time + self.skew).isoformat(), - ) + """ response = self.app.post( SNQL_ROUTE, data=json.dumps({"query": query_str, "dataset": "metrics"}) ) @@ -477,7 +471,7 @@ def test_app(self) -> Any: return self.app @pytest.fixture - def test_entity(self) -> Union[str, Tuple[str, str]]: + def test_entity(self) -> str | tuple[str, str]: return "metrics_distributions" @pytest.fixture(autouse=True) @@ -546,20 +540,15 @@ def generate_uniform_distributions(self) -> None: have_generated_dists = True def test_dists_percentiles(self) -> None: - query_str = """MATCH (metrics_distributions) + query_str = f"""MATCH (metrics_distributions) SELECT quantiles(0.5,0.9,0.95,0.99)(value) AS quants BY project_id, org_id - WHERE org_id = {org_id} + WHERE org_id = {self.org_id} AND project_id = 1 - AND metric_id = {metric_id} - AND timestamp >= toDateTime('{start_time}') - AND timestamp < toDateTime('{end_time}') + AND metric_id = {self.metric_id} + AND timestamp >= toDateTime('{(self.base_time - self.skew).isoformat()}') + AND timestamp < toDateTime('{(self.base_time + self.skew).isoformat()}') GRANULARITY 60 - """.format( - metric_id=self.metric_id, - org_id=self.org_id, - start_time=(self.base_time - self.skew).isoformat(), - end_time=(self.base_time + self.skew).isoformat(), - ) + """ response = self.app.post( SNQL_ROUTE, data=json.dumps({"query": query_str, "dataset": "metrics"}) ) @@ -580,25 +569,20 @@ def test_dists_percentiles(self) -> None: ] def test_dists_min_max_avg_one_day_granularity(self) -> None: - query_str = """MATCH (metrics_distributions) + query_str = f"""MATCH (metrics_distributions) SELECT min(value) AS dist_min, max(value) AS dist_max, avg(value) AS dist_avg, sum(value) AS dist_sum, count(value) AS dist_count BY project_id, org_id - WHERE org_id = {org_id} + WHERE org_id = {self.org_id} AND project_id = 1 - AND metric_id = {metric_id} - AND timestamp >= toDateTime('{start_time}') - AND timestamp < toDateTime('{end_time}') + AND metric_id = {self.metric_id} + AND timestamp >= toDateTime('{timestamp_to_bucket(self.base_time, 86400).isoformat()}') + AND timestamp < toDateTime('{(timestamp_to_bucket(self.base_time + timedelta(days=2), 86400)).isoformat()}') GRANULARITY 86400 - """.format( - metric_id=self.metric_id, - org_id=self.org_id, - start_time=timestamp_to_bucket(self.base_time, 86400).isoformat(), - end_time=(timestamp_to_bucket(self.base_time + timedelta(days=2), 86400)).isoformat(), - ) + """ response = self.app.post( SNQL_ROUTE, data=json.dumps({"query": query_str, "dataset": "metrics"}) ) @@ -619,20 +603,15 @@ def test_dists_min_max_avg_one_day_granularity(self) -> None: ) def test_bucketed_time(self) -> None: - query_str = """MATCH (metrics_distributions) + query_str = f"""MATCH (metrics_distributions) SELECT bucketed_time, quantiles(0.5,0.9,0.95,0.99)(value) AS quants BY bucketed_time - WHERE org_id = {org_id} + WHERE org_id = {self.org_id} AND project_id = 1 - AND metric_id = {metric_id} - AND timestamp >= toDateTime('{start_time}') - AND timestamp < toDateTime('{end_time}') + AND metric_id = {self.metric_id} + AND timestamp >= toDateTime('{timestamp_to_bucket(self.base_time - self.skew, 3600).isoformat()}') + AND timestamp < toDateTime('{timestamp_to_bucket(self.base_time + self.skew, 3600).isoformat()}') GRANULARITY 3600 - """.format( - metric_id=self.metric_id, - org_id=self.org_id, - start_time=timestamp_to_bucket(self.base_time - self.skew, 3600).isoformat(), - end_time=timestamp_to_bucket(self.base_time + self.skew, 3600).isoformat(), - ) + """ response = self.app.post( SNQL_ROUTE, data=json.dumps({"query": query_str, "dataset": "metrics"}) ) diff --git a/tests/test_metrics_meta_api.py b/tests/test_metrics_meta_api.py index 4745f32d0d5..1757fb1ef51 100644 --- a/tests/test_metrics_meta_api.py +++ b/tests/test_metrics_meta_api.py @@ -1,6 +1,7 @@ import json +from collections.abc import Callable, Generator, Mapping from datetime import UTC, datetime, timedelta -from typing import Any, Callable, Generator, Mapping, Tuple, Union +from typing import Any import pytest @@ -16,7 +17,7 @@ SNQL_ROUTE = "/generic_metrics/snql" -def get_tags() -> Generator[Mapping[str, str], None, None]: +def get_tags() -> Generator[Mapping[str, str]]: idx = 0 mappings = {"environment": "112358", "release": "132134"} while True: @@ -37,7 +38,7 @@ def test_app(self) -> Any: return self.app @pytest.fixture - def test_entity(self) -> Union[str, Tuple[str, str]]: + def test_entity(self) -> str | tuple[str, str]: return "generic_metrics_sets" @pytest.fixture @@ -167,7 +168,7 @@ def test_retrieve_tag_keys(self, test_entity: str) -> None: @pytest.mark.redis_db class TestGenericMetricsApiCounters(TestGenericMetricsApiSets): @pytest.fixture - def test_entity(self) -> Union[str, Tuple[str, str]]: + def test_entity(self) -> str | tuple[str, str]: return "generic_metrics_counters" @pytest.fixture @@ -186,7 +187,7 @@ def generate_metric_values(self) -> Any: @pytest.mark.redis_db class TestGenericMetricsApiGauges(TestGenericMetricsApiSets): @pytest.fixture - def test_entity(self) -> Union[str, Tuple[str, str]]: + def test_entity(self) -> str | tuple[str, str]: return "generic_metrics_gauges" @pytest.fixture @@ -211,7 +212,7 @@ def generate_metric_values(self) -> Any: @pytest.mark.redis_db class TestGenericMetricsApiDistributions(TestGenericMetricsApiSets): @pytest.fixture - def test_entity(self) -> Union[str, Tuple[str, str]]: + def test_entity(self) -> str | tuple[str, str]: return "generic_metrics_distributions" @pytest.fixture diff --git a/tests/test_metrics_mql_api.py b/tests/test_metrics_mql_api.py index 11f0acd223b..5c0911f1469 100644 --- a/tests/test_metrics_mql_api.py +++ b/tests/test_metrics_mql_api.py @@ -1,9 +1,10 @@ from __future__ import annotations import itertools +from collections.abc import Callable from dataclasses import dataclass -from datetime import datetime, timedelta, timezone -from typing import Any, Callable, cast +from datetime import UTC, datetime, timedelta +from typing import Any, cast import pytest import simplejson as json @@ -39,7 +40,7 @@ def utc_yesterday_12_15() -> datetime: return (datetime.utcnow() - timedelta(days=1)).replace( - hour=12, minute=15, second=0, microsecond=0, tzinfo=timezone.utc + hour=12, minute=15, second=0, microsecond=0, tzinfo=UTC ) diff --git a/tests/test_metrics_sdk_api.py b/tests/test_metrics_sdk_api.py index 607c6665a27..bed8ed1638e 100644 --- a/tests/test_metrics_sdk_api.py +++ b/tests/test_metrics_sdk_api.py @@ -1,7 +1,8 @@ from __future__ import annotations -from datetime import datetime, timedelta, timezone -from typing import Any, Callable, Tuple, Union, cast +from collections.abc import Callable +from datetime import UTC, datetime, timedelta +from typing import Any, cast import pytest import simplejson as json @@ -34,7 +35,7 @@ def utc_yesterday_12_15() -> datetime: return (datetime.utcnow() - timedelta(days=1)).replace( - hour=12, minute=15, second=0, microsecond=0, tzinfo=timezone.utc + hour=12, minute=15, second=0, microsecond=0, tzinfo=UTC ) @@ -65,7 +66,7 @@ def test_app(self) -> Any: return self.app @pytest.fixture - def test_entity(self) -> Union[str, Tuple[str, str]]: + def test_entity(self) -> str | tuple[str, str]: return "generic_metrics_counters" @pytest.fixture @@ -320,7 +321,7 @@ def test_tag_key_value(self, test_entity: str, test_dataset: str, tag_column: st rows = data["data"] assert len(rows) == 1, rows assert rows[0] == { - "tags.key": [int(k) for k in SHARED_TAGS.keys()], + "tags.key": [int(k) for k in SHARED_TAGS], "tags.raw_value": ["t1", "200"], } @@ -358,7 +359,7 @@ def test_raw_mql_string(self, test_dataset: str, tag_column: str) -> None: @pytest.mark.redis_db class TestMetricsSdkApiCounters(TestGenericMetricsSdkApiCounters): @pytest.fixture - def test_entity(self) -> Union[str, Tuple[str, str]]: + def test_entity(self) -> str | tuple[str, str]: return "metrics_counters" @pytest.fixture diff --git a/tests/test_outcomes_api.py b/tests/test_outcomes_api.py index 2e8e2e03e4a..e903298678a 100644 --- a/tests/test_outcomes_api.py +++ b/tests/test_outcomes_api.py @@ -1,7 +1,8 @@ import itertools import uuid -from datetime import datetime, timedelta, timezone -from typing import Any, Callable, Optional, Tuple, Union +from collections.abc import Callable +from datetime import UTC, datetime, timedelta +from typing import Any import pytest import simplejson as json @@ -18,7 +19,7 @@ @pytest.mark.redis_db class TestLegacyOutcomesApi(BaseApiTest): @pytest.fixture - def test_entity(self) -> Union[str, Tuple[str, str]]: + def test_entity(self) -> str | tuple[str, str]: return "outcomes" @pytest.fixture @@ -49,8 +50,8 @@ def generate_outcomes( num_outcomes: int, outcome: int, time_since_base: timedelta, - category: Optional[int], - quantity: Optional[int] = None, + category: int | None, + quantity: int | None = None, ) -> None: outcomes = [] for _ in range(num_outcomes): @@ -84,7 +85,7 @@ def generate_outcomes( write_processed_messages(self.storage, outcomes) def format_time(self, time: datetime) -> str: - return time.replace(tzinfo=timezone.utc).isoformat() + return time.replace(tzinfo=UTC).isoformat() def test_happy_path_querying(self, get_project_id: Callable[[], int]) -> None: project_id = get_project_id() @@ -170,7 +171,7 @@ def test_happy_path_querying(self, get_project_id: Callable[[], int]) -> None: data = json.loads(response.data) assert response.status_code == 200 assert len(data["data"]) == 3 - assert all([row["aggregate"] == 10 for row in data["data"]]) + assert all(row["aggregate"] == 10 for row in data["data"]) assert sorted([row["project_id"] for row in data["data"]]) == [ project_id, project_id, @@ -314,8 +315,8 @@ def generate_outcomes( num_outcomes: int, outcome: int, time_since_base: timedelta, - category: Optional[int], - quantity: Optional[int] = None, + category: int | None, + quantity: int | None = None, ) -> None: outcomes = [] for _ in range(num_outcomes): @@ -349,7 +350,7 @@ def generate_outcomes( write_processed_messages(self.storage, outcomes) def format_time(self, time: datetime) -> str: - return time.replace(tzinfo=timezone.utc).isoformat() + return time.replace(tzinfo=UTC).isoformat() @pytest.fixture(autouse=True) def setup_teardown(self, clickhouse_db: None) -> None: @@ -426,4 +427,4 @@ def test_virtual_time_column(self, get_project_id: Callable[[], int]) -> None: data = json.loads(response.data) assert response.status_code == 200 assert len(data["data"]) == 2 - assert all([row["aggregate"] == 10 for row in data["data"]]) + assert all(row["aggregate"] == 10 for row in data["data"]) diff --git a/tests/test_replacer.py b/tests/test_replacer.py index b6380170416..b7352922f18 100644 --- a/tests/test_replacer.py +++ b/tests/test_replacer.py @@ -2,8 +2,9 @@ import importlib import time -from datetime import datetime, timedelta, timezone -from typing import Any, Mapping, MutableMapping, Sequence +from collections.abc import Mapping, MutableMapping, Sequence +from datetime import UTC, datetime, timedelta +from typing import Any from unittest import mock import pytest @@ -122,7 +123,7 @@ def _issue_count(total: bool = False) -> Sequence[Mapping[str, Any]]: assert _issue_count() == [{"count": 1, "group_id": 1}] assert _issue_count(total=True) == [{"count": 1, "group_id": 1}] - timestamp = datetime.now(tz=timezone.utc) + timestamp = datetime.now(tz=UTC) message: Message[KafkaPayload] = Message( BrokerValue( diff --git a/tests/test_replays_api.py b/tests/test_replays_api.py index 0ff422e881a..fa1546e31b1 100644 --- a/tests/test_replays_api.py +++ b/tests/test_replays_api.py @@ -1,6 +1,6 @@ import uuid from datetime import datetime, timedelta -from typing import Any +from typing import Any, cast import pytest import rapidjson @@ -22,7 +22,7 @@ def post(self, url: str, data: str) -> Any: @pytest.fixture(autouse=True) def setup_teardown(self, clickhouse_db: None) -> None: self.replay_id = uuid.UUID("7400045b-25c4-43b8-8591-4600aa83ad05") - self.event = get_replay_event(replay_id=str(self.replay_id)) + self.event = cast("dict[str, Any]", get_replay_event(replay_id=str(self.replay_id))) self.project_id = self.event["project_id"] self.skew = timedelta(minutes=180) self.base_time = datetime.utcnow().replace(minute=0, second=0, microsecond=0) - timedelta( @@ -77,7 +77,7 @@ def test_sdk_user_title_nullability(self) -> None: payload.pop("user") payload.pop("sdk") payload["tags"] = list(filter(lambda tag: tag[0] != "transaction", payload["tags"])) - self.event["payload"] = list(json.dumps(payload).encode()) # type: ignore + self.event["payload"] = list(json.dumps(payload).encode()) replays_storage = get_entity(EntityKey.REPLAYS).get_writable_storage() assert replays_storage is not None diff --git a/tests/test_search_issues_api.py b/tests/test_search_issues_api.py index 804afa5de24..6aeaa55b156 100644 --- a/tests/test_search_issues_api.py +++ b/tests/test_search_issues_api.py @@ -1,6 +1,7 @@ import uuid +from collections.abc import Callable, MutableMapping from datetime import datetime, timedelta -from typing import Any, Callable, MutableMapping, Tuple, Union +from typing import Any from unittest.mock import Mock, patch import pytest @@ -17,8 +18,10 @@ def base_insert_event( - now: datetime = datetime.now(), -) -> Tuple[int, str, MutableMapping[str, Any]]: + now: datetime | None = None, +) -> tuple[int, str, MutableMapping[str, Any]]: + if now is None: + now = datetime.now() return ( 2, "insert", @@ -49,7 +52,7 @@ def base_insert_event( class TestSearchIssuesSnQLApi(SimpleAPITest, BaseApiTest, ConfigurationTest): @pytest.fixture - def test_entity(self) -> Union[str, Tuple[str, str]]: + def test_entity(self) -> str | tuple[str, str]: return "search_issues" @pytest.fixture @@ -111,25 +114,25 @@ def test_simple_delete(self, mock_produce_delete: Mock) -> None: occurrence_id = str(uuid.uuid4()) group_id = 4 - evt: MutableMapping[str, Any] = dict( - organization_id=1, - project_id=3, - event_id=str(uuid.uuid4().hex), - group_id=group_id, - primary_hash=str(uuid.uuid4().hex), - datetime=datetime.utcnow().isoformat() + "Z", - platform="other", - message="message", - data={"received": now.timestamp()}, - occurrence_data=dict( - id=occurrence_id, - type=1, - issue_title="search me", - fingerprint=["one", "two"], - detection_time=now.timestamp(), - ), - retention_days=90, - ) + evt: MutableMapping[str, Any] = { + "organization_id": 1, + "project_id": 3, + "event_id": str(uuid.uuid4().hex), + "group_id": group_id, + "primary_hash": str(uuid.uuid4().hex), + "datetime": datetime.utcnow().isoformat() + "Z", + "platform": "other", + "message": "message", + "data": {"received": now.timestamp()}, + "occurrence_data": { + "id": occurrence_id, + "type": 1, + "issue_title": "search me", + "fingerprint": ["one", "two"], + "detection_time": now.timestamp(), + }, + "retention_days": 90, + } assert self.events_storage write_unprocessed_events(self.events_storage, [evt]) @@ -196,25 +199,25 @@ def test_bad_delete(self) -> None: def test_simple_search_query(self) -> None: now = datetime.now().replace(minute=0, second=0, microsecond=0) - evt: MutableMapping[str, Any] = dict( - organization_id=1, - project_id=2, - event_id=str(uuid.uuid4().hex), - group_id=3, - primary_hash=str(uuid.uuid4().hex), - datetime=datetime.utcnow().isoformat() + "Z", - platform="other", - message="message", - data={"received": now.timestamp()}, - occurrence_data=dict( - id=str(uuid.uuid4().hex), - type=1, - issue_title="search me", - fingerprint=["one", "two"], - detection_time=now.timestamp(), - ), - retention_days=90, - ) + evt: MutableMapping[str, Any] = { + "organization_id": 1, + "project_id": 2, + "event_id": str(uuid.uuid4().hex), + "group_id": 3, + "primary_hash": str(uuid.uuid4().hex), + "datetime": datetime.utcnow().isoformat() + "Z", + "platform": "other", + "message": "message", + "data": {"received": now.timestamp()}, + "occurrence_data": { + "id": str(uuid.uuid4().hex), + "type": 1, + "issue_title": "search me", + "fingerprint": ["one", "two"], + "detection_time": now.timestamp(), + }, + "retention_days": 90, + } assert self.events_storage write_unprocessed_events(self.events_storage, [evt]) diff --git a/tests/test_snql_api.py b/tests/test_snql_api.py index 32dcb4e170e..2f67eb53643 100644 --- a/tests/test_snql_api.py +++ b/tests/test_snql_api.py @@ -211,21 +211,16 @@ def test_sub_query(self) -> None: "/discover/snql", data=json.dumps( { - "query": """MATCH { + "query": f"""MATCH {{ MATCH (discover_events) SELECT count() AS count BY project_id, tags[custom_tag] - WHERE type != 'transaction' AND project_id = %s - AND timestamp >= toDateTime('%s') - AND timestamp < toDateTime('%s') - } + WHERE type != 'transaction' AND project_id = {self.project_id} + AND timestamp >= toDateTime('{self.base_time.isoformat()}') + AND timestamp < toDateTime('{self.next_time.isoformat()}') + }} SELECT avg(count) AS avg_count ORDER BY avg_count ASC - LIMIT 1000""" - % ( - self.project_id, - self.base_time.isoformat(), - self.next_time.isoformat(), - ), + LIMIT 1000""", "tenant_ids": {"referrer": "r", "organization_id": 123}, } ), diff --git a/tests/test_transactions_api.py b/tests/test_transactions_api.py index cc7d9f635db..80ce37cbc5e 100644 --- a/tests/test_transactions_api.py +++ b/tests/test_transactions_api.py @@ -1,8 +1,9 @@ import calendar import uuid from collections import defaultdict -from datetime import datetime, timedelta, timezone -from typing import Any, Callable, Generator, Tuple, Union +from collections.abc import Callable, Generator +from datetime import UTC, datetime, timedelta +from typing import Any import pytest import simplejson as json @@ -22,7 +23,7 @@ @pytest.mark.redis_db class TestTransactionsApi(BaseApiTest): @pytest.fixture - def test_entity(self) -> Union[str, Tuple[str, str]]: + def test_entity(self) -> str | tuple[str, str]: return "transactions" @pytest.fixture @@ -35,7 +36,7 @@ def setup_teardown( events_db: None, redis_db: None, _build_snql_post_methods: Callable[[str], Any], - ) -> Generator[None, None, None]: + ) -> Generator[None]: self.post = _build_snql_post_methods # values for test data @@ -48,7 +49,7 @@ def setup_teardown( self.skew = timedelta(minutes=self.minutes) self.base_time = datetime.utcnow().replace( - minute=0, second=0, microsecond=0, tzinfo=timezone.utc + minute=0, second=0, microsecond=0, tzinfo=UTC ) - timedelta(minutes=self.minutes) self.storage = get_writable_storage(StorageKey.TRANSACTIONS) self.generate_fizzbuzz_events() @@ -100,10 +101,10 @@ def generate_fizzbuzz_events(self) -> None: "type": "transaction", "transaction": "/api/do_things", "start_timestamp": datetime.timestamp( - (self.base_time + timedelta(minutes=tick)) + self.base_time + timedelta(minutes=tick) ), "timestamp": datetime.timestamp( - (self.base_time + timedelta(minutes=tick, seconds=1)) + self.base_time + timedelta(minutes=tick, seconds=1) ), "tags": { # Sentry @@ -648,20 +649,16 @@ def test_span_id(self) -> None: assert data["data"][0]["span_id"] == "841662216cc598b1" def test_limitby_multicolumn(self) -> None: - query_str = """MATCH (transactions) + query_str = f"""MATCH (transactions) SELECT project_id, environment, platform, event_id WHERE project_id = 1 - AND finish_ts >= toDateTime('{start_time}') - AND finish_ts < toDateTime('{end_time}') - LIMIT {limit_by_count} BY environment, platform - """.format( - start_time=(self.base_time - self.skew).isoformat(), - end_time=(self.base_time + self.skew).isoformat(), - limit_by_count=LIMIT_BY_COUNT, - ) + AND finish_ts >= toDateTime('{(self.base_time - self.skew).isoformat()}') + AND finish_ts < toDateTime('{(self.base_time + self.skew).isoformat()}') + LIMIT {LIMIT_BY_COUNT} BY environment, platform + """ response = self.app.post( SNQL_ROUTE, data=json.dumps( @@ -690,24 +687,21 @@ def test_limitby_multicolumn(self) -> None: for datum in parsed_data["data"]: records_by_limit_columns[(datum["platform"], datum["environment"])].append(datum) - for key in records_by_limit_columns.keys(): + for key in records_by_limit_columns: assert len(records_by_limit_columns[key]) == LIMIT_BY_COUNT, key def test_arrayjoin_multicolumn(self) -> None: - query_str = """MATCH (transactions) + query_str = f"""MATCH (transactions) SELECT event_id, measurements.key, measurements.value ARRAY JOIN measurements.key, measurements.value WHERE project_id = 1 - AND finish_ts >= toDateTime('{start_time}') - AND finish_ts < toDateTime('{end_time}') + AND finish_ts >= toDateTime('{(self.base_time - self.skew).isoformat()}') + AND finish_ts < toDateTime('{(self.base_time + self.skew).isoformat()}') ORDER BY event_id ASC, measurements.key ASC LIMIT 4 - """.format( - start_time=(self.base_time - self.skew).isoformat(), - end_time=(self.base_time + self.skew).isoformat(), - ) + """ response = self.app.post( SNQL_ROUTE, data=json.dumps( diff --git a/tests/test_writer.py b/tests/test_writer.py index 2dde6ad4828..0ad0f6d2b0f 100644 --- a/tests/test_writer.py +++ b/tests/test_writer.py @@ -1,5 +1,4 @@ import gzip -from typing import Optional import pytest import rapidjson @@ -48,7 +47,7 @@ def test_error_handling(self) -> None: class FakeQuery(FormattedQuery): - def get_sql(self, format: Optional[str] = None) -> str: + def get_sql(self, format: str | None = None) -> str: return "SELECT count() FROM groupedmessage_local;" diff --git a/tests/utils/conftest.py b/tests/utils/conftest.py deleted file mode 100644 index 628f8b1eb97..00000000000 --- a/tests/utils/conftest.py +++ /dev/null @@ -1,6 +0,0 @@ -import pytest - - -@pytest.fixture(autouse=True) -def run_migrations() -> None: - pass diff --git a/tests/utils/metrics/test_gauge.py b/tests/utils/metrics/test_gauge.py index ef9e8253121..1ba858a6f27 100644 --- a/tests/utils/metrics/test_gauge.py +++ b/tests/utils/metrics/test_gauge.py @@ -1,8 +1,9 @@ from __future__ import annotations +from collections.abc import Callable from concurrent.futures import Future, wait from threading import Barrier, Thread -from typing import Any, Callable +from typing import Any import pytest diff --git a/tests/utils/test_columns_validator.py b/tests/utils/test_columns_validator.py index df62c2517fd..06eafd001ca 100644 --- a/tests/utils/test_columns_validator.py +++ b/tests/utils/test_columns_validator.py @@ -1,5 +1,6 @@ +from collections.abc import Sequence from datetime import datetime -from typing import Any, Sequence +from typing import Any import pytest @@ -11,6 +12,7 @@ Date, Float, Int, + InvalidColumnType, String, Tuple, UInt, @@ -81,8 +83,10 @@ def test_validator(column_name: str, values: Sequence[Any], is_valid: bool) -> None: col_validator = ColumnValidator(COLUMNS) - if is_valid == True: + if is_valid: col_validator.validate(column_name, values) else: - with pytest.raises(Exception): + # invalid inputs raise InvalidColumnType, except tuple-arity mismatches + # which trip an assert in _valid_tuple + with pytest.raises((InvalidColumnType, AssertionError)): col_validator.validate(column_name, values) diff --git a/tests/utils/test_describer.py b/tests/utils/test_describer.py index 948f5eb79ef..2e26d65151a 100644 --- a/tests/utils/test_describer.py +++ b/tests/utils/test_describer.py @@ -1,16 +1,16 @@ -from typing import List, Optional, Sequence +from collections.abc import Sequence from snuba.utils.describer import Description, DescriptionVisitor, Property class TestDescriber(DescriptionVisitor): def __init__(self) -> None: - self.__content: List[Optional[str]] = [] + self.__content: list[str | None] = [] - def get_content(self) -> Sequence[Optional[str]]: + def get_content(self) -> Sequence[str | None]: return self.__content - def visit_header(self, header: Optional[str]) -> None: + def visit_header(self, header: str | None) -> None: self.__content.append(header) def visit_description(self, desc: Description) -> None: diff --git a/tests/utils/test_rate_limiter.py b/tests/utils/test_rate_limiter.py index 04cac373f6d..244fc3710f8 100644 --- a/tests/utils/test_rate_limiter.py +++ b/tests/utils/test_rate_limiter.py @@ -1,4 +1,4 @@ -from typing import Sequence, Tuple +from collections.abc import Sequence from unittest.mock import Mock, patch import pytest @@ -37,19 +37,18 @@ @patch("time.sleep") @pytest.mark.parametrize("trials", test_cases) def test_rate_limiter( - mock_sleep: Mock, trials: Sequence[Tuple[str, float, RateLimitResult, int]] + mock_sleep: Mock, trials: Sequence[tuple[str, float, RateLimitResult, int]] ) -> None: rate_limiter = RateLimiter("bucket", 3.0) - for bucket, time_resp, expected_result, expected_quota in trials: - with patch("time.time", return_value=time_resp): - # We need to pass the rate limit in since we are mocking time - # When mocking time, get_config would be completely unreliable - # and this would propagate across tests because get_config - # memoizes the result in a way that cannot be reset between tests. - with rate_limiter as quota: - assert quota == (expected_result, expected_quota) - if expected_result == RateLimitResult.THROTTLED: - mock_sleep.assert_called_once() + for _bucket, time_resp, expected_result, expected_quota in trials: + # We need to pass the rate limit in since we are mocking time + # When mocking time, get_config would be completely unreliable + # and this would propagate across tests because get_config + # memoizes the result in a way that cannot be reset between tests. + with patch("time.time", return_value=time_resp), rate_limiter as quota: + assert quota == (expected_result, expected_quota) + if expected_result == RateLimitResult.THROTTLED: + mock_sleep.assert_called_once() def test_disabled() -> None: diff --git a/tests/utils/test_registered_class.py b/tests/utils/test_registered_class.py index e3775816b74..e2aea59c4ba 100644 --- a/tests/utils/test_registered_class.py +++ b/tests/utils/test_registered_class.py @@ -1,4 +1,4 @@ -from typing import Type, cast +from typing import cast import pytest @@ -112,10 +112,10 @@ def config_key(cls) -> str: return "base" @classmethod - def get_class_from_name(cls, name: str) -> Type["TypedFromName"]: + def get_class_from_name(cls, name: str) -> type["TypedFromName"]: # NOTE: This method cannot be type safe without doing this cast. Such is the nature of metaprogramming res = cls.class_from_name(name) - return cast(Type[TypedFromName], res) + return cast(type[TypedFromName], res) class ExtraName(TypedFromName): @@ -124,7 +124,7 @@ def config_key(cls) -> str: return "extra_name" -def get_from_name(name: str) -> Type[TypedFromName]: +def get_from_name(name: str) -> type[TypedFromName]: return TypedFromName.get_class_from_name(name) diff --git a/tests/utils/test_threaded_function_delegator.py b/tests/utils/test_threaded_function_delegator.py index 1fe0f1a43bb..f66e9fdc4da 100644 --- a/tests/utils/test_threaded_function_delegator.py +++ b/tests/utils/test_threaded_function_delegator.py @@ -1,5 +1,5 @@ import threading -from typing import Any, List, Tuple +from typing import Any from unittest.mock import ANY, Mock, call from snuba.utils.threaded_function_delegator import Result, ThreadedFunctionDelegator @@ -16,10 +16,10 @@ def test() -> None: "three": Mock(return_value=3), } - def selector_func(_: int) -> Tuple[str, List[str]]: + def selector_func(_: int) -> tuple[str, list[str]]: return ("one", ["two"]) - def callback_func(primary: Tuple[str, int], other: List[Tuple[str, int]]) -> None: + def callback_func(primary: tuple[str, int], other: list[tuple[str, int]]) -> None: assert result_received.wait(timeout=5), "Timeout while waiting for the main thread." callback_done.set() diff --git a/tests/web/rpc/test_base.py b/tests/web/rpc/test_base.py index ab49c9703ba..f060359dc9a 100644 --- a/tests/web/rpc/test_base.py +++ b/tests/web/rpc/test_base.py @@ -1,7 +1,6 @@ import time import uuid from datetime import timedelta -from typing import Type from unittest.mock import patch import pytest @@ -23,7 +22,7 @@ QueryTimeoutException, ) from snuba.web.rpc.v1.endpoint_trace_item_table import EndpointTraceItemTable -from tests.backends.metrics import TestingMetricsBackend +from tests.backends.metrics import Events, TestingMetricsBackend from tests.web.rpc.v1.test_utils import BASE_TIME RANDOM_REQUEST_ID = str(uuid.uuid4()) @@ -69,7 +68,7 @@ class ErrorRPC(RPCEndpoint[TimeSeriesRequest, TimeSeriesRequest]): duration_millis = 100 @classmethod - def response_class(cls) -> Type[TimeSeriesRequest]: + def response_class(cls) -> type[TimeSeriesRequest]: return TimeSeriesRequest @classmethod @@ -85,7 +84,7 @@ class SilentErrorRPC(RPCEndpoint[TimeSeriesRequest, TimeSeriesRequest]): duration_millis = 100 @classmethod - def response_class(cls) -> Type[TimeSeriesRequest]: + def response_class(cls) -> type[TimeSeriesRequest]: return TimeSeriesRequest @classmethod @@ -104,7 +103,7 @@ def version(cls) -> str: return "v1" @classmethod - def response_class(cls) -> Type[TimeSeriesRequest]: + def response_class(cls) -> type[TimeSeriesRequest]: return TimeSeriesRequest def _execute(self, in_msg: TimeSeriesRequest) -> TimeSeriesRequest: @@ -165,11 +164,11 @@ def test_metrics() -> None: for _ in range(len(metrics_backend.calls)) ] - metric_names_to_metric = {m.name: m for m in metrics_backend.calls} # type: ignore - assert metric_names_to_metric["rpc.endpoint_timing"].value == pytest.approx( # type: ignore + metric_names_to_metric = {m.name: m for m in metrics_backend.calls if not isinstance(m, Events)} + assert metric_names_to_metric["rpc.endpoint_timing"].value == pytest.approx( MyRPC.duration_millis, rel=10 ) - assert metric_names_to_metric["rpc.request_success"].value == 1 # type: ignore + assert metric_names_to_metric["rpc.request_success"].value == 1 @pytest.mark.redis_db @@ -192,8 +191,10 @@ def test_error_metrics() -> None: for _ in range(len(metrics_backend.calls)) ] - metric_names_to_metric = {m.name: m for m in metrics_backend.calls} # type: ignore - assert metric_names_to_metric["rpc.request_error"].value == 1 # type: ignore + metric_names_to_metric = { + m.name: m for m in metrics_backend.calls if not isinstance(m, Events) + } + assert metric_names_to_metric["rpc.request_error"].value == 1 sentry_sdk_mock.assert_called() @@ -208,8 +209,10 @@ def test_should_report_false_not_captured() -> None: with pytest.raises(ColumnTypeError): rpc_call.execute(_get_in_msg()) - metric_names_to_metric = {m.name: m for m in metrics_backend.calls} # type: ignore - assert metric_names_to_metric["rpc.request_error"].value == 1 # type: ignore + metric_names_to_metric = { + m.name: m for m in metrics_backend.calls if not isinstance(m, Events) + } + assert metric_names_to_metric["rpc.request_error"].value == 1 sentry_sdk_mock.assert_not_called() diff --git a/tests/web/rpc/test_common.py b/tests/web/rpc/test_common.py index fe018e77562..035db334ae8 100644 --- a/tests/web/rpc/test_common.py +++ b/tests/web/rpc/test_common.py @@ -1,5 +1,5 @@ import uuid -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta import pytest from google.protobuf import json_format, struct_pb2 @@ -28,7 +28,7 @@ from sentry_protos.snuba.v1.trace_item_pb2 import AnyValue from snuba import settings -from snuba.datasets.storages.factory import get_storage +from snuba.datasets.storages.factory import get_writable_storage from snuba.datasets.storages.storage_key import StorageKey from snuba.protos.common import ATTRIBUTES_TO_COALESCE from snuba.query.expressions import FunctionCall, Lambda, Literal @@ -425,7 +425,7 @@ class TestAnyAttributeFilterIntegration: @pytest.fixture(autouse=True) def setup(self, eap: None, redis_db: None) -> None: - self.base_time = datetime.now(tz=timezone.utc).replace( + self.base_time = datetime.now(tz=UTC).replace( minute=0, second=0, microsecond=0 ) - timedelta(hours=1) self.start_ts = Timestamp(seconds=int((self.base_time - timedelta(hours=1)).timestamp())) @@ -456,8 +456,8 @@ def setup(self, eap: None, redis_db: None) -> None: }, ), ] - storage = get_storage(StorageKey("eap_items")) - write_raw_unprocessed_events(storage, messages) # type: ignore + storage = get_writable_storage(StorageKey("eap_items")) + write_raw_unprocessed_events(storage, messages) def _execute(self, filt: TraceItemFilter) -> list[str]: """Run a TraceItemTable query with the given filter, returning diff --git a/tests/web/rpc/v1/routing_strategies/common.py b/tests/web/rpc/v1/routing_strategies/common.py index ced17d27f6a..f00885a1a55 100644 --- a/tests/web/rpc/v1/routing_strategies/common.py +++ b/tests/web/rpc/v1/routing_strategies/common.py @@ -1,9 +1,8 @@ from datetime import datetime -from typing import Dict, List, Tuple from sentry_relay.consts import DataCategory -from snuba.datasets.storages.factory import get_storage +from snuba.datasets.storages.factory import get_writable_storage from snuba.datasets.storages.storage_key import StorageKey from snuba.web.rpc.storage_routing.routing_strategies.common import Outcome from tests.helpers import write_raw_unprocessed_events @@ -15,7 +14,7 @@ def gen_ingest_outcome( project_id: int = 1, org_id: int = 1, outcome_category: int = DataCategory.SPAN_INDEXED, -) -> Dict[str, int | str | None]: +) -> dict[str, int | str | None]: """Generate a single ingest outcome record. Args: @@ -42,7 +41,7 @@ def gen_ingest_outcome( def store_outcomes_data( - outcome_data: List[Tuple[datetime, int]], + outcome_data: list[tuple[datetime, int]], outcome_category: int = DataCategory.SPAN_INDEXED, org_id: int = 1, project_id: int = 1, @@ -55,7 +54,7 @@ def store_outcomes_data( outcome_category: The outcome category to use for all records when using 2-tuple format (defaults to SPAN_INDEXED, ignored for 3-tuple format) """ - outcomes_storage = get_storage(StorageKey("outcomes_raw")) + outcomes_storage = get_writable_storage(StorageKey("outcomes_raw")) messages = [] for item in outcome_data: @@ -70,7 +69,7 @@ def store_outcomes_data( time, num_outcomes, outcome_category=category, org_id=org_id, project_id=project_id ) ) - write_raw_unprocessed_events(outcomes_storage, messages) # type: ignore + write_raw_unprocessed_events(outcomes_storage, messages) # Available outcome categories (from DataCategory class): diff --git a/tests/web/rpc/v1/test_create_subscription.py b/tests/web/rpc/v1/test_create_subscription.py index 4a692270132..dbc16cfd276 100644 --- a/tests/web/rpc/v1/test_create_subscription.py +++ b/tests/web/rpc/v1/test_create_subscription.py @@ -1,6 +1,7 @@ import base64 +from collections.abc import Callable from datetime import UTC, datetime, timedelta -from typing import Any, Callable +from typing import Any import pytest from confluent_kafka.admin import AdminClient diff --git a/tests/web/rpc/v1/test_endpoint_delete_trace_items.py b/tests/web/rpc/v1/test_endpoint_delete_trace_items.py index cef1037bd26..ef8b651b4f1 100644 --- a/tests/web/rpc/v1/test_endpoint_delete_trace_items.py +++ b/tests/web/rpc/v1/test_endpoint_delete_trace_items.py @@ -1,5 +1,5 @@ import uuid -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta from typing import Any from unittest.mock import Mock, patch @@ -27,7 +27,7 @@ ) from sentry_protos.snuba.v1.trace_item_pb2 import AnyValue -from snuba.datasets.storages.factory import get_storage +from snuba.datasets.storages.factory import get_writable_storage from snuba.datasets.storages.storage_key import StorageKey from snuba.web.rpc.common.exceptions import BadSnubaRPCRequestException from snuba.web.rpc.v1.endpoint_delete_trace_items import EndpointDeleteTraceItems @@ -38,7 +38,7 @@ _REQUEST_ID = uuid.uuid4().hex _TRACE_ID = str(uuid.uuid4()) -_BASE_TIME = datetime.now(tz=timezone.utc).replace( +_BASE_TIME = datetime.now(tz=UTC).replace( minute=0, second=0, microsecond=0, @@ -68,8 +68,8 @@ @pytest.fixture(autouse=False) def setup_teardown(eap: None, redis_db: None) -> None: - items_storage = get_storage(StorageKey("eap_items")) - write_raw_unprocessed_events(items_storage, _SPANS) # type: ignore + items_storage = get_writable_storage(StorageKey("eap_items")) + write_raw_unprocessed_events(items_storage, _SPANS) @pytest.mark.eap @@ -183,12 +183,14 @@ def test_filters_with_equals_operation_accepted(self) -> None: ], ) - with patch("snuba.web.bulk_delete_query._enforce_max_rows", return_value=10): - with patch("snuba.web.bulk_delete_query.produce_delete_query") as mock_produce: - EndpointDeleteTraceItems().execute(message) + with ( + patch("snuba.web.bulk_delete_query._enforce_max_rows", return_value=10), + patch("snuba.web.bulk_delete_query.produce_delete_query") as mock_produce, + ): + EndpointDeleteTraceItems().execute(message) - # Verify produce_delete_query was called with attribute_conditions - assert mock_produce.call_count == 1 + # Verify produce_delete_query was called with attribute_conditions + assert mock_produce.call_count == 1 def test_filters_with_in_operation_accepted(self) -> None: """Test that filters with OP_IN are properly converted to attribute_conditions""" @@ -226,11 +228,11 @@ def test_filters_with_in_operation_accepted(self) -> None: ], ) - with patch("snuba.web.bulk_delete_query._enforce_max_rows", return_value=10): - with patch("snuba.web.bulk_delete_query.produce_delete_query"): - assert isinstance( - EndpointDeleteTraceItems().execute(message), DeleteTraceItemsResponse - ) + with ( + patch("snuba.web.bulk_delete_query._enforce_max_rows", return_value=10), + patch("snuba.web.bulk_delete_query.produce_delete_query"), + ): + assert isinstance(EndpointDeleteTraceItems().execute(message), DeleteTraceItemsResponse) def test_filters_with_unsupported_operation_rejected(self) -> None: """Test that filters with operations other than OP_EQUALS/OP_IN are rejected""" diff --git a/tests/web/rpc/v1/test_endpoint_export_trace_items.py b/tests/web/rpc/v1/test_endpoint_export_trace_items.py index cfcf0d5d4a7..e14d30f5608 100644 --- a/tests/web/rpc/v1/test_endpoint_export_trace_items.py +++ b/tests/web/rpc/v1/test_endpoint_export_trace_items.py @@ -1,7 +1,7 @@ import re import uuid from collections import namedtuple -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta from typing import Any import pytest @@ -11,7 +11,7 @@ from sentry_protos.snuba.v1.request_common_pb2 import RequestMeta, TraceItemType from sentry_protos.snuba.v1.trace_item_pb2 import TraceItem -from snuba.datasets.storages.factory import get_storage +from snuba.datasets.storages.factory import get_writable_storage from snuba.datasets.storages.storage_key import StorageKey from snuba.web import QueryResult from snuba.web.query import run_query @@ -86,9 +86,9 @@ def _assert_attributes_keys(trace_items: list[TraceItem]) -> None: @pytest.fixture(autouse=False) def setup_teardown(eap: None, redis_db: None) -> None: - items_storage = get_storage(StorageKey("eap_items")) - write_raw_unprocessed_events(items_storage, _SPANS) # type: ignore - write_raw_unprocessed_events(items_storage, _LOGS) # type: ignore + items_storage = get_writable_storage(StorageKey("eap_items")) + write_raw_unprocessed_events(items_storage, _SPANS) + write_raw_unprocessed_events(items_storage, _LOGS) @pytest.mark.eap @@ -155,8 +155,8 @@ def test_pagination_with_128_bit_item_id(self, eap: Any, redis_db: Any) -> None: ) for _ in range(num_items) ] - items_storage = get_storage(StorageKey("eap_items")) - write_raw_unprocessed_events(items_storage, items_data) # type: ignore + items_storage = get_writable_storage(StorageKey("eap_items")) + write_raw_unprocessed_events(items_storage, items_data) message = ExportTraceItemsRequest( meta=RequestMeta( @@ -246,7 +246,7 @@ def test_pagination_with_real_flex_window( routed_start_sec = start_sec + 5 write_raw_unprocessed_events( - get_storage(StorageKey("eap_items")), # type: ignore[arg-type] + get_writable_storage(StorageKey("eap_items")), [ gen_item_message( start_timestamp=BASE_TIME + timedelta(seconds=i), @@ -294,9 +294,7 @@ def _capture_query( def _to_sec(s: str) -> int: return int( - datetime.strptime(s, "%Y-%m-%d %H:%M:%S") - .replace(tzinfo=timezone.utc) - .timestamp() + datetime.strptime(s, "%Y-%m-%d %H:%M:%S").replace(tzinfo=UTC).timestamp() ) sql_queried_windows.append((_to_sec(start_m.group(1)), _to_sec(end_m.group(1)))) diff --git a/tests/web/rpc/v1/test_endpoint_get_trace.py b/tests/web/rpc/v1/test_endpoint_get_trace.py index fc625d7edf1..f5bc28e389f 100644 --- a/tests/web/rpc/v1/test_endpoint_get_trace.py +++ b/tests/web/rpc/v1/test_endpoint_get_trace.py @@ -1,5 +1,5 @@ import uuid -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta from operator import attrgetter from typing import Any from unittest.mock import patch @@ -29,7 +29,7 @@ from sentry_protos.snuba.v1.trace_item_pb2 import AnyValue, TraceItem from snuba import state -from snuba.datasets.storages.factory import get_storage +from snuba.datasets.storages.factory import get_writable_storage from snuba.datasets.storages.storage_key import StorageKey from snuba.settings import ENABLE_TRACE_PAGINATION_DEFAULT from snuba.web.rpc.common.common import ATTRIBUTES_ARRAY_ALLOWLIST @@ -45,7 +45,7 @@ from tests.web.rpc.v1.test_utils import SERVER_NAME, gen_item_message _TRACE_ID = uuid.uuid4().hex -_BASE_TIME = datetime.now(tz=timezone.utc).replace( +_BASE_TIME = datetime.now(tz=UTC).replace( minute=0, second=0, microsecond=0, @@ -148,10 +148,10 @@ def _convert_to_attribute_value(value: AnyValue) -> AttributeValue: @pytest.fixture(autouse=False) def setup_teardown(eap: None, redis_db: None) -> None: - items_storage = get_storage(StorageKey("eap_items")) + items_storage = get_writable_storage(StorageKey("eap_items")) - write_raw_unprocessed_events(items_storage, _SPANS) # type: ignore - write_raw_unprocessed_events(items_storage, _LOGS) # type: ignore + write_raw_unprocessed_events(items_storage, _SPANS) + write_raw_unprocessed_events(items_storage, _LOGS) @pytest.mark.eap @@ -226,7 +226,7 @@ def test_with_data_all_attributes(self, setup_teardown: Any) -> None: key=attrgetter("key.name"), ), ) - for timestamp, span in zip(timestamps, spans) + for timestamp, span in zip(timestamps, spans, strict=False) ], ), ], @@ -321,7 +321,7 @@ def test_with_specific_attributes(self, setup_teardown: Any) -> None: ), ], ) - for timestamp, span in zip(timestamps, spans) + for timestamp, span in zip(timestamps, spans, strict=False) ], ), ], @@ -443,7 +443,7 @@ def test_with_logs(self, setup_teardown: Any) -> None: ), ], ) - for timestamp, log in zip(timestamps, logs) + for timestamp, log in zip(timestamps, logs, strict=False) ], ), ], diff --git a/tests/web/rpc/v1/test_endpoint_get_traces.py b/tests/web/rpc/v1/test_endpoint_get_traces.py index 3d1b9d14272..8cf9664c98d 100644 --- a/tests/web/rpc/v1/test_endpoint_get_traces.py +++ b/tests/web/rpc/v1/test_endpoint_get_traces.py @@ -1,6 +1,6 @@ import uuid from collections import defaultdict -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta from typing import Any import pytest @@ -29,7 +29,7 @@ from sentry_protos.snuba.v1.trace_item_pb2 import AnyValue, TraceItem from snuba import state -from snuba.datasets.storages.factory import get_storage +from snuba.datasets.storages.factory import get_writable_storage from snuba.datasets.storages.storage_key import StorageKey from snuba.web.rpc.common.exceptions import BadSnubaRPCRequestException from snuba.web.rpc.v1.endpoint_get_traces import EndpointGetTraces @@ -46,7 +46,7 @@ ) _TRACE_IDS = [uuid.uuid4().hex for _ in range(10)] -_BASE_TIME = datetime.now(tz=timezone.utc).replace( +_BASE_TIME = datetime.now(tz=UTC).replace( minute=0, second=0, microsecond=0, @@ -92,11 +92,11 @@ @pytest.fixture(autouse=False) def setup_teardown(clickhouse_db: None, redis_db: None) -> None: - items_storage = get_storage(StorageKey("eap_items")) + items_storage = get_writable_storage(StorageKey("eap_items")) state.set_config("enable_trace_sampling", True) - write_raw_unprocessed_events(items_storage, _SPANS) # type: ignore - write_raw_unprocessed_events(items_storage, _ADDITIONAL_SPANS) # type: ignore + write_raw_unprocessed_events(items_storage, _SPANS) + write_raw_unprocessed_events(items_storage, _ADDITIONAL_SPANS) @pytest.mark.clickhouse_db @@ -165,7 +165,7 @@ def test_with_data(self, setup_teardown: Any) -> None: ), ], ) - for start_timestamp in reversed(sorted(trace_id_per_start_timestamp.keys())) + for start_timestamp in sorted(trace_id_per_start_timestamp.keys(), reverse=True) ], page_token=PageToken(offset=len(_TRACE_IDS + _ADDITIONAL_TRACE_IDS)), meta=ResponseMeta( @@ -493,7 +493,7 @@ def test_with_data_and_aggregated_fields_all_keys(self, setup_teardown: Any) -> ), ], ) - for start_timestamp in reversed(sorted(trace_id_per_start_timestamp.keys())) + for start_timestamp in sorted(trace_id_per_start_timestamp.keys(), reverse=True) ], page_token=PageToken(offset=len(_TRACE_IDS)), meta=ResponseMeta( @@ -566,7 +566,7 @@ def test_with_data_and_aggregated_fields(self, setup_teardown: Any) -> None: ), ], ) - for start_timestamp in reversed(sorted(trace_id_per_start_timestamp.keys())) + for start_timestamp in sorted(trace_id_per_start_timestamp.keys(), reverse=True) ], page_token=PageToken(offset=len(_TRACE_IDS)), meta=ResponseMeta( @@ -640,7 +640,7 @@ def test_with_data_and_aggregated_fields_ignore_case(self, setup_teardown: Any) ), ], ) - for start_timestamp in reversed(sorted(trace_id_per_start_timestamp.keys())) + for start_timestamp in sorted(trace_id_per_start_timestamp.keys(), reverse=True) ], page_token=PageToken(offset=len(_TRACE_IDS)), meta=ResponseMeta( diff --git a/tests/web/rpc/v1/test_endpoint_time_series/test_endpoint_time_series.py b/tests/web/rpc/v1/test_endpoint_time_series/test_endpoint_time_series.py index b4932c00350..08fd67b82f3 100644 --- a/tests/web/rpc/v1/test_endpoint_time_series/test_endpoint_time_series.py +++ b/tests/web/rpc/v1/test_endpoint_time_series/test_endpoint_time_series.py @@ -1,7 +1,8 @@ +from collections.abc import Callable from dataclasses import dataclass from datetime import UTC, datetime, timedelta from itertools import chain -from typing import Any, Callable +from typing import Any from unittest.mock import MagicMock, call, patch import pytest @@ -40,7 +41,7 @@ ) from sentry_protos.snuba.v1.trace_item_pb2 import AnyValue, ArrayValue -from snuba.datasets.storages.factory import get_storage +from snuba.datasets.storages.factory import get_writable_storage from snuba.datasets.storages.storage_key import StorageKey from snuba.web import QueryException from snuba.web.rpc import RPCEndpoint @@ -72,16 +73,18 @@ def store_spans_timeseries( period_secs: int, len_secs: int, metrics: list[DummyMetric], - attributes: dict[str, AnyValue] = {}, + attributes: dict[str, AnyValue] | None = None, ) -> None: + if attributes is None: + attributes = {} messages = [] for secs in range(0, len_secs, period_secs): dt = start_datetime + timedelta(seconds=secs) a = attributes | {m.name: AnyValue(double_value=m.get_value(secs)) for m in metrics} messages.append(gen_item_message(dt, a)) - items_storage = get_storage(StorageKey("eap_items")) + items_storage = get_writable_storage(StorageKey("eap_items")) - write_raw_unprocessed_events(items_storage, messages) # type: ignore + write_raw_unprocessed_events(items_storage, messages) @pytest.mark.eap @@ -241,6 +244,7 @@ def test_conditional_aggregation(self) -> None: for sum_datapoint, avg_datapoint in zip( expected_sum_timeseries.data_points, expected_avg_timeseries.data_points, + strict=False, ) ], ) @@ -1173,6 +1177,7 @@ def test_formula(self) -> None: for sum_datapoint, avg_datapoint in zip( expected_sum_timeseries.data_points, expected_avg_timeseries.data_points, + strict=False, ) ], ) @@ -1969,8 +1974,8 @@ def make_array(*names: str) -> AnyValue: ) ) - items_storage = get_storage(StorageKey("eap_items")) - write_raw_unprocessed_events(items_storage, messages) # type: ignore + items_storage = get_writable_storage(StorageKey("eap_items")) + write_raw_unprocessed_events(items_storage, messages) message = TimeSeriesRequest( meta=RequestMeta( @@ -2080,9 +2085,7 @@ def test_muliply_attribute_aggregation(self) -> None: ) # figure out the expected value for avg(game_size * game_size_unit_mult) timeseries data_points_bytes = list( - chain( - map(lambda x: x * 10**9, data_points_gb), map(lambda x: x * 10**6, data_points_mb) - ) + chain((x * 10**9 for x in data_points_gb), (x * 10**6 for x in data_points_mb)) ) # query for avg(game_size * game_size_unit_mult) diff --git a/tests/web/rpc/v1/test_endpoint_time_series/test_endpoint_time_series_cross_item_sampling.py b/tests/web/rpc/v1/test_endpoint_time_series/test_endpoint_time_series_cross_item_sampling.py index 95ad4a199ba..4c5082f1033 100644 --- a/tests/web/rpc/v1/test_endpoint_time_series/test_endpoint_time_series_cross_item_sampling.py +++ b/tests/web/rpc/v1/test_endpoint_time_series/test_endpoint_time_series_cross_item_sampling.py @@ -102,37 +102,36 @@ def test_cross_item_query_sampling_enabled(self) -> None: storage_keys, storage_tracker = track_storage_selections() - with storage_tracker: - with patch.object(RPCEndpoint, "_RPCEndpoint__before_execute"): - message = create_time_series_request( - start_time=start_time, - end_time=end_time, - trace_item_type=TraceItemType.TRACE_ITEM_TYPE_SPAN, - expressions=[create_count_expression()], - trace_filters=trace_filters, - granularity_secs=3600, - ) - - mock_routing_decision = create_mock_routing_decision(Tier.TIER_8, message) - - endpoint = EndpointTimeSeries() - endpoint.routing_decision = mock_routing_decision - endpoint.execute(message) - - # Verify storages were selected (should have at least 2 calls: inner + outer) - assert len(storage_keys) >= 2, ( - f"Expected at least 2 storage selections, got {len(storage_keys)}" - ) - - # The inner query should use downsampled storage (TIER_8) - assert StorageKey.EAP_ITEMS_DOWNSAMPLE_8 in storage_keys, ( - f"Inner query should use EAP_ITEMS_DOWNSAMPLE_8, got: {storage_keys}" - ) - - # The outer query should use full storage (EAP_ITEMS) - assert StorageKey.EAP_ITEMS in storage_keys, ( - f"Outer query should use EAP_ITEMS, got: {storage_keys}" - ) + with storage_tracker, patch.object(RPCEndpoint, "_RPCEndpoint__before_execute"): + message = create_time_series_request( + start_time=start_time, + end_time=end_time, + trace_item_type=TraceItemType.TRACE_ITEM_TYPE_SPAN, + expressions=[create_count_expression()], + trace_filters=trace_filters, + granularity_secs=3600, + ) + + mock_routing_decision = create_mock_routing_decision(Tier.TIER_8, message) + + endpoint = EndpointTimeSeries() + endpoint.routing_decision = mock_routing_decision + endpoint.execute(message) + + # Verify storages were selected (should have at least 2 calls: inner + outer) + assert len(storage_keys) >= 2, ( + f"Expected at least 2 storage selections, got {len(storage_keys)}" + ) + + # The inner query should use downsampled storage (TIER_8) + assert StorageKey.EAP_ITEMS_DOWNSAMPLE_8 in storage_keys, ( + f"Inner query should use EAP_ITEMS_DOWNSAMPLE_8, got: {storage_keys}" + ) + + # The outer query should use full storage (EAP_ITEMS) + assert StorageKey.EAP_ITEMS in storage_keys, ( + f"Outer query should use EAP_ITEMS, got: {storage_keys}" + ) def test_cross_item_query_sampling_disabled(self) -> None: """ @@ -154,29 +153,28 @@ def test_cross_item_query_sampling_disabled(self) -> None: storage_keys, storage_tracker = track_storage_selections() - with storage_tracker: - with patch.object(RPCEndpoint, "_RPCEndpoint__before_execute"): - message = create_time_series_request( - start_time=start_time, - end_time=end_time, - trace_item_type=TraceItemType.TRACE_ITEM_TYPE_SPAN, - expressions=[create_count_expression()], - trace_filters=trace_filters, - granularity_secs=3600, - ) - - mock_routing_decision = create_mock_routing_decision(Tier.TIER_8, message) - - endpoint = EndpointTimeSeries() - endpoint.routing_decision = mock_routing_decision - endpoint.execute(message) - - # When feature is disabled, both inner and outer queries should use the same tier - assert len(storage_keys) >= 2, ( - f"Expected at least 2 storage selections, got {len(storage_keys)}" - ) - - # All storages should be TIER_8 (downsampled) - assert all(key == StorageKey.EAP_ITEMS_DOWNSAMPLE_8 for key in storage_keys), ( - f"All queries should use EAP_ITEMS_DOWNSAMPLE_8, got: {storage_keys}" - ) + with storage_tracker, patch.object(RPCEndpoint, "_RPCEndpoint__before_execute"): + message = create_time_series_request( + start_time=start_time, + end_time=end_time, + trace_item_type=TraceItemType.TRACE_ITEM_TYPE_SPAN, + expressions=[create_count_expression()], + trace_filters=trace_filters, + granularity_secs=3600, + ) + + mock_routing_decision = create_mock_routing_decision(Tier.TIER_8, message) + + endpoint = EndpointTimeSeries() + endpoint.routing_decision = mock_routing_decision + endpoint.execute(message) + + # When feature is disabled, both inner and outer queries should use the same tier + assert len(storage_keys) >= 2, ( + f"Expected at least 2 storage selections, got {len(storage_keys)}" + ) + + # All storages should be TIER_8 (downsampled) + assert all(key == StorageKey.EAP_ITEMS_DOWNSAMPLE_8 for key in storage_keys), ( + f"All queries should use EAP_ITEMS_DOWNSAMPLE_8, got: {storage_keys}" + ) diff --git a/tests/web/rpc/v1/test_endpoint_time_series/test_endpoint_time_series_extrapolation.py b/tests/web/rpc/v1/test_endpoint_time_series/test_endpoint_time_series_extrapolation.py index d3fce8472d8..e3ecf07f10c 100644 --- a/tests/web/rpc/v1/test_endpoint_time_series/test_endpoint_time_series_extrapolation.py +++ b/tests/web/rpc/v1/test_endpoint_time_series/test_endpoint_time_series_extrapolation.py @@ -1,6 +1,6 @@ +from collections.abc import Callable from dataclasses import dataclass from datetime import UTC, datetime, timedelta -from typing import Callable import pytest from google.protobuf.timestamp_pb2 import Timestamp @@ -20,7 +20,7 @@ ) from sentry_protos.snuba.v1.trace_item_pb2 import AnyValue -from snuba.datasets.storages.factory import get_storage +from snuba.datasets.storages.factory import get_writable_storage from snuba.datasets.storages.storage_key import StorageKey from snuba.web.rpc.v1.endpoint_time_series import EndpointTimeSeries from tests.base import BaseApiTest @@ -68,8 +68,8 @@ def store_timeseries( server_sample_rate=real_server_sample_rate, ), ) - items_storage = get_storage(StorageKey("eap_items")) - write_raw_unprocessed_events(items_storage, messages) # type: ignore + items_storage = get_writable_storage(StorageKey("eap_items")) + write_raw_unprocessed_events(items_storage, messages) @pytest.mark.clickhouse_db diff --git a/tests/web/rpc/v1/test_endpoint_time_series/test_endpoint_time_series_logs.py b/tests/web/rpc/v1/test_endpoint_time_series/test_endpoint_time_series_logs.py index e43d4baeb48..f47743397ce 100644 --- a/tests/web/rpc/v1/test_endpoint_time_series/test_endpoint_time_series_logs.py +++ b/tests/web/rpc/v1/test_endpoint_time_series/test_endpoint_time_series_logs.py @@ -17,7 +17,7 @@ ) from sentry_protos.snuba.v1.trace_item_pb2 import AnyValue -from snuba.datasets.storages.factory import get_storage +from snuba.datasets.storages.factory import get_writable_storage from snuba.datasets.storages.storage_key import StorageKey from snuba.web.rpc.v1.endpoint_time_series import EndpointTimeSeries from tests.base import BaseApiTest @@ -27,7 +27,7 @@ @pytest.fixture(autouse=False) def setup_logs_in_db(clickhouse_db: None, redis_db: None) -> None: - logs_storage = get_storage(StorageKey("eap_items")) + logs_storage = get_writable_storage(StorageKey("eap_items")) messages = [] for i in range(240): timestamp = BASE_TIME - timedelta(seconds=30 * i) @@ -48,7 +48,7 @@ def setup_logs_in_db(clickhouse_db: None, redis_db: None) -> None: }, ) ) - write_raw_unprocessed_events(logs_storage, messages) # type: ignore + write_raw_unprocessed_events(logs_storage, messages) @pytest.mark.clickhouse_db diff --git a/tests/web/rpc/v1/test_endpoint_trace_item_attribute_names.py b/tests/web/rpc/v1/test_endpoint_trace_item_attribute_names.py index 513a9fe13a6..ec4d6425fb6 100644 --- a/tests/web/rpc/v1/test_endpoint_trace_item_attribute_names.py +++ b/tests/web/rpc/v1/test_endpoint_trace_item_attribute_names.py @@ -11,7 +11,7 @@ from sentry_protos.snuba.v1.trace_item_filter_pb2 import ExistsFilter, TraceItemFilter from sentry_protos.snuba.v1.trace_item_pb2 import AnyValue -from snuba.datasets.storages.factory import get_storage +from snuba.datasets.storages.factory import get_writable_storage from snuba.datasets.storages.storage_key import StorageKey from snuba.web.rpc.v1.endpoint_trace_item_attribute_names import ( EndpointTraceItemAttributeNames, @@ -58,9 +58,9 @@ def generate_span_event_message(id: int) -> bytes: attributes=attributes, ) - items_storage = get_storage(StorageKey("eap_items")) + items_storage = get_writable_storage(StorageKey("eap_items")) messages = [generate_span_event_message(i) for i in range(num_rows)] - write_raw_unprocessed_events(items_storage, messages) # type: ignore + write_raw_unprocessed_events(items_storage, messages) @pytest.fixture(autouse=True) diff --git a/tests/web/rpc/v1/test_endpoint_trace_item_details.py b/tests/web/rpc/v1/test_endpoint_trace_item_details.py index 915261264e1..a7a15d0d397 100644 --- a/tests/web/rpc/v1/test_endpoint_trace_item_details.py +++ b/tests/web/rpc/v1/test_endpoint_trace_item_details.py @@ -16,7 +16,7 @@ from sentry_protos.snuba.v1.trace_item_attribute_pb2 import AttributeKey from sentry_protos.snuba.v1.trace_item_pb2 import AnyValue, ArrayValue -from snuba.datasets.storages.factory import get_storage +from snuba.datasets.storages.factory import get_writable_storage from snuba.datasets.storages.storage_key import StorageKey from snuba.web.rpc.v1.endpoint_trace_item_details import ( EndpointTraceItemDetails, @@ -38,7 +38,7 @@ @pytest.fixture(autouse=False) def setup_logs_in_db(eap: None, redis_db: None) -> None: - logs_storage = get_storage(StorageKey("eap_items")) + logs_storage = get_writable_storage(StorageKey("eap_items")) messages = [] for i in range(120): timestamp = BASE_TIME + timedelta(minutes=i) @@ -63,12 +63,12 @@ def setup_logs_in_db(eap: None, redis_db: None) -> None: }, ) ) - write_raw_unprocessed_events(logs_storage, messages) # type: ignore + write_raw_unprocessed_events(logs_storage, messages) @pytest.fixture(autouse=False) def setup_spans_in_db(eap: None, redis_db: None) -> None: - spans_storage = get_storage(StorageKey("eap_items")) + spans_storage = get_writable_storage(StorageKey("eap_items")) messages = [ gen_item_message( start_timestamp=BASE_TIME - timedelta(minutes=i), @@ -81,7 +81,7 @@ def setup_spans_in_db(eap: None, redis_db: None) -> None: for i in range(120) ] - write_raw_unprocessed_events(spans_storage, messages) # type: ignore + write_raw_unprocessed_events(spans_storage, messages) def _str_tags_array(*values: str) -> AnyValue: @@ -260,9 +260,9 @@ def test_endpoint_on_logs(self, setup_logs_in_db: Any) -> None: def test_endpoint_returns_array_attribute(self, eap: None, redis_db: None) -> None: """Allowlisted attributes_array paths are exposed as val_array on TraceItemDetails.""" span_ts = BASE_TIME - timedelta(minutes=1) - storage = get_storage(StorageKey("eap_items")) + storage = get_writable_storage(StorageKey("eap_items")) write_raw_unprocessed_events( - storage, # type: ignore + storage, [ gen_item_message( span_ts, @@ -335,9 +335,9 @@ def test_endpoint_returns_array_attribute(self, eap: None, redis_db: None) -> No def test_dotted_key_array_attribute_parsed_properly(self, eap: None, redis_db: None) -> None: trace_id = uuid.uuid4().hex span_ts = BASE_TIME - timedelta(minutes=1) - storage = get_storage(StorageKey("eap_items")) + storage = get_writable_storage(StorageKey("eap_items")) write_raw_unprocessed_events( - storage, # type: ignore + storage, [ gen_item_message( span_ts, diff --git a/tests/web/rpc/v1/test_endpoint_trace_item_stats.py b/tests/web/rpc/v1/test_endpoint_trace_item_stats.py index acf9d5d788d..24282bc51fc 100644 --- a/tests/web/rpc/v1/test_endpoint_trace_item_stats.py +++ b/tests/web/rpc/v1/test_endpoint_trace_item_stats.py @@ -18,7 +18,7 @@ ) from sentry_protos.snuba.v1.trace_item_pb2 import AnyValue -from snuba.datasets.storages.factory import get_storage +from snuba.datasets.storages.factory import get_writable_storage from snuba.datasets.storages.storage_key import StorageKey from snuba.web.rpc.v1.endpoint_trace_item_stats import EndpointTraceItemStats from tests.base import BaseApiTest @@ -66,7 +66,7 @@ def pick_n_deterministic(choices: list[Any], weights: list[int], num_choices: in @pytest.fixture(autouse=False) def setup_teardown(clickhouse_db: None, redis_db: None) -> None: - items_storage = get_storage(StorageKey("eap_items")) + items_storage = get_writable_storage(StorageKey("eap_items")) messages = [] durations = pick_n_deterministic( choices=["10", "30", "50", None], @@ -91,7 +91,7 @@ def setup_teardown(clickhouse_db: None, redis_db: None) -> None: ) ) - write_raw_unprocessed_events(items_storage, messages) # type: ignore + write_raw_unprocessed_events(items_storage, messages) @pytest.mark.clickhouse_db diff --git a/tests/web/rpc/v1/test_endpoint_trace_item_stats_heatmap.py b/tests/web/rpc/v1/test_endpoint_trace_item_stats_heatmap.py index 4945defcce8..8d8a61310c4 100644 --- a/tests/web/rpc/v1/test_endpoint_trace_item_stats_heatmap.py +++ b/tests/web/rpc/v1/test_endpoint_trace_item_stats_heatmap.py @@ -18,7 +18,7 @@ ) from sentry_protos.snuba.v1.trace_item_pb2 import AnyValue -from snuba.datasets.storages.factory import get_storage +from snuba.datasets.storages.factory import get_writable_storage from snuba.datasets.storages.storage_key import StorageKey from snuba.web.rpc.v1.endpoint_trace_item_stats import EndpointTraceItemStats from tests.base import BaseApiTest @@ -41,7 +41,7 @@ def setup_heatmap_teardown(eap: None, redis_db: None) -> None: - status_code: 200, 404, 500 - numeric_attr: values 0-99 for numeric bucketing tests """ - items_storage = get_storage(StorageKey("eap_items")) + items_storage = get_writable_storage(StorageKey("eap_items")) messages = [] for i in range(100): @@ -88,7 +88,7 @@ def setup_heatmap_teardown(eap: None, redis_db: None) -> None: ) ) - write_raw_unprocessed_events(items_storage, messages) # type: ignore + write_raw_unprocessed_events(items_storage, messages) @pytest.mark.eap diff --git a/tests/web/rpc/v1/test_endpoint_trace_item_stats_logs.py b/tests/web/rpc/v1/test_endpoint_trace_item_stats_logs.py index 1fc1e84252a..ef9629a0b8b 100644 --- a/tests/web/rpc/v1/test_endpoint_trace_item_stats_logs.py +++ b/tests/web/rpc/v1/test_endpoint_trace_item_stats_logs.py @@ -12,7 +12,7 @@ from sentry_protos.snuba.v1.request_common_pb2 import RequestMeta, TraceItemType from sentry_protos.snuba.v1.trace_item_pb2 import AnyValue -from snuba.datasets.storages.factory import get_storage +from snuba.datasets.storages.factory import get_writable_storage from snuba.datasets.storages.storage_key import StorageKey from snuba.web.rpc.v1.endpoint_trace_item_stats import EndpointTraceItemStats from tests.base import BaseApiTest @@ -22,7 +22,7 @@ @pytest.fixture(autouse=False) def setup_logs_in_db(clickhouse_db: None, redis_db: None) -> None: - logs_storage = get_storage(StorageKey("eap_items")) + logs_storage = get_writable_storage(StorageKey("eap_items")) messages = [] for i in range(120): timestamp = BASE_TIME + timedelta(minutes=i) @@ -41,7 +41,7 @@ def setup_logs_in_db(clickhouse_db: None, redis_db: None) -> None: }, ) ) - write_raw_unprocessed_events(logs_storage, messages) # type: ignore + write_raw_unprocessed_events(logs_storage, messages) @pytest.mark.clickhouse_db diff --git a/tests/web/rpc/v1/test_endpoint_trace_item_table/test_endpoint_trace_item_table.py b/tests/web/rpc/v1/test_endpoint_trace_item_table/test_endpoint_trace_item_table.py index db37dda8481..9be260ce0f5 100644 --- a/tests/web/rpc/v1/test_endpoint_trace_item_table/test_endpoint_trace_item_table.py +++ b/tests/web/rpc/v1/test_endpoint_trace_item_table/test_endpoint_trace_item_table.py @@ -57,7 +57,7 @@ ) from sentry_protos.snuba.v1.trace_item_pb2 import AnyValue, ArrayValue -from snuba.datasets.storages.factory import get_storage +from snuba.datasets.storages.factory import get_writable_storage from snuba.datasets.storages.storage_key import StorageKey from snuba.query import OrderBy, OrderByDirection from snuba.query.dsl import Functions as f @@ -95,7 +95,7 @@ @pytest.fixture(autouse=False) def setup_teardown(clickhouse_db: None, redis_db: None) -> None: - items_storage = get_storage(StorageKey("eap_items")) + items_storage = get_writable_storage(StorageKey("eap_items")) messages = [ gen_item_message( start_timestamp=BASE_TIME + timedelta(minutes=i), @@ -134,7 +134,7 @@ def setup_teardown(clickhouse_db: None, redis_db: None) -> None: ) for i in range(_SPAN_COUNT) ] - write_raw_unprocessed_events(items_storage, messages) # type: ignore + write_raw_unprocessed_events(items_storage, messages) @pytest.mark.clickhouse_db @@ -1449,7 +1449,7 @@ def test_aggregation_filter_basic_backward_compat(self) -> None: # first I write new messages with different value of kylestags, # theres a different number of messages for each tag so that # each will have a different sum value when i do aggregate - items_storage = get_storage(StorageKey("eap_items")) + items_storage = get_writable_storage(StorageKey("eap_items")) msg_timestamp = BASE_TIME + timedelta(minutes=1) messages = ( [ @@ -1474,7 +1474,7 @@ def test_aggregation_filter_basic_backward_compat(self) -> None: for i in range(30) ] ) - write_raw_unprocessed_events(items_storage, messages) # type: ignore + write_raw_unprocessed_events(items_storage, messages) message = TraceItemTableRequest( meta=RequestMeta( @@ -1546,7 +1546,7 @@ def test_aggregation_filter_basic(self) -> None: # first I write new messages with different value of kylestags, # theres a different number of messages for each tag so that # each will have a different sum value when i do aggregate - items_storage = get_storage(StorageKey("eap_items")) + items_storage = get_writable_storage(StorageKey("eap_items")) msg_timestamp = BASE_TIME - timedelta(minutes=1) messages = ( [ @@ -1573,7 +1573,7 @@ def test_aggregation_filter_basic(self) -> None: for i in range(30) ] ) - write_raw_unprocessed_events(items_storage, messages) # type: ignore + write_raw_unprocessed_events(items_storage, messages) message = TraceItemTableRequest( meta=RequestMeta( @@ -1641,7 +1641,7 @@ def test_conditional_aggregation_in_select(self, setup_teardown: Any) -> None: """ This test sums only if the traceitem contains kylestag = val2 """ - items_storage = get_storage(StorageKey("eap_items")) + items_storage = get_writable_storage(StorageKey("eap_items")) msg_timestamp = BASE_TIME - timedelta(minutes=1) messages = ( [ @@ -1666,7 +1666,7 @@ def test_conditional_aggregation_in_select(self, setup_teardown: Any) -> None: for i in range(3) ] ) - write_raw_unprocessed_events(items_storage, messages) # type: ignore + write_raw_unprocessed_events(items_storage, messages) message = TraceItemTableRequest( meta=RequestMeta( @@ -1726,7 +1726,7 @@ def test_conditional_aggregation_in_select(self, setup_teardown: Any) -> None: ] def test_reliability_with_conditional_aggregation(self) -> None: - items_storage = get_storage(StorageKey("eap_items")) + items_storage = get_writable_storage(StorageKey("eap_items")) msg_timestamp = BASE_TIME - timedelta(minutes=1) messages = [ gen_item_message( @@ -1738,7 +1738,7 @@ def test_reliability_with_conditional_aggregation(self) -> None: server_sample_rate=0.85, ), ] - write_raw_unprocessed_events(items_storage, messages) # type: ignore + write_raw_unprocessed_events(items_storage, messages) message = TraceItemTableRequest( meta=RequestMeta( @@ -1833,7 +1833,7 @@ def test_aggregation_filter_and_or_backward_compat(self, setup_teardown: Any) -> # first I write new messages with different value of kylestags, # theres a different number of messages for each tag so that # each will have a different sum value when i do aggregate - items_storage = get_storage(StorageKey("eap_items")) + items_storage = get_writable_storage(StorageKey("eap_items")) msg_timestamp = BASE_TIME - timedelta(minutes=1) messages = ( [ @@ -1858,7 +1858,7 @@ def test_aggregation_filter_and_or_backward_compat(self, setup_teardown: Any) -> for i in range(30) ] ) - write_raw_unprocessed_events(items_storage, messages) # type: ignore + write_raw_unprocessed_events(items_storage, messages) base_message = TraceItemTableRequest( meta=RequestMeta( @@ -2056,7 +2056,7 @@ def test_aggregation_filter_and_or(self, setup_teardown: Any) -> None: # first I write new messages with different value of kylestags, # theres a different number of messages for each tag so that # each will have a different sum value when i do aggregate - items_storage = get_storage(StorageKey("eap_items")) + items_storage = get_writable_storage(StorageKey("eap_items")) msg_timestamp = BASE_TIME - timedelta(minutes=1) messages = ( [ @@ -2081,7 +2081,7 @@ def test_aggregation_filter_and_or(self, setup_teardown: Any) -> None: for i in range(30) ] ) - write_raw_unprocessed_events(items_storage, messages) # type: ignore + write_raw_unprocessed_events(items_storage, messages) base_message = TraceItemTableRequest( meta=RequestMeta( @@ -2230,7 +2230,7 @@ def test_bad_aggregation_filter(self, setup_teardown: Any) -> None: # first I write new messages with different value of kylestags, # theres a different number of messages for each tag so that # each will have a different sum value when i do aggregate - items_storage = get_storage(StorageKey("eap_items")) + items_storage = get_writable_storage(StorageKey("eap_items")) msg_timestamp = BASE_TIME - timedelta(minutes=1) messages = ( [ @@ -2255,7 +2255,7 @@ def test_bad_aggregation_filter(self, setup_teardown: Any) -> None: for i in range(30) ] ) - write_raw_unprocessed_events(items_storage, messages) # type: ignore + write_raw_unprocessed_events(items_storage, messages) message = TraceItemTableRequest( meta=RequestMeta( @@ -2334,7 +2334,7 @@ def test_aggregation_filter_with_binary_formula(self) -> None: This simulates a SQL HAVING clause with complex expressions. """ # Write test data with different success/failure patterns for different services - items_storage = get_storage(StorageKey("eap_items")) + items_storage = get_writable_storage(StorageKey("eap_items")) msg_timestamp = BASE_TIME - timedelta(minutes=1) # Service A: High success rate (9 success, 1 failure = 10% failure rate) @@ -2401,7 +2401,7 @@ def test_aggregation_filter_with_binary_formula(self) -> None: ] all_messages = service_a_messages + service_b_messages + service_c_messages - write_raw_unprocessed_events(items_storage, all_messages) # type: ignore + write_raw_unprocessed_events(items_storage, all_messages) message = TraceItemTableRequest( meta=RequestMeta( @@ -3269,7 +3269,7 @@ def test_virtual_column_with_missing_attribute(self) -> None: ] + [AttributeValue(val_str="default") for _ in range(5)] def test_normal_mode_end_to_end(self) -> None: - items_storage = get_storage(StorageKey("eap_items")) + items_storage = get_writable_storage(StorageKey("eap_items")) msg_timestamp = BASE_TIME - timedelta(minutes=1) messages = [ gen_item_message( @@ -3278,7 +3278,7 @@ def test_normal_mode_end_to_end(self) -> None: ) for _ in range(3600) ] - write_raw_unprocessed_events(items_storage, messages) # type: ignore + write_raw_unprocessed_events(items_storage, messages) best_effort_message = TraceItemTableRequest( meta=RequestMeta( @@ -3299,7 +3299,7 @@ def test_normal_mode_end_to_end(self) -> None: EndpointTraceItemTable().execute(best_effort_message) def test_downsampling_uses_hexintcolumnprocessor(self) -> None: - items_storage = get_storage(StorageKey("eap_items")) + items_storage = get_writable_storage(StorageKey("eap_items")) msg_timestamp = BASE_TIME - timedelta(minutes=1) messages = [ gen_item_message( @@ -3308,7 +3308,7 @@ def test_downsampling_uses_hexintcolumnprocessor(self) -> None: ) for _ in range(3600) ] - write_raw_unprocessed_events(items_storage, messages) # type: ignore + write_raw_unprocessed_events(items_storage, messages) best_effort_message = TraceItemTableRequest( meta=RequestMeta( @@ -3552,7 +3552,7 @@ def test_multiply_attribute_aggregation(self) -> None: * Second batch: game_size = 500 to 850, game_size_unit_mult = 10^6 (MB) * Query for avg(game_size * game_size_unit_mult) and verify the result. """ - items_storage = get_storage(StorageKey("eap_items")) + items_storage = get_writable_storage(StorageKey("eap_items")) data_points_gb = list(range(1, 10 + 1)) data_points_mb = list(range(500, 850 + 1)) @@ -3578,7 +3578,7 @@ def test_multiply_attribute_aggregation(self) -> None: ) for val in data_points_mb ] - write_raw_unprocessed_events(items_storage, gb_messages + mb_messages) # type: ignore + write_raw_unprocessed_events(items_storage, gb_messages + mb_messages) # Calculate expected average of (game_size * game_size_unit_mult) all_products = [val * 10**9 for val in data_points_gb] + [ @@ -3654,9 +3654,9 @@ class TestArrayWildcardSearch(BaseApiTest): def test_like_filter_on_array_attribute(self) -> None: """Wildcard search on array attributes using LIKE returns matching items.""" span_ts = BASE_TIME - timedelta(minutes=1) - items_storage = get_storage(StorageKey("eap_items")) + items_storage = get_writable_storage(StorageKey("eap_items")) write_raw_unprocessed_events( - items_storage, # type: ignore + items_storage, [ gen_item_message( span_ts, attributes={"tags": _str_array("auth-error", "timeout", "retry")} @@ -3701,9 +3701,9 @@ def test_like_filter_on_array_attribute(self) -> None: def test_not_like_filter_on_array_attribute(self) -> None: """NOT_LIKE on array attributes excludes items where any element matches.""" span_ts = BASE_TIME - timedelta(minutes=1) - items_storage = get_storage(StorageKey("eap_items")) + items_storage = get_writable_storage(StorageKey("eap_items")) write_raw_unprocessed_events( - items_storage, # type: ignore + items_storage, [ gen_item_message(span_ts, attributes={"tags": _str_array("auth-error", "timeout")}), gen_item_message(span_ts, attributes={"tags": _str_array("success", "cached")}), @@ -3744,9 +3744,9 @@ def test_not_like_filter_on_array_attribute(self) -> None: def test_trace_item_table_array_op_equals_includes_string_ignore_case(self) -> None: """OP_EQUALS with ignore_case matches a string in a TYPE_ARRAY (element-wise).""" span_ts = BASE_TIME - timedelta(minutes=1) - items_storage = get_storage(StorageKey("eap_items")) + items_storage = get_writable_storage(StorageKey("eap_items")) write_raw_unprocessed_events( - items_storage, # type: ignore + items_storage, [ gen_item_message(span_ts, attributes={"tags": _str_array("ERROR", "other")}), gen_item_message( @@ -3793,9 +3793,9 @@ def test_trace_item_table_array_op_equals_includes_string_ignore_case(self) -> N def test_trace_item_table_array_op_equals_includes_int(self) -> None: """OP_EQUALS on TYPE_ARRAY with val_int=45 returns rows where some element is 45.""" span_ts = BASE_TIME - timedelta(minutes=1) - items_storage = get_storage(StorageKey("eap_items")) + items_storage = get_writable_storage(StorageKey("eap_items")) write_raw_unprocessed_events( - items_storage, # type: ignore + items_storage, [ gen_item_message(span_ts, attributes={"frame_linenos": _int_array(1, 45, 200)}), gen_item_message(span_ts, attributes={"frame_linenos": _int_array(10, 20)}), @@ -3893,9 +3893,9 @@ def test_trace_item_table_array_op_equals_all_scalar_rhs_types( ) -> None: """OP_EQUALS on TYPE_ARRAY: each scalar AttributeValue type matches a stored element""" span_ts = BASE_TIME - timedelta(minutes=1) - items_storage = get_storage(StorageKey("eap_items")) + items_storage = get_writable_storage(StorageKey("eap_items")) write_raw_unprocessed_events( - items_storage, # type: ignore + items_storage, [ gen_item_message(span_ts, attributes=match_attrs), gen_item_message(span_ts, attributes=no_match_attrs), @@ -3938,9 +3938,9 @@ class TestTraceItemTableArrayColumn(BaseApiTest): def test_select_array_column_returns_val_array(self) -> None: """TYPE_ARRAY columns are returned as val_array on TraceItemTable.""" span_ts = BASE_TIME - timedelta(minutes=1) - items_storage = get_storage(StorageKey("eap_items")) + items_storage = get_writable_storage(StorageKey("eap_items")) write_raw_unprocessed_events( - items_storage, # type: ignore + items_storage, [ gen_item_message( span_ts, @@ -4018,7 +4018,7 @@ def test_apply_labels_to_columns_backward_compat(self) -> None: label="avg(custom_measurement_2)", extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE, ), - label=None, # type: ignore + label=None, # type: ignore[arg-type] ), ], order_by=[], @@ -4046,7 +4046,7 @@ def test_apply_labels_to_columns(self) -> None: label="avg(custom_measurement_2)", extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE, ), - label=None, # type: ignore + label=None, # type: ignore[arg-type] ), ], order_by=[], diff --git a/tests/web/rpc/v1/test_endpoint_trace_item_table/test_endpoint_trace_item_table_cross_item_sampling.py b/tests/web/rpc/v1/test_endpoint_trace_item_table/test_endpoint_trace_item_table_cross_item_sampling.py index ab3b41ee650..3855a295702 100644 --- a/tests/web/rpc/v1/test_endpoint_trace_item_table/test_endpoint_trace_item_table_cross_item_sampling.py +++ b/tests/web/rpc/v1/test_endpoint_trace_item_table/test_endpoint_trace_item_table_cross_item_sampling.py @@ -83,40 +83,37 @@ def test_cross_item_query_sampling_enabled(self) -> None: storage_keys, storage_tracker = track_storage_selections() - with storage_tracker: - with patch.object(RPCEndpoint, "_RPCEndpoint__before_execute"): - message = create_trace_item_table_request( - start_time=start_time, - end_time=end_time, - trace_item_type=TraceItemType.TRACE_ITEM_TYPE_SPAN, - columns=[ - Column( - key=AttributeKey(type=AttributeKey.TYPE_STRING, name="sentry.span_id") - ) - ], - trace_filters=trace_filters, - ) - - mock_routing_decision = create_mock_routing_decision(Tier.TIER_8, message) - - endpoint = EndpointTraceItemTable() - endpoint.routing_decision = mock_routing_decision - endpoint.execute(message) - - # Verify storages were selected (should have at least 2 calls: inner + outer) - assert len(storage_keys) >= 2, ( - f"Expected at least 2 storage selections, got {len(storage_keys)}" - ) - - # The inner query should use downsampled storage (TIER_8) - assert StorageKey.EAP_ITEMS_DOWNSAMPLE_8 in storage_keys, ( - f"Inner query should use EAP_ITEMS_DOWNSAMPLE_8, got: {storage_keys}" - ) - - # The outer query should use full storage (EAP_ITEMS) - assert StorageKey.EAP_ITEMS in storage_keys, ( - f"Outer query should use EAP_ITEMS, got: {storage_keys}" - ) + with storage_tracker, patch.object(RPCEndpoint, "_RPCEndpoint__before_execute"): + message = create_trace_item_table_request( + start_time=start_time, + end_time=end_time, + trace_item_type=TraceItemType.TRACE_ITEM_TYPE_SPAN, + columns=[ + Column(key=AttributeKey(type=AttributeKey.TYPE_STRING, name="sentry.span_id")) + ], + trace_filters=trace_filters, + ) + + mock_routing_decision = create_mock_routing_decision(Tier.TIER_8, message) + + endpoint = EndpointTraceItemTable() + endpoint.routing_decision = mock_routing_decision + endpoint.execute(message) + + # Verify storages were selected (should have at least 2 calls: inner + outer) + assert len(storage_keys) >= 2, ( + f"Expected at least 2 storage selections, got {len(storage_keys)}" + ) + + # The inner query should use downsampled storage (TIER_8) + assert StorageKey.EAP_ITEMS_DOWNSAMPLE_8 in storage_keys, ( + f"Inner query should use EAP_ITEMS_DOWNSAMPLE_8, got: {storage_keys}" + ) + + # The outer query should use full storage (EAP_ITEMS) + assert StorageKey.EAP_ITEMS in storage_keys, ( + f"Outer query should use EAP_ITEMS, got: {storage_keys}" + ) def test_cross_item_query_sampling_disabled(self) -> None: """ @@ -138,32 +135,29 @@ def test_cross_item_query_sampling_disabled(self) -> None: storage_keys, storage_tracker = track_storage_selections() - with storage_tracker: - with patch.object(RPCEndpoint, "_RPCEndpoint__before_execute"): - message = create_trace_item_table_request( - start_time=start_time, - end_time=end_time, - trace_item_type=TraceItemType.TRACE_ITEM_TYPE_SPAN, - columns=[ - Column( - key=AttributeKey(type=AttributeKey.TYPE_STRING, name="sentry.span_id") - ) - ], - trace_filters=trace_filters, - ) - - mock_routing_decision = create_mock_routing_decision(Tier.TIER_8, message) - - endpoint = EndpointTraceItemTable() - endpoint.routing_decision = mock_routing_decision - endpoint.execute(message) - - # When feature is disabled, both inner and outer queries should use the same tier - assert len(storage_keys) >= 2, ( - f"Expected at least 2 storage selections, got {len(storage_keys)}" - ) - - # All storages should be TIER_8 (downsampled) - assert all(key == StorageKey.EAP_ITEMS_DOWNSAMPLE_8 for key in storage_keys), ( - f"All queries should use EAP_ITEMS_DOWNSAMPLE_8, got: {storage_keys}" - ) + with storage_tracker, patch.object(RPCEndpoint, "_RPCEndpoint__before_execute"): + message = create_trace_item_table_request( + start_time=start_time, + end_time=end_time, + trace_item_type=TraceItemType.TRACE_ITEM_TYPE_SPAN, + columns=[ + Column(key=AttributeKey(type=AttributeKey.TYPE_STRING, name="sentry.span_id")) + ], + trace_filters=trace_filters, + ) + + mock_routing_decision = create_mock_routing_decision(Tier.TIER_8, message) + + endpoint = EndpointTraceItemTable() + endpoint.routing_decision = mock_routing_decision + endpoint.execute(message) + + # When feature is disabled, both inner and outer queries should use the same tier + assert len(storage_keys) >= 2, ( + f"Expected at least 2 storage selections, got {len(storage_keys)}" + ) + + # All storages should be TIER_8 (downsampled) + assert all(key == StorageKey.EAP_ITEMS_DOWNSAMPLE_8 for key in storage_keys), ( + f"All queries should use EAP_ITEMS_DOWNSAMPLE_8, got: {storage_keys}" + ) diff --git a/tests/web/rpc/v1/test_endpoint_trace_item_table/test_endpoint_trace_item_table_extrapolation.py b/tests/web/rpc/v1/test_endpoint_trace_item_table/test_endpoint_trace_item_table_extrapolation.py index c0a4544d20d..a50c81b700a 100644 --- a/tests/web/rpc/v1/test_endpoint_trace_item_table/test_endpoint_trace_item_table_extrapolation.py +++ b/tests/web/rpc/v1/test_endpoint_trace_item_table/test_endpoint_trace_item_table_extrapolation.py @@ -26,7 +26,7 @@ ) from sentry_protos.snuba.v1.trace_item_pb2 import AnyValue -from snuba.datasets.storages.factory import get_storage +from snuba.datasets.storages.factory import get_writable_storage from snuba.datasets.storages.storage_key import StorageKey from snuba.web.rpc.v1.endpoint_trace_item_table import EndpointTraceItemTable from tests.base import BaseApiTest @@ -45,7 +45,7 @@ @pytest.mark.redis_db class TestTraceItemTableWithExtrapolation(BaseApiTest): def test_aggregation_on_attribute_column_backward_compat(self) -> None: - items_storage = get_storage(StorageKey("eap_items")) + items_storage = get_writable_storage(StorageKey("eap_items")) attributes = { "custom_tag": AnyValue(string_value="blah"), } @@ -77,7 +77,7 @@ def test_aggregation_on_attribute_column_backward_compat(self) -> None: ) write_raw_unprocessed_events( - items_storage, # type: ignore + items_storage, messages_w_measurement + messages_no_measurement, ) @@ -157,7 +157,7 @@ def test_aggregation_on_attribute_column_backward_compat(self) -> None: assert abs(measurement_p90 - 4) < 0.01 # weighted p90 - 4 def test_aggregation_on_attribute_column(self) -> None: - items_storage = get_storage(StorageKey("eap_items")) + items_storage = get_writable_storage(StorageKey("eap_items")) attributes = { "custom_tag": AnyValue(string_value="blah"), } @@ -189,7 +189,7 @@ def test_aggregation_on_attribute_column(self) -> None: ) write_raw_unprocessed_events( - items_storage, # type: ignore + items_storage, messages_w_measurement + messages_no_measurement, ) @@ -269,7 +269,7 @@ def test_aggregation_on_attribute_column(self) -> None: assert abs(measurement_p90 - 4) < 0.01 # weighted p90 - 4 def test_conditional_aggregation_on_attribute_column(self) -> None: - items_storage = get_storage(StorageKey("eap_items")) + items_storage = get_writable_storage(StorageKey("eap_items")) messages_w_measurement, messages_no_measurement = [], [] for i in range(5): start_timestamp = BASE_TIME - timedelta(minutes=i) @@ -298,7 +298,7 @@ def test_conditional_aggregation_on_attribute_column(self) -> None: ) write_raw_unprocessed_events( - items_storage, # type: ignore + items_storage, messages_w_measurement + messages_no_measurement, ) @@ -365,7 +365,7 @@ def test_conditional_aggregation_on_attribute_column(self) -> None: assert abs(measurement_avg - 2.6) < 0.000001 # weighted average - (1*2 + 3*8) / (2+8) def test_count_reliability_backward_compat(self) -> None: - items_storage = get_storage(StorageKey("eap_items")) + items_storage = get_writable_storage(StorageKey("eap_items")) attributes = { "custom_tag": AnyValue(string_value="blah"), } @@ -395,7 +395,7 @@ def test_count_reliability_backward_compat(self) -> None: ) write_raw_unprocessed_events( - items_storage, # type: ignore + items_storage, messages_w_measurement + messages_no_measurement, ) @@ -430,12 +430,12 @@ def test_count_reliability_backward_compat(self) -> None: response = EndpointTraceItemTable().execute(message) measurement_count = [v.val_double for v in response.column_values[0].results][0] print(measurement_count) - measurement_reliability = [v for v in response.column_values[0].reliabilities][0] + measurement_reliability = list(response.column_values[0].reliabilities)[0] assert measurement_count == 5 assert measurement_reliability == Reliability.RELIABILITY_HIGH def test_count_reliability(self) -> None: - items_storage = get_storage(StorageKey("eap_items")) + items_storage = get_writable_storage(StorageKey("eap_items")) attributes = { "custom_tag": AnyValue(string_value="blah"), } @@ -464,7 +464,7 @@ def test_count_reliability(self) -> None: ) write_raw_unprocessed_events( - items_storage, # type: ignore + items_storage, messages_w_measurement + messages_no_measurement, ) @@ -495,12 +495,12 @@ def test_count_reliability(self) -> None: ) response = EndpointTraceItemTable().execute(message) measurement_count = [v.val_double for v in response.column_values[0].results][0] - measurement_reliability = [v for v in response.column_values[0].reliabilities][0] + measurement_reliability = list(response.column_values[0].reliabilities)[0] assert measurement_count == 5 assert measurement_reliability == Reliability.RELIABILITY_HIGH def test_count_reliability_with_group_by_backward_compat(self) -> None: - items_storage = get_storage(StorageKey("eap_items")) + items_storage = get_writable_storage(StorageKey("eap_items")) messages_w_measurement, messages_no_measurement = [], [] for i in range(5): start_timestamp = BASE_TIME - timedelta(minutes=i + 1) @@ -528,7 +528,7 @@ def test_count_reliability_with_group_by_backward_compat(self) -> None: ) write_raw_unprocessed_events( - items_storage, # type: ignore + items_storage, messages_w_measurement + messages_no_measurement, ) @@ -596,29 +596,29 @@ def test_count_reliability_with_group_by_backward_compat(self) -> None: assert measurement_tags == ["foo"] measurement_sums = [v.val_double for v in response.column_values[1].results] - measurement_reliabilities = [v for v in response.column_values[1].reliabilities] + measurement_reliabilities = list(response.column_values[1].reliabilities) assert measurement_sums == [sum(range(5))] assert measurement_reliabilities == [Reliability.RELIABILITY_HIGH] measurement_avgs = [v.val_double for v in response.column_values[2].results] - measurement_reliabilities = [v for v in response.column_values[2].reliabilities] + measurement_reliabilities = list(response.column_values[2].reliabilities) assert len(measurement_avgs) == 1 assert measurement_avgs[0] == sum(range(5)) / 5 assert measurement_reliabilities == [Reliability.RELIABILITY_HIGH] measurement_counts = [v.val_double for v in response.column_values[3].results] - measurement_reliabilities = [v for v in response.column_values[3].reliabilities] + measurement_reliabilities = list(response.column_values[3].reliabilities) assert measurement_counts == [5] assert measurement_reliabilities == [Reliability.RELIABILITY_HIGH] measurement_p90s = [v.val_double for v in response.column_values[4].results] - measurement_reliabilities = [v for v in response.column_values[4].reliabilities] + measurement_reliabilities = list(response.column_values[4].reliabilities) assert len(measurement_p90s) == 1 assert measurement_p90s[0] == 4 assert measurement_reliabilities == [Reliability.RELIABILITY_LOW] def test_count_reliability_with_group_by(self) -> None: - items_storage = get_storage(StorageKey("eap_items")) + items_storage = get_writable_storage(StorageKey("eap_items")) messages_w_measurement, messages_no_measurement = [], [] for i in range(5): start_timestamp = BASE_TIME - timedelta(minutes=i + 1) @@ -646,7 +646,7 @@ def test_count_reliability_with_group_by(self) -> None: ) write_raw_unprocessed_events( - items_storage, # type: ignore + items_storage, messages_w_measurement + messages_no_measurement, ) @@ -714,23 +714,23 @@ def test_count_reliability_with_group_by(self) -> None: assert measurement_tags == ["foo"] measurement_sums = [v.val_double for v in response.column_values[1].results] - measurement_reliabilities = [v for v in response.column_values[1].reliabilities] + measurement_reliabilities = list(response.column_values[1].reliabilities) assert measurement_sums == [sum(range(5))] assert measurement_reliabilities == [Reliability.RELIABILITY_HIGH] measurement_avgs = [v.val_double for v in response.column_values[2].results] - measurement_reliabilities = [v for v in response.column_values[2].reliabilities] + measurement_reliabilities = list(response.column_values[2].reliabilities) assert len(measurement_avgs) == 1 assert measurement_avgs[0] == sum(range(5)) / 5 assert measurement_reliabilities == [Reliability.RELIABILITY_HIGH] measurement_counts = [v.val_double for v in response.column_values[3].results] - measurement_reliabilities = [v for v in response.column_values[3].reliabilities] + measurement_reliabilities = list(response.column_values[3].reliabilities) assert measurement_counts == [5] assert measurement_reliabilities == [Reliability.RELIABILITY_HIGH] measurement_p90s = [v.val_double for v in response.column_values[4].results] - measurement_reliabilities = [v for v in response.column_values[4].reliabilities] + measurement_reliabilities = list(response.column_values[4].reliabilities) assert len(measurement_p90s) == 1 assert measurement_p90s[0] == 4 assert measurement_reliabilities == [Reliability.RELIABILITY_LOW] @@ -1092,7 +1092,7 @@ def test_formula_reliability_with_group_by(self) -> None: ] def test_aggregation_with_nulls(self) -> None: - items_storage = get_storage(StorageKey("eap_items")) + items_storage = get_writable_storage(StorageKey("eap_items")) messages_a, messages_b = [], [] for i in range(5): start_timestamp = BASE_TIME - timedelta(minutes=i + 1) @@ -1117,7 +1117,7 @@ def test_aggregation_with_nulls(self) -> None: ) ) write_raw_unprocessed_events( - items_storage, # type: ignore + items_storage, messages_a + messages_b, ) diff --git a/tests/web/rpc/v1/test_endpoint_trace_item_table/test_endpoint_trace_item_table_logs.py b/tests/web/rpc/v1/test_endpoint_trace_item_table/test_endpoint_trace_item_table_logs.py index 7a3453b002d..b66186ba7bd 100644 --- a/tests/web/rpc/v1/test_endpoint_trace_item_table/test_endpoint_trace_item_table_logs.py +++ b/tests/web/rpc/v1/test_endpoint_trace_item_table/test_endpoint_trace_item_table_logs.py @@ -26,7 +26,7 @@ ) from sentry_protos.snuba.v1.trace_item_pb2 import AnyValue -from snuba.datasets.storages.factory import get_storage +from snuba.datasets.storages.factory import get_writable_storage from snuba.datasets.storages.storage_key import StorageKey from snuba.web.rpc.v1.endpoint_trace_item_table import EndpointTraceItemTable from tests.base import BaseApiTest @@ -41,7 +41,7 @@ @pytest.fixture(autouse=False) def setup_logs_in_db(eap: None, redis_db: None) -> None: - logs_storage = get_storage(StorageKey("eap_items")) + logs_storage = get_writable_storage(StorageKey("eap_items")) messages = [] for i in range(120): timestamp = BASE_TIME - timedelta(minutes=i) @@ -61,7 +61,7 @@ def setup_logs_in_db(eap: None, redis_db: None) -> None: }, ) ) - write_raw_unprocessed_events(logs_storage, messages) # type: ignore + write_raw_unprocessed_events(logs_storage, messages) @pytest.mark.eap diff --git a/tests/web/rpc/v1/test_endpoint_trace_item_table/test_occurrence_hourly_event_rate.py b/tests/web/rpc/v1/test_endpoint_trace_item_table/test_occurrence_hourly_event_rate.py index f792ca7b118..e8c1797b192 100644 --- a/tests/web/rpc/v1/test_endpoint_trace_item_table/test_occurrence_hourly_event_rate.py +++ b/tests/web/rpc/v1/test_endpoint_trace_item_table/test_occurrence_hourly_event_rate.py @@ -6,7 +6,7 @@ else: rate = count / hours_since_first_seen """ -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta from typing import Any import pytest @@ -36,7 +36,7 @@ ) from sentry_protos.snuba.v1.trace_item_pb2 import AnyValue -from snuba.datasets.storages.factory import get_storage +from snuba.datasets.storages.factory import get_writable_storage from snuba.datasets.storages.storage_key import StorageKey from snuba.web.rpc.v1.endpoint_trace_item_table import EndpointTraceItemTable from tests.base import BaseApiTest @@ -47,9 +47,7 @@ # Base time for test data - 3 hours ago to ensure data is within query window -BASE_TIME = datetime.now(tz=timezone.utc).replace(minute=0, second=0, microsecond=0) - timedelta( - hours=3 -) +BASE_TIME = datetime.now(tz=UTC).replace(minute=0, second=0, microsecond=0) - timedelta(hours=3) START_TIMESTAMP = Timestamp(seconds=int((BASE_TIME - timedelta(days=14)).timestamp())) END_TIMESTAMP = Timestamp(seconds=int((BASE_TIME + timedelta(hours=1)).timestamp())) @@ -157,7 +155,7 @@ def setup_occurrence_data(clickhouse_db: None, redis_db: None) -> dict[str, Any] - Varying first_seen times (some older than a week, some newer) - Events spread over the past week for counting """ - items_storage = get_storage(StorageKey("eap_items")) + items_storage = get_writable_storage(StorageKey("eap_items")) now = BASE_TIME one_week_ago = now - timedelta(days=7) @@ -235,7 +233,7 @@ def setup_occurrence_data(clickhouse_db: None, redis_db: None) -> dict[str, Any] messages = _create_occurrence_items_for_group(group_id, timestamps[:event_count]) all_messages.extend(messages) - write_raw_unprocessed_events(items_storage, all_messages) # type: ignore + write_raw_unprocessed_events(items_storage, all_messages) return { "expected_rates": expected_rates, @@ -399,7 +397,7 @@ def test_p95_of_precomputed_attribute(self, setup_occurrence_data: dict[str, Any not over a computed formula. """ # Create test data with pre-computed hourly rates as attributes - items_storage = get_storage(StorageKey("eap_items")) + items_storage = get_writable_storage(StorageKey("eap_items")) now = BASE_TIME # Create occurrences with explicit hourly_rate values @@ -420,7 +418,7 @@ def test_p95_of_precomputed_attribute(self, setup_occurrence_data: dict[str, Any ) ) - write_raw_unprocessed_events(items_storage, messages) # type: ignore + write_raw_unprocessed_events(items_storage, messages) # Query P95 of the pre-computed rate message = TraceItemTableRequest( diff --git a/tests/web/rpc/v1/test_endpoint_trace_item_table/test_trace_item_table_flex_time.py b/tests/web/rpc/v1/test_endpoint_trace_item_table/test_trace_item_table_flex_time.py index 08aa6f3c940..d159b5accc0 100644 --- a/tests/web/rpc/v1/test_endpoint_trace_item_table/test_trace_item_table_flex_time.py +++ b/tests/web/rpc/v1/test_endpoint_trace_item_table/test_trace_item_table_flex_time.py @@ -21,7 +21,7 @@ from sentry_protos.snuba.v1.trace_item_pb2 import AnyValue from sentry_relay.consts import DataCategory -from snuba.datasets.storages.factory import get_storage +from snuba.datasets.storages.factory import get_writable_storage from snuba.datasets.storages.storage_key import StorageKey from snuba.web.rpc.common.exceptions import BadSnubaRPCRequestException from snuba.web.rpc.storage_routing.routing_strategies.outcomes_flex_time import ( @@ -45,7 +45,7 @@ class LogOutcomeDataPoint: def _store_logs_and_outcomes(data_points: list[LogOutcomeDataPoint]) -> None: - items_storage = get_storage(StorageKey("eap_items")) + items_storage = get_writable_storage(StorageKey("eap_items")) messages = [] outcome_data = [] @@ -84,7 +84,7 @@ def _store_logs_and_outcomes(data_points: list[LogOutcomeDataPoint]) -> None: ) messages.append(message) outcome_data.append((data_point.time, data_point.num_outcomes)) - write_raw_unprocessed_events(items_storage, messages) # type: ignore + write_raw_unprocessed_events(items_storage, messages) store_outcomes_data(outcome_data, DataCategory.LOG_ITEM, org_id=_ORG_ID, project_id=_PROJECT_ID) @@ -398,7 +398,7 @@ def test_paginate_first_page_empty(self, eap: Any) -> None: elif times_queried == 2: assert result_size == 120 else: - assert False + raise AssertionError() assert times_queried == expected_times_queried diff --git a/tests/web/rpc/v1/test_storage_routing.py b/tests/web/rpc/v1/test_storage_routing.py index 792bf70d7a4..6fb0c229e18 100644 --- a/tests/web/rpc/v1/test_storage_routing.py +++ b/tests/web/rpc/v1/test_storage_routing.py @@ -42,12 +42,12 @@ class AnyInt(int): def __eq__(self, other: object) -> bool: - return isinstance(other, int) or isinstance(other, self.__class__) + return isinstance(other, (int, self.__class__)) class AnyFloat(float): def __eq__(self, other: object) -> bool: - return isinstance(other, int) or isinstance(other, self.__class__) + return isinstance(other, (int, self.__class__)) def _get_in_msg() -> TimeSeriesRequest: diff --git a/tests/web/rpc/v1/test_trace_item_attribute_values_v1.py b/tests/web/rpc/v1/test_trace_item_attribute_values_v1.py index d3fb60fe701..78d1f63ce99 100644 --- a/tests/web/rpc/v1/test_trace_item_attribute_values_v1.py +++ b/tests/web/rpc/v1/test_trace_item_attribute_values_v1.py @@ -1,5 +1,6 @@ -from datetime import UTC, datetime, timedelta, timezone -from typing import Any, Generator, List +from collections.abc import Generator +from datetime import UTC, datetime, timedelta +from typing import Any import pytest from google.protobuf.timestamp_pb2 import Timestamp @@ -12,16 +13,14 @@ from sentry_protos.snuba.v1.trace_item_pb2 import AnyValue from sentry_protos.snuba.v1.trace_item_pb2 import TraceItem as TraceItemMessage -from snuba.datasets.storages.factory import get_storage +from snuba.datasets.storages.factory import get_writable_storage from snuba.datasets.storages.storage_key import StorageKey from snuba.web.rpc.v1.trace_item_attribute_values import AttributeValuesRequest from tests.base import BaseApiTest from tests.helpers import write_raw_unprocessed_events from tests.web.rpc.v1.test_utils import gen_item_message -BASE_TIME = datetime.now(timezone.utc).replace(minute=0, second=0, microsecond=0) - timedelta( - minutes=180 -) +BASE_TIME = datetime.now(UTC).replace(minute=0, second=0, microsecond=0) - timedelta(minutes=180) COMMON_META = RequestMeta( project_ids=[1, 2, 3], organization_id=1, @@ -54,8 +53,8 @@ @pytest.fixture(autouse=True) -def setup_teardown(eap: None, redis_db: None) -> Generator[List[bytes], None, None]: - items_storage = get_storage(StorageKey("eap_items")) +def setup_teardown(eap: None, redis_db: None) -> Generator[list[bytes]]: + items_storage = get_writable_storage(StorageKey("eap_items")) start_timestamp = BASE_TIME messages = [ gen_item_message( @@ -117,7 +116,7 @@ def setup_teardown(eap: None, redis_db: None) -> Generator[List[bytes], None, No }, ), ] - write_raw_unprocessed_events(items_storage, messages) # type: ignore + write_raw_unprocessed_events(items_storage, messages) yield messages @@ -173,7 +172,7 @@ def test_empty_results(self) -> None: assert res.values == [] assert res.counts == [] - def test_item_id_substring_match(self, setup_teardown: List[bytes]) -> None: + def test_item_id_substring_match(self, setup_teardown: list[bytes]) -> None: first_msg_bytes = setup_teardown[0] first_msg = TraceItemMessage() first_msg.ParseFromString(first_msg_bytes) @@ -197,9 +196,9 @@ def test_item_id_substring_match(self, setup_teardown: List[bytes]) -> None: def test_deprecated_alias_attribute(self) -> None: """db.system.name request returns values stored only under deprecated key db.system.""" - items_storage = get_storage(StorageKey("eap_items")) + items_storage = get_writable_storage(StorageKey("eap_items")) write_raw_unprocessed_events( - items_storage, # type: ignore + items_storage, [ gen_item_message( start_timestamp=BASE_TIME, diff --git a/tests/web/rpc/v1/test_utils.py b/tests/web/rpc/v1/test_utils.py index 5f7ed0760de..1293563b854 100644 --- a/tests/web/rpc/v1/test_utils.py +++ b/tests/web/rpc/v1/test_utils.py @@ -1,6 +1,6 @@ import uuid -from datetime import UTC, datetime, timedelta, timezone -from typing import Any, Optional +from datetime import UTC, datetime, timedelta +from typing import Any from google.protobuf.timestamp_pb2 import Timestamp from sentry_protos.snuba.v1.request_common_pb2 import RequestMeta, TraceItemType @@ -13,7 +13,7 @@ ) from sentry_protos.snuba.v1.trace_item_pb2 import AnyValue, ArrayValue, TraceItem -from snuba.datasets.storages.factory import get_storage +from snuba.datasets.storages.factory import get_writable_storage from snuba.datasets.storages.storage_key import StorageKey from tests.helpers import write_raw_unprocessed_events @@ -82,7 +82,7 @@ } # current UTC time rounded down to the start of the current hour, then minus 180 minutes. -BASE_TIME = datetime.now(tz=timezone.utc).replace( +BASE_TIME = datetime.now(tz=UTC).replace( minute=0, second=0, microsecond=0, @@ -91,10 +91,10 @@ def write_eap_item( start_timestamp: datetime, - raw_attributes: dict[str, str | float | int | bool] = {}, + raw_attributes: dict[str, str | float | int | bool] | None = None, count: int = 1, server_sample_rate: float = 1.0, - item_id: Optional[bytes] = None, + item_id: bytes | None = None, ) -> None: """ This is a helper function to write a single or multiple eap-spans to the database. @@ -106,16 +106,19 @@ def write_eap_item( count: the number of these spans to write. """ + if raw_attributes is None: + raw_attributes = {} + def convert_attribute_value(value: Any) -> AnyValue: if isinstance(value, str): return AnyValue(string_value=value) - elif isinstance(value, int): + if isinstance(value, int): return AnyValue(int_value=value) - elif isinstance(value, bool): + if isinstance(value, bool): return AnyValue(bool_value=value) - elif isinstance(value, float): + if isinstance(value, float): return AnyValue(double_value=value) - elif isinstance(value, list): + if isinstance(value, list): return AnyValue( array_value=ArrayValue(values=[convert_attribute_value(v) for v in value]) ) @@ -126,7 +129,7 @@ def convert_attribute_value(value: Any) -> AnyValue: attributes[key] = convert_attribute_value(value) write_raw_unprocessed_events( - get_storage(StorageKey("eap_items")), # type: ignore + get_writable_storage(StorageKey("eap_items")), [ gen_item_message( start_timestamp=start_timestamp, @@ -141,17 +144,19 @@ def convert_attribute_value(value: Any) -> AnyValue: def gen_item_message( start_timestamp: datetime, - attributes: dict[str, AnyValue] = {}, + attributes: dict[str, AnyValue] | None = None, type: TraceItemType.ValueType = TraceItemType.TRACE_ITEM_TYPE_SPAN, - trace_id: Optional[str] = None, + trace_id: str | None = None, server_sample_rate: float = 1.0, client_sample_rate: float = 1.0, - end_timestamp: Optional[datetime] = None, + end_timestamp: datetime | None = None, remove_default_attributes: bool = False, - item_id: Optional[bytes] = None, - project_id: Optional[int] = None, - organization_id: Optional[int] = None, + item_id: bytes | None = None, + project_id: int | None = None, + organization_id: int | None = None, ) -> bytes: + if attributes is None: + attributes = {} item_timestamp = Timestamp() item_timestamp.FromDatetime(start_timestamp) received = Timestamp() @@ -320,12 +325,12 @@ def create_cross_item_test_data() -> tuple[list[str], list[bytes], datetime, dat def write_cross_item_data_to_storage(items: list[bytes]) -> None: """Write cross-item test data to storage.""" - from snuba.datasets.storages.factory import get_storage + from snuba.datasets.storages.factory import get_writable_storage from snuba.datasets.storages.storage_key import StorageKey from tests.helpers import write_raw_unprocessed_events - storage = get_storage(StorageKey("eap_items")) - write_raw_unprocessed_events(storage, items) # type: ignore + storage = get_writable_storage(StorageKey("eap_items")) + write_raw_unprocessed_events(storage, items) def track_storage_selections() -> tuple[list[StorageKey], Any]: diff --git a/tests/web/test__get_allocation_policy.py b/tests/web/test__get_allocation_policy.py index b265f2086a2..c3157bec024 100644 --- a/tests/web/test__get_allocation_policy.py +++ b/tests/web/test__get_allocation_policy.py @@ -1,7 +1,6 @@ from __future__ import annotations from datetime import datetime -from typing import Union import pytest @@ -141,7 +140,7 @@ def __post_init__(self) -> None: ], ) def test__get_allocation_policies( - query: Union[ClickhouseQuery, CompositeQuery[Table]], + query: ClickhouseQuery | CompositeQuery[Table], expected_allocation_policies: list[AllocationPolicy], ) -> None: assert _get_allocation_policies(query) == expected_allocation_policies diff --git a/tests/web/test_bulk_delete_query.py b/tests/web/test_bulk_delete_query.py index 0a07b0f75ab..c05db76d606 100644 --- a/tests/web/test_bulk_delete_query.py +++ b/tests/web/test_bulk_delete_query.py @@ -1,7 +1,8 @@ from __future__ import annotations import time -from typing import Any, Mapping, Optional +from collections.abc import Mapping +from typing import Any from unittest.mock import Mock, patch import pytest @@ -37,7 +38,7 @@ } -def get_attribution_info(tenant_ids: Optional[Mapping[str, int | str]] = None) -> Mapping[str, Any]: +def get_attribution_info(tenant_ids: Mapping[str, int | str] | None = None) -> Mapping[str, Any]: return { "tenant_ids": tenant_ids or {"project_id": 1, "organization_id": 1}, "referrer": "some_referrer", @@ -177,16 +178,18 @@ def test_attribute_conditions_valid_occurrence() -> None: attr_info = get_attribution_info() # Mock out _enforce_max_rows to avoid needing actual data - with patch("snuba.web.bulk_delete_query._enforce_max_rows", return_value=10): - with patch("snuba.web.bulk_delete_query.produce_delete_query") as mock_produce: - # Should not raise an exception, but should return empty dict since - # functionality is not yet launched (permit_delete_by_attribute=0 by default) - result = delete_from_storage(storage, conditions, attr_info, attribute_conditions) + with ( + patch("snuba.web.bulk_delete_query._enforce_max_rows", return_value=10), + patch("snuba.web.bulk_delete_query.produce_delete_query") as mock_produce, + ): + # Should not raise an exception, but should return empty dict since + # functionality is not yet launched (permit_delete_by_attribute=0 by default) + result = delete_from_storage(storage, conditions, attr_info, attribute_conditions) - # Should return empty because the feature flag is off - assert result == {} - # Should not have produced a message since we return early - assert mock_produce.call_count == 0 + # Should return empty because the feature flag is off + assert result == {} + # Should not have produced a message since we return early + assert mock_produce.call_count == 0 @pytest.mark.redis_db @@ -225,10 +228,12 @@ def test_attribute_conditions_missing_item_type() -> None: # Since item_type is now in AttributeConditions, we need to test a different scenario # The validation now should pass, but we need to ensure item_type is also in conditions - with patch("snuba.web.bulk_delete_query._enforce_max_rows", return_value=10): - with patch("snuba.web.bulk_delete_query.produce_delete_query"): - # This should now succeed since we're no longer checking conditions dict - delete_from_storage(storage, conditions, attr_info, attribute_conditions) + with ( + patch("snuba.web.bulk_delete_query._enforce_max_rows", return_value=10), + patch("snuba.web.bulk_delete_query.produce_delete_query"), + ): + # This should now succeed since we're no longer checking conditions dict + delete_from_storage(storage, conditions, attr_info, attribute_conditions) @pytest.mark.redis_db @@ -268,27 +273,29 @@ def test_attribute_conditions_feature_flag_enabled() -> None: try: # Mock out _enforce_max_rows to avoid needing actual data - with patch("snuba.web.bulk_delete_query._enforce_max_rows", return_value=10): - with patch("snuba.web.bulk_delete_query.produce_delete_query") as mock_produce: - # Should process normally and produce a message - result = delete_from_storage(storage, conditions, attr_info, attribute_conditions) - - # Should have produced a message - assert mock_produce.call_count == 1 - # Should return success results - assert result != {} - - # Verify the message includes attribute_conditions - call_args = mock_produce.call_args[0][0] - assert "attribute_conditions" in call_args - assert call_args["attribute_conditions"] == { - "group_id": { - "attr_key_name": "group_id", - "attr_key_type": AttributeKey.TYPE_INT, - "attr_values": [12345], - } + with ( + patch("snuba.web.bulk_delete_query._enforce_max_rows", return_value=10), + patch("snuba.web.bulk_delete_query.produce_delete_query") as mock_produce, + ): + # Should process normally and produce a message + result = delete_from_storage(storage, conditions, attr_info, attribute_conditions) + + # Should have produced a message + assert mock_produce.call_count == 1 + # Should return success results + assert result != {} + + # Verify the message includes attribute_conditions + call_args = mock_produce.call_args[0][0] + assert "attribute_conditions" in call_args + assert call_args["attribute_conditions"] == { + "group_id": { + "attr_key_name": "group_id", + "attr_key_type": AttributeKey.TYPE_INT, + "attr_values": [12345], } - assert call_args["attribute_conditions_item_type"] == TRACE_ITEM_TYPE_OCCURRENCE + } + assert call_args["attribute_conditions_item_type"] == TRACE_ITEM_TYPE_OCCURRENCE finally: # Clean up: disable the feature flag set_config("permit_delete_by_attribute", 0) diff --git a/tests/web/test_db_query.py b/tests/web/test_db_query.py index 6b5290e2702..bb9a804ffd7 100644 --- a/tests/web/test_db_query.py +++ b/tests/web/test_db_query.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Any, Mapping, MutableMapping, Optional +from collections.abc import Mapping, MutableMapping +from typing import Any from unittest import mock import pytest @@ -11,6 +12,7 @@ from snuba.clickhouse.formatter.query import format_query from snuba.clickhouse.query import Query as ClickhouseQuery from snuba.configs.configuration import Configuration, ResourceIdentifier +from snuba.datasets.schemas.tables import TableSource from snuba.datasets.storage import Storage from snuba.datasets.storages.factory import get_storage from snuba.datasets.storages.storage_key import StorageKey @@ -197,7 +199,7 @@ def test_query_settings_from_config( query_config: Mapping[str, Any], expected: MutableMapping[str, Any], - query_prefix: Optional[str], + query_prefix: str | None, async_override: bool, referrer: str, ) -> None: @@ -212,10 +214,12 @@ def _build_test_query( select_expression: str, allocation_policies: list[AllocationPolicy] | None = None ) -> tuple[ClickhouseQuery, Storage, AttributionInfo]: storage = get_storage(StorageKey("errors_ro")) + data_source = storage.get_schema().get_data_source() + assert isinstance(data_source, TableSource) return ( ClickhouseQuery( from_clause=Table( - storage.get_schema().get_data_source().get_table_name(), # type: ignore + data_source.get_table_name(), schema=storage.get_schema().get_columns(), final=False, allocation_policies=allocation_policies or storage.get_allocation_policies(), @@ -380,7 +384,8 @@ def test_db_query_success() -> None: assert len(query_metadata_list) == 1 assert result.extra["stats"] == stats assert result.extra["sql"] is not None - assert set(result.result["profile"].keys()) == { # type: ignore + assert result.result["profile"] is not None + assert set(result.result["profile"].keys()) == { "elapsed", "bytes", "progress_bytes", @@ -413,31 +418,34 @@ def test_bypass_cache_referrer() -> None: # cache should not be used for "some_bypass_cache_referrer" so if the # bypass does not work, the test will try to use a bad cache - with mock.patch("snuba.settings.BYPASS_CACHE_REFERRERS", ["some_bypass_cache_referrer"]): - with mock.patch("snuba.web.db_query._get_cache_partition"): - result = db_query( - clickhouse_query=query, - query_settings=HTTPQuerySettings(), - attribution_info=attribution_info, - dataset_name="events", - query_metadata_list=query_metadata_list, - formatted_query=format_query(query), - reader=storage.get_cluster().get_reader(), - timer=Timer("foo"), - stats=stats, - trace_id="trace_id", - robust=False, - ) - assert len(query_metadata_list) == 1 - assert result.extra["stats"] == stats - assert result.extra["sql"] is not None - assert set(result.result["profile"].keys()) == { # type: ignore - "elapsed", - "bytes", - "progress_bytes", - "blocks", - "rows", - } + with ( + mock.patch("snuba.settings.BYPASS_CACHE_REFERRERS", ["some_bypass_cache_referrer"]), + mock.patch("snuba.web.db_query._get_cache_partition"), + ): + result = db_query( + clickhouse_query=query, + query_settings=HTTPQuerySettings(), + attribution_info=attribution_info, + dataset_name="events", + query_metadata_list=query_metadata_list, + formatted_query=format_query(query), + reader=storage.get_cluster().get_reader(), + timer=Timer("foo"), + stats=stats, + trace_id="trace_id", + robust=False, + ) + assert len(query_metadata_list) == 1 + assert result.extra["stats"] == stats + assert result.extra["sql"] is not None + assert result.result["profile"] is not None + assert set(result.result["profile"].keys()) == { + "elapsed", + "bytes", + "progress_bytes", + "blocks", + "rows", + } @pytest.mark.events_db @@ -473,10 +481,16 @@ def __init__( self, max_threads: int, policy_name: str, - storage_key: StorageKey = StorageKey("doesntmatter"), - required_tenant_types: list[str] = ["a", "b", "c"], - default_config_overrides: dict[str, Any] = {}, + storage_key: StorageKey | None = None, + required_tenant_types: list[str] | None = None, + default_config_overrides: dict[str, Any] | None = None, ) -> None: + if storage_key is None: + storage_key = StorageKey("doesntmatter") + if default_config_overrides is None: + default_config_overrides = {} + if required_tenant_types is None: + required_tenant_types = ["a", "b", "c"] super().__init__( storage_key=ResourceIdentifier(storage_key), required_tenant_types=required_tenant_types, @@ -791,7 +805,9 @@ def _get_quota_allowance( trace_id="trace_id", robust=False, ) - assert settings.get_resource_quota().max_threads == POLICY_THREADS # type: ignore + resource_quota = settings.get_resource_quota() + assert resource_quota is not None + assert resource_quota.max_threads == POLICY_THREADS assert stats["max_threads"] == POLICY_THREADS assert query_metadata_list[0].stats["max_threads"] == POLICY_THREADS diff --git a/tests/web/test_max_rows_enforcer.py b/tests/web/test_max_rows_enforcer.py index 729f8d19e99..7ebabdedd5a 100644 --- a/tests/web/test_max_rows_enforcer.py +++ b/tests/web/test_max_rows_enforcer.py @@ -1,5 +1,6 @@ +from collections.abc import Callable from datetime import datetime -from typing import Any, Callable +from typing import Any from unittest import mock import pytest diff --git a/tests/web/test_project_finder.py b/tests/web/test_project_finder.py index 2736924e028..aa833213088 100644 --- a/tests/web/test_project_finder.py +++ b/tests/web/test_project_finder.py @@ -1,5 +1,3 @@ -from typing import Set, Union - import pytest from snuba.clickhouse.columns import UUID, ColumnSet, UInt @@ -8,7 +6,7 @@ from snuba.query.composite import CompositeQuery from snuba.query.conditions import ConditionFunctions, binary_condition from snuba.query.data_source.projects_finder import ProjectsFinder -from snuba.query.data_source.simple import Entity +from snuba.query.data_source.simple import Entity, LogicalDataSource from snuba.query.expressions import Column, FunctionCall, Literal from snuba.query.logical import Query from snuba.utils.schemas import Column as EntityColumn @@ -65,8 +63,8 @@ TEST_CASES, ) def test_count_columns( - query: Union[Query, CompositeQuery[Entity]], - expected_proj: Set[int], + query: Query | CompositeQuery[LogicalDataSource], + expected_proj: set[int], ) -> None: project_finder = ProjectsFinder() - assert project_finder.visit(query) == expected_proj # type: ignore + assert project_finder.visit(query) == expected_proj diff --git a/tests/web/test_results.py b/tests/web/test_results.py index fbcef2eb36c..60cb4973290 100644 --- a/tests/web/test_results.py +++ b/tests/web/test_results.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Mapping +from collections.abc import Mapping import pytest diff --git a/tests/web/test_tables_collector.py b/tests/web/test_tables_collector.py index 450611aa546..72a90b00ffd 100644 --- a/tests/web/test_tables_collector.py +++ b/tests/web/test_tables_collector.py @@ -1,5 +1,3 @@ -from typing import Optional, Set, Union - import pytest from snuba.clickhouse.columns import UUID, ColumnSet, String, UInt @@ -135,10 +133,10 @@ TEST_CASES, ) def test_count_columns( - query: Union[ClickhouseQuery, CompositeQuery[Table]], - expected_tables: Set[str], + query: ClickhouseQuery | CompositeQuery[Table], + expected_tables: set[str], expected_final: bool, - expected_sampling: Optional[float], + expected_sampling: float | None, ) -> None: tables_collector = TablesCollector() tables_collector.visit(query) diff --git a/tests/web/test_views.py b/tests/web/test_views.py index 7459220555d..c185757b6db 100644 --- a/tests/web/test_views.py +++ b/tests/web/test_views.py @@ -32,7 +32,7 @@ def snuba_api() -> FlaskClient: def test_response_dumping() -> None: - data = { + data: dict[str, Any] = { "data": [ {"count": 5181337, "release": "elsa"}, {"count": 2170, "release": "simba"}, @@ -58,7 +58,7 @@ def test_response_dumping() -> None: dumped_payload = dump_payload(data) clean_data = copy.deepcopy(data) - clean_data["data"][3]["release"] = "RAW_BYTESTRING__" + b"x;\x83\xc0\x05".hex() # type: ignore + clean_data["data"][3]["release"] = "RAW_BYTESTRING__" + b"x;\x83\xc0\x05".hex() assert json.loads(dumped_payload) == clean_data diff --git a/uv.lock b/uv.lock index 9712d097269..40e10a7ad1c 100644 --- a/uv.lock +++ b/uv.lock @@ -2,8 +2,10 @@ version = 1 revision = 3 requires-python = ">=3.13" resolution-markers = [ - "sys_platform == 'darwin'", - "sys_platform == 'linux'", + "python_full_version >= '3.15' and sys_platform == 'darwin'", + "python_full_version < '3.15' and sys_platform == 'darwin'", + "python_full_version >= '3.15' and sys_platform == 'linux'", + "python_full_version < '3.15' and sys_platform == 'linux'", ] supported-markers = [ "sys_platform == 'darwin'", @@ -20,6 +22,16 @@ members = [ "snuba", ] +[[package]] +name = "ast-serialize" +version = "0.3.0" +source = { registry = "https://pypi.devinfra.sentry.io/simple" } +wheels = [ + { url = "https://pypi.devinfra.sentry.io/wheels/ast_serialize-0.3.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:ef6d3c08b7b4cd29b48410338e134764a00e76d25841eb02c1084e868c888ecc" }, + { url = "https://pypi.devinfra.sentry.io/wheels/ast_serialize-0.3.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3d841424f41b886e98044abc80769c14a956e6e5ccd5fb5b0d9f5ead72be18a4" }, + { url = "https://pypi.devinfra.sentry.io/wheels/ast_serialize-0.3.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c1c9e763d70293d65ce1e1ea8c943140c68d0953f0268c7ee0998f2e07f77dd0" }, +] + [[package]] name = "attrs" version = "25.3.0" @@ -486,6 +498,19 @@ wheels = [ { url = "https://pypi.devinfra.sentry.io/wheels/jsonschema_specifications-2023.12.1-py3-none-any.whl", hash = "sha256:87e4fdf3a94858b8a2ba2778d9ba57d8a9cafca7c7489c46ba0d30a8bc6a9c3c" }, ] +[[package]] +name = "librt" +version = "0.11.0" +source = { registry = "https://pypi.devinfra.sentry.io/simple" } +wheels = [ + { url = "https://pypi.devinfra.sentry.io/wheels/librt-0.11.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:fa475675db22290c3158e1d42326d0f5a65f04f44a0e68c3630a25b53560fb9c" }, + { url = "https://pypi.devinfra.sentry.io/wheels/librt-0.11.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:621db29691044bdeda22e789e482e1b0f3a985d90e3426c9c6d17606416205ea" }, + { url = "https://pypi.devinfra.sentry.io/wheels/librt-0.11.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7c39513d8b7477a2e1ed8c43fc21c524e8d5a0f8d4e8b7b074dbdbe7820a08e2" }, + { url = "https://pypi.devinfra.sentry.io/wheels/librt-0.11.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:b1ecbd9819deccc39b7542bf4d2a740d8a620694d39989e58661d3763458f8d4" }, + { url = "https://pypi.devinfra.sentry.io/wheels/librt-0.11.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7da327dacd7be8f8ec36547373550744a3cc0e536d54665cd83f8bcd961200e8" }, + { url = "https://pypi.devinfra.sentry.io/wheels/librt-0.11.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:05fb8fb2ab90e21c8d12ea240d744ad514da9baf381ebfa70d91d20d21713175" }, +] + [[package]] name = "markupsafe" version = "3.0.2" @@ -527,14 +552,23 @@ wheels = [ [[package]] name = "mypy" -version = "1.1.1" +version = "2.1.0" source = { registry = "https://pypi.devinfra.sentry.io/simple" } dependencies = [ + { name = "ast-serialize", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "librt", marker = "(platform_python_implementation != 'PyPy' and sys_platform == 'darwin') or (platform_python_implementation != 'PyPy' and sys_platform == 'linux')" }, { name = "mypy-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "pathspec", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "typing-extensions", version = "4.12.2", source = { registry = "https://pypi.devinfra.sentry.io/simple" }, marker = "(python_full_version < '3.15' and sys_platform == 'darwin') or (python_full_version < '3.15' and sys_platform == 'linux')" }, + { name = "typing-extensions", version = "4.15.0", source = { registry = "https://pypi.devinfra.sentry.io/simple" }, marker = "(python_full_version >= '3.15' and sys_platform == 'darwin') or (python_full_version >= '3.15' and sys_platform == 'linux')" }, ] wheels = [ - { url = "https://pypi.devinfra.sentry.io/wheels/mypy-1.1.1-py3-none-any.whl", hash = "sha256:4e4e8b362cdf99ba00c2b218036002bdcdf1e0de085cdb296a49df03fb31dfc4" }, + { url = "https://pypi.devinfra.sentry.io/wheels/mypy-2.1.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:8de55a8c861f2a49331f807be98d90caeceeef520bde13d43a160207f8af613e" }, + { url = "https://pypi.devinfra.sentry.io/wheels/mypy-2.1.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5fdf2941a07434af755837d9880f7d7d25f1dacb1af9dcd4b9b66f2220a3024e" }, + { url = "https://pypi.devinfra.sentry.io/wheels/mypy-2.1.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e195b817c13f02352a9c124301f9f30f078405444679b6753c1b96b6eed37285" }, + { url = "https://pypi.devinfra.sentry.io/wheels/mypy-2.1.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:49890d4f76ac9e06ec117f9e09f3174da70a620a0c300953d8595c926e80947f" }, + { url = "https://pypi.devinfra.sentry.io/wheels/mypy-2.1.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:761be68e023ef5d94678772396a8af1220030f80837a3afd8d0aef3b419666f4" }, + { url = "https://pypi.devinfra.sentry.io/wheels/mypy-2.1.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c90345fc182dc363b891350457ec69c35140858538f38b4540845afcc32b1aef" }, ] [[package]] @@ -572,6 +606,14 @@ wheels = [ { url = "https://pypi.devinfra.sentry.io/wheels/parsimonious-0.10.0-py3-none-any.whl", hash = "sha256:982ab435fabe86519b57f6b35610aa4e4e977e9f02a14353edf4bbc75369fc0f" }, ] +[[package]] +name = "pathspec" +version = "1.0.4" +source = { registry = "https://pypi.devinfra.sentry.io/simple" } +wheels = [ + { url = "https://pypi.devinfra.sentry.io/wheels/pathspec-1.0.4-py3-none-any.whl", hash = "sha256:fb6ae2fd4e7c921a165808a552060e722767cfa526f99ca5156ed2ce45a5c723" }, +] + [[package]] name = "platformdirs" version = "4.3.7" @@ -762,7 +804,8 @@ name = "python-utils" version = "3.8.1" source = { registry = "https://pypi.devinfra.sentry.io/simple" } dependencies = [ - { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "typing-extensions", version = "4.12.2", source = { registry = "https://pypi.devinfra.sentry.io/simple" }, marker = "(python_full_version < '3.15' and sys_platform == 'darwin') or (python_full_version < '3.15' and sys_platform == 'linux')" }, + { name = "typing-extensions", version = "4.15.0", source = { registry = "https://pypi.devinfra.sentry.io/simple" }, marker = "(python_full_version >= '3.15' and sys_platform == 'darwin') or (python_full_version >= '3.15' and sys_platform == 'linux')" }, ] wheels = [ { url = "https://pypi.devinfra.sentry.io/wheels/python_utils-3.8.1-py2.py3-none-any.whl", hash = "sha256:efdf31c8154667d7dc0317547c8e6d3b506c5d4b6e360e0c89662306262fc0ab" }, @@ -911,7 +954,8 @@ version = "1.28.0" source = { registry = "https://pypi.devinfra.sentry.io/simple" } dependencies = [ { name = "sentry-sdk", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "typing-extensions", version = "4.12.2", source = { registry = "https://pypi.devinfra.sentry.io/simple" }, marker = "(python_full_version < '3.15' and sys_platform == 'darwin') or (python_full_version < '3.15' and sys_platform == 'linux')" }, + { name = "typing-extensions", version = "4.15.0", source = { registry = "https://pypi.devinfra.sentry.io/simple" }, marker = "(python_full_version >= '3.15' and sys_platform == 'darwin') or (python_full_version >= '3.15' and sys_platform == 'linux')" }, ] wheels = [ { url = "https://pypi.devinfra.sentry.io/wheels/sentry_devenv-1.28.0-py3-none-any.whl", hash = "sha256:304b603c561c4a0a206c7d1346aebf8ec44e6175f44ecce0b9c1a5848fb3f7ca" }, @@ -927,7 +971,8 @@ dependencies = [ { name = "python-rapidjson", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "pyyaml", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "sentry-protos", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "typing-extensions", version = "4.12.2", source = { registry = "https://pypi.devinfra.sentry.io/simple" }, marker = "(python_full_version < '3.15' and sys_platform == 'darwin') or (python_full_version < '3.15' and sys_platform == 'linux')" }, + { name = "typing-extensions", version = "4.15.0", source = { registry = "https://pypi.devinfra.sentry.io/simple" }, marker = "(python_full_version >= '3.15' and sys_platform == 'darwin') or (python_full_version >= '3.15' and sys_platform == 'linux')" }, ] wheels = [ { url = "https://pypi.devinfra.sentry.io/wheels/sentry_kafka_schemas-2.1.24-py2.py3-none-any.whl", hash = "sha256:50ca2ce88598b2a2eacdeb609a783153c6a7bb4e3f5e3bc4121690b8a3c2f293" }, @@ -1106,7 +1151,8 @@ dev = [ { name = "types-requests", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "types-setuptools", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "types-simplejson", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "typing-extensions", version = "4.12.2", source = { registry = "https://pypi.devinfra.sentry.io/simple" }, marker = "(python_full_version < '3.15' and sys_platform == 'darwin') or (python_full_version < '3.15' and sys_platform == 'linux')" }, + { name = "typing-extensions", version = "4.15.0", source = { registry = "https://pypi.devinfra.sentry.io/simple" }, marker = "(python_full_version >= '3.15' and sys_platform == 'darwin') or (python_full_version >= '3.15' and sys_platform == 'linux')" }, ] [package.metadata] @@ -1160,7 +1206,7 @@ requires-dist = [ dev = [ { name = "devservices", specifier = ">=1.2.1" }, { name = "freezegun", specifier = ">=1.5.5" }, - { name = "mypy", specifier = ">=1.1.1" }, + { name = "mypy", specifier = ">=1.18.2" }, { name = "pre-commit", specifier = ">=4.2.0" }, { name = "pytest", specifier = ">=9.0.3" }, { name = "pytest-cov", specifier = ">=4.1.0" }, @@ -1319,10 +1365,26 @@ wheels = [ name = "typing-extensions" version = "4.12.2" source = { registry = "https://pypi.devinfra.sentry.io/simple" } +resolution-markers = [ + "python_full_version < '3.15' and sys_platform == 'darwin'", + "python_full_version < '3.15' and sys_platform == 'linux'", +] wheels = [ { url = "https://pypi.devinfra.sentry.io/wheels/typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d" }, ] +[[package]] +name = "typing-extensions" +version = "4.15.0" +source = { registry = "https://pypi.devinfra.sentry.io/simple" } +resolution-markers = [ + "python_full_version >= '3.15' and sys_platform == 'darwin'", + "python_full_version >= '3.15' and sys_platform == 'linux'", +] +wheels = [ + { url = "https://pypi.devinfra.sentry.io/wheels/typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548" }, +] + [[package]] name = "tzdata" version = "2025.2" @@ -1376,7 +1438,8 @@ name = "watchdog" version = "3.0.0" source = { registry = "https://pypi.devinfra.sentry.io/simple" } resolution-markers = [ - "sys_platform == 'linux'", + "python_full_version >= '3.15' and sys_platform == 'linux'", + "python_full_version < '3.15' and sys_platform == 'linux'", ] wheels = [ { url = "https://pypi.devinfra.sentry.io/wheels/watchdog-3.0.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0e06ab8858a76e1219e68c7573dfeba9dd1c0219476c5a44d5333b01d7e1743a" }, @@ -1388,7 +1451,8 @@ name = "watchdog" version = "6.0.0" source = { registry = "https://pypi.devinfra.sentry.io/simple" } resolution-markers = [ - "sys_platform == 'darwin'", + "python_full_version >= '3.15' and sys_platform == 'darwin'", + "python_full_version < '3.15' and sys_platform == 'darwin'", ] wheels = [ { url = "https://pypi.devinfra.sentry.io/wheels/watchdog-6.0.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:76aae96b00ae814b181bb25b1b98076d5fc84e8a53cd8885a318b42b6d3a5134" },