diff --git a/sklift/models/models.py b/sklift/models/models.py index 52d3e05..df5a618 100644 --- a/sklift/models/models.py +++ b/sklift/models/models.py @@ -142,8 +142,8 @@ def predict(self, X): if self.method == 'dummy': if isinstance(X, np.ndarray): - X_mod_trmnt = np.column_stack((X, np.ones(X.shape[0]))) - X_mod_ctrl = np.column_stack((X, np.zeros(X.shape[0]))) + X_mod_trmnt = np.vstack((X, np.ones(X.shape[1]))) + X_mod_ctrl = np.vstack((X, np.zeros(X.shape[1]))) elif isinstance(X, pd.DataFrame): X_mod_trmnt = X.assign(treatment=np.ones(X.shape[0])) X_mod_ctrl = X.assign(treatment=np.zeros(X.shape[0])) diff --git a/sklift/tests/conftest.py b/sklift/tests/conftest.py index ed23faf..01b5e50 100644 --- a/sklift/tests/conftest.py +++ b/sklift/tests/conftest.py @@ -36,7 +36,7 @@ def random_xy_dataset_regr(request): treat = (np.random.normal(0, 2, (n,)) > 0.0).astype(int) if dataset_type == 'numpy': return X, y, treat - return pd.DataFrame(X), pd.Series(y), pd.Series(treat) + return pd.DataFrame(X, columns=[f"feat_{i}" for i in range(X.shape[1])]), pd.Series(y), pd.Series(treat) @pytest.fixture( @@ -65,5 +65,5 @@ def random_xyt_dataset_clf(request): if dataset_type == 'numpy': return X, y, treat - return pd.DataFrame(X), pd.Series(y), pd.Series(treat) + return pd.DataFrame(X, columns=[f"feat_{i}" for i in range(X.shape[1])]), pd.Series(y), pd.Series(treat) diff --git a/sklift/tests/test_models.py b/sklift/tests/test_models.py index 2e58281..e8c3d1b 100644 --- a/sklift/tests/test_models.py +++ b/sklift/tests/test_models.py @@ -27,7 +27,9 @@ ) def test_shape_classification(model, random_xyt_dataset_clf): X, y, treat = random_xyt_dataset_clf - assert model.fit(X, y, treat).predict(X).shape[0] == y.shape[0] + preds = model.fit(X, y, treat).predict(X) + assert preds.shape[0] == y.shape[0], 'different 0 dim' + assert pd.DataFrame(preds).shape[1] == pd.DataFrame(y).shape[1], 'different 1 dim' pipe = Pipeline(steps=[("scaler", StandardScaler()), ("clf", model)]) assert pipe.fit(X, y, clf__treatment=treat).predict(X).shape[0] == y.shape[0]