@@ -181,6 +181,30 @@ def _get_all_first_order_variables(indict) -> Iterable[str]:
181181
182182 return variable_names
183183
184+ def symbol_appears_in_any_expr (param_name , solver_json ) -> bool :
185+ if "update_expressions" in solver_json .keys ():
186+ for sym , expr in solver_json ["update_expressions" ].items ():
187+ if param_name in [str (sym ) for sym in list (expr .atoms ())]:
188+ return True
189+
190+ if "propagators" in solver_json .keys ():
191+ for sym , expr in solver_json ["propagators" ].items ():
192+ if param_name in [str (sym ) for sym in list (expr .atoms ())]:
193+ return True
194+
195+ if "conditions" in solver_json .keys ():
196+ for conditional_solver_json in solver_json ["conditions" ].values ():
197+ if "update_expressions" in conditional_solver_json .keys ():
198+ for sym , expr in conditional_solver_json ["update_expressions" ].items ():
199+ if param_name in [str (sym ) for sym in list (expr .atoms ())]:
200+ return True
201+
202+ if "propagators" in conditional_solver_json .keys ():
203+ for sym , expr in solver_json ["propagators" ].items ():
204+ if param_name in [str (sym ) for sym in list (expr .atoms ())]:
205+ return True
206+
207+ return False
184208
185209def _analysis (indict , disable_stiffness_check : bool = False , disable_analytic_solver : bool = False , disable_singularity_detection : bool = False , preserve_expressions : Union [bool , Iterable [str ]] = False , log_level : Union [str , int ] = logging .WARNING ) -> Tuple [List [Dict ], SystemOfShapes , List [Shape ]]:
186210 r"""
@@ -320,20 +344,7 @@ def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_so
320344 solver_json ["parameters" ] = {}
321345 for param_name , param_expr in indict ["parameters" ].items ():
322346 # only make parameters appear in a solver if they are actually used there
323- symbol_appears_in_any_expr = False
324- if "update_expressions" in solver_json .keys ():
325- for sym , expr in solver_json ["update_expressions" ].items ():
326- if param_name in [str (sym ) for sym in list (expr .atoms ())]:
327- symbol_appears_in_any_expr = True
328- break
329-
330- if "propagators" in solver_json .keys ():
331- for sym , expr in solver_json ["propagators" ].items ():
332- if param_name in [str (sym ) for sym in list (expr .atoms ())]:
333- symbol_appears_in_any_expr = True
334- break
335-
336- if symbol_appears_in_any_expr :
347+ if symbol_appears_in_any_expr (sym , solver_json ):
337348 sympy_expr = sympy .parsing .sympy_parser .parse_expr (param_expr , global_dict = Shape ._sympy_globals )
338349
339350 # validate output for numerical problems
@@ -388,6 +399,26 @@ def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_so
388399 for sym , expr in solver_json ["propagators" ].items ():
389400 solver_json ["propagators" ][sym ] = str (expr )
390401
402+ if "conditions" in solver_json .keys ():
403+ for cond , cond_solver in solver_json ["conditions" ].items ():
404+ if "update_expressions" in cond_solver :
405+ for sym , expr in cond_solver ["update_expressions" ].items ():
406+ cond_solver ["update_expressions" ][sym ] = str (expr )
407+
408+ if preserve_expressions and sym in preserve_expressions :
409+ if "analytic" in solver_json ["solver" ]:
410+ logging .warning ("Not preserving expression for variable \" " + sym + "\" as it is solved by propagator solver" )
411+ continue
412+
413+ logging .info ("Preserving expression for variable \" " + sym + "\" " )
414+ var_def_str = _find_variable_definition (indict , sym , order = 1 )
415+ assert var_def_str is not None
416+ cond_solver ["update_expressions" ][sym ] = var_def_str .replace ("'" , Config ().differential_order_symbol )
417+
418+ if "propagators" in cond_solver :
419+ for sym , expr in cond_solver ["propagators" ].items ():
420+ cond_solver ["propagators" ][sym ] = str (expr )
421+
391422 logging .info ("In ode-toolbox: returning outdict = " )
392423 logging .info (json .dumps (solvers_json , indent = 4 , sort_keys = True ))
393424
0 commit comments