Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions kauldron/data/transforms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class Elements(tr_abc.MapTransform):
drop: Iterable[str] = ()
rename: Mapping[str, str] = _FrozenDict()
copy: Mapping[str, str] = _FrozenDict()
skip_missing: bool = False

def __post_init__(self):
if self.keep and self.drop:
Expand Down Expand Up @@ -83,7 +84,7 @@ def map(self, features):
if bool(self.copy):
copy_keys = set(self.copy.keys())
missing_copy_keys = copy_keys - feature_keys
if missing_copy_keys:
if missing_copy_keys and not self.skip_missing:
raise KeyError(
f"copy-key(s) {missing_copy_keys} not found in batch. "
f"Available keys are {sorted(feature_keys)!r}."
Expand All @@ -95,13 +96,16 @@ def map(self, features):
f"copy-value(s) {overlap_keys} will overwrite existing values in "
f"batch. Existing keys are {sorted(feature_keys)!r}."
)
copy_output = {v: features[k] for k, v in self.copy.items()}
copy_output = {}
for k, v in self.copy.items():
if k in features:
copy_output[v] = features[k]

# resolve keep or drop
if self.keep:
keep_keys = set(self.keep)
missing_keep_keys = keep_keys - feature_keys
if missing_keep_keys:
if missing_keep_keys and not self.skip_missing:
raise KeyError(
f"keep-key(s) {missing_keep_keys} not found in batch. "
f"Available keys are {sorted(feature_keys)!r}."
Expand All @@ -110,7 +114,7 @@ def map(self, features):
elif self.drop:
drop_keys = set(self.drop)
missing_drop_keys = drop_keys - feature_keys
if missing_drop_keys:
if missing_drop_keys and not self.skip_missing:
raise KeyError(
f"drop-key(s) {missing_drop_keys} not found in batch. "
f"Available keys are {sorted(feature_keys)!r}."
Expand All @@ -127,7 +131,7 @@ def map(self, features):
# resolve renaming
rename_keys = set(self.rename.keys())
missing_rename_keys = rename_keys - feature_keys
if missing_rename_keys:
if missing_rename_keys and not self.skip_missing:
raise KeyError(
f"rename-key(s) {missing_rename_keys} not found in batch. "
f"Available keys are {sorted(feature_keys)!r}."
Expand Down
61 changes: 61 additions & 0 deletions kauldron/data/transforms/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,22 @@ def test_elements_keep():
assert after["no_copy"] == before["no"]


def test_elements_keep_skip_missing():
    """`keep`, `rename` and `copy` may name absent keys when `skip_missing=True`.

    The transform is configured with one key per option ("missing",
    "old_missing") that is not in the input; `map` is expected to succeed and
    to produce only the entries derived from keys that are actually present.
    """
    source = {"yes": 1, "definitely": 2, "old": 3, "no": 4, "drop": 5}
    transform = kd.data.py.Elements(
        keep={"yes", "definitely", "missing"},
        rename={"old": "new", "old_missing": "new_missing"},
        copy={"no": "no_copy", "missing": "missing_copy"},
        skip_missing=True,
    )

    result = transform.map(source)

    # No output key may originate from a key that was missing in the input.
    assert set(result.keys()) == {"yes", "definitely", "new", "no_copy"}

    # (output key, input key it must carry the value of)
    expected_origin = (
        ("yes", "yes"),
        ("definitely", "definitely"),
        ("new", "old"),
        ("no_copy", "no"),
    )
    for out_key, in_key in expected_origin:
        assert result[out_key] == source[in_key]


def test_elements_drop():
el = kd.data.py.Elements(
drop={"no", "drop"}, rename={"old": "new"}, copy={"yes": "yes_copy"}
Expand All @@ -42,6 +58,22 @@ def test_elements_drop():
assert after["yes_copy"] == before["yes"]


def test_elements_drop_skip_missing():
    """`drop`, `rename` and `copy` may name absent keys when `skip_missing=True`.

    Each option lists one key ("missing", "old_missing") that the input does
    not contain; `map` is expected to succeed and to act only on the keys that
    are present.
    """
    source = {"yes": 1, "definitely": 2, "old": 3, "no": 4, "drop": 5}
    transform = kd.data.py.Elements(
        drop={"no", "drop", "missing"},
        rename={"old": "new", "old_missing": "new_missing"},
        copy={"yes": "yes_copy", "missing": "missing_copy"},
        skip_missing=True,
    )

    result = transform.map(source)

    # Dropped keys are gone and nothing was created for the missing ones.
    assert set(result.keys()) == {"yes", "definitely", "new", "yes_copy"}

    # (output key, input key it must carry the value of)
    expected_origin = (
        ("yes", "yes"),
        ("definitely", "definitely"),
        ("new", "old"),
        ("yes_copy", "yes"),
    )
    for out_key, in_key in expected_origin:
        assert result[out_key] == source[in_key]


def test_elements_rename_only():
el = kd.data.py.Elements(rename={"old": "new"})
before = {"yes": 1, "definitely": 2, "old": 3, "no": 4, "drop": 5}
Expand All @@ -60,6 +92,12 @@ def test_elements_rename_overwrite_raises():
with pytest.raises(KeyError):
el.map(before)

# Same as above but with skip_missing=True.
el = kd.data.py.Elements(rename={"old": "oops"}, skip_missing=True)
before = {"old": 1, "oops": 2}
with pytest.raises(KeyError):
el.map(before)


def test_elements_copy_only():
el = kd.data.py.Elements(copy={"yes": "no", "old": "new"})
Expand All @@ -84,3 +122,26 @@ def test_elements_copy_overwrite_raises():
# copy two fields to the same target name
with pytest.raises(ValueError):
_ = kd.data.py.Elements(copy={"old": "oops", "yes": "oops"})


def test_elements_copy_overwrite_raises_skip_missing():
    """Copy-target conflicts must still raise when `skip_missing=True`.

    Covers three conflicts: a copy target that already exists in the input
    (KeyError from `map`), a copy target that is also a rename target
    (KeyError at construction), and two copies sharing one target
    (ValueError at construction).
    """
    # Case 1: the copy target "oops" is already a key of the input. The
    # transform is built outside the `raises` block so that only `map` is
    # allowed to raise.
    source = {"old": 1, "oops": 2}
    transform = kd.data.py.Elements(
        copy={"old": "oops", "missing": "missing_copy"}, skip_missing=True
    )
    with pytest.raises(KeyError):
        transform.map(source)

    # Case 2: "oops" is both a copy target and a rename target.
    clashing_with_rename = dict(
        copy={"old": "oops"},
        rename={"yes": "oops", "missing": "missing_remamed"},
        skip_missing=True,
    )
    with pytest.raises(KeyError):
        kd.data.py.Elements(**clashing_with_rename)

    # Case 3: two different source keys are copied to the same target name.
    duplicate_targets = dict(
        copy={"old": "oops", "yes": "oops", "missing": "missing_copy"},
        skip_missing=True,
    )
    with pytest.raises(ValueError):
        kd.data.py.Elements(**duplicate_targets)