Skip to content

Commit be965dc

Browse files
committed
Remove deprecated stan_variable behavior for optimization
1 parent 3add172 commit be965dc

File tree

3 files changed

+13
-26
lines changed

3 files changed

+13
-26
lines changed

cmdstanpy/stanfit/mle.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Container for the result of running optimization"""
22

33
from collections import OrderedDict
4-
from typing import Optional, Union
4+
from typing import Optional
55

66
import numpy as np
77
import pandas as pd
@@ -95,7 +95,7 @@ def __repr__(self) -> str:
9595
repr = '{} optimization failed to converge.'.format(repr)
9696
return repr
9797

98-
def __getattr__(self, attr: str) -> Union[np.ndarray, float]:
98+
def __getattr__(self, attr: str) -> np.ndarray:
9999
"""Synonymous with ``fit.stan_variable(attr)"""
100100
if attr.startswith("_"):
101101
raise AttributeError(f"Unknown variable name {attr}")
@@ -206,7 +206,7 @@ def stan_variable(
206206
*,
207207
inc_iterations: bool = False,
208208
warn: bool = True,
209-
) -> Union[np.ndarray, float]:
209+
) -> np.ndarray:
210210
"""
211211
Return a numpy.ndarray which contains the estimates for the
212212
for the named Stan program variable where the dimensions of the
@@ -254,14 +254,6 @@ def stan_variable(
254254
out: np.ndarray = self._metadata.stan_vars[var].extract_reshape(
255255
data
256256
)
257-
# TODO(2.0) remove
258-
if out.shape == () or out.shape == (1,):
259-
get_logger().warning(
260-
"The default behavior of CmdStanMLE.stan_variable() "
261-
"will change in a future release to always return a "
262-
"numpy.ndarray, even for scalar variables."
263-
)
264-
return out.item() # type: ignore
265257
return out
266258
except KeyError:
267259
# pylint: disable=raise-missing-from
@@ -273,7 +265,7 @@ def stan_variable(
273265

274266
def stan_variables(
275267
self, inc_iterations: bool = False
276-
) -> dict[str, Union[np.ndarray, float]]:
268+
) -> dict[str, np.ndarray]:
277269
"""
278270
Return a dictionary mapping Stan program variables names
279271
to the corresponding numpy.ndarray containing the inferred values.

test/test_optimize.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -201,8 +201,7 @@ def test_variable_bern() -> None:
201201
assert 'theta' in bern_mle.metadata.stan_vars
202202
assert bern_mle.metadata.stan_vars['theta'].dimensions == ()
203203
theta = bern_mle.stan_variable(var='theta')
204-
# TODO(2.0): remove before `or` clause
205-
assert isinstance(theta, float) or theta.shape == ()
204+
assert theta.shape == ()
206205
with pytest.raises(ValueError):
207206
bern_mle.stan_variable(var='eta')
208207
with pytest.raises(ValueError):
@@ -234,17 +233,15 @@ def test_variables_3d() -> None:
234233
var_beta = multidim_mle.stan_variable(var='beta')
235234
assert var_beta.shape == (2,)
236235
var_frac_60 = multidim_mle.stan_variable(var='frac_60')
237-
# TODO(2.0): remove before `or` clause
238-
assert isinstance(var_frac_60, float) or var_frac_60.shape == ()
236+
assert var_frac_60.shape == ()
239237
vars = multidim_mle.stan_variables()
240238
assert len(vars) == len(multidim_mle.metadata.stan_vars)
241239
assert 'y_rep' in vars
242240
assert vars['y_rep'].shape == (5, 4, 3)
243241
assert 'beta' in vars
244242
assert vars['beta'].shape == (2,)
245243
assert 'frac_60' in vars
246-
# TODO(2.0): remove before `or` clause
247-
assert isinstance(vars['frac_60'], float) or vars['frac_60'].shape == ()
244+
assert vars['frac_60'].shape == ()
248245

249246
multidim_mle_iters = multidim_model.optimize(
250247
data=jdata,
@@ -565,8 +562,7 @@ def test_single_row_csv() -> None:
565562
model = CmdStanModel(stan_file=stan)
566563
mle = model.optimize()
567564
theta = mle.stan_variable('theta')
568-
# TODO(2.0): remove before `or` clause
569-
assert isinstance(theta, float) or theta.shape == ()
565+
assert theta.shape == ()
570566
z_as_ndarray = mle.stan_variable(var="z")
571567
assert z_as_ndarray.shape == (4, 3)
572568
for i in range(4):
@@ -627,8 +623,7 @@ def test_attrs() -> None:
627623

628624
assert fit.a == 4.5
629625
assert fit.b.shape == (3,)
630-
# TODO(2.0) remove before `or` clause
631-
assert isinstance(fit.theta, float) or fit.theta.shape == ()
626+
assert fit.theta.shape == ()
632627

633628
assert fit.stan_variable('thin') == 3.5
634629

@@ -673,7 +668,7 @@ def test_serialization() -> None:
673668
)
674669

675670

676-
def test_optimize_create_inits():
671+
def test_optimize_create_inits() -> None:
677672
stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
678673
bern_model = CmdStanModel(stan_file=stan)
679674
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
@@ -686,7 +681,7 @@ def test_optimize_create_inits():
686681
assert len(inits) == 1
687682

688683

689-
def test_optimize_init_sampling():
684+
def test_optimize_init_sampling() -> None:
690685
stan = os.path.join(DATAFILES_PATH, 'logistic.stan')
691686
logistic_model = CmdStanModel(stan_file=stan)
692687
logistic_data = os.path.join(DATAFILES_PATH, 'logistic.data.R')

test/test_variational.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ def test_serialization() -> None:
344344
)
345345

346346

347-
def test_variational_create_inits():
347+
def test_variational_create_inits() -> None:
348348
stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
349349
bern_model = CmdStanModel(stan_file=stan)
350350
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
@@ -374,7 +374,7 @@ def test_variational_create_inits():
374374
)
375375

376376

377-
def test_variational_init_sampling():
377+
def test_variational_init_sampling() -> None:
378378
stan = os.path.join(DATAFILES_PATH, 'logistic.stan')
379379
logistic_model = CmdStanModel(stan_file=stan)
380380
logistic_data = os.path.join(DATAFILES_PATH, 'logistic.data.R')

0 commit comments

Comments
 (0)