Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions pyiceberg/table/update/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,6 +894,38 @@ def remove_tag(self, tag_name: str) -> ManageSnapshots:
"""
return self._remove_ref_snapshot(ref_name=tag_name)

def replace_branch(self, branch_name: str, snapshot_id: int) -> ManageSnapshots:
    """
    Repoint an existing branch at a different snapshot.

    The retention settings of the existing ref (max_ref_age_ms, max_snapshot_age_ms
    and min_snapshots_to_keep) are read from it and passed through unchanged.

    Args:
        branch_name (str): Name of the branch to repoint; it must already exist as a branch
        snapshot_id (int): Id of the snapshot the branch should reference afterwards
    Returns:
        This for method chaining
    Raises:
        ValueError: If no ref named branch_name exists, or if that ref is not a branch
    """
    # NOTE(review): presumably commits ref updates queued earlier in the chain so the
    # lookup below sees them -- confirm against _commit_if_ref_updates_exist.
    self._commit_if_ref_updates_exist()

    existing_ref = self._transaction.table_metadata.refs.get(branch_name)
    if existing_ref is None:
        raise ValueError(f"Branch does not exist: {branch_name}")
    if existing_ref.snapshot_ref_type != SnapshotRefType.BRANCH:
        raise ValueError(f"Ref {branch_name} is not a branch")

    # Only the snapshot id changes; every other property is copied from the current ref.
    ref_update, ref_requirement = self._transaction._set_ref_snapshot(
        snapshot_id=snapshot_id,
        ref_name=branch_name,
        type=SnapshotRefType.BRANCH,
        max_ref_age_ms=existing_ref.max_ref_age_ms,
        max_snapshot_age_ms=existing_ref.max_snapshot_age_ms,
        min_snapshots_to_keep=existing_ref.min_snapshots_to_keep,
    )
    self._updates += ref_update
    self._requirements += ref_requirement
    return self

def create_branch(
self,
snapshot_id: int,
Expand Down
58 changes: 57 additions & 1 deletion tests/integration/test_snapshot_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from pyiceberg.catalog import Catalog
from pyiceberg.table import Table
from pyiceberg.table.refs import SnapshotRef
from pyiceberg.table.refs import SnapshotRef, SnapshotRefType


@pytest.fixture
Expand Down Expand Up @@ -107,6 +107,62 @@ def test_remove_branch(catalog: Catalog) -> None:
assert tbl.metadata.refs.get(branch_name, None) is None


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [lf("session_catalog_hive"), lf("session_catalog")])
def test_replace_branch(catalog: Catalog) -> None:
    """Replacing a branch moves it to the new snapshot and keeps its retention settings."""
    tbl = catalog.load_table("default.test_table_snapshot_operations")
    assert len(tbl.history()) > 2

    # The branch starts on the second-newest history entry and is then moved to the newest.
    older_snapshot_id = tbl.history()[-2].snapshot_id
    current_snapshot_id = tbl.history()[-1].snapshot_id

    branch_name = "my-branch"

    def check_branch(expected_snapshot_id: int) -> None:
        # The ref must exist, be a branch on the expected snapshot, and carry the
        # retention values (1, 2, 3) it was created with.
        ref = tbl.metadata.refs.get(branch_name)
        assert ref is not None
        assert ref.snapshot_id == expected_snapshot_id
        assert ref.snapshot_ref_type == SnapshotRefType.BRANCH
        assert ref.max_ref_age_ms == 1
        assert ref.max_snapshot_age_ms == 2
        assert ref.min_snapshots_to_keep == 3

    tbl.manage_snapshots().create_branch(older_snapshot_id, branch_name, 1, 2, 3).commit()
    check_branch(older_snapshot_id)

    tbl.manage_snapshots().replace_branch(branch_name=branch_name, snapshot_id=current_snapshot_id).commit()
    check_branch(current_snapshot_id)


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [lf("session_catalog_hive"), lf("session_catalog")])
def test_replace_missing_branch(catalog: Catalog) -> None:
    """Replacing a branch that was never created raises a ValueError naming the branch."""
    tbl = catalog.load_table("default.test_table_snapshot_operations")
    latest_snapshot_id = tbl.history()[-1].snapshot_id

    with pytest.raises(ValueError, match="Branch does not exist: test"):
        manager = tbl.manage_snapshots()
        manager.replace_branch(branch_name="test", snapshot_id=latest_snapshot_id).commit()


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [lf("session_catalog_hive"), lf("session_catalog")])
def test_replace_branch_with_tag(catalog: Catalog) -> None:
    """Calling replace_branch on a ref that is a tag raises a ValueError."""
    tbl = catalog.load_table("default.test_table_snapshot_operations")
    latest_snapshot_id = tbl.history()[-1].snapshot_id

    # Create a tag first so the name exists as a ref, but not as a branch.
    tag_name = "my-tag"
    tbl.manage_snapshots().create_tag(snapshot_id=latest_snapshot_id, tag_name=tag_name).commit()

    with pytest.raises(ValueError, match="Ref my-tag is not a branch"):
        manager = tbl.manage_snapshots()
        manager.replace_branch(branch_name=tag_name, snapshot_id=latest_snapshot_id).commit()


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [lf("session_catalog_hive"), lf("session_catalog")])
def test_set_current_snapshot(catalog: Catalog) -> None:
Expand Down