Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 3 additions & 103 deletions cmdstanpy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from concurrent.futures import ThreadPoolExecutor
from io import StringIO
from multiprocessing import cpu_count
from typing import Any, Callable, Iterable, Mapping, Optional, TypeVar, Union
from typing import Any, Callable, Mapping, Optional, TypeVar, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -242,53 +242,6 @@ def src_info(self) -> dict[str, Any]:
return {}
return compilation.src_info(str(self.stan_file), self._stanc_options)

# TODO(2.0) remove
def format(
self,
overwrite_file: bool = False,
canonicalize: Union[bool, str, Iterable[str]] = False,
max_line_length: int = 78,
*,
backup: bool = True,
) -> None:
"""
Deprecated: Use :func:`cmdstanpy.format_stan_file()` instead.

Run stanc's auto-formatter on the model code. Either saves directly
back to the file or prints for inspection


:param overwrite_file: If True, save the updated code to disk, rather
than printing it. By default False
:param canonicalize: Whether or not the compiler should 'canonicalize'
the Stan model, removing things like deprecated syntax. Default is
False. If True, all canonicalizations are run. If it is a list of
strings, those options are passed to stanc (new in Stan 2.29)
:param max_line_length: Set the wrapping point for the formatter. The
default value is 78, which wraps most lines by the 80th character.
:param backup: If True, create a stanfile.bak backup before
writing to the file. Only disable this if you're sure you have other
copies of the file or are using a version control system like Git.
"""

get_logger().warning(
"CmdStanModel.format() is deprecated and will be "
"removed in the next major version.\n"
"Use cmdstanpy.format_stan_file() instead."
)

if self.stan_file is None:
raise ValueError("No Stan file found for this module")

compilation.format_stan_file(
self.stan_file,
overwrite_file=overwrite_file,
max_line_length=max_line_length,
canonicalize=canonicalize,
backup=backup,
stanc_options=self._stanc_options,
)

