From d7451be96dbce6ab66f060163f48585cdf5c7c6f Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Fri, 12 Sep 2025 17:42:51 -0700 Subject: [PATCH 1/2] Add test and fix bug --- src/pybamm/solvers/base_solver.py | 8 +++--- tests/unit/test_solvers/test_base_solver.py | 30 +++++++++++++++++++++ 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/src/pybamm/solvers/base_solver.py b/src/pybamm/solvers/base_solver.py index 51e6dcede7..f5fbae205a 100644 --- a/src/pybamm/solvers/base_solver.py +++ b/src/pybamm/solvers/base_solver.py @@ -797,7 +797,7 @@ def solve( for it in model.concatenated_initial_conditions.pre_order() ] ) - if all_inputs_names.issubset(initial_conditions_node_names): + if not initial_conditions_node_names.isdisjoint(all_inputs_names): raise pybamm.SolverError( "Input parameters cannot appear in expression " "for initial conditions." @@ -847,9 +847,9 @@ def solve( # If the new initial conditions are different # and cannot be evaluated directly, set up again self.set_up(model, model_inputs_list[0], t_eval, ics_only=True) - self._model_set_up[model]["initial conditions"] = ( - model.concatenated_initial_conditions - ) + self._model_set_up[model][ + "initial conditions" + ] = model.concatenated_initial_conditions else: # Set the standard initial conditions self._set_initial_conditions(model, t_eval[0], model_inputs_list[0]) diff --git a/tests/unit/test_solvers/test_base_solver.py b/tests/unit/test_solvers/test_base_solver.py index c0bcbf3ce2..18ec1aece1 100644 --- a/tests/unit/test_solvers/test_base_solver.py +++ b/tests/unit/test_solvers/test_base_solver.py @@ -446,3 +446,33 @@ def test_on_extrapolation_and_on_failure_settings(self): ValueError, match="on_failure must be 'warn', 'raise', or 'ignore'" ): base_solver.on_failure = "invalid" + + def test_solver_multiple_inputs_initial_conditions_error(self): + + y = pybamm.Variable("y") + y0 = pybamm.InputParameter("y0") + k = pybamm.InputParameter("k") + + model = pybamm.BaseModel() + model.rhs = {y: -k * y} + model.initial_conditions = {y: y0} + model.variables = {"y": y} + + disc = pybamm.Discretisation() + disc.process_model(model) + + t_eval = np.linspace(0.0, 1.0, 6) + + # Three different ICs so each run is clearly distinct + inputs_list = [ + {"y0": 1.0, "k": 0.5}, + {"y0": 2.0, "k": 1.0}, + {"y0": 3.0, "k": 1.5}, + ] + + solver = pybamm.BaseSolver() + with pytest.raises( + pybamm.SolverError, + match="Input parameters cannot appear in expression for initial conditions", + ): + solver.solve(model, t_eval=t_eval, inputs=inputs_list) From 2d5ea7e99ee963261b3da0ef952379f7c20c70cc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 13 Sep 2025 00:48:26 +0000 Subject: [PATCH 2/2] style: pre-commit fixes --- src/pybamm/solvers/base_solver.py | 6 +++--- tests/unit/test_solvers/test_base_solver.py | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/pybamm/solvers/base_solver.py b/src/pybamm/solvers/base_solver.py index f5fbae205a..35851adc4f 100644 --- a/src/pybamm/solvers/base_solver.py +++ b/src/pybamm/solvers/base_solver.py @@ -847,9 +847,9 @@ def solve( # If the new initial conditions are different # and cannot be evaluated directly, set up again self.set_up(model, model_inputs_list[0], t_eval, ics_only=True) - self._model_set_up[model][ - "initial conditions" - ] = model.concatenated_initial_conditions + self._model_set_up[model]["initial conditions"] = ( + model.concatenated_initial_conditions + ) else: # Set the standard initial conditions self._set_initial_conditions(model, t_eval[0], model_inputs_list[0]) diff --git a/tests/unit/test_solvers/test_base_solver.py b/tests/unit/test_solvers/test_base_solver.py index 18ec1aece1..90c84eb986 100644 --- a/tests/unit/test_solvers/test_base_solver.py +++ b/tests/unit/test_solvers/test_base_solver.py @@ -448,7 +448,6 @@ def test_on_extrapolation_and_on_failure_settings(self): base_solver.on_failure = "invalid" def test_solver_multiple_inputs_initial_conditions_error(self): - y = pybamm.Variable("y") y0 = pybamm.InputParameter("y0") k = pybamm.InputParameter("k")