From 7181ae166ee88f6abcc8bfe4acbfdc19c0f58edc Mon Sep 17 00:00:00 2001 From: Yingjian Wu Date: Sun, 3 Aug 2025 16:22:22 -0700 Subject: [PATCH 1/4] implement stageOnly Commit --- pyiceberg/table/__init__.py | 8 +- pyiceberg/table/update/snapshot.py | 58 ++++--- tests/integration/test_writes/test_writes.py | 170 +++++++++++++++++++ 3 files changed, 215 insertions(+), 21 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 7d5cc10de5..f392ad30d3 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -430,7 +430,9 @@ def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive name_mapping=self.table_metadata.name_mapping(), ) - def update_snapshot(self, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None) -> UpdateSnapshot: + def update_snapshot( + self, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None, stage_only: bool = False + ) -> UpdateSnapshot: """Create a new UpdateSnapshot to produce a new snapshot for the table. Returns: @@ -439,7 +441,9 @@ def update_snapshot(self, snapshot_properties: Dict[str, str] = EMPTY_DICT, bran if branch is None: branch = MAIN_BRANCH - return UpdateSnapshot(self, io=self._table.io, branch=branch, snapshot_properties=snapshot_properties) + return UpdateSnapshot( + self, io=self._table.io, branch=branch, snapshot_properties=snapshot_properties, stage_only=stage_only + ) def update_statistics(self) -> UpdateStatistics: """ diff --git a/pyiceberg/table/update/snapshot.py b/pyiceberg/table/update/snapshot.py index 3ffb275ded..a9322578be 100644 --- a/pyiceberg/table/update/snapshot.py +++ b/pyiceberg/table/update/snapshot.py @@ -109,6 +109,7 @@ class _SnapshotProducer(UpdateTableMetadata[U], Generic[U]): _deleted_data_files: Set[DataFile] _compression: AvroCompressionCodec _target_branch = MAIN_BRANCH + _stage_only = False def __init__( self, @@ -118,6 +119,7 @@ def __init__( commit_uuid: Optional[uuid.UUID] = None, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: str = MAIN_BRANCH, + stage_only: bool = False, ) -> None: super().__init__(transaction) self.commit_uuid = commit_uuid or uuid.uuid4() @@ -137,6 +139,7 @@ def __init__( self._parent_snapshot_id = ( snapshot.snapshot_id if (snapshot := self._transaction.table_metadata.snapshot_by_name(self._target_branch)) else None ) + self._stage_only = stage_only def _validate_target_branch(self, branch: str) -> str: # Default is already set to MAIN_BRANCH. So branch name can't be None. 
@@ -292,25 +295,33 @@ def _commit(self) -> UpdatesAndRequirements: schema_id=self._transaction.table_metadata.current_schema_id, ) - return ( - ( - AddSnapshotUpdate(snapshot=snapshot), - SetSnapshotRefUpdate( - snapshot_id=self._snapshot_id, - parent_snapshot_id=self._parent_snapshot_id, - ref_name=self._target_branch, - type=SnapshotRefType.BRANCH, + add_snapshot_update = AddSnapshotUpdate(snapshot=snapshot) + + if self._stage_only: + return ( + (add_snapshot_update,), + (), + ) + else: + return ( + ( + add_snapshot_update, + SetSnapshotRefUpdate( + snapshot_id=self._snapshot_id, + parent_snapshot_id=self._parent_snapshot_id, + ref_name=self._target_branch, + type=SnapshotRefType.BRANCH, + ), ), - ), - ( - AssertRefSnapshotId( - snapshot_id=self._transaction.table_metadata.refs[self._target_branch].snapshot_id - if self._target_branch in self._transaction.table_metadata.refs - else None, - ref=self._target_branch, + ( + AssertRefSnapshotId( + snapshot_id=self._transaction.table_metadata.refs[self._target_branch].snapshot_id + if self._target_branch in self._transaction.table_metadata.refs + else None, + ref=self._target_branch, + ), ), - ), - ) + ) @property def snapshot_id(self) -> int: @@ -360,8 +371,9 @@ def __init__( branch: str, commit_uuid: Optional[uuid.UUID] = None, snapshot_properties: Dict[str, str] = EMPTY_DICT, + stage_only: bool = False, ): - super().__init__(operation, transaction, io, commit_uuid, snapshot_properties, branch) + super().__init__(operation, transaction, io, commit_uuid, snapshot_properties, branch, stage_only) self._predicate = AlwaysFalse() self._case_sensitive = True @@ -530,10 +542,11 @@ def __init__( branch: str, commit_uuid: Optional[uuid.UUID] = None, snapshot_properties: Dict[str, str] = EMPTY_DICT, + stage_only: bool = False, ) -> None: from pyiceberg.table import TableProperties - super().__init__(operation, transaction, io, commit_uuid, snapshot_properties, branch) + super().__init__(operation, transaction, io, commit_uuid, snapshot_properties, branch, stage_only) self._target_size_bytes = property_as_int( self._transaction.table_metadata.properties, TableProperties.MANIFEST_TARGET_SIZE_BYTES, @@ -649,6 +662,7 @@ class UpdateSnapshot: _transaction: Transaction _io: FileIO _branch: str + _stage_only: bool _snapshot_properties: Dict[str, str] def __init__( @@ -656,12 +670,14 @@ def __init__( transaction: Transaction, io: FileIO, branch: str, + stage_only: bool = False, snapshot_properties: Dict[str, str] = EMPTY_DICT, ) -> None: self._transaction = transaction self._io = io self._snapshot_properties = snapshot_properties self._branch = branch + self._stage_only = stage_only def fast_append(self) -> _FastAppendFiles: return _FastAppendFiles( @@ -670,6 +686,7 @@ def fast_append(self) -> _FastAppendFiles: io=self._io, branch=self._branch, snapshot_properties=self._snapshot_properties, + stage_only=self._stage_only, ) def merge_append(self) -> _MergeAppendFiles: @@ -679,6 +696,7 @@ def merge_append(self) -> _MergeAppendFiles: io=self._io, branch=self._branch, snapshot_properties=self._snapshot_properties, + stage_only=self._stage_only, ) def overwrite(self, commit_uuid: Optional[uuid.UUID] = None) -> _OverwriteFiles: @@ -691,6 +709,7 @@ def overwrite(self, commit_uuid: Optional[uuid.UUID] = None) -> _OverwriteFiles: io=self._io, branch=self._branch, snapshot_properties=self._snapshot_properties, + stage_only=self._stage_only, ) def delete(self) -> _DeleteFiles: @@ -700,6 +719,7 @@ def delete(self) -> _DeleteFiles: io=self._io, branch=self._branch, 
snapshot_properties=self._snapshot_properties, + stage_only=self._stage_only, ) diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py index 38aea1e255..51f01b77dd 100644 --- a/tests/integration/test_writes/test_writes.py +++ b/tests/integration/test_writes/test_writes.py @@ -2261,3 +2261,173 @@ def test_nanosecond_support_on_catalog(session_catalog: Catalog) -> None: ) _create_table(session_catalog, identifier, {"format-version": "3"}, schema=table.schema) + + +@pytest.mark.parametrize("format_version", [1, 2]) +def test_stage_only_delete( + spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int +) -> None: + identifier = f"default.test_stage_only_delete_files_v{format_version}" + iceberg_spec = PartitionSpec( + *[PartitionField(source_id=4, field_id=1001, transform=IdentityTransform(), name="integer_partition")] + ) + tbl = _create_table( + session_catalog, identifier, {"format-version": str(format_version)}, [arrow_table_with_null], iceberg_spec + ) + + current_snapshot = tbl.metadata.current_snapshot_id + assert current_snapshot is not None + + original_count = len(tbl.scan().to_arrow()) + assert original_count == 3 + + files_to_delete = [] + for file_task in tbl.scan().plan_files(): + files_to_delete.append(file_task.file) + assert len(files_to_delete) > 0 + + with tbl.transaction() as txn: + with txn.update_snapshot(stage_only=True).delete() as delete: + delete.delete_by_predicate(EqualTo("int", 9)) + + # a new delete snapshot is added + snapshots = tbl.snapshots() + assert len(snapshots) == 2 + + rows = spark.sql( + f""" + SELECT operation, summary + FROM {identifier}.snapshots + ORDER BY committed_at ASC + """ + ).collect() + operations = [row.operation for row in rows] + assert operations == ["append", "delete"] + + # snapshot main ref has not changed + assert current_snapshot == tbl.metadata.current_snapshot_id + assert len(tbl.scan().to_arrow()) == original_count + + +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +def test_stage_only_fast_append( + spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int +) -> None: + identifier = f"default.test_stage_only_fast_append_files_v{format_version}" + tbl = _create_table(session_catalog, identifier, {"format-version": str(format_version)}, [arrow_table_with_null]) + + current_snapshot = tbl.metadata.current_snapshot_id + assert current_snapshot is not None + + original_count = len(tbl.scan().to_arrow()) + assert original_count == 3 + + with tbl.transaction() as txn: + with txn.update_snapshot(stage_only=True).fast_append() as fast_append: + for data_file in _dataframe_to_data_files( + table_metadata=txn.table_metadata, df=arrow_table_with_null, io=txn._table.io + ): + fast_append.append_data_file(data_file=data_file) + + # Main ref has not changed and data is not yet appended + assert current_snapshot == tbl.metadata.current_snapshot_id + assert len(tbl.scan().to_arrow()) == original_count + + # There should be a new staged snapshot + snapshots = tbl.snapshots() + assert len(snapshots) == 2 + + rows = spark.sql( + f""" + SELECT operation, summary + FROM {identifier}.snapshots + ORDER BY committed_at ASC + """ + ).collect() + operations = [row.operation for row in rows] + assert operations == ["append", "append"] + + +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +def test_stage_only_merge_append( + spark: SparkSession, session_catalog: 
Catalog, arrow_table_with_null: pa.Table, format_version: int +) -> None: + identifier = f"default.test_stage_only_merge_append_files_v{format_version}" + tbl = _create_table(session_catalog, identifier, {"format-version": str(format_version)}, [arrow_table_with_null]) + + current_snapshot = tbl.metadata.current_snapshot_id + assert current_snapshot is not None + + original_count = len(tbl.scan().to_arrow()) + assert original_count == 3 + + with tbl.transaction() as txn: + with txn.update_snapshot(stage_only=True).merge_append() as merge_append: + for data_file in _dataframe_to_data_files( + table_metadata=txn.table_metadata, df=arrow_table_with_null, io=txn._table.io + ): + merge_append.append_data_file(data_file=data_file) + + # Main ref has not changed and data is not yet appended + assert current_snapshot == tbl.metadata.current_snapshot_id + assert len(tbl.scan().to_arrow()) == original_count + + # There should be a new staged snapshot + snapshots = tbl.snapshots() + assert len(snapshots) == 2 + + rows = spark.sql( + f""" + SELECT operation, summary + FROM {identifier}.snapshots + ORDER BY committed_at ASC + """ + ).collect() + operations = [row.operation for row in rows] + assert operations == ["append", "append"] + + +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +def test_stage_only_overwrite_files( + spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int +) -> None: + identifier = f"default.test_stage_only_overwrite_files_v{format_version}" + tbl = _create_table(session_catalog, identifier, {"format-version": str(format_version)}, [arrow_table_with_null]) + + current_snapshot = tbl.metadata.current_snapshot_id + assert current_snapshot is not None + + original_count = len(tbl.scan().to_arrow()) + assert original_count == 3 + + files_to_delete = [] + for file_task in tbl.scan().plan_files(): + files_to_delete.append(file_task.file) + assert len(files_to_delete) > 0 + + with tbl.transaction() as txn: + with txn.update_snapshot(stage_only=True).overwrite() as overwrite: + for data_file in _dataframe_to_data_files( + table_metadata=txn.table_metadata, df=arrow_table_with_null, io=txn._table.io + ): + overwrite.append_data_file(data_file=data_file) + overwrite.delete_data_file(files_to_delete[0]) + + assert current_snapshot == tbl.metadata.current_snapshot_id + assert len(tbl.scan().to_arrow()) == original_count + + snapshots = tbl.snapshots() + assert len(snapshots) == 2 + + rows = spark.sql( + f""" + SELECT operation, summary + FROM {identifier}.snapshots + ORDER BY committed_at ASC + """ + ).collect() + operations = [row.operation for row in rows] + assert operations == ["append", "overwrite"] From 884eca90e1f0830919fa5853b9b10bfe261b51b6 Mon Sep 17 00:00:00 2001 From: Yingjian Wu Date: Sun, 24 Aug 2025 10:33:12 -0700 Subject: [PATCH 2/4] wip wip --- pyiceberg/table/__init__.py | 35 ++++++++-------- pyiceberg/table/metadata.py | 4 +- pyiceberg/table/update/snapshot.py | 43 +++++++------------- tests/integration/test_writes/test_writes.py | 8 ++-- 4 files changed, 38 insertions(+), 52 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index f392ad30d3..0ba5f452a8 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -397,7 +397,9 @@ def _build_partition_predicate(self, partition_records: Set[Record]) -> BooleanE expr = Or(expr, match_partition_expression) return expr - def _append_snapshot_producer(self, snapshot_properties: Dict[str, str], branch: 
Optional[str]) -> _FastAppendFiles: + def _append_snapshot_producer( + self, snapshot_properties: Dict[str, str], branch: Optional[str] = MAIN_BRANCH + ) -> _FastAppendFiles: """Determine the append type based on table properties. Args: @@ -431,19 +433,14 @@ def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive ) def update_snapshot( - self, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None, stage_only: bool = False + self, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH ) -> UpdateSnapshot: """Create a new UpdateSnapshot to produce a new snapshot for the table. Returns: A new UpdateSnapshot """ - if branch is None: - branch = MAIN_BRANCH - - return UpdateSnapshot( - self, io=self._table.io, branch=branch, snapshot_properties=snapshot_properties, stage_only=stage_only - ) + return UpdateSnapshot(self, io=self._table.io, branch=branch, snapshot_properties=snapshot_properties) def update_statistics(self) -> UpdateStatistics: """ @@ -454,7 +451,7 @@ def update_statistics(self) -> UpdateStatistics: """ return UpdateStatistics(transaction=self) - def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None) -> None: + def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH) -> None: """ Shorthand API for appending a PyArrow table to a table transaction. @@ -499,7 +496,7 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, append_files.append_data_file(data_file) def dynamic_partition_overwrite( - self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None + self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH ) -> None: """ Shorthand for overwriting existing partitions with a PyArrow table. @@ -566,7 +563,7 @@ def overwrite( overwrite_filter: Union[BooleanExpression, str] = ALWAYS_TRUE, snapshot_properties: Dict[str, str] = EMPTY_DICT, case_sensitive: bool = True, - branch: Optional[str] = None, + branch: Optional[str] = MAIN_BRANCH, ) -> None: """ Shorthand for adding a table overwrite with a PyArrow table to the transaction. @@ -632,7 +629,7 @@ def delete( delete_filter: Union[str, BooleanExpression], snapshot_properties: Dict[str, str] = EMPTY_DICT, case_sensitive: bool = True, - branch: Optional[str] = None, + branch: Optional[str] = MAIN_BRANCH, ) -> None: """ Shorthand for deleting record from a table. @@ -735,7 +732,7 @@ def upsert( when_matched_update_all: bool = True, when_not_matched_insert_all: bool = True, case_sensitive: bool = True, - branch: Optional[str] = None, + branch: Optional[str] = MAIN_BRANCH, ) -> UpsertResult: """Shorthand API for performing an upsert to an iceberg table. @@ -820,7 +817,7 @@ def upsert( case_sensitive=case_sensitive, ) - if branch is not None: + if branch in self.table_metadata.refs: matched_iceberg_record_batches_scan = matched_iceberg_record_batches_scan.use_ref(branch) matched_iceberg_record_batches = matched_iceberg_record_batches_scan.to_arrow_batch_reader() @@ -1311,7 +1308,7 @@ def upsert( when_matched_update_all: bool = True, when_not_matched_insert_all: bool = True, case_sensitive: bool = True, - branch: Optional[str] = None, + branch: Optional[str] = MAIN_BRANCH, ) -> UpsertResult: """Shorthand API for performing an upsert to an iceberg table. 
@@ -1358,7 +1355,7 @@ def upsert( branch=branch, ) - def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None) -> None: + def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH) -> None: """ Shorthand API for appending a PyArrow table to the table. @@ -1371,7 +1368,7 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, tx.append(df=df, snapshot_properties=snapshot_properties, branch=branch) def dynamic_partition_overwrite( - self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None + self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH ) -> None: """Shorthand for dynamic overwriting the table with a PyArrow table. @@ -1390,7 +1387,7 @@ def overwrite( overwrite_filter: Union[BooleanExpression, str] = ALWAYS_TRUE, snapshot_properties: Dict[str, str] = EMPTY_DICT, case_sensitive: bool = True, - branch: Optional[str] = None, + branch: Optional[str] = MAIN_BRANCH, ) -> None: """ Shorthand for overwriting the table with a PyArrow table. @@ -1423,7 +1420,7 @@ def delete( delete_filter: Union[BooleanExpression, str] = ALWAYS_TRUE, snapshot_properties: Dict[str, str] = EMPTY_DICT, case_sensitive: bool = True, - branch: Optional[str] = None, + branch: Optional[str] = MAIN_BRANCH, ) -> None: """ Shorthand for deleting rows from the table. diff --git a/pyiceberg/table/metadata.py b/pyiceberg/table/metadata.py index 9c2ae29cdd..9ab29815e9 100644 --- a/pyiceberg/table/metadata.py +++ b/pyiceberg/table/metadata.py @@ -295,8 +295,10 @@ def new_snapshot_id(self) -> int: return snapshot_id - def snapshot_by_name(self, name: str) -> Optional[Snapshot]: + def snapshot_by_name(self, name: Optional[str]) -> Optional[Snapshot]: """Return the snapshot referenced by the given name or null if no such reference exists.""" + if name is None: + name = MAIN_BRANCH if ref := self.refs.get(name): return self.snapshot_by_id(ref.snapshot_id) return None diff --git a/pyiceberg/table/update/snapshot.py b/pyiceberg/table/update/snapshot.py index a9322578be..6a2d5b1785 100644 --- a/pyiceberg/table/update/snapshot.py +++ b/pyiceberg/table/update/snapshot.py @@ -108,8 +108,7 @@ class _SnapshotProducer(UpdateTableMetadata[U], Generic[U]): _manifest_num_counter: itertools.count[int] _deleted_data_files: Set[DataFile] _compression: AvroCompressionCodec - _target_branch = MAIN_BRANCH - _stage_only = False + _target_branch: Optional[str] def __init__( self, @@ -118,8 +117,7 @@ def __init__( io: FileIO, commit_uuid: Optional[uuid.UUID] = None, snapshot_properties: Dict[str, str] = EMPTY_DICT, - branch: str = MAIN_BRANCH, - stage_only: bool = False, + branch: Optional[str] = MAIN_BRANCH, ) -> None: super().__init__(transaction) self.commit_uuid = commit_uuid or uuid.uuid4() @@ -139,16 +137,14 @@ def __init__( self._parent_snapshot_id = ( snapshot.snapshot_id if (snapshot := self._transaction.table_metadata.snapshot_by_name(self._target_branch)) else None ) - self._stage_only = stage_only - def _validate_target_branch(self, branch: str) -> str: + def _validate_target_branch(self, branch: Optional[str]) -> Optional[str]: # Default is already set to MAIN_BRANCH. So branch name can't be None. 
- if branch is None: - raise ValueError("Invalid branch name: null") - if branch in self._transaction.table_metadata.refs: - ref = self._transaction.table_metadata.refs[branch] - if ref.snapshot_ref_type != SnapshotRefType.BRANCH: - raise ValueError(f"{branch} is a tag, not a branch. Tags cannot be targets for producing snapshots") + if branch is not None: + if branch in self._transaction.table_metadata.refs: + ref = self._transaction.table_metadata.refs[branch] + if ref.snapshot_ref_type != SnapshotRefType.BRANCH: + raise ValueError(f"{branch} is a tag, not a branch. Tags cannot be targets for producing snapshots") return branch def append_data_file(self, data_file: DataFile) -> _SnapshotProducer[U]: @@ -297,7 +293,7 @@ def _commit(self) -> UpdatesAndRequirements: add_snapshot_update = AddSnapshotUpdate(snapshot=snapshot) - if self._stage_only: + if self._target_branch is None: return ( (add_snapshot_update,), (), @@ -368,12 +364,11 @@ def __init__( operation: Operation, transaction: Transaction, io: FileIO, - branch: str, + branch: Optional[str] = MAIN_BRANCH, commit_uuid: Optional[uuid.UUID] = None, snapshot_properties: Dict[str, str] = EMPTY_DICT, - stage_only: bool = False, ): - super().__init__(operation, transaction, io, commit_uuid, snapshot_properties, branch, stage_only) + super().__init__(operation, transaction, io, commit_uuid, snapshot_properties, branch) self._predicate = AlwaysFalse() self._case_sensitive = True @@ -539,14 +534,13 @@ def __init__( operation: Operation, transaction: Transaction, io: FileIO, - branch: str, + branch: Optional[str] = MAIN_BRANCH, commit_uuid: Optional[uuid.UUID] = None, snapshot_properties: Dict[str, str] = EMPTY_DICT, - stage_only: bool = False, ) -> None: from pyiceberg.table import TableProperties - super().__init__(operation, transaction, io, commit_uuid, snapshot_properties, branch, stage_only) + super().__init__(operation, transaction, io, commit_uuid, snapshot_properties, branch) self._target_size_bytes = property_as_int( self._transaction.table_metadata.properties, TableProperties.MANIFEST_TARGET_SIZE_BYTES, @@ -661,23 +655,20 @@ def _get_entries(manifest: ManifestFile) -> List[ManifestEntry]: class UpdateSnapshot: _transaction: Transaction _io: FileIO - _branch: str - _stage_only: bool + _branch: Optional[str] _snapshot_properties: Dict[str, str] def __init__( self, transaction: Transaction, io: FileIO, - branch: str, - stage_only: bool = False, + branch: Optional[str] = MAIN_BRANCH, snapshot_properties: Dict[str, str] = EMPTY_DICT, ) -> None: self._transaction = transaction self._io = io self._snapshot_properties = snapshot_properties self._branch = branch - self._stage_only = stage_only def fast_append(self) -> _FastAppendFiles: return _FastAppendFiles( @@ -686,7 +677,6 @@ def fast_append(self) -> _FastAppendFiles: io=self._io, branch=self._branch, snapshot_properties=self._snapshot_properties, - stage_only=self._stage_only, ) def merge_append(self) -> _MergeAppendFiles: @@ -696,7 +686,6 @@ def merge_append(self) -> _MergeAppendFiles: io=self._io, branch=self._branch, snapshot_properties=self._snapshot_properties, - stage_only=self._stage_only, ) def overwrite(self, commit_uuid: Optional[uuid.UUID] = None) -> _OverwriteFiles: @@ -709,7 +698,6 @@ def overwrite(self, commit_uuid: Optional[uuid.UUID] = None) -> _OverwriteFiles: io=self._io, branch=self._branch, snapshot_properties=self._snapshot_properties, - stage_only=self._stage_only, ) def delete(self) -> _DeleteFiles: @@ -719,7 +707,6 @@ def delete(self) -> _DeleteFiles: 
io=self._io, branch=self._branch, snapshot_properties=self._snapshot_properties, - stage_only=self._stage_only, ) diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py index 51f01b77dd..58aae8c531 100644 --- a/tests/integration/test_writes/test_writes.py +++ b/tests/integration/test_writes/test_writes.py @@ -2287,7 +2287,7 @@ def test_stage_only_delete( assert len(files_to_delete) > 0 with tbl.transaction() as txn: - with txn.update_snapshot(stage_only=True).delete() as delete: + with txn.update_snapshot(branch=None).delete() as delete: delete.delete_by_predicate(EqualTo("int", 9)) # a new delete snapshot is added @@ -2324,7 +2324,7 @@ def test_stage_only_fast_append( assert original_count == 3 with tbl.transaction() as txn: - with txn.update_snapshot(stage_only=True).fast_append() as fast_append: + with txn.update_snapshot(branch=None).fast_append() as fast_append: for data_file in _dataframe_to_data_files( table_metadata=txn.table_metadata, df=arrow_table_with_null, io=txn._table.io ): @@ -2364,7 +2364,7 @@ def test_stage_only_merge_append( assert original_count == 3 with tbl.transaction() as txn: - with txn.update_snapshot(stage_only=True).merge_append() as merge_append: + with txn.update_snapshot(branch=None).merge_append() as merge_append: for data_file in _dataframe_to_data_files( table_metadata=txn.table_metadata, df=arrow_table_with_null, io=txn._table.io ): @@ -2409,7 +2409,7 @@ def test_stage_only_overwrite_files( assert len(files_to_delete) > 0 with tbl.transaction() as txn: - with txn.update_snapshot(stage_only=True).overwrite() as overwrite: + with txn.update_snapshot(branch=None).overwrite() as overwrite: for data_file in _dataframe_to_data_files( table_metadata=txn.table_metadata, df=arrow_table_with_null, io=txn._table.io ): From 3a942cf2cb04e7a6926a4c413e409c91e204061b Mon Sep 17 00:00:00 2001 From: Yingjian Wu Date: Sun, 24 Aug 2025 13:32:17 -0700 Subject: [PATCH 3/4] add test --- pyiceberg/table/update/snapshot.py | 2 +- tests/integration/test_writes/test_writes.py | 113 +++++++++++++++---- 2 files changed, 91 insertions(+), 24 deletions(-) diff --git a/pyiceberg/table/update/snapshot.py b/pyiceberg/table/update/snapshot.py index 6a2d5b1785..c5c0bfa738 100644 --- a/pyiceberg/table/update/snapshot.py +++ b/pyiceberg/table/update/snapshot.py @@ -139,7 +139,7 @@ def __init__( ) def _validate_target_branch(self, branch: Optional[str]) -> Optional[str]: - # Default is already set to MAIN_BRANCH. So branch name can't be None. 
+        # If branch is None, the write is staged: the snapshot is added to the table metadata, but no branch ref is updated
         if branch is not None:
             if branch in self._transaction.table_metadata.refs:
                 ref = self._transaction.table_metadata.refs[branch]
                 if ref.snapshot_ref_type != SnapshotRefType.BRANCH:
                     raise ValueError(f"{branch} is a tag, not a branch. Tags cannot be targets for producing snapshots")
diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py
index 58aae8c531..0cc0b79032 100644
--- a/tests/integration/test_writes/test_writes.py
+++ b/tests/integration/test_writes/test_writes.py
@@ -2293,20 +2293,36 @@ def test_stage_only_delete(
     # a new delete snapshot is added
     snapshots = tbl.snapshots()
     assert len(snapshots) == 2
+    # snapshot main ref has not changed
+    assert current_snapshot == tbl.metadata.current_snapshot_id
+    assert len(tbl.scan().to_arrow()) == original_count
+
+    # Write to main branch
+    with tbl.transaction() as txn:
+        with txn.update_snapshot().fast_append() as fast_append:
+            for data_file in _dataframe_to_data_files(
+                table_metadata=txn.table_metadata, df=arrow_table_with_null, io=txn._table.io
+            ):
+                fast_append.append_data_file(data_file=data_file)
+
+    # Main ref has changed
+    assert current_snapshot != tbl.metadata.current_snapshot_id
+    assert len(tbl.scan().to_arrow()) == 3
+    snapshots = tbl.snapshots()
+    assert len(snapshots) == 3

     rows = spark.sql(
         f"""
