diff --git a/sqlmesh/core/config/__init__.py b/sqlmesh/core/config/__init__.py index 42ed82c6e6..a57af66b2e 100644 --- a/sqlmesh/core/config/__init__.py +++ b/sqlmesh/core/config/__init__.py @@ -32,6 +32,7 @@ load_configs as load_configs, ) from sqlmesh.core.config.migration import MigrationConfig as MigrationConfig +from sqlmesh.core.config.ownership import OwnershipConfig as OwnershipConfig from sqlmesh.core.config.model import ModelDefaultsConfig as ModelDefaultsConfig from sqlmesh.core.config.naming import NameInferenceConfig as NameInferenceConfig from sqlmesh.core.config.linter import LinterConfig as LinterConfig diff --git a/sqlmesh/core/config/ownership.py b/sqlmesh/core/config/ownership.py new file mode 100644 index 0000000000..9730be8d78 --- /dev/null +++ b/sqlmesh/core/config/ownership.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +import re +import typing as t + +from pydantic.functional_validators import BeforeValidator + +from sqlmesh.core.config.base import BaseConfig +from sqlmesh.core.config.common import compile_regex_mapping + +if t.TYPE_CHECKING: + from sqlmesh.core.engine_adapter.base import EngineAdapter + OwnershipMapping = t.Dict[re.Pattern, str] + EnvironmentOwnerResolver = t.Callable[[str, EngineAdapter], t.Optional[str]] + PhysicalOwnerResolver = t.Callable[[EngineAdapter], t.Optional[str]] +else: + OwnershipMapping = t.Annotated[t.Dict[re.Pattern, str], BeforeValidator(compile_regex_mapping)] + EnvironmentOwnerResolver = t.Callable + PhysicalOwnerResolver = t.Callable + + +class OwnershipConfig(BaseConfig): + """Configuration for object ownership rules applied at creation time. + + For static YAML-based config, use ``environment_owner_mapping`` and + ``physical_owner``. For programmatic config where the principal must be + resolved at plan-execution time (e.g. via ``adapter.current_user()`` or a + Databricks API call), supply ``environment_owner_resolver`` and/or + ``physical_owner_resolver`` instead — callables take precedence over the + static fields. + + Example (YAML):: + + ownership: + environment_owner_mapping: + "^prod$": "svc_prod_spn" + ".*": "group:shared-developers" + physical_owner: "group:shared-developers" + + Example (Python):: + + OwnershipConfig( + environment_owner_resolver=lambda env, adapter: ( + adapter.current_user() if env == "prod" else "group:shared-developers" + ), + physical_owner="group:shared-developers", + ) + """ + + environment_owner_mapping: OwnershipMapping = {} + environment_owner_resolver: t.Optional[EnvironmentOwnerResolver] = None + physical_owner: t.Optional[str] = None + physical_owner_resolver: t.Optional[PhysicalOwnerResolver] = None + + @property + def is_active(self) -> bool: + """True when any ownership rule is configured.""" + return bool( + self.environment_owner_resolver is not None + or self.environment_owner_mapping + or self.physical_owner is not None + or self.physical_owner_resolver is not None + ) + + def resolve_owner(self, environment_name: str, adapter: "EngineAdapter") -> t.Optional[str]: + """Return the configured owner for the given environment, or None.""" + if self.environment_owner_resolver is not None: + return self.environment_owner_resolver(environment_name, adapter) + for pattern, owner in self.environment_owner_mapping.items(): + if pattern.fullmatch(environment_name): + return owner + return None + + def resolve_physical_owner(self, adapter: "EngineAdapter") -> t.Optional[str]: + """Return the configured physical-layer owner, or None.""" + if self.physical_owner_resolver is not None: + return self.physical_owner_resolver(adapter) + return self.physical_owner diff --git a/sqlmesh/core/config/root.py b/sqlmesh/core/config/root.py index 211d271b01..fae30f6fbd 100644 --- a/sqlmesh/core/config/root.py +++ b/sqlmesh/core/config/root.py @@ -30,6 +30,7 @@ from sqlmesh.core.config.format import FormatConfig from sqlmesh.core.config.gateway import GatewayConfig from sqlmesh.core.config.janitor import JanitorConfig +from sqlmesh.core.config.ownership import OwnershipConfig from sqlmesh.core.config.migration import MigrationConfig from sqlmesh.core.config.model import ModelDefaultsConfig from sqlmesh.core.config.naming import NameInferenceConfig as NameInferenceConfig @@ -118,6 +119,7 @@ class Config(BaseConfig): gateway_managed_virtual_layer: Whether the models' views in the virtual layer are created by the model-specific gateway rather than the default gateway. infer_python_dependencies: Whether to statically analyze Python code to automatically infer Python package requirements. environment_catalog_mapping: A mapping from regular expressions to catalog names. The catalog name is used to determine the target catalog for a given environment. + ownership: Ownership rules applied at schema/view creation time. Maps environment name patterns to owner principals so objects are correctly owned even after a partial run. default_target_environment: The name of the environment that will be the default target for the `sqlmesh plan` and `sqlmesh run` commands. log_limit: The default number of logs to keep. format: The formatting options for SQL code. @@ -175,6 +177,7 @@ class Config(BaseConfig): janitor: JanitorConfig = JanitorConfig() cache_dir: t.Optional[str] = None dbt: t.Optional[DbtConfig] = None + ownership: OwnershipConfig = Field(default_factory=OwnershipConfig) _FIELD_UPDATE_STRATEGY: t.ClassVar[t.Dict[str, UpdateStrategy]] = { "gateways": UpdateStrategy.NESTED_UPDATE, @@ -194,6 +197,7 @@ class Config(BaseConfig): "after_all": UpdateStrategy.EXTEND, "linter": UpdateStrategy.NESTED_UPDATE, "dbt": UpdateStrategy.NESTED_UPDATE, + "ownership": UpdateStrategy.NESTED_UPDATE, } _connection_config_validator = connection_config_validator diff --git a/sqlmesh/core/config/scheduler.py b/sqlmesh/core/config/scheduler.py index 9d9d1d3c79..dd8f9caaa7 100644 --- a/sqlmesh/core/config/scheduler.py +++ b/sqlmesh/core/config/scheduler.py @@ -128,12 +128,15 @@ class BuiltInSchedulerConfig(_EngineAdapterStateSyncSchedulerConfig, BaseConfig) type_: t.Literal["builtin"] = Field(alias="type", default="builtin") def create_plan_evaluator(self, context: GenericContext) -> PlanEvaluator: + ownership = context.config.ownership + ownership_config = ownership if ownership.is_active else None return BuiltInPlanEvaluator( state_sync=context.state_sync, snapshot_evaluator=context.snapshot_evaluator, create_scheduler=context.create_scheduler, default_catalog=context.default_catalog, console=context.console, + ownership_config=ownership_config, ) def get_default_catalog_per_gateway(self, context: GenericContext) -> t.Dict[str, str]: diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py index 5465ea1197..b6954d3f3d 100644 --- a/sqlmesh/core/engine_adapter/base.py +++ b/sqlmesh/core/engine_adapter/base.py @@ -1418,6 +1418,38 @@ def _create_schema( raise logger.warning("Failed to create %s '%s': %s", kind.lower(), schema_name, e) + def current_user(self) -> str: + """Return the identity of the currently-connected principal. + + Uses SQL ``CURRENT_USER()`` which is supported by Spark/Databricks and + DuckDB. Override in adapters where a different mechanism is required. + """ + row = self.fetchone("SELECT CURRENT_USER()") + if not row: + raise SQLMeshError("Could not determine current user: CURRENT_USER() returned no rows") + return row[0] + + def alter_schema_owner(self, schema_name: SchemaName, owner: str) -> None: + """Set the owner of a schema. + + No-op by default. Override in dialect-specific adapters that support ownership control + (e.g. Spark/Databricks Unity Catalog: ALTER SCHEMA ... OWNER TO ...). + """ + + def alter_view_owner(self, view_name: TableName, owner: str) -> None: + """Set the owner of a view. + + No-op by default. Override in dialect-specific adapters that support ownership control + (e.g. Spark/Databricks Unity Catalog: ALTER VIEW ... OWNER TO ...). + """ + + def alter_table_owner(self, table_name: TableName, owner: str) -> None: + """Set the owner of a table. + + No-op by default. Override in dialect-specific adapters that support ownership control + (e.g. Spark/Databricks Unity Catalog: ALTER TABLE ... OWNER TO ...). + """ + def drop_schema( self, schema_name: SchemaName, diff --git a/sqlmesh/core/engine_adapter/spark.py b/sqlmesh/core/engine_adapter/spark.py index 9199aa3bcd..411dbd69f8 100644 --- a/sqlmesh/core/engine_adapter/spark.py +++ b/sqlmesh/core/engine_adapter/spark.py @@ -553,6 +553,27 @@ def _build_create_comment_column_exp( return f"ALTER TABLE {table_sql} ALTER COLUMN {column_sql} COMMENT {comment_sql}" + def alter_schema_owner(self, schema_name: SchemaName, owner: str) -> None: + schema_sql = exp.to_table(schema_name, dialect=self.dialect).sql( + dialect=self.dialect, identify=True + ) + owner_sql = exp.to_identifier(owner, quoted=True).sql(dialect=self.dialect) + self.execute(f"ALTER SCHEMA {schema_sql} OWNER TO {owner_sql}") + + def alter_view_owner(self, view_name: TableName, owner: str) -> None: + view_sql = exp.to_table(view_name, dialect=self.dialect).sql( + dialect=self.dialect, identify=True + ) + owner_sql = exp.to_identifier(owner, quoted=True).sql(dialect=self.dialect) + self.execute(f"ALTER VIEW {view_sql} OWNER TO {owner_sql}") + + def alter_table_owner(self, table_name: TableName, owner: str) -> None: + table_sql = exp.to_table(table_name, dialect=self.dialect).sql( + dialect=self.dialect, identify=True + ) + owner_sql = exp.to_identifier(owner, quoted=True).sql(dialect=self.dialect) + self.execute(f"ALTER TABLE {table_sql} OWNER TO {owner_sql}") + @classmethod def _wap_branch_name(cls, wap_id: str) -> str: return f"{cls.WAP_PREFIX}{wap_id}" diff --git a/sqlmesh/core/plan/evaluator.py b/sqlmesh/core/plan/evaluator.py index f2f432a97e..e69bf4b78f 100644 --- a/sqlmesh/core/plan/evaluator.py +++ b/sqlmesh/core/plan/evaluator.py @@ -40,6 +40,7 @@ from sqlmesh.core.plan.common import identify_restatement_intervals_across_snapshot_versions from sqlmesh.utils import CorrelationId from sqlmesh.utils.concurrency import NodeExecutionFailedError +from sqlmesh.core.config.ownership import OwnershipConfig from sqlmesh.utils.errors import PlanError, ConflictingPlanError, SQLMeshError from sqlmesh.utils.date import now, to_timestamp @@ -74,12 +75,14 @@ def __init__( create_scheduler: t.Callable[[t.Iterable[Snapshot], SnapshotEvaluator], Scheduler], default_catalog: t.Optional[str], console: t.Optional[Console] = None, + ownership_config: t.Optional[OwnershipConfig] = None, ): self.state_sync = state_sync self.snapshot_evaluator = snapshot_evaluator self.create_scheduler = create_scheduler self.default_catalog = default_catalog self.console = console or get_console() + self.ownership_config = ownership_config self._circuit_breaker: t.Optional[t.Callable[[], bool]] = None def evaluate( @@ -172,6 +175,11 @@ def visit_physical_layer_update_stage( self.console.log_success(skip_message) return + physical_owner = ( + self.ownership_config.resolve_physical_owner(self.snapshot_evaluator.adapter) + if self.ownership_config + else None + ) completion_status = None progress_stopped = False try: @@ -185,6 +193,7 @@ def visit_physical_layer_update_stage( x, plan.environment, self.default_catalog ), on_complete=self.console.update_creation_progress, + owner=physical_owner, ) if completion_status.is_nothing_to_do: self.console.log_success(skip_message) @@ -209,9 +218,14 @@ def visit_physical_layer_update_stage( def visit_physical_layer_schema_creation_stage( self, stage: stages.PhysicalLayerSchemaCreationStage, plan: EvaluatablePlan ) -> None: + physical_owner = ( + self.ownership_config.resolve_physical_owner(self.snapshot_evaluator.adapter) + if self.ownership_config + else None + ) try: self.snapshot_evaluator.create_physical_schemas( - stage.snapshots, stage.deployability_index + stage.snapshots, stage.deployability_index, owner=physical_owner ) except Exception as ex: raise PlanError("Plan application failed.") from ex @@ -434,6 +448,11 @@ def _promote_snapshots( deployability_index: t.Optional[DeployabilityIndex] = None, on_complete: t.Optional[t.Callable[[SnapshotInfoLike], None]] = None, ) -> None: + owner: t.Optional[str] = None + if self.ownership_config: + owner = self.ownership_config.resolve_owner( + environment_naming_info.name, self.snapshot_evaluator.adapter + ) self.snapshot_evaluator.promote( target_snapshots, start=plan.start, @@ -449,6 +468,7 @@ def _promote_snapshots( environment_naming_info=environment_naming_info, deployability_index=deployability_index, on_complete=on_complete, + owner=owner, ) def _demote_snapshots( diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index 4df9ecb695..89ff5b6b82 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -275,6 +275,7 @@ def promote( snapshots: t.Optional[t.Dict[SnapshotId, Snapshot]] = None, table_mapping: t.Optional[t.Dict[str, str]] = None, on_complete: t.Optional[t.Callable[[SnapshotInfoLike], None]] = None, + owner: t.Optional[str] = None, ) -> None: """Promotes the given collection of snapshots in the target environment by replacing a corresponding view with a physical table associated with the given snapshot. @@ -306,7 +307,7 @@ def promote( gateway_table_pairs = [ (gateway, table) for gateway, tables in tables_by_gateway.items() for table in tables ] - self._create_schemas(gateway_table_pairs=gateway_table_pairs) + self._create_schemas(gateway_table_pairs=gateway_table_pairs, owner=owner) # Fetch the view data objects for the promoted snapshots to get them cached self._get_virtual_data_objects(target_snapshots, environment_naming_info) @@ -325,6 +326,7 @@ def promote( environment_naming_info=environment_naming_info, deployability_index=deployability_index, # type: ignore on_complete=on_complete, + owner=owner, ), self.ddl_concurrent_tasks, ) @@ -366,6 +368,7 @@ def create( on_complete: t.Optional[t.Callable[[SnapshotInfoLike], None]] = None, allow_destructive_snapshots: t.Optional[t.Set[str]] = None, allow_additive_snapshots: t.Optional[t.Set[str]] = None, + owner: t.Optional[str] = None, ) -> CompletionStatus: """Creates a physical snapshot schema and table for the given collection of snapshots. @@ -377,6 +380,7 @@ def create( on_complete: A callback to call on each successfully created snapshot. allow_destructive_snapshots: Set of snapshots that are allowed to have destructive schema changes. allow_additive_snapshots: Set of snapshots that are allowed to have additive schema changes. + owner: Optional principal to set as table owner after creation. Returns: CompletionStatus: The status of the creation operation (success, failure, nothing to do). @@ -396,17 +400,22 @@ def create( on_complete=on_complete, allow_destructive_snapshots=allow_destructive_snapshots or set(), allow_additive_snapshots=allow_additive_snapshots or set(), + owner=owner, ) return CompletionStatus.SUCCESS def create_physical_schemas( - self, snapshots: t.Iterable[Snapshot], deployability_index: DeployabilityIndex + self, + snapshots: t.Iterable[Snapshot], + deployability_index: DeployabilityIndex, + owner: t.Optional[str] = None, ) -> None: """Creates the physical schemas for the given snapshots. Args: snapshots: Snapshots to create physical schemas for. deployability_index: Determines snapshots that are deployable in the context of this creation. + owner: Optional principal to set as schema owner after creation. """ tables_by_gateway: t.Dict[t.Optional[str], t.List[str]] = defaultdict(list) for snapshot in snapshots: @@ -418,7 +427,7 @@ def create_physical_schemas( gateway_table_pairs = [ (gateway, table) for gateway, tables in tables_by_gateway.items() for table in tables ] - self._create_schemas(gateway_table_pairs=gateway_table_pairs) + self._create_schemas(gateway_table_pairs=gateway_table_pairs, owner=owner) def get_snapshots_to_create( self, target_snapshots: t.Iterable[Snapshot], deployability_index: DeployabilityIndex @@ -451,6 +460,7 @@ def _create_snapshots( on_complete: t.Optional[t.Callable[[SnapshotInfoLike], None]], allow_destructive_snapshots: t.Set[str], allow_additive_snapshots: t.Set[str], + owner: t.Optional[str] = None, ) -> None: """Internal method to create tables in parallel.""" with self.concurrent_context(): @@ -463,6 +473,7 @@ def _create_snapshots( allow_destructive_snapshots=allow_destructive_snapshots, allow_additive_snapshots=allow_additive_snapshots, on_complete=on_complete, + owner=owner, ), self.ddl_concurrent_tasks, raise_on_error=False, @@ -868,6 +879,7 @@ def create_snapshot( allow_destructive_snapshots: t.Set[str], allow_additive_snapshots: t.Set[str], on_complete: t.Optional[t.Callable[[SnapshotInfoLike], None]] = None, + owner: t.Optional[str] = None, ) -> None: """Creates a physical table for the given snapshot. @@ -878,6 +890,7 @@ def create_snapshot( on_complete: A callback to call on each successfully created database object. allow_destructive_snapshots: Snapshots for which destructive schema changes are allowed. allow_additive_snapshots: Snapshots for which additive schema changes are allowed. + owner: Optional principal to set as table owner after creation. """ if not snapshot.is_model: return @@ -905,6 +918,9 @@ def create_snapshot( **create_render_kwargs ) + is_table_deployable = deployability_index.is_deployable(snapshot) + table_name = snapshot.table_name(is_deployable=is_table_deployable) + if self._can_clone(snapshot, deployability_index): self._clone_snapshot_in_dev( snapshot=snapshot, @@ -917,10 +933,9 @@ def create_snapshot( run_pre_post_statements=True, ) else: - is_table_deployable = deployability_index.is_deployable(snapshot) self._execute_create( snapshot=snapshot, - table_name=snapshot.table_name(is_deployable=is_table_deployable), + table_name=table_name, is_table_deployable=is_table_deployable, deployability_index=deployability_index, create_render_kwargs=create_render_kwargs, @@ -928,6 +943,9 @@ def create_snapshot( dry_run=True, ) + if owner and not isinstance(snapshot.model.kind, ViewKind): + adapter.alter_table_owner(table_name, owner) + evaluation_strategy.run_post_statements( snapshot=snapshot, render_kwargs={**create_render_kwargs, "inside_transaction": False} ) @@ -1257,6 +1275,7 @@ def _promote_snapshot( execution_time: t.Optional[TimeLike] = None, snapshots: t.Optional[t.Dict[SnapshotId, Snapshot]] = None, table_mapping: t.Optional[t.Dict[str, str]] = None, + owner: t.Optional[str] = None, ) -> None: if not snapshot.is_model: return @@ -1298,6 +1317,9 @@ def _promote_snapshot( render_kwargs["snapshots"] = snapshot_by_name adapter.execute(snapshot.model.render_on_virtual_update(**render_kwargs)) + if owner: + adapter.alter_view_owner(view_name, owner) + if on_complete is not None: on_complete(snapshot) @@ -1449,6 +1471,7 @@ def _create_catalogs( def _create_schemas( self, gateway_table_pairs: t.Iterable[t.Tuple[t.Optional[str], t.Union[exp.Table, str]]], + owner: t.Optional[str] = None, ) -> None: table_exprs = [(gateway, exp.to_table(t)) for gateway, t in gateway_table_pairs] unique_schemas = { @@ -1464,6 +1487,8 @@ def _create_schema( logger.info("Creating schema '%s'", schema) adapter = self.get_adapter(gateway) adapter.create_schema(schema) + if owner: + adapter.alter_schema_owner(schema, owner) with self.concurrent_context(): concurrent_apply_to_values( diff --git a/tests/core/engine_adapter/test_spark.py b/tests/core/engine_adapter/test_spark.py index d7c3127f05..618bd381cb 100644 --- a/tests/core/engine_adapter/test_spark.py +++ b/tests/core/engine_adapter/test_spark.py @@ -1092,6 +1092,86 @@ def test_table_format(adapter: SparkEngineAdapter, mocker: MockerFixture): ] +def test_alter_schema_owner(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(SparkEngineAdapter) + adapter.alter_schema_owner("catalog.my_schema", "svc_prod_spn") + assert to_sql_calls(adapter) == ["ALTER SCHEMA `catalog`.`my_schema` OWNER TO `svc_prod_spn`"] + + +def test_alter_schema_owner_three_part_name(make_mocked_engine_adapter: t.Callable): + # Schema references are typically 2-part (catalog.schema), but verify quoting is correct. + adapter = make_mocked_engine_adapter(SparkEngineAdapter) + adapter.alter_schema_owner("my_schema", "svc_prod_spn") + assert to_sql_calls(adapter) == ["ALTER SCHEMA `my_schema` OWNER TO `svc_prod_spn`"] + + +def test_alter_view_owner(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(SparkEngineAdapter) + adapter.alter_view_owner("catalog.my_schema.my_view", "svc_prod_spn") + assert to_sql_calls(adapter) == [ + "ALTER VIEW `catalog`.`my_schema`.`my_view` OWNER TO `svc_prod_spn`" + ] + + +def test_alter_view_owner_special_chars_in_principal(make_mocked_engine_adapter: t.Callable): + # Databricks Unity Catalog principals can contain colons and @ signs. + # Verify they are safely backtick-quoted and not interpreted as SQL syntax. + adapter = make_mocked_engine_adapter(SparkEngineAdapter) + adapter.alter_view_owner("catalog.sushi__dev.orders", "group:devs@company.com") + assert to_sql_calls(adapter) == [ + "ALTER VIEW `catalog`.`sushi__dev`.`orders` OWNER TO `group:devs@company.com`" + ] + + +def test_alter_table_owner(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(SparkEngineAdapter) + adapter.alter_table_owner("catalog.sqlmesh__sushi.orders__abc123", "svc_prod_spn") + assert to_sql_calls(adapter) == [ + "ALTER TABLE `catalog`.`sqlmesh__sushi`.`orders__abc123` OWNER TO `svc_prod_spn`" + ] + + +def test_alter_table_owner_special_chars_in_principal(make_mocked_engine_adapter: t.Callable): + # Databricks Unity Catalog principals can contain colons and @ signs. + adapter = make_mocked_engine_adapter(SparkEngineAdapter) + adapter.alter_table_owner("catalog.sqlmesh__sushi.orders__abc123", "group:data@company.com") + assert to_sql_calls(adapter) == [ + "ALTER TABLE `catalog`.`sqlmesh__sushi`.`orders__abc123` OWNER TO `group:data@company.com`" + ] + + +def test_alter_schema_owner_base_noop(make_mocked_engine_adapter: t.Callable): + # The base EngineAdapter.alter_schema_owner is a no-op: adapters that don't + # support ownership control silently skip it without emitting any SQL. + from sqlmesh.core.engine_adapter.duckdb import DuckDBEngineAdapter + + adapter = make_mocked_engine_adapter(DuckDBEngineAdapter) + adapter.alter_schema_owner("my_schema", "some_owner") + adapter.alter_view_owner("my_schema.my_view", "some_owner") + adapter.alter_table_owner("my_schema.my_table", "some_owner") + # No ALTER SQL should have been emitted + alter_calls = [s for s in to_sql_calls(adapter) if "OWNER" in s.upper()] + assert alter_calls == [] + + +def test_current_user(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(SparkEngineAdapter) + adapter.cursor.fetchone.return_value = ("spn-abc-123",) + result = adapter.current_user() + assert result == "spn-abc-123" + sql_calls = to_sql_calls(adapter) + assert any("CURRENT_USER" in s.upper() for s in sql_calls) + + +def test_current_user_base_noop(make_mocked_engine_adapter: t.Callable): + from sqlmesh.core.engine_adapter.duckdb import DuckDBEngineAdapter + + adapter = make_mocked_engine_adapter(DuckDBEngineAdapter) + adapter.cursor.fetchone.return_value = ("duckdb-user",) + result = adapter.current_user() + assert result == "duckdb-user" + + def test_get_data_object_wap_branch(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): adapter = make_mocked_engine_adapter(SparkEngineAdapter, patch_get_data_objects=False) mocker.patch.object(adapter, "_get_data_objects", return_value=[]) diff --git a/tests/core/integration/test_config.py b/tests/core/integration/test_config.py index 5d571cd7c5..2845057079 100644 --- a/tests/core/integration/test_config.py +++ b/tests/core/integration/test_config.py @@ -15,6 +15,7 @@ GatewayConfig, ModelDefaultsConfig, DuckDBConnectionConfig, + OwnershipConfig, TableNamingConvention, AutoCategorizationMode, ) @@ -578,3 +579,47 @@ def test_auto_categorization(sushi_context: Context): sushi_context.get_snapshot("sushi.waiter_as_customer_by_day", raise_if_missing=True).version == version ) + + +def test_ownership_config_plan_applies_without_error( + tmp_path: Path, monkeypatch: MonkeyPatch +) -> None: + """OwnershipConfig flows through the full plan/apply lifecycle without errors. + + DuckDB's alter_schema_owner/alter_view_owner are no-ops, so we cannot verify + that ownership was actually changed — but we confirm the config plumbing + doesn't break schema creation, view promotion, or dev environment application. + """ + monkeypatch.chdir(tmp_path) + + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + default_connection=DuckDBConnectionConfig(), + ownership=OwnershipConfig( + environment_owner_mapping={ + "^prod$": "svc_prod_owner", + ".*": "group:shared-developers", + } + ), + ) + + models_dir = tmp_path / "models" + models_dir.mkdir() + (models_dir / "model.sql").write_text( + """ + MODEL (name example_schema.test_model, kind FULL); + SELECT '1' AS a + """ + ) + + ctx = Context(config=config, paths=tmp_path) + + # Prod plan/apply — exercises virtual layer schema and view creation with ownership config + ctx.plan(auto_apply=True) + assert ctx.engine_adapter.table_exists("example_schema.test_model") + + # Dev plan/apply — exercises env-suffixed schema and view creation with ownership config + ctx.plan(environment="dev", include_unmodified=True, auto_apply=True) + metadata = DuckDBMetadata.from_context(ctx) + dev_schemas = {s for s in metadata.schemas if "__dev" in s} + assert len(dev_schemas) > 0 diff --git a/tests/core/test_config.py b/tests/core/test_config.py index 8c81a90b8d..582434c0c7 100644 --- a/tests/core/test_config.py +++ b/tests/core/test_config.py @@ -14,6 +14,7 @@ DuckDBConnectionConfig, GatewayConfig, ModelDefaultsConfig, + OwnershipConfig, BigQueryConnectionConfig, MotherDuckConnectionConfig, BuiltInSchedulerConfig, @@ -606,19 +607,189 @@ def test_load_duckdb_attach_config(tmp_path): assert config.gateways["another_gateway"].connection.catalogs.get("memory") == ":memory:" - attach_config_1 = config.gateways["another_gateway"].connection.catalogs.get("sqlite") - assert isinstance(attach_config_1, DuckDBAttachOptions) - assert attach_config_1.type == "sqlite" - assert attach_config_1.path == "test.db" - assert attach_config_1.read_only is False +# --------------------------------------------------------------------------- +# OwnershipConfig tests +# --------------------------------------------------------------------------- - attach_config_2 = config.gateways["another_gateway"].connection.catalogs.get("postgres") - assert isinstance(attach_config_2, DuckDBAttachOptions) - assert attach_config_2.type == "postgres" - assert attach_config_2.path == "dbname=postgres user=postgres host=127.0.0.1" - assert attach_config_2.read_only is True +def test_ownership_config_resolve_owner(): + mock_adapter = mock.MagicMock() + config = OwnershipConfig( + environment_owner_mapping={ + "^prod$": "svc_prod_spn", + ".*": "group:shared-developers", + } + ) + assert config.resolve_owner("prod", mock_adapter) == "svc_prod_spn" + assert config.resolve_owner("dev_alice", mock_adapter) == "group:shared-developers" + assert config.resolve_owner("staging", mock_adapter) == "group:shared-developers" + # "production" does not match ^prod$ so falls through to .* + assert config.resolve_owner("production", mock_adapter) == "group:shared-developers" + + +def test_ownership_config_empty_returns_none(): + mock_adapter = mock.MagicMock() + assert OwnershipConfig().resolve_owner("prod", mock_adapter) is None + assert OwnershipConfig().resolve_owner("dev_env", mock_adapter) is None + + +def test_ownership_config_first_match_wins(): + # The catch-all .* comes before a more specific pattern — it always wins. + # This documents the ordering contract: users must put specific patterns first. + mock_adapter = mock.MagicMock() + config = OwnershipConfig( + environment_owner_mapping={ + ".*": "catch_all_owner", + "^prod$": "prod_owner", + } + ) + assert config.resolve_owner("prod", mock_adapter) == "catch_all_owner" + + +def test_ownership_config_case_sensitive(): + # Patterns are compiled without re.IGNORECASE, so matching is case-sensitive. + mock_adapter = mock.MagicMock() + config = OwnershipConfig(environment_owner_mapping={"^prod$": "svc_prod"}) + assert config.resolve_owner("prod", mock_adapter) == "svc_prod" + assert config.resolve_owner("PROD", mock_adapter) is None + assert config.resolve_owner("Prod", mock_adapter) is None + + +def test_ownership_config_no_match_returns_none(): + mock_adapter = mock.MagicMock() + config = OwnershipConfig(environment_owner_mapping={"^prod$": "svc_prod"}) + assert config.resolve_owner("staging", mock_adapter) is None + assert config.resolve_owner("dev_bob", mock_adapter) is None + + +def test_ownership_config_deserialization_from_dict(): + # Simulates YAML/dict-based config loading (as produced by load_config_from_yaml). + mock_adapter = mock.MagicMock() + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + ownership={ + "environment_owner_mapping": { + "^prod$": "svc_prod_spn", + ".*": "group:shared-developers", + } + }, + ) + assert config.ownership.resolve_owner("prod", mock_adapter) == "svc_prod_spn" + assert config.ownership.resolve_owner("dev", mock_adapter) == "group:shared-developers" + + +def test_ownership_config_nested_update(): + # Config.ownership uses UpdateStrategy.NESTED_UPDATE. + # When two Configs are merged, the second one's environment_owner_mapping + # replaces the first's (REPLACE semantics within OwnershipConfig since + # environment_owner_mapping has no explicit strategy). + mock_adapter = mock.MagicMock() + c1 = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + ownership=OwnershipConfig(environment_owner_mapping={"^prod$": "spn_prod"}), + ) + c2 = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + ownership=OwnershipConfig(environment_owner_mapping={".*": "grp_devs"}), + ) + merged = c1.update_with(c2) + # c2's mapping fully replaces c1's — the ^prod$ pattern is gone + assert merged.ownership.resolve_owner("prod", mock_adapter) == "grp_devs" + assert merged.ownership.resolve_owner("dev_alice", mock_adapter) == "grp_devs" + + +def test_config_ownership_defaults_to_empty(): + # Configs without an explicit ownership block have a no-op OwnershipConfig. + mock_adapter = mock.MagicMock() + config = Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")) + assert config.ownership.environment_owner_mapping == {} + assert config.ownership.resolve_owner("prod", mock_adapter) is None + + +def test_ownership_config_physical_owner(): + # physical_owner is a simple optional string — no pattern matching. + config = OwnershipConfig(physical_owner="group:data-platform") + assert config.physical_owner == "group:data-platform" + + +def test_ownership_config_physical_owner_default_none(): + assert OwnershipConfig().physical_owner is None + + +def test_ownership_config_physical_owner_deserialization(): + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + ownership={ + "environment_owner_mapping": {"^prod$": "svc_prod"}, + "physical_owner": "group:data-platform", + }, + ) + assert config.ownership.physical_owner == "group:data-platform" + assert config.ownership.resolve_owner("prod", mock.MagicMock()) == "svc_prod" + + +def test_ownership_config_resolve_owner_callable(): + # A callable resolver takes precedence over environment_owner_mapping and + # receives (env_name, adapter) so it can call adapter.current_user() etc. + mock_adapter = mock.MagicMock() + mock_adapter.current_user.return_value = "spn-dynamic-uuid" + + config = OwnershipConfig( + environment_owner_mapping={".*": "group:fallback"}, + environment_owner_resolver=lambda env, adapter: ( + adapter.current_user() if env == "prod" else "group:shared-developers" + ), + ) + + assert config.resolve_owner("prod", mock_adapter) == "spn-dynamic-uuid" + assert config.resolve_owner("dev_alice", mock_adapter) == "group:shared-developers" + mock_adapter.current_user.assert_called_once() + + +def test_ownership_config_resolver_overrides_mapping(): + # Resolver always wins when set, even if the mapping would also match. + mock_adapter = mock.MagicMock() + config = OwnershipConfig( + environment_owner_mapping={"^prod$": "static-owner"}, + environment_owner_resolver=lambda env, adapter: "dynamic-owner", + ) + assert config.resolve_owner("prod", mock_adapter) == "dynamic-owner" + + +def test_ownership_config_resolve_physical_owner_callable(): + mock_adapter = mock.MagicMock() + mock_adapter.current_user.return_value = "spn-uuid-123" + + config = OwnershipConfig( + physical_owner_resolver=lambda adapter: adapter.current_user(), + ) + assert config.resolve_physical_owner(mock_adapter) == "spn-uuid-123" + mock_adapter.current_user.assert_called_once() + + +def test_ownership_config_resolve_physical_owner_static(): + mock_adapter = mock.MagicMock() + config = OwnershipConfig(physical_owner="group:data-platform") + assert config.resolve_physical_owner(mock_adapter) == "group:data-platform" + mock_adapter.current_user.assert_not_called() + + +def test_ownership_config_physical_owner_resolver_overrides_static(): + mock_adapter = mock.MagicMock() + config = OwnershipConfig( + physical_owner="static-owner", + physical_owner_resolver=lambda adapter: "dynamic-owner", + ) + assert config.resolve_physical_owner(mock_adapter) == "dynamic-owner" + + +def test_ownership_config_is_active(): + assert not OwnershipConfig().is_active + assert OwnershipConfig(environment_owner_mapping={".*": "grp"}).is_active + assert OwnershipConfig(environment_owner_resolver=lambda e, a: None).is_active + assert OwnershipConfig(physical_owner="grp").is_active + assert OwnershipConfig(physical_owner_resolver=lambda a: "grp").is_active def test_load_model_defaults_audits(tmp_path): diff --git a/tests/core/test_snapshot_evaluator.py b/tests/core/test_snapshot_evaluator.py index d7aa9e4a80..a333fb1746 100644 --- a/tests/core/test_snapshot_evaluator.py +++ b/tests/core/test_snapshot_evaluator.py @@ -436,6 +436,152 @@ def test_promote_forward_only(mocker: MockerFixture, adapter_mock, make_snapshot ) +def test_promote_with_owner(mocker: MockerFixture, adapter_mock, make_snapshot): + """When owner is supplied, alter_schema_owner and alter_view_owner are called.""" + evaluator = SnapshotEvaluator(adapter_mock) + + model = SqlModel( + name="test_schema.test_model", + kind=IncrementalByTimeRangeKind(time_column="a"), + storage_format="parquet", + query=parse_one("SELECT a FROM tbl WHERE ds BETWEEN @start_ds and @end_ds"), + ) + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + evaluator.promote( + [snapshot], EnvironmentNamingInfo(name="test_env"), owner="group:shared-developers" + ) + + adapter_mock.alter_schema_owner.assert_called_once_with( + to_schema("test_schema__test_env"), "group:shared-developers" + ) + adapter_mock.alter_view_owner.assert_called_once_with( + "test_schema__test_env.test_model", "group:shared-developers" + ) + + +def test_promote_without_owner_skips_alter(mocker: MockerFixture, adapter_mock, make_snapshot): + """When no owner is configured (the default), ownership DDL is never issued.""" + evaluator = SnapshotEvaluator(adapter_mock) + + model = SqlModel( + name="test_schema.test_model", + kind=IncrementalByTimeRangeKind(time_column="a"), + storage_format="parquet", + query=parse_one("SELECT a FROM tbl WHERE ds BETWEEN @start_ds and @end_ds"), + ) + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + evaluator.promote([snapshot], EnvironmentNamingInfo(name="test_env")) + + adapter_mock.alter_schema_owner.assert_not_called() + adapter_mock.alter_view_owner.assert_not_called() + + +def test_promote_owner_applied_per_view(mocker: MockerFixture, adapter_mock, make_snapshot): + """alter_view_owner is called once per promoted snapshot.""" + evaluator = SnapshotEvaluator(adapter_mock) + + snapshots = [] + for name in ("model_a", "model_b", "model_c"): + model = SqlModel( + name=f"test_schema.{name}", + kind=ViewKind(), + query=parse_one("SELECT 1"), + ) + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + snapshots.append(snapshot) + + evaluator.promote(snapshots, EnvironmentNamingInfo(name="test_env"), owner="svc_prod_spn") + + assert adapter_mock.alter_view_owner.call_count == 3 + called_owners = {c.args[1] for c in adapter_mock.alter_view_owner.call_args_list} + assert called_owners == {"svc_prod_spn"} + + +def test_create_with_physical_owner(mocker: MockerFixture, adapter_mock, make_snapshot): + """alter_table_owner is called for each non-view table when physical owner is set.""" + adapter_mock.get_data_objects.return_value = [] + evaluator = SnapshotEvaluator(adapter_mock) + + model = SqlModel( + name="test_schema.test_model", + kind=IncrementalByTimeRangeKind(time_column="ds"), + storage_format="parquet", + query=parse_one("SELECT a, ds FROM tbl WHERE ds BETWEEN @start_ds AND @end_ds"), + ) + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + evaluator.create([snapshot], {}, owner="group:data-platform") + + adapter_mock.alter_table_owner.assert_called_once() + call_args = adapter_mock.alter_table_owner.call_args + assert call_args.args[1] == "group:data-platform" + + +def test_create_without_physical_owner_skips_alter( + mocker: MockerFixture, adapter_mock, make_snapshot +): + """When no physical owner is set, alter_table_owner is never called.""" + adapter_mock.get_data_objects.return_value = [] + evaluator = SnapshotEvaluator(adapter_mock) + + model = SqlModel( + name="test_schema.test_model", + kind=IncrementalByTimeRangeKind(time_column="ds"), + storage_format="parquet", + query=parse_one("SELECT a, ds FROM tbl WHERE ds BETWEEN @start_ds AND @end_ds"), + ) + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + evaluator.create([snapshot], {}) + + adapter_mock.alter_table_owner.assert_not_called() + + +def test_create_view_kind_skips_physical_owner(mocker: MockerFixture, adapter_mock, make_snapshot): + """ViewKind snapshots skip alter_table_owner even when physical_owner is set.""" + adapter_mock.get_data_objects.return_value = [] + evaluator = SnapshotEvaluator(adapter_mock) + + model = SqlModel( + name="test_schema.test_view", + kind=ViewKind(), + query=parse_one("SELECT 1"), + ) + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + evaluator.create([snapshot], {}, owner="group:data-platform") + + adapter_mock.alter_table_owner.assert_not_called() + + +def test_create_physical_schemas_with_owner(mocker: MockerFixture, adapter_mock, make_snapshot): + """create_physical_schemas passes owner to _create_schemas so alter_schema_owner is called.""" + evaluator = SnapshotEvaluator(adapter_mock) + deployability_index = DeployabilityIndex.all_deployable() + + model = SqlModel( + name="test_schema.test_model", + kind=IncrementalByTimeRangeKind(time_column="ds"), + storage_format="parquet", + query=parse_one("SELECT a, ds FROM tbl WHERE ds BETWEEN @start_ds AND @end_ds"), + ) + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + evaluator.create_physical_schemas([snapshot], deployability_index, owner="svc_prod_spn") + + adapter_mock.alter_schema_owner.assert_called_once() + assert adapter_mock.alter_schema_owner.call_args.args[1] == "svc_prod_spn" + + def test_cleanup(mocker: MockerFixture, adapter_mock, make_snapshot): evaluator = SnapshotEvaluator(adapter_mock)