implement stageOnly Commit #2269

Open · wants to merge 4 commits into main
33 changes: 17 additions & 16 deletions pyiceberg/table/__init__.py
@@ -397,7 +397,9 @@ def _build_partition_predicate(self, partition_records: Set[Record]) -> BooleanE
expr = Or(expr, match_partition_expression)
return expr

-    def _append_snapshot_producer(self, snapshot_properties: Dict[str, str], branch: Optional[str]) -> _FastAppendFiles:
+    def _append_snapshot_producer(
+        self, snapshot_properties: Dict[str, str], branch: Optional[str] = MAIN_BRANCH
+    ) -> _FastAppendFiles:
"""Determine the append type based on table properties.

Args:
@@ -430,15 +432,14 @@ def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive
name_mapping=self.table_metadata.name_mapping(),
)

-    def update_snapshot(self, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None) -> UpdateSnapshot:
+    def update_snapshot(
+        self, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH
+    ) -> UpdateSnapshot:
"""Create a new UpdateSnapshot to produce a new snapshot for the table.

Returns:
A new UpdateSnapshot
"""
-        if branch is None:
-            branch = MAIN_BRANCH

return UpdateSnapshot(self, io=self._table.io, branch=branch, snapshot_properties=snapshot_properties)

def update_statistics(self) -> UpdateStatistics:
@@ -450,7 +451,7 @@ def update_statistics(self) -> UpdateStatistics:
"""
return UpdateStatistics(transaction=self)

-    def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None) -> None:
+    def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH) -> None:
"""
Shorthand API for appending a PyArrow table to a table transaction.

@@ -495,7 +496,7 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT,
append_files.append_data_file(data_file)

def dynamic_partition_overwrite(
-        self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None
+        self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH
) -> None:
"""
Shorthand for overwriting existing partitions with a PyArrow table.
@@ -562,7 +563,7 @@ def overwrite(
overwrite_filter: Union[BooleanExpression, str] = ALWAYS_TRUE,
snapshot_properties: Dict[str, str] = EMPTY_DICT,
case_sensitive: bool = True,
-        branch: Optional[str] = None,
+        branch: Optional[str] = MAIN_BRANCH,
) -> None:
"""
Shorthand for adding a table overwrite with a PyArrow table to the transaction.
@@ -628,7 +629,7 @@ def delete(
delete_filter: Union[str, BooleanExpression],
snapshot_properties: Dict[str, str] = EMPTY_DICT,
case_sensitive: bool = True,
-        branch: Optional[str] = None,
+        branch: Optional[str] = MAIN_BRANCH,
) -> None:
"""
Shorthand for deleting record from a table.
@@ -731,7 +732,7 @@ def upsert(
when_matched_update_all: bool = True,
when_not_matched_insert_all: bool = True,
case_sensitive: bool = True,
-        branch: Optional[str] = None,
+        branch: Optional[str] = MAIN_BRANCH,
) -> UpsertResult:
"""Shorthand API for performing an upsert to an iceberg table.

@@ -816,7 +817,7 @@ def upsert(
case_sensitive=case_sensitive,
)

-        if branch is not None:
+        if branch in self.table_metadata.refs:
matched_iceberg_record_batches_scan = matched_iceberg_record_batches_scan.use_ref(branch)

matched_iceberg_record_batches = matched_iceberg_record_batches_scan.to_arrow_batch_reader()
@@ -1307,7 +1308,7 @@ def upsert(
when_matched_update_all: bool = True,
when_not_matched_insert_all: bool = True,
case_sensitive: bool = True,
-        branch: Optional[str] = None,
+        branch: Optional[str] = MAIN_BRANCH,
) -> UpsertResult:
"""Shorthand API for performing an upsert to an iceberg table.

@@ -1354,7 +1355,7 @@ def upsert(
branch=branch,
)

-    def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None) -> None:
+    def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH) -> None:
"""
Shorthand API for appending a PyArrow table to the table.

@@ -1367,7 +1368,7 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT,
tx.append(df=df, snapshot_properties=snapshot_properties, branch=branch)

def dynamic_partition_overwrite(
-        self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None
+        self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH
) -> None:
"""Shorthand for dynamic overwriting the table with a PyArrow table.

@@ -1386,7 +1387,7 @@ def overwrite(
overwrite_filter: Union[BooleanExpression, str] = ALWAYS_TRUE,
snapshot_properties: Dict[str, str] = EMPTY_DICT,
case_sensitive: bool = True,
-        branch: Optional[str] = None,
+        branch: Optional[str] = MAIN_BRANCH,
) -> None:
"""
Shorthand for overwriting the table with a PyArrow table.
@@ -1419,7 +1420,7 @@ def delete(
delete_filter: Union[BooleanExpression, str] = ALWAYS_TRUE,
snapshot_properties: Dict[str, str] = EMPTY_DICT,
case_sensitive: bool = True,
-        branch: Optional[str] = None,
+        branch: Optional[str] = MAIN_BRANCH,
) -> None:
"""
Shorthand for deleting rows from the table.
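Taken together, the `__init__.py` changes make `MAIN_BRANCH` the explicit default for every write path and reserve `branch=None` for stage-only commits. A minimal sketch of the three resulting call patterns (the catalog and table names are hypothetical):

```python
import pyarrow as pa
from pyiceberg.catalog import load_catalog

catalog = load_catalog("default")           # hypothetical catalog config
tbl = catalog.load_table("default.events")  # hypothetical table
df = pa.table({"id": pa.array([1, 2], type=pa.int32())})

tbl.append(df)                  # default: commits to "main" and advances the ref
tbl.append(df, branch="audit")  # commits to (or creates) the "audit" branch ref
tbl.append(df, branch=None)     # stage-only: the snapshot is added to table
                                # metadata, but no ref is created or moved
```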
4 changes: 3 additions & 1 deletion pyiceberg/table/metadata.py
@@ -295,8 +295,10 @@ def new_snapshot_id(self) -> int:

return snapshot_id

-    def snapshot_by_name(self, name: str) -> Optional[Snapshot]:
+    def snapshot_by_name(self, name: Optional[str]) -> Optional[Snapshot]:
        """Return the snapshot referenced by the given name or null if no such reference exists."""
+        if name is None:
+            name = MAIN_BRANCH
if ref := self.refs.get(name):
return self.snapshot_by_id(ref.snapshot_id)
return None
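Accepting `None` here lets `_SnapshotProducer` resolve its parent snapshot uniformly: a stage-only write falls back to `main`, which is why the staged snapshots in the integration test below report the current `main` snapshot as their parent. A sketch of the lookup behavior, assuming `metadata` is a `TableMetadata` instance:

```python
# Lookup behavior after the change (sketch):
metadata.snapshot_by_name(None)     # falls back to the "main" ref
metadata.snapshot_by_name("audit")  # returns None if no ref named "audit" exists
```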
69 changes: 38 additions & 31 deletions pyiceberg/table/update/snapshot.py
@@ -108,7 +108,7 @@ class _SnapshotProducer(UpdateTableMetadata[U], Generic[U]):
_manifest_num_counter: itertools.count[int]
_deleted_data_files: Set[DataFile]
_compression: AvroCompressionCodec
-    _target_branch = MAIN_BRANCH
+    _target_branch: Optional[str]

def __init__(
self,
@@ -117,7 +117,7 @@ def __init__(
io: FileIO,
commit_uuid: Optional[uuid.UUID] = None,
snapshot_properties: Dict[str, str] = EMPTY_DICT,
-        branch: str = MAIN_BRANCH,
+        branch: Optional[str] = MAIN_BRANCH,
) -> None:
super().__init__(transaction)
self.commit_uuid = commit_uuid or uuid.uuid4()
@@ -138,14 +138,13 @@ def __init__(
snapshot.snapshot_id if (snapshot := self._transaction.table_metadata.snapshot_by_name(self._target_branch)) else None
)

-    def _validate_target_branch(self, branch: str) -> str:
-        # Default is already set to MAIN_BRANCH. So branch name can't be None.
-        if branch is None:
-            raise ValueError("Invalid branch name: null")
-        if branch in self._transaction.table_metadata.refs:
-            ref = self._transaction.table_metadata.refs[branch]
-            if ref.snapshot_ref_type != SnapshotRefType.BRANCH:
-                raise ValueError(f"{branch} is a tag, not a branch. Tags cannot be targets for producing snapshots")
+    def _validate_target_branch(self, branch: Optional[str]) -> Optional[str]:
+        # If branch is None, the write is staged as a snapshot without updating any ref.
+        if branch is not None:
+            if branch in self._transaction.table_metadata.refs:
+                ref = self._transaction.table_metadata.refs[branch]
+                if ref.snapshot_ref_type != SnapshotRefType.BRANCH:
+                    raise ValueError(f"{branch} is a tag, not a branch. Tags cannot be targets for producing snapshots")
        return branch

def append_data_file(self, data_file: DataFile) -> _SnapshotProducer[U]:
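The relaxed validation keeps the tag guard: targeting an existing tag still raises, an unknown name is allowed (the ref is created on commit), and `None` bypasses ref handling entirely. A hedged example of the guarded failure, assuming the `ManageSnapshots` API and a hypothetical tag name:

```python
# Create a tag, then try to write to it (sketch; assumes tbl and df from above):
tbl.manage_snapshots().create_tag(tbl.metadata.current_snapshot_id, "v1").commit()

tbl.append(df, branch="v1")
# ValueError: v1 is a tag, not a branch. Tags cannot be targets for producing snapshots
```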
@@ -292,25 +291,33 @@ def _commit(self) -> UpdatesAndRequirements:
schema_id=self._transaction.table_metadata.current_schema_id,
)

-        return (
-            (
-                AddSnapshotUpdate(snapshot=snapshot),
-                SetSnapshotRefUpdate(
-                    snapshot_id=self._snapshot_id,
-                    parent_snapshot_id=self._parent_snapshot_id,
-                    ref_name=self._target_branch,
-                    type=SnapshotRefType.BRANCH,
-                ),
-            ),
-            (
-                AssertRefSnapshotId(
-                    snapshot_id=self._transaction.table_metadata.refs[self._target_branch].snapshot_id
-                    if self._target_branch in self._transaction.table_metadata.refs
-                    else None,
-                    ref=self._target_branch,
-                ),
-            ),
-        )
+        add_snapshot_update = AddSnapshotUpdate(snapshot=snapshot)
+
+        if self._target_branch is None:
+            return (
+                (add_snapshot_update,),
+                (),
+            )
+        else:
+            return (
+                (
+                    add_snapshot_update,
+                    SetSnapshotRefUpdate(
+                        snapshot_id=self._snapshot_id,
+                        parent_snapshot_id=self._parent_snapshot_id,
+                        ref_name=self._target_branch,
+                        type=SnapshotRefType.BRANCH,
+                    ),
+                ),
+                (
+                    AssertRefSnapshotId(
+                        snapshot_id=self._transaction.table_metadata.refs[self._target_branch].snapshot_id
+                        if self._target_branch in self._transaction.table_metadata.refs
+                        else None,
+                        ref=self._target_branch,
+                    ),
+                ),
+            )

@property
def snapshot_id(self) -> int:
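This restructuring is the heart of the stage-only path: with `_target_branch` set to `None`, `_commit` emits only the `AddSnapshotUpdate`, with no `SetSnapshotRefUpdate` and no `AssertRefSnapshotId` requirement, so the snapshot enters table metadata without moving any ref and without an optimistic-concurrency check against one. A runnable sketch of the two result shapes (names mirror the diff; the update objects are string stand-ins):

```python
# Sketch: the (updates, requirements) shapes produced by _commit.
def commit_shape(target_branch, add_snapshot="AddSnapshotUpdate"):
    if target_branch is None:
        # Stage-only: the snapshot is recorded, no ref is touched.
        return (add_snapshot,), ()
    # Branch commit: advance the ref and require that it has not moved underneath us.
    return (
        (add_snapshot, f"SetSnapshotRefUpdate(ref_name={target_branch!r})"),
        (f"AssertRefSnapshotId(ref={target_branch!r})",),
    )

print(commit_shape(None))    # (('AddSnapshotUpdate',), ())
print(commit_shape("main"))  # ((add, set_ref), (assert_ref,))
```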
@@ -357,7 +364,7 @@ def __init__(
operation: Operation,
transaction: Transaction,
io: FileIO,
-        branch: str,
+        branch: Optional[str] = MAIN_BRANCH,
commit_uuid: Optional[uuid.UUID] = None,
snapshot_properties: Dict[str, str] = EMPTY_DICT,
):
@@ -527,7 +534,7 @@ def __init__(
operation: Operation,
transaction: Transaction,
io: FileIO,
-        branch: str,
+        branch: Optional[str] = MAIN_BRANCH,
commit_uuid: Optional[uuid.UUID] = None,
snapshot_properties: Dict[str, str] = EMPTY_DICT,
) -> None:
@@ -648,14 +655,14 @@ def _get_entries(manifest: ManifestFile) -> List[ManifestEntry]:
class UpdateSnapshot:
_transaction: Transaction
_io: FileIO
-    _branch: str
+    _branch: Optional[str]
_snapshot_properties: Dict[str, str]

def __init__(
self,
transaction: Transaction,
io: FileIO,
-        branch: str,
+        branch: Optional[str] = MAIN_BRANCH,
snapshot_properties: Dict[str, str] = EMPTY_DICT,
) -> None:
self._transaction = transaction
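Because a staged snapshot is recorded in metadata but referenced by no branch or tag, it shows up in the snapshot list while every ref-based read is unaffected; it can only be read back by snapshot id. A sketch of observing this (assumes `tbl` and `df` from the earlier sketch):

```python
tbl.append(df, branch=None)

staged = tbl.snapshots()[-1]  # assumes snapshots() preserves commit order
assert staged.snapshot_id != tbl.metadata.current_snapshot_id  # main did not move
assert all(r.snapshot_id != staged.snapshot_id for r in tbl.metadata.refs.values())

staged_view = tbl.scan(snapshot_id=staged.snapshot_id).to_arrow()  # reachable by id only
```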
58 changes: 58 additions & 0 deletions tests/integration/test_writes/test_partitioned_writes.py
@@ -1133,3 +1133,61 @@ def test_append_multiple_partitions(
"""
)
assert files_df.count() == 6


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_stage_only_dynamic_partition_overwrite_files(
    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
) -> None:
    identifier = f"default.test_stage_only_dynamic_partition_overwrite_files_v{format_version}"
    try:
        session_catalog.drop_table(identifier=identifier)
    except NoSuchTableError:
        pass
    tbl = session_catalog.create_table(
        identifier=identifier,
        schema=TABLE_SCHEMA,
        partition_spec=PartitionSpec(
            PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="bool"),
            PartitionField(source_id=4, field_id=1002, transform=IdentityTransform(), name="int"),
        ),
        properties={"format-version": str(format_version)},
    )

    tbl.append(arrow_table_with_null)
    current_snapshot = tbl.metadata.current_snapshot_id
    assert current_snapshot is not None

    original_count = len(tbl.scan().to_arrow())
    assert original_count == 3

    # Write to a staging snapshot
    tbl.dynamic_partition_overwrite(arrow_table_with_null.slice(0, 1), branch=None)

    assert current_snapshot == tbl.metadata.current_snapshot_id
    assert len(tbl.scan().to_arrow()) == original_count
    snapshots = tbl.snapshots()
    # A dynamic partition overwrite creates 2 snapshots: one delete and one append
    assert len(snapshots) == 3

    # Write to the main branch
    tbl.append(arrow_table_with_null)

    # The main ref has advanced
    assert current_snapshot != tbl.metadata.current_snapshot_id
    assert len(tbl.scan().to_arrow()) == 6
    snapshots = tbl.snapshots()
    assert len(snapshots) == 4

    rows = spark.sql(
        f"""
        SELECT operation, parent_id, snapshot_id
        FROM {identifier}.snapshots
        ORDER BY committed_at ASC
        """
    ).collect()
    operations = [row.operation for row in rows]
    parent_snapshot_ids = [row.parent_id for row in rows]
    assert operations == ["append", "delete", "append", "append"]
    assert parent_snapshot_ids == [None, current_snapshot, current_snapshot, current_snapshot]
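The parent assertions pin down the intended lineage: the staged delete/append pair and the later `main` append all descend from the first snapshot, because staged writes resolve their parent from `main` (via `snapshot_by_name`) without advancing it. To read the staged result back in a test like this, scanning by snapshot id is one option (a sketch; assumes `snapshots()` is commit-ordered):

```python
# The staged overwrite produced the 2nd (delete) and 3rd (append) snapshots.
staged_append = tbl.snapshots()[2]
staged_view = tbl.scan(snapshot_id=staged_append.snapshot_id).to_arrow()
# This view reflects the overwritten partition, while main still serves 6 rows.
```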