Skip to content

Commit 3a7b4e4

Browse files
committed
Remove LocalEnsemble.load_all_gen_kw_data()
This commit removes the function, and replaces it with `LocalEnsemble.load_scalars()` as it gradually moves from pandas towards polars.
1 parent d1a849d commit 3a7b4e4

File tree

8 files changed

+71
-77
lines changed

8 files changed

+71
-77
lines changed

src/ert/plugins/hook_implementations/workflows/csv_export.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,10 @@ def run(
9090
f"The ensemble '{ensemble.name}' does not have any data!"
9191
)
9292

93-
ensemble_data = ensemble.load_all_gen_kw_data()
93+
ensemble_data = ensemble.load_scalars().to_pandas().set_index("realization")
94+
ensemble_data.columns.name = None
95+
ensemble_data.index.name = "Realization"
96+
ensemble_data = ensemble_data.sort_index(axis=1)
9497

9598
if design_matrix_path is not None:
9699
design_matrix_data = loadDesignMatrix(design_matrix_path)

src/ert/storage/local_ensemble.py

Lines changed: 0 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -788,50 +788,6 @@ def _load_responses_lazy(
788788

789789
return pl.concat(loaded) if loaded else pl.DataFrame().lazy()
790790

791-
def load_all_gen_kw_data(
792-
self,
793-
group: str | None = None,
794-
realization_index: int | None = None,
795-
) -> pd.DataFrame:
796-
"""Loads scalar parameters (GEN_KWs) into a pandas DataFrame
797-
with columns <PARAMETER_GROUP>:<PARAMETER_NAME> and
798-
"Realization" as index.
799-
800-
Parameters
801-
----------
802-
group : str, optional
803-
Name of parameter group to load.
804-
relization_index : int, optional
805-
The realization to load.
806-
807-
Returns
808-
-------
809-
data : DataFrame
810-
A pandas DataFrame containing the GEN_KW data.
811-
812-
Notes
813-
-----
814-
Any provided keys that are not gen_kw will be ignored.
815-
"""
816-
if realization_index is not None:
817-
realizations = np.array([realization_index])
818-
else:
819-
ens_mask = (
820-
self.get_realization_mask_with_responses()
821-
+ self.get_realization_mask_with_parameters()
822-
)
823-
realizations = np.flatnonzero(ens_mask)
824-
825-
df = self.load_scalars(group, realizations)
826-
827-
if df.is_empty():
828-
return pd.DataFrame()
829-
830-
dataframe = df.to_pandas().set_index("realization")
831-
dataframe.columns.name = None
832-
dataframe.index.name = "Realization"
833-
return dataframe.sort_index(axis=1)
834-
835791
@require_write
836792
def save_parameters(
837793
self,

tests/ert/ui_tests/cli/analysis/test_es_update.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,9 @@ def test_that_posterior_has_lower_variance_than_prior():
5353
with open_storage("storage") as storage:
5454
experiment = storage.get_experiment_by_name("es-test")
5555
prior_ensemble = experiment.get_ensemble_by_name("iter-0")
56-
df_default = prior_ensemble.load_all_gen_kw_data()
56+
df_default = prior_ensemble.load_scalars()
5757
posterior_ensemble = experiment.get_ensemble_by_name("iter-1")
58-
df_target = posterior_ensemble.load_all_gen_kw_data()
58+
df_target = posterior_ensemble.load_scalars()
5959

6060
# The std for the ensemble should decrease
6161
assert float(
@@ -68,8 +68,8 @@ def test_that_posterior_has_lower_variance_than_prior():
6868
# generalized variance for the parameters.
6969
assert (
7070
0
71-
< np.linalg.det(df_target.cov().to_numpy())
72-
< np.linalg.det(df_default.cov().to_numpy())
71+
< np.linalg.det(np.cov(df_target.to_numpy(), rowvar=False))
72+
< np.linalg.det(np.cov(df_default.to_numpy(), rowvar=False))
7373
)
7474

7575

tests/ert/ui_tests/cli/test_cli.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,11 @@ def test_that_es_mda_on_poly_case_matches_snapshot(snapshot):
537537
experiment = storage.get_experiment_by_name("es-mda")
538538
for iter_nr in range(4):
539539
ensemble = experiment.get_ensemble_by_name(f"iter-{iter_nr}")
540-
data.append(ensemble.load_all_gen_kw_data())
540+
ensemble_data = ensemble.load_scalars().to_pandas().set_index("realization")
541+
ensemble_data.columns.name = None
542+
ensemble_data.index.name = "Realization"
543+
ensemble_data = ensemble_data.sort_index(axis=1)
544+
data.append(ensemble_data)
541545
result = pd.concat(
542546
data,
543547
keys=[f"iter-{iter_}" for iter_ in range(len(data))],
@@ -573,7 +577,11 @@ def test_that_enif_on_poly_case_matches_snapshot(snapshot):
573577
experiment = storage.get_experiment_by_name("enif")
574578
for iter_nr in range(2):
575579
ensemble = experiment.get_ensemble_by_name(f"iter-{iter_nr}")
576-
data.append(ensemble.load_all_gen_kw_data())
580+
ensemble_data = ensemble.load_scalars().to_pandas().set_index("realization")
581+
ensemble_data.columns.name = None
582+
ensemble_data.index.name = "Realization"
583+
ensemble_data = ensemble_data.sort_index(axis=1)
584+
data.append(ensemble_data)
577585
result = pd.concat(
578586
data,
579587
keys=[f"iter-{i}" for i in range(len(data))],

tests/ert/ui_tests/cli/test_update.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -232,12 +232,10 @@ def test_update_lowers_generalized_variance_or_deactivates_observations(
232232
if success:
233233
with open_storage("storage") as storage:
234234
experiment = storage.get_experiment_by_name("experiment")
235-
prior = experiment.get_ensemble_by_name("iter-0").load_all_gen_kw_data()
236-
posterior = experiment.get_ensemble_by_name(
237-
"iter-1"
238-
).load_all_gen_kw_data()
235+
prior = experiment.get_ensemble_by_name("iter-0").load_scalars()
236+
posterior = experiment.get_ensemble_by_name("iter-1").load_scalars()
239237

240238
assert (
241-
np.linalg.det(posterior.cov().to_numpy())
242-
<= np.linalg.det(prior.cov().to_numpy()) + 0.001
239+
np.linalg.det(np.cov(posterior.to_numpy(), rowvar=False))
240+
<= np.linalg.det(np.cov(prior.to_numpy(), rowvar=False)) + 0.001
243241
)

tests/ert/ui_tests/gui/test_csv_export.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def verify_exported_content(file_name, gui, ensemble_select):
6868
for name in ensemble_names:
6969
experiment = gui.notifier.storage.get_experiment_by_name("es_mda")
7070
ensemble = experiment.get_ensemble_by_name(name)
71-
gen_kw_data = ensemble.load_all_gen_kw_data()
71+
gen_kw_data = ensemble.load_scalars().to_pandas()
7272

7373
facade = LibresFacade.from_config_file("poly.ert")
7474
misfit_data = facade.load_all_misfit_data(ensemble)

tests/ert/ui_tests/gui/test_full_manual_update_workflow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,11 +90,11 @@ def test_manual_analysis_workflow(ensemble_experiment_has_run, qtbot):
9090
10,
9191
)
9292

93-
df_prior = ensemble_prior.load_all_gen_kw_data()
93+
df_prior = ensemble_prior.load_scalars().to_pandas()
9494

9595
exp_posterior = storage.get_experiment_by_name("Manual update of iter-0")
9696
ensemble_posterior = exp_posterior.get_ensemble_by_name("iter-0_1")
97-
df_posterior = ensemble_posterior.load_all_gen_kw_data()
97+
df_posterior = ensemble_posterior.load_scalars().to_pandas()
9898

9999
# Making sure measured data works with failed realizations
100100
MeasuredData(experiment.get_ensemble_by_name("iter-0"), ["POLY_OBS"])

tests/ert/unit_tests/storage/test_local_storage.py

Lines changed: 46 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1068,40 +1068,69 @@ def test_load_gen_kw_not_sorted(storage, tmpdir, snapshot):
10681068
)
10691069

10701070
sample_prior(ensemble, range(ensemble_size), random_seed=1234)
1071-
1072-
data = ensemble.load_all_gen_kw_data()
1071+
data = ensemble.load_scalars().to_pandas().set_index("realization")
1072+
data.columns.name = None
1073+
data.index.name = "Realization"
1074+
data = data.sort_index(axis=1)
10731075
snapshot.assert_match(data.round(12).to_csv(), "gen_kw_unsorted")
10741076

10751077

10761078
def test_gen_kw_collector(snake_oil_default_storage, snapshot):
1077-
data = snake_oil_default_storage.load_all_gen_kw_data()
1079+
data = snake_oil_default_storage.load_scalars().to_pandas().set_index("realization")
1080+
data.columns.name = None
1081+
data.index.name = "Realization"
1082+
data = data.sort_index(axis=1)
10781083
snapshot.assert_match(data.round(6).to_csv(), "gen_kw_collector.csv")
10791084

10801085
with pytest.raises(KeyError):
10811086
# realization 60:
10821087
_ = data.loc[60]
10831088

1084-
data = snake_oil_default_storage.load_all_gen_kw_data(
1085-
"SNAKE_OIL_PARAM",
1086-
)[["SNAKE_OIL_PARAM:OP1_PERSISTENCE", "SNAKE_OIL_PARAM:OP1_OFFSET"]]
1089+
data = (
1090+
snake_oil_default_storage.load_scalars(
1091+
"SNAKE_OIL_PARAM",
1092+
)
1093+
.to_pandas()
1094+
.set_index("realization")
1095+
)
1096+
data.columns.name = None
1097+
data.index.name = "Realization"
1098+
data = data.sort_index(axis=1)
1099+
data = data[["SNAKE_OIL_PARAM:OP1_PERSISTENCE", "SNAKE_OIL_PARAM:OP1_OFFSET"]]
10871100
snapshot.assert_match(data.round(6).to_csv(), "gen_kw_collector_2.csv")
10881101

10891102
with pytest.raises(KeyError):
10901103
_ = data["SNAKE_OIL_PARAM:OP1_DIVERGENCE_SCALE"]
10911104

10921105
realization_index = 3
1093-
data = snake_oil_default_storage.load_all_gen_kw_data(
1094-
"SNAKE_OIL_PARAM",
1095-
realization_index=realization_index,
1096-
)["SNAKE_OIL_PARAM:OP1_PERSISTENCE"]
1106+
data = (
1107+
snake_oil_default_storage.load_scalars(
1108+
"SNAKE_OIL_PARAM",
1109+
realizations=[realization_index],
1110+
)
1111+
.to_pandas()
1112+
.set_index("realization")
1113+
)
1114+
data.columns.name = None
1115+
data.index.name = "Realization"
1116+
data = data.sort_index(axis=1)
1117+
data = data["SNAKE_OIL_PARAM:OP1_PERSISTENCE"]
10971118
snapshot.assert_match(data.round(6).to_csv(), "gen_kw_collector_3.csv")
10981119

10991120
non_existing_realization_index = 150
11001121
with pytest.raises((IndexError, KeyError)):
1101-
_ = snake_oil_default_storage.load_all_gen_kw_data(
1102-
"SNAKE_OIL_PARAM",
1103-
realization_index=non_existing_realization_index,
1104-
)["SNAKE_OIL_PARAM:OP1_PERSISTENCE"]
1122+
data = (
1123+
snake_oil_default_storage.load_scalars(
1124+
"SNAKE_OIL_PARAM",
1125+
realizations=[non_existing_realization_index],
1126+
)
1127+
.to_pandas()
1128+
.set_index("realization")
1129+
)
1130+
data.columns.name = None
1131+
data.index.name = "Realization"
1132+
data = data.sort_index(axis=1)
1133+
data = data["SNAKE_OIL_PARAM:OP1_PERSISTENCE"]
11051134

11061135

11071136
def test_keyword_type_checks(snake_oil_default_storage):
@@ -1130,12 +1159,12 @@ def test_data_fetching_missing_key(snake_oil_case):
11301159
empty_case = experiment.create_ensemble(name="new_case", ensemble_size=25)
11311160

11321161
data = [
1133-
empty_case.load_all_gen_kw_data("nokey", None),
1162+
empty_case.load_scalars("nokey", None),
11341163
]
11351164

11361165
for dataframe in data:
1137-
assert isinstance(dataframe, DataFrame)
1138-
assert dataframe.empty
1166+
assert isinstance(dataframe, pl.DataFrame)
1167+
assert dataframe.is_empty()
11391168

11401169

11411170
def test_set_failure_will_create_realization_directory(storage):

0 commit comments

Comments
 (0)