Skip to content

Commit 827de32

Browse files
author
C.A.P. Linssen
committed
conditional propagators
1 parent 16db15e commit 827de32

File tree

3 files changed

+99
-25
lines changed

3 files changed

+99
-25
lines changed

odetoolbox/__init__.py

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,30 @@ def _get_all_first_order_variables(indict) -> Iterable[str]:
181181

182182
return variable_names
183183

184+
def symbol_appears_in_any_expr(param_name, solver_json) -> bool:
185+
if "update_expressions" in solver_json.keys():
186+
for sym, expr in solver_json["update_expressions"].items():
187+
if param_name in [str(sym) for sym in list(expr.atoms())]:
188+
return True
189+
190+
if "propagators" in solver_json.keys():
191+
for sym, expr in solver_json["propagators"].items():
192+
if param_name in [str(sym) for sym in list(expr.atoms())]:
193+
return True
194+
195+
if "conditions" in solver_json.keys():
196+
for conditional_solver_json in solver_json["conditions"].values():
197+
if "update_expressions" in conditional_solver_json.keys():
198+
for sym, expr in conditional_solver_json["update_expressions"].items():
199+
if param_name in [str(sym) for sym in list(expr.atoms())]:
200+
return True
201+
202+
if "propagators" in conditional_solver_json.keys():
203+
for sym, expr in solver_json["propagators"].items():
204+
if param_name in [str(sym) for sym in list(expr.atoms())]:
205+
return True
206+
207+
return False
184208

