diff --git a/src/pybamm/solvers/base_solver.py b/src/pybamm/solvers/base_solver.py index 51e6dcede7..35851adc4f 100644 --- a/src/pybamm/solvers/base_solver.py +++ b/src/pybamm/solvers/base_solver.py @@ -797,7 +797,7 @@ def solve( for it in model.concatenated_initial_conditions.pre_order() ] ) - if all_inputs_names.issubset(initial_conditions_node_names): + if not initial_conditions_node_names.isdisjoint(all_inputs_names): raise pybamm.SolverError( "Input parameters cannot appear in expression " "for initial conditions." diff --git a/tests/unit/test_solvers/test_base_solver.py b/tests/unit/test_solvers/test_base_solver.py index c0bcbf3ce2..90c84eb986 100644 --- a/tests/unit/test_solvers/test_base_solver.py +++ b/tests/unit/test_solvers/test_base_solver.py @@ -446,3 +446,32 @@ def test_on_extrapolation_and_on_failure_settings(self): ValueError, match="on_failure must be 'warn', 'raise', or 'ignore'" ): base_solver.on_failure = "invalid" + + def test_solver_multiple_inputs_initial_conditions_error(self): + y = pybamm.Variable("y") + y0 = pybamm.InputParameter("y0") + k = pybamm.InputParameter("k") + + model = pybamm.BaseModel() + model.rhs = {y: -k * y} + model.initial_conditions = {y: y0} + model.variables = {"y": y} + + disc = pybamm.Discretisation() + disc.process_model(model) + + t_eval = np.linspace(0.0, 1.0, 6) + + # Three different ICs so each run is clearly distinct + inputs_list = [ + {"y0": 1.0, "k": 0.5}, + {"y0": 2.0, "k": 1.0}, + {"y0": 3.0, "k": 1.5}, + ] + + solver = pybamm.BaseSolver() + with pytest.raises( + pybamm.SolverError, + match="Input parameters cannot appear in expression for initial conditions", + ): + solver.solve(model, t_eval=t_eval, inputs=inputs_list)