def code(self) -> Optional[str]:
"""Return Stan program as a string."""
if not self._stan_file:
Expand Down Expand Up @@ -517,9 +470,7 @@ def sample(
save_warmup: bool = False,
thin: Optional[int] = None,
max_treedepth: Optional[int] = None,
metric: Union[
str, dict[str, Any], list[str], list[dict[str, Any]], None
] = None,
metric: Optional[str] = None,
step_size: Union[float, list[float], None] = None,
adapt_engaged: bool = True,
adapt_delta: Optional[float] = None,
Expand Down Expand Up @@ -845,27 +796,6 @@ def sample(
'Chain_id must be a non-negative integer value,'
' found {}.'.format(chain_id)
)
if metric is not None and metric not in (
'diag',
'dense',
'unit_e',
'diag_e',
'dense_e',
):
get_logger().warning(
"Providing anything other than metric type for"
" 'metric' is deprecated and will be removed"
" in the next major release."
" Please provide such information via"
" 'inv_metric' argument."
)
if inv_metric is not None:
raise ValueError(
"Cannot provide both (deprecated) non-metric-type 'metric'"
" argument and 'inv_metric' argument."
)
inv_metric = metric # type: ignore # for backwards compatibility
metric = None

if metric is None and inv_metric is not None:
metric = try_deduce_metric_type(inv_metric)
Expand Down Expand Up @@ -908,7 +838,7 @@ def sample(
save_warmup=save_warmup,
thin=thin,
max_treedepth=max_treedepth,
metric_type=metric, # type: ignore
metric_type=metric,
metric_file=cmdstan_metrics,
step_size=step_size,
adapt_engaged=adapt_engaged,
Expand Down Expand Up @@ -1037,8 +967,6 @@ def generate_quantities(
refresh: Optional[int] = None,
time_fmt: str = "%Y%m%d%H%M%S",
timeout: Optional[float] = None,
*,
mcmc_sample: Union[CmdStanMCMC, list[str], None] = None,
) -> CmdStanGQ[Fit]:
"""
Run CmdStan's generate_quantities method which runs the generated
Expand Down Expand Up @@ -1104,19 +1032,6 @@ def generate_quantities(

:return: CmdStanGQ object
"""
# TODO(2.0): remove
if mcmc_sample is not None:
if previous_fit:
raise ValueError(
"Cannot supply both 'previous_fit' and "
"deprecated argument 'mcmc_sample'"
)
get_logger().warning(
"Argument name `mcmc_sample` is deprecated, please "
"rename to `previous_fit`."
)

previous_fit = mcmc_sample # type: ignore

if isinstance(previous_fit, (CmdStanMCMC, CmdStanMLE, CmdStanVB)):
fit_object = previous_fit
Expand Down Expand Up @@ -1243,8 +1158,6 @@ def variational(
refresh: Optional[int] = None,
time_fmt: str = "%Y%m%d%H%M%S",
timeout: Optional[float] = None,
*,
output_samples: Optional[int] = None,
) -> CmdStanVB:
"""
Run CmdStan's variational inference algorithm to approximate
Expand Down Expand Up @@ -1342,19 +1255,6 @@ def variational(

:return: CmdStanVB object
"""
# TODO(2.0): remove
if output_samples is not None:
if draws is not None:
raise ValueError(
"Cannot supply both 'draws' and deprecated argument "
"'output_samples'"
)
get_logger().warning(
"Argument name `output_samples` is deprecated, please "
"rename to `draws`."
)

draws = output_samples

variational_args = VariationalArgs(
algorithm=algorithm,
Expand Down
20 changes: 5 additions & 15 deletions cmdstanpy/stanfit/mle.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Container for the result of running optimization"""

from collections import OrderedDict
from typing import Optional, Union
from typing import Optional

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -77,9 +77,7 @@ def create_inits(
"""
# pylint: disable=unused-argument

return {
name: np.array(val) for name, val in self.stan_variables().items()
}
return self.stan_variables()

def __repr__(self) -> str:
repr = 'CmdStanMLE: model={}{}'.format(
Expand All @@ -95,7 +93,7 @@ def __repr__(self) -> str:
repr = '{} optimization failed to converge.'.format(repr)
return repr

def __getattr__(self, attr: str) -> Union[np.ndarray, float]:
def __getattr__(self, attr: str) -> np.ndarray:
"""Synonymous with ``fit.stan_variable(attr)"""
if attr.startswith("_"):
raise AttributeError(f"Unknown variable name {attr}")
Expand Down Expand Up @@ -206,7 +204,7 @@ def stan_variable(
*,
inc_iterations: bool = False,
warn: bool = True,
) -> Union[np.ndarray, float]:
) -> np.ndarray:
"""
Return a numpy.ndarray which contains the estimates for the
for the named Stan program variable where the dimensions of the
Expand Down Expand Up @@ -254,14 +252,6 @@ def stan_variable(
out: np.ndarray = self._metadata.stan_vars[var].extract_reshape(
data
)
# TODO(2.0) remove
if out.shape == () or out.shape == (1,):
get_logger().warning(
"The default behavior of CmdStanMLE.stan_variable() "
"will change in a future release to always return a "
"numpy.ndarray, even for scalar variables."
)
return out.item() # type: ignore
return out
except KeyError:
# pylint: disable=raise-missing-from
Expand All @@ -273,7 +263,7 @@ def stan_variable(

def stan_variables(
self, inc_iterations: bool = False
) -> dict[str, Union[np.ndarray, float]]:
) -> dict[str, np.ndarray]:
"""
Return a dictionary mapping Stan program variables names
to the corresponding numpy.ndarray containing the inferred values.
Expand Down
36 changes: 6 additions & 30 deletions cmdstanpy/stanfit/vb.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from cmdstanpy.cmdstan_args import Method
from cmdstanpy.utils import stancsv
from cmdstanpy.utils.logging import get_logger

from .metadata import InferenceMetadata
from .runset import RunSet
Expand Down Expand Up @@ -100,7 +99,7 @@ def __repr__(self) -> str:
# TODO - diagnostic, profiling files
return repr

def __getattr__(self, attr: str) -> Union[np.ndarray, float]:
def __getattr__(self, attr: str) -> np.ndarray:
"""Synonymous with ``fit.stan_variable(attr)"""
if attr.startswith("_"):
raise AttributeError(f"Unknown variable name {attr}")
Expand Down Expand Up @@ -163,9 +162,7 @@ def metadata(self) -> InferenceMetadata:
"""
return self._metadata

def stan_variable(
self, var: str, *, mean: Optional[bool] = None
) -> Union[np.ndarray, float]:
def stan_variable(self, var: str, *, mean: bool = False) -> np.ndarray:
"""
Return a numpy.ndarray which contains the estimates for the
for the named Stan program variable where the dimensions of the
Expand All @@ -188,8 +185,7 @@ def stan_variable(
:param var: variable name

:param mean: if True, return the variational mean. Otherwise,
return the variational sample. The default behavior will
change in a future release to return the variational sample.
return the variational sample. Defaults to False.

See Also
--------
Expand All @@ -200,16 +196,7 @@ def stan_variable(
CmdStanGQ.stan_variable
CmdStanLaplace.stan_variable
"""
# TODO(2.0): remove None case, make default `False`
if mean is None:
get_logger().warning(
"The default behavior of CmdStanVB.stan_variable() "
"will change in a future release to return the "
"variational sample, rather than the mean.\n"
"To maintain the current behavior, pass the argument "
"mean=True"
)
mean = True

if mean:
draws = self._variational_mean
else:
Expand All @@ -219,16 +206,7 @@ def stan_variable(
out: np.ndarray = self._metadata.stan_vars[var].extract_reshape(
draws
)
# TODO(2.0): remove
if out.shape == () or out.shape == (1,):
if mean:
get_logger().warning(
"The default behavior of "
"CmdStanVB.stan_variable(mean=True) will change in a "
"future release to always return a numpy.ndarray, even "
"for scalar variables."
)
return out.item() # type: ignore

return out
except KeyError:
# pylint: disable=raise-missing-from
Expand All @@ -238,9 +216,7 @@ def stan_variable(
+ ", ".join(self._metadata.stan_vars.keys())
)

def stan_variables(
self, *, mean: Optional[bool] = None
) -> dict[str, Union[np.ndarray, float]]:
def stan_variables(self, *, mean: bool = False) -> dict[str, np.ndarray]:
"""
Return a dictionary mapping Stan program variables names
to the corresponding numpy.ndarray containing the inferred values.
Expand Down
55 changes: 53 additions & 2 deletions test/test_compiler_opts.py → test/test_compilation.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
"""Compiler options tests"""

import contextlib
import io
import logging
import os
from test import check_present
from test import check_present, raises_nested
from unittest.mock import MagicMock, patch

import pytest

from cmdstanpy.compilation import CompilerOptions
from cmdstanpy.compilation import CompilerOptions, format_stan_file
from cmdstanpy.utils import cmdstan_version_before

HERE = os.path.dirname(os.path.abspath(__file__))
DATAFILES_PATH = os.path.join(HERE, 'data')
Expand Down Expand Up @@ -191,3 +195,50 @@ def test_user_header() -> None:
)
with pytest.raises(ValueError, match="Disagreement"):
opts.validate()


def test_model_format_options() -> None:
stan = os.path.join(DATAFILES_PATH, 'format_me.stan')

sys_stdout = io.StringIO()
with contextlib.redirect_stdout(sys_stdout):
format_stan_file(stan, max_line_length=10)
formatted = sys_stdout.getvalue()
assert len(formatted.splitlines()) > 11

sys_stdout = io.StringIO()
with contextlib.redirect_stdout(sys_stdout):
format_stan_file(stan, canonicalize='braces')
formatted = sys_stdout.getvalue()
assert formatted.count('{') == 3
assert formatted.count('(') == 4

sys_stdout = io.StringIO()
with contextlib.redirect_stdout(sys_stdout):
format_stan_file(stan, canonicalize=['parentheses'])
formatted = sys_stdout.getvalue()
assert formatted.count('{') == 1
assert formatted.count('(') == 1

sys_stdout = io.StringIO()
with contextlib.redirect_stdout(sys_stdout):
format_stan_file(stan, canonicalize=True)
formatted = sys_stdout.getvalue()
assert formatted.count('{') == 3
assert formatted.count('(') == 1


@patch(
'cmdstanpy.utils.cmdstan.cmdstan_version',
MagicMock(return_value=(2, 27)),
)
def test_format_old_version() -> None:
assert cmdstan_version_before(2, 28)

stan = os.path.join(DATAFILES_PATH, 'format_me.stan')
with raises_nested(RuntimeError, r"--canonicalize"):
format_stan_file(stan, canonicalize='braces')
with raises_nested(RuntimeError, r"--max-line"):
format_stan_file(stan, max_line_length=88)

format_stan_file(stan, canonicalize=True)
Loading