
Commit bd883ac

Author: C.A.P. Linssen (committed)

conditional propagators

1 parent 16db15e, commit bd883ac

5 files changed: +155 additions, -58 deletions


odetoolbox/__init__.py

Lines changed: 47 additions & 14 deletions
@@ -182,6 +182,32 @@ def _get_all_first_order_variables(indict) -> Iterable[str]:
     return variable_names


+def symbol_appears_in_any_expr(param_name, solver_json) -> bool:
+    if "update_expressions" in solver_json.keys():
+        for sym, expr in solver_json["update_expressions"].items():
+            if param_name in [str(sym) for sym in list(expr.atoms())]:
+                return True
+
+    if "propagators" in solver_json.keys():
+        for sym, expr in solver_json["propagators"].items():
+            if param_name in [str(sym) for sym in list(expr.atoms())]:
+                return True
+
+    if "conditions" in solver_json.keys():
+        for conditional_solver_json in solver_json["conditions"].values():
+            if "update_expressions" in conditional_solver_json.keys():
+                for sym, expr in conditional_solver_json["update_expressions"].items():
+                    if param_name in [str(sym) for sym in list(expr.atoms())]:
+                        return True
+
+            if "propagators" in conditional_solver_json.keys():
+                for sym, expr in solver_json["propagators"].items():
+                    if param_name in [str(sym) for sym in list(expr.atoms())]:
+                        return True
+
+    return False
+
+
 def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_solver: bool = False, disable_singularity_detection: bool = False, preserve_expressions: Union[bool, Iterable[str]] = False, log_level: Union[str, int] = logging.WARNING) -> Tuple[List[Dict], SystemOfShapes, List[Shape]]:
     r"""
     Like analysis(), but additionally returns ``shape_sys`` and ``shapes``.
@@ -320,20 +346,7 @@ def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_so
             solver_json["parameters"] = {}
             for param_name, param_expr in indict["parameters"].items():
                 # only make parameters appear in a solver if they are actually used there
-                symbol_appears_in_any_expr = False
-                if "update_expressions" in solver_json.keys():
-                    for sym, expr in solver_json["update_expressions"].items():
-                        if param_name in [str(sym) for sym in list(expr.atoms())]:
-                            symbol_appears_in_any_expr = True
-                            break
-
-                if "propagators" in solver_json.keys():
-                    for sym, expr in solver_json["propagators"].items():
-                        if param_name in [str(sym) for sym in list(expr.atoms())]:
-                            symbol_appears_in_any_expr = True
-                            break
-
-                if symbol_appears_in_any_expr:
+                if symbol_appears_in_any_expr(sym, solver_json):
                     sympy_expr = sympy.parsing.sympy_parser.parse_expr(param_expr, global_dict=Shape._sympy_globals)

                     # validate output for numerical problems
@@ -388,6 +401,26 @@ def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_so
             for sym, expr in solver_json["propagators"].items():
                 solver_json["propagators"][sym] = str(expr)

+        if "conditions" in solver_json.keys():
+            for cond, cond_solver in solver_json["conditions"].items():
+                if "update_expressions" in cond_solver:
+                    for sym, expr in cond_solver["update_expressions"].items():
+                        cond_solver["update_expressions"][sym] = str(expr)
+
+                        if preserve_expressions and sym in preserve_expressions:
+                            if "analytic" in solver_json["solver"]:
+                                logging.warning("Not preserving expression for variable \"" + sym + "\" as it is solved by propagator solver")
+                                continue
+
+                            logging.info("Preserving expression for variable \"" + sym + "\"")
+                            var_def_str = _find_variable_definition(indict, sym, order=1)
+                            assert var_def_str is not None
+                            cond_solver["update_expressions"][sym] = var_def_str.replace("'", Config().differential_order_symbol)
+
+                if "propagators" in cond_solver:
+                    for sym, expr in cond_solver["propagators"].items():
+                        cond_solver["propagators"][sym] = str(expr)
+
     logging.info("In ode-toolbox: returning outdict = ")
     logging.info(json.dumps(solvers_json, indent=4, sort_keys=True))

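As a quick illustration (not part of this commit), the new helper can be exercised on a hand-written solver dict; the dict below is made up, and it is assumed that symbol_appears_in_any_expr is importable from the odetoolbox package after this change:

import sympy

from odetoolbox import symbol_appears_in_any_expr   # assumption: exported at package level

# hypothetical solver dict in the same layout as the ones produced by _analysis()
solver_json = {
    "update_expressions": {"V_m": sympy.sympify("V_m * __P__V_m__V_m")},
    "propagators": {"__P__V_m__V_m": sympy.sympify("exp(-__h / tau_m)")},
}

assert symbol_appears_in_any_expr("tau_m", solver_json)        # referenced by a propagator
assert not symbol_appears_in_any_expr("tau_syn", solver_json)  # never referenced anywhere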
odetoolbox/analytic_integrator.py

Lines changed: 57 additions & 11 deletions
@@ -19,6 +19,7 @@
 # along with NEST. If not, see <http://www.gnu.org/licenses/>.
 #

+import logging
 from typing import Dict, List, Optional

 import sympy
@@ -59,7 +60,6 @@ def __init__(self, solver_dict, spike_times: Optional[Dict[str, List[float]]] =
         #
         # define the necessary numerical state variables
         #
-
         self.dim = len(self.all_variable_symbols)
         self.initial_values = self.solver_dict["initial_values"].copy()
         self.set_initial_values(self.initial_values)
@@ -72,11 +72,16 @@ def __init__(self, solver_dict, spike_times: Optional[Dict[str, List[float]]] =
                 subs_dict[k_] = v_
             self.shape_starting_values[k] = float(expr.evalf(subs=subs_dict))

-        self.update_expressions = self.solver_dict["update_expressions"].copy()
-        for k, v in self.update_expressions.items():
-            if type(self.update_expressions[k]) is str:
-                self.update_expressions[k] = sympy.parsing.sympy_parser.parse_expr(self.update_expressions[k], global_dict=Shape._sympy_globals)

+        #
+        # initialise update expressions depending on whether conditional solver or not
+        #
+
+        if "update_expressions" in self.solver_dict.keys():
+            self._pick_unconditional_solver()
+        else:
+            assert "conditions" in self.solver_dict.keys()
+            self._pick_solver_based_on_condition()

         #
         # reset the system to t = 0
@@ -85,24 +90,65 @@ def __init__(self, solver_dict, spike_times: Optional[Dict[str, List[float]]] =
         self.reset()


+    def _condition_holds(self, condition_string) -> bool:
+        parts = condition_string.strip('()').split('==')
+        lhs_str = parts[0].strip()
+        rhs_str = parts[1].strip()
+
+        # Sympify each side individually and create the Eq
+        equation = sympy.Eq(sympy.sympify(lhs_str), sympy.sympify(rhs_str))
+
+        return equation.subs(self.solver_dict["parameters"])
+
+
+    def _pick_unconditional_solver(self):
+        self.update_expressions = self.solver_dict["update_expressions"].copy()
+        self.propagators = self.solver_dict["propagators"].copy()
+        self._process_update_expressions_from_solver_dict()
+
+    def _pick_solver_based_on_condition(self):
+        r"""In case of a conditional propagator solver: pick a solver depending on the conditions that hold (depending on parameter values)"""
+        self.update_expressions = self.solver_dict["conditions"]["default"]["update_expressions"]
+        self.propagators = self.solver_dict["conditions"]["default"]["propagators"]
+
+        for condition, conditional_solver in self.solver_dict["conditions"].items():
+            print("Checking condition " + str(condition) + ", params = " + str(self.solver_dict["parameters"]))
+            if condition != "default" and self._condition_holds(condition):
+                self.update_expressions = conditional_solver["update_expressions"]
+                self.propagators = conditional_solver["propagators"]
+                print("Picking solver based on condition: " + str(condition))
+
+                break
+
+        self._process_update_expressions_from_solver_dict()
+
+
+    def _process_update_expressions_from_solver_dict(self):
         #
-        # in the update expression, replace symbolic variables with their numerical values
+        # create substitution dictionary to replace symbolic variables with their numerical values
         #

-        self.subs_dict = {}
-        for prop_symbol, prop_expr in self.solver_dict["propagators"].items():
-            self.subs_dict[prop_symbol] = prop_expr
+        subs_dict = {}
+        for prop_symbol, prop_expr in self.propagators.items():
+            subs_dict[prop_symbol] = prop_expr
         if "parameters" in self.solver_dict.keys():
             for param_symbol, param_expr in self.solver_dict["parameters"].items():
-                self.subs_dict[param_symbol] = param_expr
+                subs_dict[param_symbol] = param_expr

+        #
+        # parse the expressions from JSON if necessary
+        #
+
+        for k, v in self.update_expressions.items():
+            if type(self.update_expressions[k]) is str:
+                self.update_expressions[k] = sympy.parsing.sympy_parser.parse_expr(self.update_expressions[k], global_dict=Shape._sympy_globals)

         #
         # perform substitution in update expressions ahead of time to save time later
         #

         for k, v in self.update_expressions.items():
-            self.update_expressions[k] = self.update_expressions[k].subs(self.subs_dict).subs(self.subs_dict)
+            self.update_expressions[k] = self.update_expressions[k].subs(subs_dict).subs(subs_dict)

         #
         # autowrap
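For orientation, a standalone sketch of the check performed by _condition_holds() above; this is not the class method itself, and it assumes condition strings of the simple form "(tau_1 == tau_2)" and a parameters dict mapping names to value strings, as in the diff:

import sympy

def condition_holds(condition_string, parameters) -> bool:
    # split "(lhs == rhs)" into its two sides, then compare them after
    # substituting the numerical parameter values
    lhs_str, rhs_str = [s.strip() for s in condition_string.strip("()").split("==")]
    equation = sympy.Eq(sympy.sympify(lhs_str), sympy.sympify(rhs_str))
    return bool(equation.subs(parameters))

print(condition_holds("(tau_1 == tau_2)", {"tau_1": "10", "tau_2": "10"}))   # True
print(condition_holds("(tau_1 == tau_2)", {"tau_1": "10", "tau_2": "2"}))    # False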

odetoolbox/system_of_shapes.py

Lines changed: 8 additions & 4 deletions
@@ -253,13 +253,17 @@ def generate_propagator_solver(self, disable_singularity_detection: bool = False
            solver_dict = {"solver": "analytical",
                           "state_variables": default_solver["state_variables"],
                           "initial_values": default_solver["initial_values"],
-                           "conditions": {"default" : {"propagators": default_solver["propagators"],
-                                                       "update_expressions": default_solver["update_expressions"]}}}
+                           "conditions": {"default": {"propagators": default_solver["propagators"],
+                                                      "update_expressions": default_solver["update_expressions"]}}}

            # XXX: TODO: generate/loop over all combinations of conditions!!!

            for cond_set in conditions:
-                condition_str: str = " && ".join(["(" + str(eq.lhs) + " == " + str(eq.rhs) + ")" for eq in cond_set])
+                if len(cond_set) == 1:
+                    eq = list(cond_set)[0]
+                    condition_str: str = str(eq.lhs) + " == " + str(eq.rhs)
+                else:
+                    condition_str: str = " && ".join(["(" + str(eq.lhs) + " == " + str(eq.rhs) + ")" for eq in cond_set])
                conditional_A = self.A_.copy()
                conditional_b = self.b_.copy()
                conditional_c = self.c_.copy()
@@ -297,7 +301,7 @@ def generate_propagator_solver(self, disable_singularity_detection: bool = False
            for col in range(P.shape[1]):
                if not _is_zero(P[row, col]):
                    sym_str = Config().propagators_prefix + "__{}__{}".format(str(self.x_[row]), str(self.x_[col]))
-                    P_expr[sym_str] = str(P[row, col])
+                    P_expr[sym_str] = P[row, col]
                    if row != col and not _is_zero(self.b_[col]):
                        # the ODE for x_[row] depends on the inhomogeneous ODE of x_[col]. We can't solve this analytically in the general case (even though some specific cases might admit a solution)
                        raise PropagatorGenerationException("the ODE for " + str(self.x_[row]) + " depends on the inhomogeneous ODE of " + str(self.x_[col]) + ". We can't solve this analytically in the general case (even though some specific cases might admit a solution)")
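The keys of the "conditions" dict thus follow the convention built above: a single equation is emitted bare, several equations are parenthesised and joined with "&&". A standalone sketch of that rule (make_condition_str is a hypothetical helper; in the commit this logic is inlined in generate_propagator_solver()):

import sympy

def make_condition_str(cond_set) -> str:
    # single condition: no parentheses; several conditions: "(..) && (..)"
    if len(cond_set) == 1:
        eq = list(cond_set)[0]
        return str(eq.lhs) + " == " + str(eq.rhs)
    return " && ".join(["(" + str(eq.lhs) + " == " + str(eq.rhs) + ")" for eq in cond_set])

tau_1, tau_2, tau_m = sympy.symbols("tau_1 tau_2 tau_m")
print(make_condition_str([sympy.Eq(tau_1, tau_2)]))                          # tau_1 == tau_2
print(make_condition_str([sympy.Eq(tau_1, tau_2), sympy.Eq(tau_2, tau_m)]))  # (tau_1 == tau_2) && (tau_2 == tau_m)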

tests/test_double_exponential.py

Lines changed: 38 additions & 27 deletions
@@ -20,6 +20,7 @@
 #

 import numpy as np
+import pytest
 from scipy.integrate import odeint

 import odetoolbox
@@ -29,7 +30,7 @@

 try:
     import matplotlib as mpl
-    mpl.use('Agg')
+    mpl.use("Agg")
     import matplotlib.pyplot as plt
     INTEGRATION_TEST_DEBUG_PLOTS = True
 except ImportError:
@@ -39,23 +40,32 @@
 class TestDoubleExponential:
     r"""Test propagators generation for double exponential"""

-    def test_double_exponential(self):
-        r"""Test propagators generation for double exponential"""
+    @pytest.mark.parametrize("tau_1, tau_2", [(10., 2.), (10., 10.)])
+    def test_double_exponential(self, tau_1, tau_2):
+        r"""Test propagators generation for double exponential
+
+        Test for a case where tau_1 != tau_2 and where tau_1 == tau_2; this tests handling of numerical singularities.
+
+        tau_1: decay time constant (ms)
+        tau_2: rise time constant (ms)
+        """

         def time_to_max(tau_1, tau_2):
             r"""
             Time of maximum.
             """
-            tmax = (np.log(tau_1) - np.log(tau_2)) / (1. / tau_2 - 1. / tau_1)
-            return tmax
+            if tau_1 == tau_2:
+                return tau_1
+
+            return (np.log(tau_1) - np.log(tau_2)) / (1. / tau_2 - 1. / tau_1)

         def unit_amplitude(tau_1, tau_2):
             r"""
             Scaling factor ensuring that amplitude of solution is one.
             """
             tmax = time_to_max(tau_1, tau_2)
-            alpha = 1. / (np.exp(-tmax / tau_1) - np.exp(-tmax / tau_2))
-            return alpha
+
+            return 1. / (np.exp(-tmax / tau_1) - np.exp(-tmax / tau_2))

         def flow(y, t, tau_1, tau_2, alpha, dt):
             r"""
@@ -66,26 +76,27 @@ def flow(y, t, tau_1, tau_2, alpha, dt):

            return np.array([dy1dt, dy2dt])

+        if tau_1 == tau_2:
+            alpha = 1.
+        else:
+            alpha = unit_amplitude(tau_1=tau_1, tau_2=tau_2)
+
         indict = {"dynamics": [{"expression": "I_aux' = -I_aux / tau_1",
                                 "initial_values": {"I_aux": "0."}},
                                {"expression": "I' = I_aux - I / tau_2",
                                 "initial_values": {"I": "0"}}],
                   "options": {"output_timestep_symbol": "__h"},
-                  "parameters": {"tau_1": "10",
-                                 "tau_2": "2",
+                  "parameters": {"tau_1": str(tau_1),
+                                 "tau_2": str(tau_2),
                                  "w": "3.14",
-                                 "alpha": str(unit_amplitude(tau_1=10., tau_2=2.)),
+                                 "alpha": str(alpha),
                                  "weighted_input_spikes": "0."}}

         w = 3.14  # weight (amplitude; pA)
-        tau_1 = 10.  # decay time constant (ms)
-        tau_2 = 2.  # rise time constant (ms)
         dt = .125  # time resolution (ms)
         T = 500.  # simulation time (ms)
         input_spike_times = np.array([100., 300.])  # array of input spike times (ms)

-        alpha = unit_amplitude(tau_1, tau_2)
-
         stimuli = [{"type": "list",
                     "list": " ".join([str(el) for el in input_spike_times]),
                     "variables": ["I_aux"]}]
@@ -103,7 +114,7 @@ def flow(y, t, tau_1, tau_2, alpha, dt):
         N = int(np.ceil(T / dt) + 1)
         timevec = np.linspace(0., T, N)
         analytic_integrator = AnalyticIntegrator(solver_dict, spike_times)
-        analytic_integrator.shape_starting_values["I_aux"] = w * alpha * (1. / tau_2 - 1. / tau_1)
+        analytic_integrator.shape_starting_values["I_aux"] = w * alpha
         analytic_integrator.set_initial_values(ODE_INITIAL_VALUES)
         analytic_integrator.reset()
         state = {"timevec": [], "I": [], "I_aux": []}
@@ -119,24 +130,24 @@ def flow(y, t, tau_1, tau_2, alpha, dt):
         ts2 = np.arange(input_spike_times[1], T + dt, dt)

         y_ = odeint(flow, [0., 0.], ts0, args=(tau_1, tau_2, alpha, dt))
-        y_ = np.vstack([y_, odeint(flow, [y_[-1, 0] + w * alpha * (1. / tau_2 - 1. / tau_1), y_[-1, 1]], ts1, args=(tau_1, tau_2, alpha, dt))])
-        y_ = np.vstack([y_, odeint(flow, [y_[-1, 0] + w * alpha * (1. / tau_2 - 1. / tau_1), y_[-1, 1]], ts2, args=(tau_1, tau_2, alpha, dt))])
+        y_ = np.vstack([y_, odeint(flow, [y_[-1, 0] + w * alpha, y_[-1, 1]], ts1, args=(tau_1, tau_2, alpha, dt))])
+        y_ = np.vstack([y_, odeint(flow, [y_[-1, 0] + w * alpha, y_[-1, 1]], ts2, args=(tau_1, tau_2, alpha, dt))])

-        rec_I_interp = np.interp(np.hstack([ts0, ts1, ts2]), timevec, state['I'])
-        rec_I_aux_interp = np.interp(np.hstack([ts0, ts1, ts2]), timevec, state['I_aux'])
+        rec_I_interp = np.interp(np.hstack([ts0, ts1, ts2]), timevec, state["I"])
+        rec_I_aux_interp = np.interp(np.hstack([ts0, ts1, ts2]), timevec, state["I_aux"])

         if INTEGRATION_TEST_DEBUG_PLOTS:
             tmax = time_to_max(tau_1, tau_2)
-            mpl.rcParams['text.usetex'] = True
+            mpl.rcParams["text.usetex"] = True

             fig, ax = plt.subplots(nrows=2, figsize=(5, 4), dpi=300)
-            ax[0].plot(timevec, state['I_aux'], '--', lw=3, color='k', label=r'$I_\mathsf{aux}(t)$ (NEST)')
-            ax[0].plot(timevec, state['I'], '-', lw=3, color='k', label=r'$I(t)$ (NEST)')
-            ax[0].plot(np.hstack([ts0, ts1, ts2]), y_[:, 0], '--', lw=2, color='r', label=r'$I_\mathsf{aux}(t)$ (odeint)')
-            ax[0].plot(np.hstack([ts0, ts1, ts2]), y_[:, 1], '-', lw=2, color='r', label=r'$I(t)$ (odeint)')
+            ax[0].plot(timevec, state["I_aux"], "--", lw=3, color="k", label=r"$I_\mathsf{aux}(t)$ (ODEtb)")
+            ax[0].plot(timevec, state["I"], "-", lw=3, color="k", label=r"$I(t)$ (ODEtb)")
+            ax[0].plot(np.hstack([ts0, ts1, ts2]), y_[:, 0], "--", lw=2, color="r", label=r"$I_\mathsf{aux}(t)$ (odeint)")
+            ax[0].plot(np.hstack([ts0, ts1, ts2]), y_[:, 1], "-", lw=2, color="r", label=r"$I(t)$ (odeint)")

             for tin in input_spike_times:
-                ax[0].vlines(tin + tmax, ax[0].get_ylim()[0], ax[0].get_ylim()[1], colors='k', linestyles=':')
+                ax[0].vlines(tin + tmax, ax[0].get_ylim()[0], ax[0].get_ylim()[1], colors="k", linestyles=":")

             ax[1].semilogy(np.hstack([ts0, ts1, ts2]), np.abs(y_[:, 1] - rec_I_interp), label="I")
             ax[1].semilogy(np.hstack([ts0, ts1, ts2]), np.abs(y_[:, 0] - rec_I_aux_interp), linestyle="--", label="I_aux")
@@ -146,9 +157,9 @@ def flow(y, t, tau_1, tau_2, alpha, dt):
                 _ax.set_xlim(0., T + dt)
                 _ax.legend()

-            ax[-1].set_xlabel(r'time (ms)')
+            ax[-1].set_xlabel(r"time (ms)")

-            fig.savefig('double_exp_test.png')
+            fig.savefig("double_exp_test_[tau_1=" + str(tau_1) + "]_[tau_2=" + str(tau_2) + "].png")

         np.testing.assert_allclose(y_[:, 1], rec_I_interp, atol=1E-7)

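The special case added to time_to_max() above reflects the analytic limit: as tau_2 approaches tau_1, the expression (log(tau_1) - log(tau_2)) / (1/tau_2 - 1/tau_1) tends to tau_1, which is also the regime in which the generic propagator expressions become singular and the conditional solver takes over. A quick numerical check (illustrative only):

import numpy as np

tau_1 = 10.
for tau_2 in [9.9, 9.99, 9.999]:
    tmax = (np.log(tau_1) - np.log(tau_2)) / (1. / tau_2 - 1. / tau_1)
    print(tau_2, tmax)   # tmax approaches tau_1 = 10 as tau_2 -> tau_1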
tests/test_propagator_solver_homogeneous.py

Lines changed: 5 additions & 2 deletions
@@ -32,5 +32,8 @@ def test_propagator_solver_homogeneous(self):
         assert len(solver_dict) == 1
         solver_dict = solver_dict[0]
         assert solver_dict["solver"] == "analytical"
-        assert float(solver_dict["propagators"]["__P__refr_t__refr_t"]) == 1.
-        assert solver_dict["propagators"]["__P__V_m__V_m"] == "1.0*exp(-__h/tau_m)"
+
+        for cond_solver_dict in solver_dict["conditions"].values():
+            assert float(cond_solver_dict["propagators"]["__P__refr_t__refr_t"]) == 1.
+
+        assert solver_dict["conditions"]["default"]["propagators"]["__P__V_m__V_m"] == "1.0*exp(-__h/tau_m)"
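For reference, a hand-written illustration of the conditional solver dict layout that the updated assertions traverse; the placeholder expressions below are not taken from the repository, and keys other than "default" would be condition strings such as "tau_1 == tau_2", depending on the model:

solver_dict = {
    "solver": "analytical",
    "state_variables": ["V_m", "refr_t"],
    "initial_values": {"V_m": "0", "refr_t": "0"},
    "conditions": {
        "default": {
            "propagators": {"__P__V_m__V_m": "1.0*exp(-__h/tau_m)",
                            "__P__refr_t__refr_t": "1.0"},
            "update_expressions": {"V_m": "V_m*__P__V_m__V_m",
                                   "refr_t": "refr_t*__P__refr_t__refr_t"},
        },
    },
}

for cond_solver_dict in solver_dict["conditions"].values():
    assert float(cond_solver_dict["propagators"]["__P__refr_t__refr_t"]) == 1.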
