diff --git a/pyiceberg/table/update/snapshot.py b/pyiceberg/table/update/snapshot.py index 7931edacdd..b8be1793ef 100644 --- a/pyiceberg/table/update/snapshot.py +++ b/pyiceberg/table/update/snapshot.py @@ -894,6 +894,38 @@ def remove_tag(self, tag_name: str) -> ManageSnapshots: """ return self._remove_ref_snapshot(ref_name=tag_name) + def replace_branch(self, branch_name: str, snapshot_id: int) -> ManageSnapshots: + """ + Replace the branch with the given name to point to the specified snapshot. + + Args: + branch_name (str): Branch to replace + snapshot_id (int): new snapshot id for the given branch + Returns: + This for method chaining + """ + self._commit_if_ref_updates_exist() + + refs = self._transaction.table_metadata.refs + if branch_name not in refs: + raise ValueError(f"Branch does not exist: {branch_name}") + + ref = refs[branch_name] + if ref.snapshot_ref_type != SnapshotRefType.BRANCH: + raise ValueError(f"Ref {branch_name} is not a branch") + + update, requirement = self._transaction._set_ref_snapshot( + snapshot_id=snapshot_id, + ref_name=branch_name, + type=SnapshotRefType.BRANCH, + max_ref_age_ms=ref.max_ref_age_ms, + max_snapshot_age_ms=ref.max_snapshot_age_ms, + min_snapshots_to_keep=ref.min_snapshots_to_keep, + ) + self._updates += update + self._requirements += requirement + return self + def create_branch( self, snapshot_id: int, diff --git a/tests/integration/test_snapshot_operations.py b/tests/integration/test_snapshot_operations.py index 07fb77edbb..d4ba5a147d 100644 --- a/tests/integration/test_snapshot_operations.py +++ b/tests/integration/test_snapshot_operations.py @@ -23,7 +23,7 @@ from pyiceberg.catalog import Catalog from pyiceberg.table import Table -from pyiceberg.table.refs import SnapshotRef +from pyiceberg.table.refs import SnapshotRef, SnapshotRefType @pytest.fixture @@ -107,6 +107,62 @@ def test_remove_branch(catalog: Catalog) -> None: assert tbl.metadata.refs.get(branch_name, None) is None +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [lf("session_catalog_hive"), lf("session_catalog")]) +def test_replace_branch(catalog: Catalog) -> None: + identifier = "default.test_table_snapshot_operations" + tbl = catalog.load_table(identifier) + assert len(tbl.history()) > 2 + + current_snapshot_id = tbl.history()[-1].snapshot_id + older_snapshot_id = tbl.history()[-2].snapshot_id + + branch_name = "my-branch" + tbl.manage_snapshots().create_branch(older_snapshot_id, branch_name, 1, 2, 3).commit() + branch = tbl.metadata.refs.get(branch_name) + assert branch is not None + assert branch.snapshot_id == older_snapshot_id + assert branch.snapshot_ref_type == SnapshotRefType.BRANCH + assert branch.max_ref_age_ms == 1 + assert branch.max_snapshot_age_ms == 2 + assert branch.min_snapshots_to_keep == 3 + + tbl.manage_snapshots().replace_branch(branch_name=branch_name, snapshot_id=current_snapshot_id).commit() + + branch = tbl.metadata.refs.get(branch_name) + assert branch is not None + assert branch.snapshot_id == current_snapshot_id + assert branch.snapshot_ref_type == SnapshotRefType.BRANCH + assert branch.max_ref_age_ms == 1 + assert branch.max_snapshot_age_ms == 2 + assert branch.min_snapshots_to_keep == 3 + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [lf("session_catalog_hive"), lf("session_catalog")]) +def test_replace_missing_branch(catalog: Catalog) -> None: + identifier = "default.test_table_snapshot_operations" + tbl = catalog.load_table(identifier) + snapshot_id = tbl.history()[-1].snapshot_id + + with pytest.raises(ValueError, match="Branch does not exist: test"): + tbl.manage_snapshots().replace_branch(branch_name="test", snapshot_id=snapshot_id).commit() + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [lf("session_catalog_hive"), lf("session_catalog")]) +def test_replace_branch_with_tag(catalog: Catalog) -> None: + identifier = "default.test_table_snapshot_operations" + tbl = catalog.load_table(identifier) + snapshot_id = tbl.history()[-1].snapshot_id + + tag_name = "my-tag" + tbl.manage_snapshots().create_tag(snapshot_id=snapshot_id, tag_name=tag_name).commit() + + with pytest.raises(ValueError, match="Ref my-tag is not a branch"): + tbl.manage_snapshots().replace_branch(branch_name=tag_name, snapshot_id=snapshot_id).commit() + + @pytest.mark.integration @pytest.mark.parametrize("catalog", [lf("session_catalog_hive"), lf("session_catalog")]) def test_set_current_snapshot(catalog: Catalog) -> None: