Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions cmdstanpy/cmdstan_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,7 @@ class CmdStanArgs:
def __init__(
self,
model_name: str,
model_exe: OptionalPath,
model_exe: str,
chain_ids: Optional[list[int]],
method_args: Union[
SamplerArgs,
Expand Down Expand Up @@ -692,10 +692,6 @@ def validate(self) -> None:
* if no seed specified, set random seed.
* length of per-chain lists equals specified # of chains
"""
if self.model_name is None:
raise ValueError('no stan model specified')
if self.model_exe is None:
raise ValueError('model not compiled')

if self.chain_ids is not None:
for chain_id in self.chain_ids:
Expand Down Expand Up @@ -857,10 +853,10 @@ def compose_command(
idx, len(self.chain_ids)
)
)
cmd.append(self.model_exe) # type: ignore # guaranteed by validate
cmd.append(self.model_exe)
cmd.append(f'id={self.chain_ids[idx]}')
else:
cmd.append(self.model_exe) # type: ignore # guaranteed by validate
cmd.append(self.model_exe)

if self.seed is not None:
if not isinstance(self.seed, list):
Expand Down
101 changes: 19 additions & 82 deletions cmdstanpy/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import platform
import shutil
import subprocess
from copy import copy
from datetime import datetime
from pathlib import Path
from typing import Any, Iterable, Optional, Union
Expand Down Expand Up @@ -37,12 +36,6 @@
'warn-pedantic',
]

# TODO(2.0): remove
STANC_DEPRECATED_OPTS = {
'allow_undefined': 'allow-undefined',
'include_paths': 'include-paths',
}

STANC_IGNORE_OPTS = [
'debug-lex',
'debug-parse',
Expand All @@ -67,7 +60,6 @@
OptionalPath = Union[str, os.PathLike, None]


# TODO(2.0): can remove add function and other logic
class CompilerOptions:
"""
User-specified flags for stanc and C++ compiler.
Expand Down Expand Up @@ -95,26 +87,6 @@ def __repr__(self) -> str:
self._stanc_options, self._cpp_options
)

def __eq__(self, other: Any) -> bool:
"""Overrides the default implementation"""
if self.is_empty() and other is None: # equiv w/r/t compiler
return True
if not isinstance(other, CompilerOptions):
return False
return (
self._stanc_options == other.stanc_options
and self._cpp_options == other.cpp_options
and self._user_header == other.user_header
)

def is_empty(self) -> bool:
"""True if no options specified."""
return (
self._stanc_options == {}
and self._cpp_options == {}
and self._user_header == ''
)

@property
def stanc_options(self) -> dict[str, Union[bool, int, str, Iterable[str]]]:
"""Stanc compiler options."""
Expand Down Expand Up @@ -144,31 +116,12 @@ def validate_stanc_opts(self) -> None:
Check stanc compiler args and consistency between stanc and C++ options.
Raise ValueError if bad config is found.
"""
# pylint: disable=no-member
if self._stanc_options is None:
return
ignore = []
paths = None
has_o_flag = False

for deprecated, replacement in STANC_DEPRECATED_OPTS.items():
if deprecated in self._stanc_options:
if replacement:
get_logger().warning(
'compiler option "%s" is deprecated, use "%s" instead',
deprecated,
replacement,
)
self._stanc_options[replacement] = copy(
self._stanc_options[deprecated]
)
del self._stanc_options[deprecated]
else:
get_logger().warning(
'compiler option "%s" is deprecated and should '
'not be used',
deprecated,
)
for key, val in self._stanc_options.items():
if key in STANC_IGNORE_OPTS:
get_logger().info('ignoring compiler option: %s', key)
Expand Down Expand Up @@ -267,37 +220,6 @@ def validate_user_header(self) -> None:

self._cpp_options['USER_HEADER'] = self._user_header

def add(self, new_opts: "CompilerOptions") -> None: # noqa: disable=Q000
"""Adds options to existing set of compiler options."""
if new_opts.stanc_options is not None:
if self._stanc_options is None:
self._stanc_options = new_opts.stanc_options
else:
for key, val in new_opts.stanc_options.items():
if key == 'include-paths':
if isinstance(val, Iterable) and not isinstance(
val, str
):
for path in val:
self.add_include_path(str(path))
else:
self.add_include_path(str(val))
else:
self._stanc_options[key] = val
if new_opts.cpp_options is not None:
for key, val in new_opts.cpp_options.items():
self._cpp_options[key] = val
if new_opts._user_header != '' and self._user_header == '':
self._user_header = new_opts._user_header

def add_include_path(self, path: str) -> None:
"""Adds include path to existing set of compiler options."""
path = os.path.abspath(os.path.expanduser(path))
if 'include-paths' not in self._stanc_options:
self._stanc_options['include-paths'] = [path]
elif path not in self._stanc_options['include-paths']:
self._stanc_options['include-paths'].append(path)

def compose_stanc(self, filename_in_msg: Optional[str]) -> list[str]:
opts = []

Expand Down Expand Up @@ -343,7 +265,8 @@ def compose(self, filename_in_msg: Optional[str] = None) -> list[str]:


def src_info(
stan_file: str, compiler_options: CompilerOptions
stan_file: str,
stanc_options: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
"""
Get source info for Stan program file.
Expand All @@ -354,7 +277,7 @@ def src_info(
cmd = (
[stanc_path()]
# handle include-paths, allow-undefined etc
+ compiler_options.compose_stanc(None)
+ CompilerOptions(stanc_options=stanc_options).compose_stanc(None)
+ ['--info', str(stan_file)]
)
proc = subprocess.run(cmd, capture_output=True, text=True, check=False)
Expand Down Expand Up @@ -407,12 +330,26 @@ def compile_stan_file(
)
compiler_options.validate()

# if program has include directives, record path
if '#include' in src.read_text():
path = os.fspath(src.parent.resolve())
if 'include-paths' not in compiler_options.stanc_options:
compiler_options.stanc_options['include-paths'] = [path]
else:
paths: list[str] = compiler_options.stanc_options[
'include-paths'
] # type: ignore
if path not in paths:
paths.append(path)

exe_target = src.with_suffix(EXTENSION)
if exe_target.exists():
exe_time = os.path.getmtime(exe_target)
included_files = [src]
included_files.extend(
src_info(str(src), compiler_options).get('included_files', [])
src_info(str(src), compiler_options.stanc_options).get(
'included_files', []
)
)
out_of_date = any(
os.path.getmtime(included_file) > exe_time
Expand Down Expand Up @@ -482,7 +419,7 @@ def compile_stan_file(
raise ValueError(
f"Failed to compile Stan model '{src}'. Console:\n{console}"
)
return str(exe_target)
return os.fspath(exe_target)


def format_stan_file(
Expand Down
Loading