diff --git a/pyproject.toml b/pyproject.toml index 4b866d5ebb..f4e458fa4b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -264,6 +264,14 @@ concurrency = ["multiprocessing"] ignore_missing_imports = true allow_redefinition = true disable_error_code = ["call-overload", "operator"] +strict = false +enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"] +exclude = [ + "^build/", + "^build/lib/", + "^docs/conf\\.py$", + "^examples/scripts/" +] [[tool.mypy.overrides]] module = [ diff --git a/src/pybamm/experiment/experiment.py b/src/pybamm/experiment/experiment.py index ce44457cb2..b9cd4b4d9d 100644 --- a/src/pybamm/experiment/experiment.py +++ b/src/pybamm/experiment/experiment.py @@ -40,7 +40,9 @@ class Experiment: def __init__( self, - operating_conditions: list[str | tuple[str] | BaseStep], + operating_conditions: list[ + str | tuple[str, ...] | tuple[str | BaseStep] | BaseStep + ], period: str | None = None, temperature: float | None = None, termination: list[str] | None = None, diff --git a/src/pybamm/experiment/step/base_step.py b/src/pybamm/experiment/step/base_step.py index ebbc4057c4..98c612781a 100644 --- a/src/pybamm/experiment/step/base_step.py +++ b/src/pybamm/experiment/step/base_step.py @@ -140,7 +140,7 @@ def __init__( self.value = pybamm.Interpolant( t, y, - pybamm.t - pybamm.InputParameter("start time"), + [pybamm.t - pybamm.InputParameter("start time")], name="Drive Cycle", ) self.period = np.diff(t).min() diff --git a/src/pybamm/expression_tree/binary_operators.py b/src/pybamm/expression_tree/binary_operators.py index efd9874664..c0dd6017bd 100644 --- a/src/pybamm/expression_tree/binary_operators.py +++ b/src/pybamm/expression_tree/binary_operators.py @@ -2,7 +2,6 @@ # Binary operator classes # from __future__ import annotations -import numbers import numpy as np import numpy.typing as npt @@ -34,8 +33,8 @@ def _preprocess_binary( raise ValueError("right must be a 1D array") right = pybamm.Vector(right) - # Check both left and right are pybamm Symbols - if not (isinstance(left, pybamm.Symbol) and isinstance(right, pybamm.Symbol)): + # Check right is pybamm Symbol + if not isinstance(right, pybamm.Symbol): raise NotImplementedError( f"BinaryOperator not implemented for symbols of type {type(left)} and {type(right)}" ) @@ -114,6 +113,9 @@ def __str__(self): right_str = f"{self.right!s}" return f"{left_str} {self.name} {right_str}" + def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: + return self.__class__(self.name, left, right) # pragma: no cover + def create_copy( self, new_children: list[pybamm.Symbol] | None = None, @@ -128,7 +130,7 @@ def create_copy( children = self._children_for_copying(new_children) if not perform_simplifications: - out = self.__class__(children[0], children[1]) + out = self._new_instance(children[0], children[1]) else: # creates a new instance using the overloaded binary operator to perform # additional simplifications, rather than just calling the constructor @@ -225,6 +227,9 @@ def __init__( """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("**", left, right) + def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: + return Power(left, right) + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" # apply chain rule and power rule @@ -274,6 +279,9 @@ def __init__( """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("+", left, right) + def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: + return Addition(left, right) + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" return self.left.diff(variable) + self.right.diff(variable) @@ -301,6 +309,9 @@ def __init__( super().__init__("-", left, right) + def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: + return Subtraction(left, right) + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" return self.left.diff(variable) - self.right.diff(variable) @@ -330,6 +341,9 @@ def __init__( super().__init__("*", left, right) + def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: + return Multiplication(left, right) + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" # apply product rule @@ -370,6 +384,9 @@ def __init__( """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("@", left, right) + def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: + return MatrixMultiplication(left, right) # pragma: no cover + def diff(self, variable): """See :meth:`pybamm.Symbol.diff()`.""" # We shouldn't need this @@ -419,6 +436,9 @@ def __init__( """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("/", left, right) + def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: + return Division(left, right) + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" # apply quotient rule @@ -467,6 +487,9 @@ def __init__( """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("inner product", left, right) + def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: + return Inner(left, right) # pragma: no cover + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" # apply product rule @@ -544,6 +567,9 @@ def __init__( """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("==", left, right) + def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: + return Equality(left, right) + def diff(self, variable): """See :meth:`pybamm.Symbol.diff()`.""" # Equality should always be multiplied by something else so hopefully don't @@ -601,6 +627,10 @@ def __init__( ): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__(name, left, right) + self.name = name + + def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: + return _Heaviside(self.name, left, right) # pragma: no cover def diff(self, variable): """See :meth:`pybamm.Symbol.diff()`.""" @@ -679,6 +709,9 @@ def __init__( ): super().__init__("%", left, right) + def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: + return Modulo(left, right) + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" # apply chain rule and power rule @@ -721,6 +754,9 @@ def __init__( ): super().__init__("minimum", left, right) + def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: + return Minimum(left, right) + def __str__(self): """See :meth:`pybamm.Symbol.__str__()`.""" return f"minimum({self.left!s}, {self.right!s})" @@ -765,6 +801,9 @@ def __init__( ): super().__init__("maximum", left, right) + def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: + return Maximum(left, right) + def __str__(self): """See :meth:`pybamm.Symbol.__str__()`.""" return f"maximum({self.left!s}, {self.right!s})" @@ -1539,7 +1578,7 @@ def source( corresponding to a source term in the bulk. """ # Broadcast if left is number - if isinstance(left, numbers.Number): + if isinstance(left, (int, float)): left = pybamm.PrimaryBroadcast(left, "current collector") # force type cast for mypy diff --git a/src/pybamm/expression_tree/broadcasts.py b/src/pybamm/expression_tree/broadcasts.py index 6045c3f3e8..1fabef127c 100644 --- a/src/pybamm/expression_tree/broadcasts.py +++ b/src/pybamm/expression_tree/broadcasts.py @@ -78,8 +78,7 @@ def _from_json(cls, snippet): ) def _unary_new_copy(self, child: pybamm.Symbol, perform_simplifications=True): - """See :meth:`pybamm.UnaryOperator._unary_new_copy()`.""" - return self.__class__(child, self.broadcast_domain) + pass # pragma: no cover class PrimaryBroadcast(Broadcast): @@ -191,6 +190,10 @@ def reduce_one_dimension(self): """Reduce the broadcast by one dimension.""" return self.orphans[0] + def _unary_new_copy(self, child: pybamm.Symbol, perform_simplifications=True): + """See :meth:`pybamm.UnaryOperator._unary_new_copy()`.""" + return self.__class__(child, self.broadcast_domain) + class PrimaryBroadcastToEdges(PrimaryBroadcast): """A primary broadcast onto the edges of the domain.""" @@ -321,6 +324,10 @@ def reduce_one_dimension(self): """Reduce the broadcast by one dimension.""" return self.orphans[0] + def _unary_new_copy(self, child: pybamm.Symbol, perform_simplifications=True): + """See :meth:`pybamm.UnaryOperator._unary_new_copy()`.""" + return self.__class__(child, self.broadcast_domain) + class SecondaryBroadcastToEdges(SecondaryBroadcast): """A secondary broadcast onto the edges of a domain.""" @@ -438,6 +445,10 @@ def reduce_one_dimension(self): """Reduce the broadcast by one dimension.""" raise NotImplementedError + def _unary_new_copy(self, child: pybamm.Symbol, perform_simplifications=True): + """See :meth:`pybamm.UnaryOperator._unary_new_copy()`.""" + return self.__class__(child, self.broadcast_domain) + class TertiaryBroadcastToEdges(TertiaryBroadcast): """A tertiary broadcast onto the edges of a domain.""" @@ -463,7 +474,7 @@ def __init__( self, child_input: Numeric | pybamm.Symbol, broadcast_domain: DomainType = None, - auxiliary_domains: AuxiliaryDomainType = None, + auxiliary_domains: AuxiliaryDomainType | str = None, broadcast_domains: DomainsType = None, name: str | None = None, ): diff --git a/src/pybamm/expression_tree/concatenations.py b/src/pybamm/expression_tree/concatenations.py index 6f190aaf59..50abf5c29f 100644 --- a/src/pybamm/expression_tree/concatenations.py +++ b/src/pybamm/expression_tree/concatenations.py @@ -516,7 +516,7 @@ def substrings(s: str): yield s[i : j + 1] -def intersect(s1: str, s2: str): +def intersect(s1: str, s2: str) -> str: # find all the common strings between two strings all_intersects = set(substrings(s1)) & set(substrings(s2)) # intersect is the longest such intercept @@ -527,7 +527,7 @@ def intersect(s1: str, s2: str): return intersect.lstrip().rstrip() -def simplified_concatenation(*children, name: str | None = None): +def simplified_concatenation(*children: pybamm.Symbol, name: str | None = None): """Perform simplifications on a concatenation.""" # remove children that are None children = list(filter(lambda x: x is not None, children)) @@ -543,7 +543,8 @@ def simplified_concatenation(*children, name: str | None = None): # Create Concatenation to easily read domains concat = Concatenation(*children, name=name) if all( - isinstance(child, pybamm.Broadcast) and child.child == children[0].child + isinstance(child, pybamm.Broadcast) + and getattr(child, "child", None) == getattr(children[0], "child", None) for child in children ): unique_child = children[0].orphans[0] diff --git a/src/pybamm/expression_tree/symbol.py b/src/pybamm/expression_tree/symbol.py index 34ca9d627b..96dc3a844f 100644 --- a/src/pybamm/expression_tree/symbol.py +++ b/src/pybamm/expression_tree/symbol.py @@ -965,7 +965,9 @@ def to_casadi( """ return pybamm.CasadiConverter(casadi_symbols).convert(self, t, y, y_dot, inputs) - def _children_for_copying(self, children: list[Symbol] | None = None) -> Symbol: + def _children_for_copying( + self, children: list[Symbol] | None = None + ) -> list[Symbol]: """ Gets existing children for a symbol being copied if they aren't provided. """ diff --git a/src/pybamm/models/base_model.py b/src/pybamm/models/base_model.py index a8dc6aa5ea..59b97a8e87 100644 --- a/src/pybamm/models/base_model.py +++ b/src/pybamm/models/base_model.py @@ -77,6 +77,8 @@ def __init__(self, name="Unnamed model"): self.use_jacobian = True self.convert_to_format = "casadi" + self.calculate_sensitivities: list[str] = [] + # Model is not initially discretised self.is_discretised = False self.y_slices = None diff --git a/src/pybamm/plotting/quick_plot.py b/src/pybamm/plotting/quick_plot.py index ee146a2002..6a4fce9a04 100644 --- a/src/pybamm/plotting/quick_plot.py +++ b/src/pybamm/plotting/quick_plot.py @@ -126,7 +126,7 @@ def __init__( # Set colors, linestyles, figsize, axis limits # call LoopList to make sure list index never runs out if colors is None: - self.colors = LoopList(colors or ["r", "b", "k", "g", "m", "c"]) + self.colors = LoopList(["r", "b", "k", "g", "m", "c"]) else: self.colors = LoopList(colors) self.linestyles = LoopList(linestyles or ["-", ":", "--", "-."]) diff --git a/src/pybamm/solvers/base_solver.py b/src/pybamm/solvers/base_solver.py index 49e9b928ae..4cd219fd3e 100644 --- a/src/pybamm/solvers/base_solver.py +++ b/src/pybamm/solvers/base_solver.py @@ -94,8 +94,8 @@ def supports_parallel_solve(self): def requires_explicit_sensitivities(self): return True - @root_method.setter - def root_method(self, method): + @root_method.setter # type: ignore[attr-defined, no-redef] + def root_method(self, method) -> None: if method == "casadi": method = pybamm.CasadiAlgebraicSolver(self.root_tol) elif isinstance(method, str): @@ -1122,7 +1122,7 @@ def _set_sens_initial_conditions_from( """ ninputs = len(model.calculate_sensitivities) - initial_conditions = tuple([] for _ in range(ninputs)) + initial_conditions: tuple[list[float], ...] = tuple([] for _ in range(ninputs)) solution = solution.last_state for var in model.initial_conditions: final_state = solution[var.name] @@ -1143,10 +1143,10 @@ def _set_sens_initial_conditions_from( slices = [y_slices[symbol][0] for symbol in model.initial_conditions.keys()] # sort equations according to slices - concatenated_initial_conditions = [ + concatenated_initial_conditions = tuple( casadi.vertcat(*[eq for _, eq in sorted(zip(slices, init))]) for init in initial_conditions - ] + ) return concatenated_initial_conditions def process_t_interp(self, t_interp): diff --git a/src/pybamm/solvers/idaklu_jax.py b/src/pybamm/solvers/idaklu_jax.py index ef505570fa..6cfa624490 100644 --- a/src/pybamm/solvers/idaklu_jax.py +++ b/src/pybamm/solvers/idaklu_jax.py @@ -259,9 +259,9 @@ def f_isolated(*args, **kwargs): def jax_value( self, - t: npt.NDArray = None, - inputs: Union[dict, None] = None, - output_variables: Union[list[str], None] = None, + t: npt.NDArray | None = None, + inputs: dict | None = None, + output_variables: list[str] | None = None, ): """Helper function to compute the gradient of a jaxified expression @@ -292,9 +292,9 @@ def jax_value( def jax_grad( self, - t: npt.NDArray = None, - inputs: Union[dict, None] = None, - output_variables: Union[list[str], None] = None, + t: npt.NDArray | None = None, + inputs: dict | None = None, + output_variables: list[str] | None = None, ): """Helper function to compute the gradient of a jaxified expression diff --git a/src/pybamm/solvers/processed_variable_time_integral.py b/src/pybamm/solvers/processed_variable_time_integral.py index ce41c1796e..8b2f5712c8 100644 --- a/src/pybamm/solvers/processed_variable_time_integral.py +++ b/src/pybamm/solvers/processed_variable_time_integral.py @@ -7,7 +7,7 @@ @dataclass class ProcessedVariableTimeIntegral: method: Literal["discrete", "continuous"] - initial_condition: npt.NDArray + initial_condition: npt.NDArray | float discrete_times: Optional[npt.NDArray] @staticmethod diff --git a/src/pybamm/solvers/solution.py b/src/pybamm/solvers/solution.py index 4f17c60d94..d96a667344 100644 --- a/src/pybamm/solvers/solution.py +++ b/src/pybamm/solvers/solution.py @@ -160,7 +160,7 @@ def __init__( def has_sensitivities(self) -> bool: if isinstance(self._all_sensitivities, bool): return self._all_sensitivities - elif isinstance(self._all_sensitivities, dict): + else: return len(self._all_sensitivities) > 0 def extract_explicit_sensitivities(self): diff --git a/src/pybamm/solvers/summary_variable.py b/src/pybamm/solvers/summary_variable.py index 4c3da92a42..26dd40e174 100644 --- a/src/pybamm/solvers/summary_variable.py +++ b/src/pybamm/solvers/summary_variable.py @@ -4,7 +4,8 @@ from __future__ import annotations import pybamm import numpy as np -from typing import Any +from typing import Any, cast +from numpy.typing import NDArray class SummaryVariables: @@ -40,12 +41,15 @@ def __init__( ): self.user_inputs = user_inputs or {} self.esoh_solver = esoh_solver - self._variables = {} # Store computed variables + + self._variables: dict[str, float | list[float]] = {} # Store computed variables self.cycle_number = np.array([]) + self.cycles: list[SummaryVariables] | None = None + self._all_variables: list[str] | None = None model = solution.all_models[0] self._possible_variables = model.summary_variables # minus esoh variables - self._esoh_variables = None # Store eSOH variable names + self._esoh_variables: list[str] | None = None # Store eSOH variable names # Flag if eSOH calculations are needed self.calc_esoh = ( @@ -81,31 +85,34 @@ def all_variables(self) -> list[str]: Return names of all possible summary variables, including eSOH variables if appropriate. """ - try: - return self._all_variables - except AttributeError: - base_vars = self._possible_variables.copy() - base_vars.extend( - f"Change in {var[0].lower() + var[1:]}" - for var in self._possible_variables - ) + all_vars = getattr(self, "_all_variables", None) + if all_vars is not None: + return all_vars + base_vars = self._possible_variables.copy() + base_vars.extend( + f"Change in {var[0].lower() + var[1:]}" for var in self._possible_variables + ) - if self.calc_esoh: - base_vars.extend(self.esoh_variables) + if self.calc_esoh: + base_vars.extend(self.esoh_variables) - self._all_variables = base_vars - return self._all_variables + self._all_variables = base_vars + return base_vars @property def esoh_variables(self) -> list[str] | None: """Return names of all eSOH variables.""" - if self.calc_esoh and self._esoh_variables is None: + if ( + self.esoh_solver is not None + and self.calc_esoh + and self._esoh_variables is None + ): esoh_model = self.esoh_solver._get_electrode_soh_sims_full().model esoh_vars = list(esoh_model.variables.keys()) self._esoh_variables = esoh_vars return self._esoh_variables - def __getitem__(self, key: str) -> float | list[float]: + def __getitem__(self, key: str) -> float | list[float] | NDArray[Any]: """ Access or compute a summary variable by its name. @@ -148,10 +155,11 @@ def update(self, var: str): def _update_multiple_cycles(self, var: str, var_lowercase: str): """Creates aggregated summary variables for where more than one cycle exists.""" - var_cycle = [cycle[var] for cycle in self.cycles] - change_var_cycle = [ - cycle[f"Change in {var_lowercase}"] for cycle in self.cycles - ] + cycles = cast(list[SummaryVariables], self.cycles) + var_cycle = cast(list[float], [cycle[var] for cycle in cycles]) + change_var_cycle = cast( + list[float], [cycle[f"Change in {var_lowercase}"] for cycle in cycles] + ) self._variables[var] = var_cycle self._variables[f"Change in {var_lowercase}"] = change_var_cycle @@ -180,8 +188,9 @@ def _get_esoh_variables(self) -> dict[str, float]: Q_p = self.last_state["Positive electrode capacity [A.h]"].data[0] Q_Li = self.last_state["Total lithium capacity in particles [A.h]"].data[0] all_inputs = {**self.user_inputs, "Q_n": Q_n, "Q_p": Q_p, "Q_Li": Q_Li} + esoh_solver = cast(pybamm.lithium_ion.ElectrodeSOHSolver, self.esoh_solver) try: - esoh_sol = self.esoh_solver.solve(inputs=all_inputs) + esoh_sol = esoh_solver.solve(inputs=all_inputs) except pybamm.SolverError as error: # pragma: no cover raise pybamm.SolverError( "Could not solve for eSOH summary variables" diff --git a/src/pybamm/telemetry.py b/src/pybamm/telemetry.py index 3825738d47..ac5103139b 100644 --- a/src/pybamm/telemetry.py +++ b/src/pybamm/telemetry.py @@ -1,3 +1,4 @@ +from typing import cast from posthog import Posthog import pybamm import sys @@ -20,7 +21,7 @@ def capture(**kwargs): # pragma: no cover project_api_key="phc_acTt7KxmvBsAxaE0NyRd5WfJyNxGvBq1U9HnlQSztmb", host="https://us.i.posthog.com", ) - _posthog.log.setLevel("CRITICAL") + cast(Posthog, _posthog).log.setLevel("CRITICAL") def disable(): diff --git a/src/pybamm/util.py b/src/pybamm/util.py index aac8e7127d..7abaf39888 100644 --- a/src/pybamm/util.py +++ b/src/pybamm/util.py @@ -152,7 +152,7 @@ def search( Default is 0.4 """ - if not isinstance(keys, (str, list)) or not all( + if not isinstance(keys, (str, list)) or not all( # type: ignore[redundant-expr] isinstance(k, str) for k in keys ): msg = f"'keys' must be a string or a list of strings, got {type(keys)}" diff --git a/tests/unit/test_expression_tree/test_binary_operators.py b/tests/unit/test_expression_tree/test_binary_operators.py index eba3ca1bbd..ceed90fda2 100644 --- a/tests/unit/test_expression_tree/test_binary_operators.py +++ b/tests/unit/test_expression_tree/test_binary_operators.py @@ -11,7 +11,7 @@ import pybamm import sympy -EMPTY_DOMAINS = { +EMPTY_DOMAINS: dict[str, list[str]] = { "primary": [], "secondary": [], "tertiary": [], diff --git a/tests/unit/test_expression_tree/test_operations/test_copy.py b/tests/unit/test_expression_tree/test_operations/test_copy.py index f0d59a1fe1..67a884771c 100644 --- a/tests/unit/test_expression_tree/test_operations/test_copy.py +++ b/tests/unit/test_expression_tree/test_operations/test_copy.py @@ -79,6 +79,7 @@ def test_symbol_create_copy_new_children(self): a * b, a / b, a**b, + b % a, pybamm.minimum(a, b), pybamm.maximum(a, b), pybamm.Equality(a, b), @@ -89,12 +90,15 @@ def test_symbol_create_copy_new_children(self): b * a, b / a, b**a, + b % a, pybamm.minimum(b, a), pybamm.maximum(b, a), pybamm.Equality(b, a), ], ): - new_symbol = symbol_ab.create_copy(new_children=[b, a]) + new_symbol = symbol_ab.create_copy( + new_children=[b, a], perform_simplifications=False + ) assert new_symbol == symbol_ba assert new_symbol.print_name == symbol_ba.print_name diff --git a/tests/unit/test_solvers/test_solution.py b/tests/unit/test_solvers/test_solution.py index 78a8dca58e..325d1f3adc 100644 --- a/tests/unit/test_solvers/test_solution.py +++ b/tests/unit/test_solvers/test_solution.py @@ -253,7 +253,7 @@ def test_copy_with_computed_variables(self): sol2 = sol1.copy() - assert ( + assert all( sol1._variables[k] == sol2._variables[k] for k in sol1._variables.keys() ) assert sol2.variables_returned is True