-        SELECT operation, summary
-        FROM {identifier}.snapshots
-        ORDER BY committed_at ASC
-    """
+            SELECT operation, parent_id
+            FROM {identifier}.snapshots
+            ORDER BY committed_at ASC
+        """
     ).collect()
     operations = [row.operation for row in rows]
-    assert operations == ["append", "delete"]
-
-    # snapshot main ref has not changed
-    assert current_snapshot == tbl.metadata.current_snapshot_id
-    assert len(tbl.scan().to_arrow()) == original_count
+    parent_snapshot_id = [row.parent_id for row in rows]
+    assert operations == ["append", "delete", "append"]
+    # both subsequent parent ids should be the first snapshot id
+    assert parent_snapshot_id == [None, current_snapshot, current_snapshot]


 @pytest.mark.integration
@@ -2323,6 +2339,7 @@ def test_stage_only_fast_append(

     assert original_count == 3

+    # Write a staged snapshot
     with tbl.transaction() as txn:
         with txn.update_snapshot(branch=None).fast_append() as fast_append:
             for data_file in _dataframe_to_data_files(
@@ -2333,20 +2350,37 @@ def test_stage_only_fast_append(
     # Main ref has not changed and data is not yet appended
     assert current_snapshot == tbl.metadata.current_snapshot_id
     assert len(tbl.scan().to_arrow()) == original_count
-
     # There should be a new staged snapshot
     snapshots = tbl.snapshots()
     assert len(snapshots) == 2

+    # Write to main branch
+    with tbl.transaction() as txn:
+        with txn.update_snapshot().fast_append() as fast_append:
+            for data_file in _dataframe_to_data_files(
+                table_metadata=txn.table_metadata, df=arrow_table_with_null, io=txn._table.io
+            ):
+                fast_append.append_data_file(data_file=data_file)
+
+    # Main ref has changed
+    assert current_snapshot != tbl.metadata.current_snapshot_id
+    assert len(tbl.scan().to_arrow()) == 6
+    snapshots = tbl.snapshots()
+    assert len(snapshots) == 3
+
     rows = spark.sql(
         f"""
-        SELECT operation, summary
+        SELECT operation, parent_id
         FROM {identifier}.snapshots
         ORDER BY committed_at ASC
         """
     ).collect()
     operations = [row.operation for row in rows]
-    assert operations == ["append", "append"]
+    parent_snapshot_id = [row.parent_id for row in rows]
+    assert operations == ["append", "append", "append"]
+    # both subsequent parent ids should be the first snapshot id
+    assert parent_snapshot_id == [None, current_snapshot, current_snapshot]
+
@pytest.mark.integration @@ -2378,15 +2412,32 @@ def test_stage_only_merge_append( snapshots = tbl.snapshots() assert len(snapshots) == 2 + # Write to main branch + with tbl.transaction() as txn: + with txn.update_snapshot().fast_append() as fast_append: + for data_file in _dataframe_to_data_files( + table_metadata=txn.table_metadata, df=arrow_table_with_null, io=txn._table.io + ): + fast_append.append_data_file(data_file=data_file) + + # Main ref has changed + assert current_snapshot != tbl.metadata.current_snapshot_id + assert len(tbl.scan().to_arrow()) == 6 + snapshots = tbl.snapshots() + assert len(snapshots) == 3 + rows = spark.sql( f""" - SELECT operation, summary - FROM {identifier}.snapshots - ORDER BY committed_at ASC - """ + SELECT operation, parent_id + FROM {identifier}.snapshots + ORDER BY committed_at ASC + """ ).collect() operations = [row.operation for row in rows] - assert operations == ["append", "append"] + parent_snapshot_id = [row.parent_id for row in rows] + assert operations == ["append", "append", "append"] + # both subsequent parent id should be the first snapshot id + assert parent_snapshot_id == [None, current_snapshot, current_snapshot] @pytest.mark.integration @@ -2418,16 +2469,32 @@ def test_stage_only_overwrite_files( assert current_snapshot == tbl.metadata.current_snapshot_id assert len(tbl.scan().to_arrow()) == original_count - snapshots = tbl.snapshots() assert len(snapshots) == 2 + # Write to main branch + with tbl.transaction() as txn: + with txn.update_snapshot().fast_append() as fast_append: + for data_file in _dataframe_to_data_files( + table_metadata=txn.table_metadata, df=arrow_table_with_null, io=txn._table.io + ): + fast_append.append_data_file(data_file=data_file) + + # Main ref has changed + assert current_snapshot != tbl.metadata.current_snapshot_id + assert len(tbl.scan().to_arrow()) == 6 + snapshots = tbl.snapshots() + assert len(snapshots) == 3 + rows = spark.sql( f""" - SELECT operation, summary - FROM {identifier}.snapshots - ORDER BY committed_at ASC - """ + SELECT operation, parent_id + FROM {identifier}.snapshots + ORDER BY committed_at ASC + """ ).collect() operations = [row.operation for row in rows] - assert operations == ["append", "overwrite"] + parent_snapshot_id = [row.parent_id for row in rows] + assert operations == ["append", "overwrite", "append"] + # both subsequent parent id should be the first snapshot id + assert parent_snapshot_id == [None, current_snapshot, current_snapshot] From 0bcdcbe6f3c9034f86363ca2f3c721082e70e47e Mon Sep 17 00:00:00 2001 From: Yingjian Wu Date: Sun, 24 Aug 2025 17:06:47 -0700 Subject: [PATCH 4/4] improve test --- .../test_writes/test_partitioned_writes.py | 58 +++++++ tests/integration/test_writes/test_writes.py | 144 ++++-------------- tests/table/test_upsert.py | 64 ++++++++ 3 files changed, 148 insertions(+), 118 deletions(-) diff --git a/tests/integration/test_writes/test_partitioned_writes.py b/tests/integration/test_writes/test_partitioned_writes.py index e9698067c1..6d1c15212a 100644 --- a/tests/integration/test_writes/test_partitioned_writes.py +++ b/tests/integration/test_writes/test_partitioned_writes.py @@ -1133,3 +1133,61 @@ def test_append_multiple_partitions( """ ) assert files_df.count() == 6 + + +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +def test_stage_only_dynamic_partition_overwrite_files( + spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int +) -> None: + identifier = 
f"default.test_stage_only_dynamic_partition_overwrite_files_v{format_version}" + try: + session_catalog.drop_table(identifier=identifier) + except NoSuchTableError: + pass + tbl = session_catalog.create_table( + identifier=identifier, + schema=TABLE_SCHEMA, + partition_spec=PartitionSpec( + PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="bool"), + PartitionField(source_id=4, field_id=1002, transform=IdentityTransform(), name="int"), + ), + properties={"format-version": str(format_version)}, + ) + + tbl.append(arrow_table_with_null) + current_snapshot = tbl.metadata.current_snapshot_id + assert current_snapshot is not None + + original_count = len(tbl.scan().to_arrow()) + assert original_count == 3 + + # write to staging snapshot + tbl.dynamic_partition_overwrite(arrow_table_with_null.slice(0, 1), branch=None) + + assert current_snapshot == tbl.metadata.current_snapshot_id + assert len(tbl.scan().to_arrow()) == original_count + snapshots = tbl.snapshots() + # dynamic partition overwrite will create 2 snapshots, one delete and another append + assert len(snapshots) == 3 + + # Write to main branch + tbl.append(arrow_table_with_null) + + # Main ref has changed + assert current_snapshot != tbl.metadata.current_snapshot_id + assert len(tbl.scan().to_arrow()) == 6 + snapshots = tbl.snapshots() + assert len(snapshots) == 4 + + rows = spark.sql( + f""" + SELECT operation, parent_id, snapshot_id + FROM {identifier}.snapshots + ORDER BY committed_at ASC + """ + ).collect() + operations = [row.operation for row in rows] + parent_snapshot_id = [row.parent_id for row in rows] + assert operations == ["append", "delete", "append", "append"] + assert parent_snapshot_id == [None, current_snapshot, current_snapshot, current_snapshot] diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py index 0cc0b79032..25a3c70bc5 100644 --- a/tests/integration/test_writes/test_writes.py +++ b/tests/integration/test_writes/test_writes.py @@ -2281,14 +2281,7 @@ def test_stage_only_delete( original_count = len(tbl.scan().to_arrow()) assert original_count == 3 - files_to_delete = [] - for file_task in tbl.scan().plan_files(): - files_to_delete.append(file_task.file) - assert len(files_to_delete) > 0 - - with tbl.transaction() as txn: - with txn.update_snapshot(branch=None).delete() as delete: - delete.delete_by_predicate(EqualTo("int", 9)) + tbl.delete("int = 9", branch=None) # a new delete snapshot is added snapshots = tbl.snapshots() @@ -2298,16 +2291,11 @@ def test_stage_only_delete( assert len(tbl.scan().to_arrow()) == original_count # Write to main branch - with tbl.transaction() as txn: - with txn.update_snapshot().fast_append() as fast_append: - for data_file in _dataframe_to_data_files( - table_metadata=txn.table_metadata, df=arrow_table_with_null, io=txn._table.io - ): - fast_append.append_data_file(data_file=data_file) + tbl.append(arrow_table_with_null) # Main ref has changed assert current_snapshot != tbl.metadata.current_snapshot_id - assert len(tbl.scan().to_arrow()) == 3 + assert len(tbl.scan().to_arrow()) == 6 snapshots = tbl.snapshots() assert len(snapshots) == 3 @@ -2327,7 +2315,7 @@ def test_stage_only_delete( @pytest.mark.integration @pytest.mark.parametrize("format_version", [1, 2]) -def test_stage_only_fast_append( +def test_stage_only_append( spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int ) -> None: identifier = 
f"default.test_stage_only_fast_append_files_v{format_version}" @@ -2340,12 +2328,7 @@ def test_stage_only_fast_append( assert original_count == 3 # Write to staging branch - with tbl.transaction() as txn: - with txn.update_snapshot(branch=None).fast_append() as fast_append: - for data_file in _dataframe_to_data_files( - table_metadata=txn.table_metadata, df=arrow_table_with_null, io=txn._table.io - ): - fast_append.append_data_file(data_file=data_file) + tbl.append(arrow_table_with_null, branch=None) # Main ref has not changed and data is not yet appended assert current_snapshot == tbl.metadata.current_snapshot_id @@ -2355,12 +2338,7 @@ def test_stage_only_fast_append( assert len(snapshots) == 2 # Write to main branch - with tbl.transaction() as txn: - with txn.update_snapshot().fast_append() as fast_append: - for data_file in _dataframe_to_data_files( - table_metadata=txn.table_metadata, df=arrow_table_with_null, io=txn._table.io - ): - fast_append.append_data_file(data_file=data_file) + tbl.append(arrow_table_with_null) # Main ref has changed assert current_snapshot != tbl.metadata.current_snapshot_id @@ -2382,64 +2360,6 @@ def test_stage_only_fast_append( assert parent_snapshot_id == [None, current_snapshot, current_snapshot] - -@pytest.mark.integration -@pytest.mark.parametrize("format_version", [1, 2]) -def test_stage_only_merge_append( - spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int -) -> None: - identifier = f"default.test_stage_only_merge_append_files_v{format_version}" - tbl = _create_table(session_catalog, identifier, {"format-version": str(format_version)}, [arrow_table_with_null]) - - current_snapshot = tbl.metadata.current_snapshot_id - assert current_snapshot is not None - - original_count = len(tbl.scan().to_arrow()) - assert original_count == 3 - - with tbl.transaction() as txn: - with txn.update_snapshot(branch=None).merge_append() as merge_append: - for data_file in _dataframe_to_data_files( - table_metadata=txn.table_metadata, df=arrow_table_with_null, io=txn._table.io - ): - merge_append.append_data_file(data_file=data_file) - - # Main ref has not changed and data is not yet appended - assert current_snapshot == tbl.metadata.current_snapshot_id - assert len(tbl.scan().to_arrow()) == original_count - - # There should be a new staged snapshot - snapshots = tbl.snapshots() - assert len(snapshots) == 2 - - # Write to main branch - with tbl.transaction() as txn: - with txn.update_snapshot().fast_append() as fast_append: - for data_file in _dataframe_to_data_files( - table_metadata=txn.table_metadata, df=arrow_table_with_null, io=txn._table.io - ): - fast_append.append_data_file(data_file=data_file) - - # Main ref has changed - assert current_snapshot != tbl.metadata.current_snapshot_id - assert len(tbl.scan().to_arrow()) == 6 - snapshots = tbl.snapshots() - assert len(snapshots) == 3 - - rows = spark.sql( - f""" - SELECT operation, parent_id - FROM {identifier}.snapshots - ORDER BY committed_at ASC - """ - ).collect() - operations = [row.operation for row in rows] - parent_snapshot_id = [row.parent_id for row in rows] - assert operations == ["append", "append", "append"] - # both subsequent parent id should be the first snapshot id - assert parent_snapshot_id == [None, current_snapshot, current_snapshot] - - @pytest.mark.integration @pytest.mark.parametrize("format_version", [1, 2]) def test_stage_only_overwrite_files( @@ -2447,54 +2367,42 @@ def test_stage_only_overwrite_files( ) -> None: identifier = 
f"default.test_stage_only_overwrite_files_v{format_version}" tbl = _create_table(session_catalog, identifier, {"format-version": str(format_version)}, [arrow_table_with_null]) + first_snapshot = tbl.metadata.current_snapshot_id - current_snapshot = tbl.metadata.current_snapshot_id - assert current_snapshot is not None + # duplicate data with a new insert + tbl.append(arrow_table_with_null) + second_snapshot = tbl.metadata.current_snapshot_id + assert second_snapshot is not None original_count = len(tbl.scan().to_arrow()) - assert original_count == 3 + assert original_count == 6 - files_to_delete = [] - for file_task in tbl.scan().plan_files(): - files_to_delete.append(file_task.file) - assert len(files_to_delete) > 0 - - with tbl.transaction() as txn: - with txn.update_snapshot(branch=None).overwrite() as overwrite: - for data_file in _dataframe_to_data_files( - table_metadata=txn.table_metadata, df=arrow_table_with_null, io=txn._table.io - ): - overwrite.append_data_file(data_file=data_file) - overwrite.delete_data_file(files_to_delete[0]) - - assert current_snapshot == tbl.metadata.current_snapshot_id + # write to non-main branch + tbl.overwrite(arrow_table_with_null, branch=None) + assert second_snapshot == tbl.metadata.current_snapshot_id assert len(tbl.scan().to_arrow()) == original_count snapshots = tbl.snapshots() - assert len(snapshots) == 2 + # overwrite will create 2 snapshots + assert len(snapshots) == 4 - # Write to main branch - with tbl.transaction() as txn: - with txn.update_snapshot().fast_append() as fast_append: - for data_file in _dataframe_to_data_files( - table_metadata=txn.table_metadata, df=arrow_table_with_null, io=txn._table.io - ): - fast_append.append_data_file(data_file=data_file) + # Write to main branch again + tbl.append(arrow_table_with_null) # Main ref has changed - assert current_snapshot != tbl.metadata.current_snapshot_id - assert len(tbl.scan().to_arrow()) == 6 + assert second_snapshot != tbl.metadata.current_snapshot_id + assert len(tbl.scan().to_arrow()) == 9 snapshots = tbl.snapshots() - assert len(snapshots) == 3 + assert len(snapshots) == 5 rows = spark.sql( f""" - SELECT operation, parent_id + SELECT operation, parent_id, snapshot_id FROM {identifier}.snapshots ORDER BY committed_at ASC """ ).collect() operations = [row.operation for row in rows] parent_snapshot_id = [row.parent_id for row in rows] - assert operations == ["append", "overwrite", "append"] - # both subsequent parent id should be the first snapshot id - assert parent_snapshot_id == [None, current_snapshot, current_snapshot] + assert operations == ["append", "append", "delete", "append", "append"] + + assert parent_snapshot_id == [None, first_snapshot, second_snapshot, second_snapshot, second_snapshot] diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py index cc6e008b1e..1a9c35fc07 100644 --- a/tests/table/test_upsert.py +++ b/tests/table/test_upsert.py @@ -770,3 +770,67 @@ def test_transaction_multiple_upserts(catalog: Catalog) -> None: {"id": 1, "name": "Alicia"}, {"id": 2, "name": "Bob"}, ] + + +def test_stage_only_upsert(catalog: Catalog) -> None: + identifier = "default.test_stage_only_dynamic_partition_overwrite_files" + _drop_table(catalog, identifier) + + schema = Schema( + NestedField(1, "city", StringType(), required=True), + NestedField(2, "inhabitants", IntegerType(), required=True), + # Mark City as the identifier field, also known as the primary-key + identifier_field_ids=[1], + ) + + tbl = catalog.create_table(identifier, schema=schema) + + arrow_schema 
= pa.schema( [ pa.field("city", pa.string(), nullable=False), pa.field("inhabitants", pa.int32(), nullable=False), ] )
+
+    # Write some data
+    df = pa.Table.from_pylist(
+        [
+            {"city": "Amsterdam", "inhabitants": 921402},
+            {"city": "San Francisco", "inhabitants": 808988},
+            {"city": "Drachten", "inhabitants": 45019},
+            {"city": "Paris", "inhabitants": 2103000},
+        ],
+        schema=arrow_schema,
+    )
+
+    tbl.append(df.slice(0, 1))
+    current_snapshot = tbl.metadata.current_snapshot_id
+    assert current_snapshot is not None
+
+    original_count = len(tbl.scan().to_arrow())
+    assert original_count == 1
+
+    # write to staging snapshot
+    upd = tbl.upsert(df, branch=None)
+    assert upd.rows_updated == 0
+    assert upd.rows_inserted == 3
+
+    assert current_snapshot == tbl.metadata.current_snapshot_id
+    assert len(tbl.scan().to_arrow()) == original_count
+    snapshots = tbl.snapshots()
+    assert len(snapshots) == 2
+
+    # Write to main ref
+    tbl.append(df.slice(1, 1))
+    # Main ref has changed
+    assert current_snapshot != tbl.metadata.current_snapshot_id
+    assert len(tbl.scan().to_arrow()) == 2
+    snapshots = tbl.snapshots()
+    assert len(snapshots) == 3
+
+    sorted_snapshots = sorted(tbl.snapshots(), key=lambda s: s.timestamp_ms)
+    operations = [snapshot.summary.operation.value if snapshot.summary else None for snapshot in sorted_snapshots]
+    parent_snapshot_id = [snapshot.parent_snapshot_id for snapshot in sorted_snapshots]
+    assert operations == ["append", "append", "append"]
+    # both subsequent parent ids should be the first snapshot id
+    assert parent_snapshot_id == [None, current_snapshot, current_snapshot]