185209
def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_solver: bool = False, disable_singularity_detection: bool = False, preserve_expressions: Union[bool, Iterable[str]] = False, log_level: Union[str, int] = logging.WARNING) -> Tuple[List[Dict], SystemOfShapes, List[Shape]]:
186210
r"""
@@ -320,20 +344,7 @@ def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_so
320344
solver_json["parameters"] = {}
321345
for param_name, param_expr in indict["parameters"].items():
322346
# only make parameters appear in a solver if they are actually used there
323-
symbol_appears_in_any_expr = False
324-
if "update_expressions" in solver_json.keys():
325-
for sym, expr in solver_json["update_expressions"].items():
326-
if param_name in [str(sym) for sym in list(expr.atoms())]:
327-
symbol_appears_in_any_expr = True
328-
break
329-
330-
if "propagators" in solver_json.keys():
331-
for sym, expr in solver_json["propagators"].items():
332-
if param_name in [str(sym) for sym in list(expr.atoms())]:
333-
symbol_appears_in_any_expr = True
334-
break
335-
336-
if symbol_appears_in_any_expr:
347+
if symbol_appears_in_any_expr(sym, solver_json):
337348
sympy_expr = sympy.parsing.sympy_parser.parse_expr(param_expr, global_dict=Shape._sympy_globals)
338349

339350
# validate output for numerical problems
@@ -388,6 +399,26 @@ def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_so
388399
for sym, expr in solver_json["propagators"].items():
389400
solver_json["propagators"][sym] = str(expr)
390401

402+
if "conditions" in solver_json.keys():
403+
for cond, cond_solver in solver_json["conditions"].items():
404+
if "update_expressions" in cond_solver:
405+
for sym, expr in cond_solver["update_expressions"].items():
406+
cond_solver["update_expressions"][sym] = str(expr)
407+
408+
if preserve_expressions and sym in preserve_expressions:
409+
if "analytic" in solver_json["solver"]:
410+
logging.warning("Not preserving expression for variable \"" + sym + "\" as it is solved by propagator solver")
411+
continue
412+
413+
logging.info("Preserving expression for variable \"" + sym + "\"")
414+
var_def_str = _find_variable_definition(indict, sym, order=1)
415+
assert var_def_str is not None
416+
cond_solver["update_expressions"][sym] = var_def_str.replace("'", Config().differential_order_symbol)
417+
418+
if "propagators" in cond_solver:
419+
for sym, expr in cond_solver["propagators"].items():
420+
cond_solver["propagators"][sym] = str(expr)
421+
391422
logging.info("In ode-toolbox: returning outdict = ")
392423
logging.info(json.dumps(solvers_json, indent=4, sort_keys=True))
393424

odetoolbox/analytic_integrator.py

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,17 @@ def __init__(self, solver_dict, spike_times: Optional[Dict[str, List[float]]] =
7272
subs_dict[k_] = v_
7373
self.shape_starting_values[k] = float(expr.evalf(subs=subs_dict))
7474

75-
self.update_expressions = self.solver_dict["update_expressions"].copy()
76-
for k, v in self.update_expressions.items():
77-
if type(self.update_expressions[k]) is str:
78-
self.update_expressions[k] = sympy.parsing.sympy_parser.parse_expr(self.update_expressions[k], global_dict=Shape._sympy_globals)
7975

76+
#
77+
# initialise update expressions depending on whether conditional solver or not
78+
#
79+
80+
if "update_expressions" in self.solver_dict.keys():
81+
self.update_expressions = self.solver_dict["update_expressions"].copy()
82+
self._process_update_expressions_from_solver_dict()
83+
else:
84+
assert "conditions" in self.solver_dict.keys()
85+
self._pick_solver_based_on_condition()
8086

8187
#
8288
# reset the system to t = 0
@@ -85,24 +91,57 @@ def __init__(self, solver_dict, spike_times: Optional[Dict[str, List[float]]] =
8591
self.reset()
8692

8793

94+
def _condition_holds(self, condition_string) -> bool:
95+
parts = condition_string.strip('()').split('==')
96+
lhs_str = parts[0].strip()
97+
rhs_str = parts[1].strip()
98+
99+
# Sympify each side individually and create the Eq
100+
equation = sympy.Eq(sympy.sympify(lhs_str), sympy.sympify(rhs_str))
101+
102+
return equation.subs(self.solver_dict["parameters"])
103+
104+
105+
def _pick_solver_based_on_condition(self):
106+
r"""In case of a conditional propagator solver: pick a solver depending on the conditions that hold (depending on parameter values)"""
107+
self.update_expressions = self.solver_dict["conditions"]["default"]["update_expressions"]
108+
self.propagators = self.solver_dict["conditions"]["default"]["propagators"]
109+
110+
for condition, conditional_solver in self.solver_dict["conditions"].items():
111+
if condition != "default" and self._condition_holds(condition):
112+
self.update_expressions = conditional_solver["update_expressions"]
113+
self.propagators = conditional_solver["update_expressions"]
114+
break
115+
116+
self._process_update_expressions_from_solver_dict()
117+
118+
119+
def _process_update_expressions_from_solver_dict(self):
88120
#
89-
# in the update expression, replace symbolic variables with their numerical values
121+
# create substitution dictionary to replace symbolic variables with their numerical values
90122
#
91123

92-
self.subs_dict = {}
93-
for prop_symbol, prop_expr in self.solver_dict["propagators"].items():
94-
self.subs_dict[prop_symbol] = prop_expr
124+
subs_dict = {}
125+
for prop_symbol, prop_expr in self.propagators.items():
126+
subs_dict[prop_symbol] = prop_expr
95127
if "parameters" in self.solver_dict.keys():
96128
for param_symbol, param_expr in self.solver_dict["parameters"].items():
97-
self.subs_dict[param_symbol] = param_expr
129+
subs_dict[param_symbol] = param_expr
130+
131+
#
132+
# parse the expressions from JSON if necessary
133+
#
98134

135+
for k, v in self.update_expressions.items():
136+
if type(self.update_expressions[k]) is str:
137+
self.update_expressions[k] = sympy.parsing.sympy_parser.parse_expr(self.update_expressions[k], global_dict=Shape._sympy_globals)
99138

100139
#
101140
# perform substitution in update expressions ahead of time to save time later
102141
#
103142

104143
for k, v in self.update_expressions.items():
105-
self.update_expressions[k] = self.update_expressions[k].subs(self.subs_dict).subs(self.subs_dict)
144+
self.update_expressions[k] = self.update_expressions[k].subs(subs_dict).subs(subs_dict)
106145

107146
#
108147
# autowrap

odetoolbox/system_of_shapes.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,11 @@ def generate_propagator_solver(self, disable_singularity_detection: bool = False
259259
# XXX: TODO: generate/loop over all combinations of conditions!!!
260260

261261
for cond_set in conditions:
262-
condition_str: str = " && ".join(["(" + str(eq.lhs) + " == " + str(eq.rhs) + ")" for eq in cond_set])
262+
if len(cond_set) == 1:
263+
eq = list(cond_set)[0]
264+
condition_str: str = str(eq.lhs) + " == " + str(eq.rhs)
265+
else:
266+
condition_str: str = " && ".join(["(" + str(eq.lhs) + " == " + str(eq.rhs) + ")" for eq in cond_set])
263267
conditional_A = self.A_.copy()
264268
conditional_b = self.b_.copy()
265269
conditional_c = self.c_.copy()

0 commit comments

Comments
 (0)