Skip to content

Commit 89dcd4c

Browse files
authored
Merge pull request #825 from stan-dev/2.0-more-deprecation-removals
2.0 more deprecation removals
2 parents 9662f33 + 78a1304 commit 89dcd4c

File tree

7 files changed

+83
-260
lines changed

7 files changed

+83
-260
lines changed

cmdstanpy/model.py

Lines changed: 3 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from concurrent.futures import ThreadPoolExecutor
1414
from io import StringIO
1515
from multiprocessing import cpu_count
16-
from typing import Any, Callable, Iterable, Mapping, Optional, TypeVar, Union
16+
from typing import Any, Callable, Mapping, Optional, TypeVar, Union
1717

1818
import numpy as np
1919
import pandas as pd
@@ -242,53 +242,6 @@ def src_info(self) -> dict[str, Any]:
242242
return {}
243243
return compilation.src_info(str(self.stan_file), self._stanc_options)
244244

245-
# TODO(2.0) remove
246-
def format(
247-
self,
248-
overwrite_file: bool = False,
249-
canonicalize: Union[bool, str, Iterable[str]] = False,
250-
max_line_length: int = 78,
251-
*,
252-
backup: bool = True,
253-
) -> None:
254-
"""
255-
Deprecated: Use :func:`cmdstanpy.format_stan_file()` instead.
256-
257-
Run stanc's auto-formatter on the model code. Either saves directly
258-
back to the file or prints for inspection
259-
260-
261-
:param overwrite_file: If True, save the updated code to disk, rather
262-
than printing it. By default False
263-
:param canonicalize: Whether or not the compiler should 'canonicalize'
264-
the Stan model, removing things like deprecated syntax. Default is
265-
False. If True, all canonicalizations are run. If it is a list of
266-
strings, those options are passed to stanc (new in Stan 2.29)
267-
:param max_line_length: Set the wrapping point for the formatter. The
268-
default value is 78, which wraps most lines by the 80th character.
269-
:param backup: If True, create a stanfile.bak backup before
270-
writing to the file. Only disable this if you're sure you have other
271-
copies of the file or are using a version control system like Git.
272-
"""
273-
274-
get_logger().warning(
275-
"CmdStanModel.format() is deprecated and will be "
276-
"removed in the next major version.\n"
277-
"Use cmdstanpy.format_stan_file() instead."
278-
)
279-
280-
if self.stan_file is None:
281-
raise ValueError("No Stan file found for this module")
282-
283-
compilation.format_stan_file(
284-
self.stan_file,
285-
overwrite_file=overwrite_file,
286-
max_line_length=max_line_length,
287-
canonicalize=canonicalize,
288-
backup=backup,
289-
stanc_options=self._stanc_options,
290-
)
291-
292245
def code(self) -> Optional[str]:
293246
"""Return Stan program as a string."""
294247
if not self._stan_file:
@@ -517,9 +470,7 @@ def sample(
517470
save_warmup: bool = False,
518471
thin: Optional[int] = None,
519472
max_treedepth: Optional[int] = None,
520-
metric: Union[
521-
str, dict[str, Any], list[str], list[dict[str, Any]], None
522-
] = None,
473+
metric: Optional[str] = None,
523474
step_size: Union[float, list[float], None] = None,
524475
adapt_engaged: bool = True,
525476
adapt_delta: Optional[float] = None,
@@ -845,27 +796,6 @@ def sample(
845796
'Chain_id must be a non-negative integer value,'
846797
' found {}.'.format(chain_id)
847798
)
848-
if metric is not None and metric not in (
849-
'diag',
850-
'dense',
851-
'unit_e',
852-
'diag_e',
853-
'dense_e',
854-
):
855-
get_logger().warning(
856-
"Providing anything other than metric type for"
857-
" 'metric' is deprecated and will be removed"
858-
" in the next major release."
859-
" Please provide such information via"
860-
" 'inv_metric' argument."
861-
)
862-
if inv_metric is not None:
863-
raise ValueError(
864-
"Cannot provide both (deprecated) non-metric-type 'metric'"
865-
" argument and 'inv_metric' argument."
866-
)
867-
inv_metric = metric # type: ignore # for backwards compatibility
868-
metric = None
869799

870800
if metric is None and inv_metric is not None:
871801
metric = try_deduce_metric_type(inv_metric)
@@ -908,7 +838,7 @@ def sample(
908838
save_warmup=save_warmup,
909839
thin=thin,
910840
max_treedepth=max_treedepth,
911-
metric_type=metric, # type: ignore
841+
metric_type=metric,
912842
metric_file=cmdstan_metrics,
913843
step_size=step_size,
914844
adapt_engaged=adapt_engaged,
@@ -1037,8 +967,6 @@ def generate_quantities(
1037967
refresh: Optional[int] = None,
1038968
time_fmt: str = "%Y%m%d%H%M%S",
1039969
timeout: Optional[float] = None,
1040-
*,
1041-
mcmc_sample: Union[CmdStanMCMC, list[str], None] = None,
1042970
) -> CmdStanGQ[Fit]:
1043971
"""
1044972
Run CmdStan's generate_quantities method which runs the generated
@@ -1104,19 +1032,6 @@ def generate_quantities(
11041032
11051033
:return: CmdStanGQ object
11061034
"""
1107-
# TODO(2.0): remove
1108-
if mcmc_sample is not None:
1109-
if previous_fit:
1110-
raise ValueError(
1111-
"Cannot supply both 'previous_fit' and "
1112-
"deprecated argument 'mcmc_sample'"
1113-
)
1114-
get_logger().warning(
1115-
"Argument name `mcmc_sample` is deprecated, please "
1116-
"rename to `previous_fit`."
1117-
)
1118-
1119-
previous_fit = mcmc_sample # type: ignore
11201035

11211036
if isinstance(previous_fit, (CmdStanMCMC, CmdStanMLE, CmdStanVB)):
11221037
fit_object = previous_fit
@@ -1243,8 +1158,6 @@ def variational(
12431158
refresh: Optional[int] = None,
12441159
time_fmt: str = "%Y%m%d%H%M%S",
12451160
timeout: Optional[float] = None,
1246-
*,
1247-
output_samples: Optional[int] = None,
12481161
) -> CmdStanVB:
12491162
"""
12501163
Run CmdStan's variational inference algorithm to approximate
@@ -1342,19 +1255,6 @@ def variational(
13421255
13431256
:return: CmdStanVB object
13441257
"""
1345-
# TODO(2.0): remove
1346-
if output_samples is not None:
1347-
if draws is not None:
1348-
raise ValueError(
1349-
"Cannot supply both 'draws' and deprecated argument "
1350-
"'output_samples'"
1351-
)
1352-
get_logger().warning(
1353-
"Argument name `output_samples` is deprecated, please "
1354-
"rename to `draws`."
1355-
)
1356-
1357-
draws = output_samples
13581258

13591259
variational_args = VariationalArgs(
13601260
algorithm=algorithm,

cmdstanpy/stanfit/mle.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Container for the result of running optimization"""
22

33
from collections import OrderedDict
4-
from typing import Optional, Union
4+
from typing import Optional
55

66
import numpy as np
77
import pandas as pd
@@ -77,9 +77,7 @@ def create_inits(
7777
"""
7878
# pylint: disable=unused-argument
7979

80-
return {
81-
name: np.array(val) for name, val in self.stan_variables().items()
82-
}
80+
return self.stan_variables()
8381

8482
def __repr__(self) -> str:
8583
repr = 'CmdStanMLE: model={}{}'.format(
@@ -95,7 +93,7 @@ def __repr__(self) -> str:
9593
repr = '{} optimization failed to converge.'.format(repr)
9694
return repr
9795

98-
def __getattr__(self, attr: str) -> Union[np.ndarray, float]:
96+
def __getattr__(self, attr: str) -> np.ndarray:
9997
"""Synonymous with ``fit.stan_variable(attr)"""
10098
if attr.startswith("_"):
10199
raise AttributeError(f"Unknown variable name {attr}")
@@ -206,7 +204,7 @@ def stan_variable(
206204
*,
207205
inc_iterations: bool = False,
208206
warn: bool = True,
209-
) -> Union[np.ndarray, float]:
207+
) -> np.ndarray:
210208
"""
211209
Return a numpy.ndarray which contains the estimates for the
212210
for the named Stan program variable where the dimensions of the
@@ -254,14 +252,6 @@ def stan_variable(
254252
out: np.ndarray = self._metadata.stan_vars[var].extract_reshape(
255253
data
256254
)
257-
# TODO(2.0) remove
258-
if out.shape == () or out.shape == (1,):
259-
get_logger().warning(
260-
"The default behavior of CmdStanMLE.stan_variable() "
261-
"will change in a future release to always return a "
262-
"numpy.ndarray, even for scalar variables."
263-
)
264-
return out.item() # type: ignore
265255
return out
266256
except KeyError:
267257
# pylint: disable=raise-missing-from
@@ -273,7 +263,7 @@ def stan_variable(
273263

274264
def stan_variables(
275265
self, inc_iterations: bool = False
276-
) -> dict[str, Union[np.ndarray, float]]:
266+
) -> dict[str, np.ndarray]:
277267
"""
278268
Return a dictionary mapping Stan program variables names
279269
to the corresponding numpy.ndarray containing the inferred values.

cmdstanpy/stanfit/vb.py

Lines changed: 6 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
from cmdstanpy.cmdstan_args import Method
1010
from cmdstanpy.utils import stancsv
11-
from cmdstanpy.utils.logging import get_logger
1211

1312
from .metadata import InferenceMetadata
1413
from .runset import RunSet
@@ -100,7 +99,7 @@ def __repr__(self) -> str:
10099
# TODO - diagnostic, profiling files
101100
return repr
102101

103-
def __getattr__(self, attr: str) -> Union[np.ndarray, float]:
102+
def __getattr__(self, attr: str) -> np.ndarray:
104103
"""Synonymous with ``fit.stan_variable(attr)"""
105104
if attr.startswith("_"):
106105
raise AttributeError(f"Unknown variable name {attr}")
@@ -163,9 +162,7 @@ def metadata(self) -> InferenceMetadata:
163162
"""
164163
return self._metadata
165164

166-
def stan_variable(
167-
self, var: str, *, mean: Optional[bool] = None
168-
) -> Union[np.ndarray, float]:
165+
def stan_variable(self, var: str, *, mean: bool = False) -> np.ndarray:
169166
"""
170167
Return a numpy.ndarray which contains the estimates for the
171168
for the named Stan program variable where the dimensions of the
@@ -188,8 +185,7 @@ def stan_variable(
188185
:param var: variable name
189186
190187
:param mean: if True, return the variational mean. Otherwise,
191-
return the variational sample. The default behavior will
192-
change in a future release to return the variational sample.
188+
return the variational sample. Defaults to False.
193189
194190
See Also
195191
--------
@@ -200,16 +196,7 @@ def stan_variable(
200196
CmdStanGQ.stan_variable
201197
CmdStanLaplace.stan_variable
202198
"""
203-
# TODO(2.0): remove None case, make default `False`
204-
if mean is None:
205-
get_logger().warning(
206-
"The default behavior of CmdStanVB.stan_variable() "
207-
"will change in a future release to return the "
208-
"variational sample, rather than the mean.\n"
209-
"To maintain the current behavior, pass the argument "
210-
"mean=True"
211-
)
212-
mean = True
199+
213200
if mean:
214201
draws = self._variational_mean
215202
else:
@@ -219,16 +206,7 @@ def stan_variable(
219206
out: np.ndarray = self._metadata.stan_vars[var].extract_reshape(
220207
draws
221208
)
222-
# TODO(2.0): remove
223-
if out.shape == () or out.shape == (1,):
224-
if mean:
225-
get_logger().warning(
226-
"The default behavior of "
227-
"CmdStanVB.stan_variable(mean=True) will change in a "
228-
"future release to always return a numpy.ndarray, even "
229-
"for scalar variables."
230-
)
231-
return out.item() # type: ignore
209+
232210
return out
233211
except KeyError:
234212
# pylint: disable=raise-missing-from
@@ -238,9 +216,7 @@ def stan_variable(
238216
+ ", ".join(self._metadata.stan_vars.keys())
239217
)
240218

241-
def stan_variables(
242-
self, *, mean: Optional[bool] = None
243-
) -> dict[str, Union[np.ndarray, float]]:
219+
def stan_variables(self, *, mean: bool = False) -> dict[str, np.ndarray]:
244220
"""
245221
Return a dictionary mapping Stan program variables names
246222
to the corresponding numpy.ndarray containing the inferred values.

test/test_compiler_opts.py renamed to test/test_compilation.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
"""Compiler options tests"""
22

3+
import contextlib
4+
import io
35
import logging
46
import os
5-
from test import check_present
7+
from test import check_present, raises_nested
8+
from unittest.mock import MagicMock, patch
69

710
import pytest
811

9-
from cmdstanpy.compilation import CompilerOptions
12+
from cmdstanpy.compilation import CompilerOptions, format_stan_file
13+
from cmdstanpy.utils import cmdstan_version_before
1014

1115
HERE = os.path.dirname(os.path.abspath(__file__))
1216
DATAFILES_PATH = os.path.join(HERE, 'data')
@@ -191,3 +195,50 @@ def test_user_header() -> None:
191195
)
192196
with pytest.raises(ValueError, match="Disagreement"):
193197
opts.validate()
198+
199+
200+
def test_model_format_options() -> None:
201+
stan = os.path.join(DATAFILES_PATH, 'format_me.stan')
202+
203+
sys_stdout = io.StringIO()
204+
with contextlib.redirect_stdout(sys_stdout):
205+
format_stan_file(stan, max_line_length=10)
206+
formatted = sys_stdout.getvalue()
207+
assert len(formatted.splitlines()) > 11
208+
209+
sys_stdout = io.StringIO()
210+
with contextlib.redirect_stdout(sys_stdout):
211+
format_stan_file(stan, canonicalize='braces')
212+
formatted = sys_stdout.getvalue()
213+
assert formatted.count('{') == 3
214+
assert formatted.count('(') == 4
215+
216+
sys_stdout = io.StringIO()
217+
with contextlib.redirect_stdout(sys_stdout):
218+
format_stan_file(stan, canonicalize=['parentheses'])
219+
formatted = sys_stdout.getvalue()
220+
assert formatted.count('{') == 1
221+
assert formatted.count('(') == 1
222+
223+
sys_stdout = io.StringIO()
224+
with contextlib.redirect_stdout(sys_stdout):
225+
format_stan_file(stan, canonicalize=True)
226+
formatted = sys_stdout.getvalue()
227+
assert formatted.count('{') == 3
228+
assert formatted.count('(') == 1
229+
230+
231+
@patch(
232+
'cmdstanpy.utils.cmdstan.cmdstan_version',
233+
MagicMock(return_value=(2, 27)),
234+
)
235+
def test_format_old_version() -> None:
236+
assert cmdstan_version_before(2, 28)
237+
238+
stan = os.path.join(DATAFILES_PATH, 'format_me.stan')
239+
with raises_nested(RuntimeError, r"--canonicalize"):
240+
format_stan_file(stan, canonicalize='braces')
241+
with raises_nested(RuntimeError, r"--max-line"):
242+
format_stan_file(stan, max_line_length=88)
243+
244+
format_stan_file(stan, canonicalize=True)

0 commit comments

Comments
 (0)