diff --git a/pyiceberg/table/update/snapshot.py b/pyiceberg/table/update/snapshot.py index 7931edacdd..de541cd4a2 100644 --- a/pyiceberg/table/update/snapshot.py +++ b/pyiceberg/table/update/snapshot.py @@ -894,6 +894,36 @@ def remove_tag(self, tag_name: str) -> ManageSnapshots: """ return self._remove_ref_snapshot(ref_name=tag_name) + def replace_tag(self, tag_name: str, snapshot_id: int) -> ManageSnapshots: + """ + Replace the tag with the given name to point to the specified snapshot. + + Args: + tag_name (str): Tag to replace + snapshot_id (int): new snapshot id for the given tag + Returns: + This for method chaining + """ + self._commit_if_ref_updates_exist() + + refs = self._transaction.table_metadata.refs + if tag_name not in refs: + raise ValueError(f"Tag does not exist: {tag_name}") + + ref = refs[tag_name] + if ref.snapshot_ref_type != SnapshotRefType.TAG: + raise ValueError(f"Ref {tag_name} is not a tag") + + update, requirement = self._transaction._set_ref_snapshot( + snapshot_id=snapshot_id, + ref_name=tag_name, + type=SnapshotRefType.TAG, + max_ref_age_ms=ref.max_ref_age_ms, + ) + self._updates += update + self._requirements += requirement + return self + def create_branch( self, snapshot_id: int, diff --git a/tests/integration/test_snapshot_operations.py b/tests/integration/test_snapshot_operations.py index 07fb77edbb..3e0f5644a8 100644 --- a/tests/integration/test_snapshot_operations.py +++ b/tests/integration/test_snapshot_operations.py @@ -23,7 +23,7 @@ from pyiceberg.catalog import Catalog from pyiceberg.table import Table -from pyiceberg.table.refs import SnapshotRef +from pyiceberg.table.refs import SnapshotRef, SnapshotRefType @pytest.fixture @@ -107,6 +107,55 @@ def test_remove_branch(catalog: Catalog) -> None: assert tbl.metadata.refs.get(branch_name, None) is None +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [lf("session_catalog_hive"), lf("session_catalog")]) +def test_replace_tag(catalog: Catalog) -> None: + identifier = "default.test_table_snapshot_operations" + tbl = catalog.load_table(identifier) + assert len(tbl.history()) > 2 + + current_snapshot_id = tbl.history()[-1].snapshot_id + older_snapshot_id = tbl.history()[-2].snapshot_id + + tag_name = "my-tag" + tbl.manage_snapshots().create_tag(older_snapshot_id, tag_name, 1).commit() + tag = tbl.metadata.refs.get(tag_name) + assert tag is not None + assert tag.snapshot_id == older_snapshot_id + assert tag.snapshot_ref_type == SnapshotRefType.TAG + assert tag.max_ref_age_ms == 1 + + tbl.manage_snapshots().replace_tag(tag_name=tag_name, snapshot_id=current_snapshot_id).commit() + + tag = tbl.metadata.refs.get(tag_name) + assert tag is not None + assert tag.snapshot_id == current_snapshot_id + assert tag.snapshot_ref_type == SnapshotRefType.TAG + assert tag.max_ref_age_ms == 1 + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [lf("session_catalog_hive"), lf("session_catalog")]) +def test_replace_missing_tag(catalog: Catalog) -> None: + identifier = "default.test_table_snapshot_operations" + tbl = catalog.load_table(identifier) + snapshot_id = tbl.history()[-1].snapshot_id + + with pytest.raises(ValueError, match="Tag does not exist: test"): + tbl.manage_snapshots().replace_tag(tag_name="test", snapshot_id=snapshot_id).commit() + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [lf("session_catalog_hive"), lf("session_catalog")]) +def test_replace_tag_with_branch(catalog: Catalog) -> None: + identifier = "default.test_table_snapshot_operations" + tbl = catalog.load_table(identifier) + snapshot_id = tbl.history()[-1].snapshot_id + + with pytest.raises(ValueError, match="Ref main is not a tag"): + tbl.manage_snapshots().replace_tag(tag_name="main", snapshot_id=snapshot_id).commit() + + @pytest.mark.integration @pytest.mark.parametrize("catalog", [lf("session_catalog_hive"), lf("session_catalog")]) def test_set_current_snapshot(catalog: Catalog) -> None: