Skip to content
8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,14 @@ concurrency = ["multiprocessing"]
ignore_missing_imports = true
allow_redefinition = true
disable_error_code = ["call-overload", "operator"]
strict = false
enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
exclude = [
"^build/",
"^build/lib/",
"^docs/conf\\.py$",
"^examples/scripts/"
]

[[tool.mypy.overrides]]
module = [
Expand Down
4 changes: 3 additions & 1 deletion src/pybamm/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ class Experiment:

def __init__(
self,
operating_conditions: list[str | tuple[str] | BaseStep],
operating_conditions: list[
str | tuple[str, ...] | tuple[str | BaseStep] | BaseStep
],
period: str | None = None,
temperature: float | None = None,
termination: list[str] | None = None,
Expand Down
2 changes: 1 addition & 1 deletion src/pybamm/experiment/step/base_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def __init__(
self.value = pybamm.Interpolant(
t,
y,
pybamm.t - pybamm.InputParameter("start time"),
[pybamm.t - pybamm.InputParameter("start time")],
name="Drive Cycle",
)
self.period = np.diff(t).min()
Expand Down
49 changes: 44 additions & 5 deletions src/pybamm/expression_tree/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# Binary operator classes
#
from __future__ import annotations
import numbers

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -34,8 +33,8 @@ def _preprocess_binary(
raise ValueError("right must be a 1D array")
right = pybamm.Vector(right)

# Check both left and right are pybamm Symbols
if not (isinstance(left, pybamm.Symbol) and isinstance(right, pybamm.Symbol)):
# Check right is pybamm Symbol
if not isinstance(right, pybamm.Symbol):
raise NotImplementedError(
f"BinaryOperator not implemented for symbols of type {type(left)} and {type(right)}"
)
Expand Down Expand Up @@ -114,6 +113,9 @@ def __str__(self):
right_str = f"{self.right!s}"
return f"{left_str} {self.name} {right_str}"

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please add some information on why this was added?

Copy link
Member Author

@Rishab87 Rishab87 Mar 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've replaced self.__ class __ with new_instance because earlier when we were using self.__ class __ it showed a third arg was not getting passed:

error: Missing positional argument "right_child" in call to "BinaryOperator"  [call-arg]

but this function was always getting called from instance of its child classes which don't need to pass 3 arguments, so i thought it was better to make a new_instance method which can be overrided in child classes

I've already added this in the PR description of previous sp-check-guidelines PR, should I add it here too? Or follow some different approach

return self.__class__(self.name, left, right) # pragma: no cover

def create_copy(
self,
new_children: list[pybamm.Symbol] | None = None,
Expand All @@ -128,7 +130,7 @@ def create_copy(
children = self._children_for_copying(new_children)

if not perform_simplifications:
out = self.__class__(children[0], children[1])
out = self._new_instance(children[0], children[1])
else:
# creates a new instance using the overloaded binary operator to perform
# additional simplifications, rather than just calling the constructor
Expand Down Expand Up @@ -225,6 +227,9 @@ def __init__(
"""See :meth:`pybamm.BinaryOperator.__init__()`."""
super().__init__("**", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Power(left, right)

def _diff(self, variable: pybamm.Symbol):
"""See :meth:`pybamm.Symbol._diff()`."""
# apply chain rule and power rule
Expand Down Expand Up @@ -274,6 +279,9 @@ def __init__(
"""See :meth:`pybamm.BinaryOperator.__init__()`."""
super().__init__("+", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Addition(left, right)

def _diff(self, variable: pybamm.Symbol):
"""See :meth:`pybamm.Symbol._diff()`."""
return self.left.diff(variable) + self.right.diff(variable)
Expand Down Expand Up @@ -301,6 +309,9 @@ def __init__(

super().__init__("-", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Subtraction(left, right)

def _diff(self, variable: pybamm.Symbol):
"""See :meth:`pybamm.Symbol._diff()`."""
return self.left.diff(variable) - self.right.diff(variable)
Expand Down Expand Up @@ -330,6 +341,9 @@ def __init__(

super().__init__("*", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Multiplication(left, right)

def _diff(self, variable: pybamm.Symbol):
"""See :meth:`pybamm.Symbol._diff()`."""
# apply product rule
Expand Down Expand Up @@ -370,6 +384,9 @@ def __init__(
"""See :meth:`pybamm.BinaryOperator.__init__()`."""
super().__init__("@", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return MatrixMultiplication(left, right) # pragma: no cover

def diff(self, variable):
"""See :meth:`pybamm.Symbol.diff()`."""
# We shouldn't need this
Expand Down Expand Up @@ -419,6 +436,9 @@ def __init__(
"""See :meth:`pybamm.BinaryOperator.__init__()`."""
super().__init__("/", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Division(left, right)

def _diff(self, variable: pybamm.Symbol):
"""See :meth:`pybamm.Symbol._diff()`."""
# apply quotient rule
Expand Down Expand Up @@ -467,6 +487,9 @@ def __init__(
"""See :meth:`pybamm.BinaryOperator.__init__()`."""
super().__init__("inner product", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Inner(left, right) # pragma: no cover

def _diff(self, variable: pybamm.Symbol):
"""See :meth:`pybamm.Symbol._diff()`."""
# apply product rule
Expand Down Expand Up @@ -544,6 +567,9 @@ def __init__(
"""See :meth:`pybamm.BinaryOperator.__init__()`."""
super().__init__("==", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Equality(left, right)

def diff(self, variable):
"""See :meth:`pybamm.Symbol.diff()`."""
# Equality should always be multiplied by something else so hopefully don't
Expand Down Expand Up @@ -601,6 +627,10 @@ def __init__(
):
"""See :meth:`pybamm.BinaryOperator.__init__()`."""
super().__init__(name, left, right)
self.name = name

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return _Heaviside(self.name, left, right) # pragma: no cover

def diff(self, variable):
"""See :meth:`pybamm.Symbol.diff()`."""
Expand Down Expand Up @@ -679,6 +709,9 @@ def __init__(
):
super().__init__("%", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Modulo(left, right)

def _diff(self, variable: pybamm.Symbol):
"""See :meth:`pybamm.Symbol._diff()`."""
# apply chain rule and power rule
Expand Down Expand Up @@ -721,6 +754,9 @@ def __init__(
):
super().__init__("minimum", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Minimum(left, right)

def __str__(self):
"""See :meth:`pybamm.Symbol.__str__()`."""
return f"minimum({self.left!s}, {self.right!s})"
Expand Down Expand Up @@ -765,6 +801,9 @@ def __init__(
):
super().__init__("maximum", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Maximum(left, right)

def __str__(self):
"""See :meth:`pybamm.Symbol.__str__()`."""
return f"maximum({self.left!s}, {self.right!s})"
Expand Down Expand Up @@ -1539,7 +1578,7 @@ def source(
corresponding to a source term in the bulk.
"""
# Broadcast if left is number
if isinstance(left, numbers.Number):
if isinstance(left, (int, float)):
left = pybamm.PrimaryBroadcast(left, "current collector")

# force type cast for mypy
Expand Down
17 changes: 14 additions & 3 deletions src/pybamm/expression_tree/broadcasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,7 @@ def _from_json(cls, snippet):
)

def _unary_new_copy(self, child: pybamm.Symbol, perform_simplifications=True):
"""See :meth:`pybamm.UnaryOperator._unary_new_copy()`."""
return self.__class__(child, self.broadcast_domain)
pass # pragma: no cover
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why was this removed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In src/pybamm/expression_tree/broadcasts.py I've forced child classes of Broadcast to implement _unary_new_copy because earlier we were using self.broadcast_domain in this function in Broadcast class but it does not have any attribute self.broadcast_domain, this function was only getting called by instance of their child classes which has self.broadcast_domain property.



class PrimaryBroadcast(Broadcast):
Expand Down Expand Up @@ -191,6 +190,10 @@ def reduce_one_dimension(self):
"""Reduce the broadcast by one dimension."""
return self.orphans[0]

def _unary_new_copy(self, child: pybamm.Symbol, perform_simplifications=True):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are these functions being repeated?

Copy link
Member Author

@Rishab87 Rishab87 Apr 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As mentioned above now it needs to be overrided in child classes

"""See :meth:`pybamm.UnaryOperator._unary_new_copy()`."""
return self.__class__(child, self.broadcast_domain)


class PrimaryBroadcastToEdges(PrimaryBroadcast):
"""A primary broadcast onto the edges of the domain."""
Expand Down Expand Up @@ -321,6 +324,10 @@ def reduce_one_dimension(self):
"""Reduce the broadcast by one dimension."""
return self.orphans[0]

def _unary_new_copy(self, child: pybamm.Symbol, perform_simplifications=True):
"""See :meth:`pybamm.UnaryOperator._unary_new_copy()`."""
return self.__class__(child, self.broadcast_domain)


class SecondaryBroadcastToEdges(SecondaryBroadcast):
"""A secondary broadcast onto the edges of a domain."""
Expand Down Expand Up @@ -438,6 +445,10 @@ def reduce_one_dimension(self):
"""Reduce the broadcast by one dimension."""
raise NotImplementedError

def _unary_new_copy(self, child: pybamm.Symbol, perform_simplifications=True):
"""See :meth:`pybamm.UnaryOperator._unary_new_copy()`."""
return self.__class__(child, self.broadcast_domain)


class TertiaryBroadcastToEdges(TertiaryBroadcast):
"""A tertiary broadcast onto the edges of a domain."""
Expand All @@ -463,7 +474,7 @@ def __init__(
self,
child_input: Numeric | pybamm.Symbol,
broadcast_domain: DomainType = None,
auxiliary_domains: AuxiliaryDomainType = None,
auxiliary_domains: AuxiliaryDomainType | str = None,
broadcast_domains: DomainsType = None,
name: str | None = None,
):
Expand Down
7 changes: 4 additions & 3 deletions src/pybamm/expression_tree/concatenations.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ def substrings(s: str):
yield s[i : j + 1]


def intersect(s1: str, s2: str):
def intersect(s1: str, s2: str) -> str:
# find all the common strings between two strings
all_intersects = set(substrings(s1)) & set(substrings(s2))
# intersect is the longest such intercept
Expand All @@ -527,7 +527,7 @@ def intersect(s1: str, s2: str):
return intersect.lstrip().rstrip()


def simplified_concatenation(*children, name: str | None = None):
def simplified_concatenation(*children: pybamm.Symbol, name: str | None = None):
"""Perform simplifications on a concatenation."""
# remove children that are None
children = list(filter(lambda x: x is not None, children))
Expand All @@ -543,7 +543,8 @@ def simplified_concatenation(*children, name: str | None = None):
# Create Concatenation to easily read domains
concat = Concatenation(*children, name=name)
if all(
isinstance(child, pybamm.Broadcast) and child.child == children[0].child
isinstance(child, pybamm.Broadcast)
and getattr(child, "child", None) == getattr(children[0], "child", None)
for child in children
):
unique_child = children[0].orphans[0]
Expand Down
4 changes: 3 additions & 1 deletion src/pybamm/expression_tree/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,7 +965,9 @@ def to_casadi(
"""
return pybamm.CasadiConverter(casadi_symbols).convert(self, t, y, y_dot, inputs)

def _children_for_copying(self, children: list[Symbol] | None = None) -> Symbol:
def _children_for_copying(
self, children: list[Symbol] | None = None
) -> list[Symbol]:
"""
Gets existing children for a symbol being copied if they aren't provided.
"""
Expand Down
2 changes: 2 additions & 0 deletions src/pybamm/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ def __init__(self, name="Unnamed model"):
self.use_jacobian = True
self.convert_to_format = "casadi"

self.calculate_sensitivities: list[str] = []

# Model is not initially discretised
self.is_discretised = False
self.y_slices = None
Expand Down
2 changes: 1 addition & 1 deletion src/pybamm/plotting/quick_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def __init__(
# Set colors, linestyles, figsize, axis limits
# call LoopList to make sure list index never runs out
if colors is None:
self.colors = LoopList(colors or ["r", "b", "k", "g", "m", "c"])
self.colors = LoopList(["r", "b", "k", "g", "m", "c"])
else:
self.colors = LoopList(colors)
self.linestyles = LoopList(linestyles or ["-", ":", "--", "-."])
Expand Down
10 changes: 5 additions & 5 deletions src/pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ def supports_parallel_solve(self):
def requires_explicit_sensitivities(self):
return True

@root_method.setter
def root_method(self, method):
@root_method.setter # type: ignore[attr-defined, no-redef]
def root_method(self, method) -> None:
if method == "casadi":
method = pybamm.CasadiAlgebraicSolver(self.root_tol)
elif isinstance(method, str):
Expand Down Expand Up @@ -1122,7 +1122,7 @@ def _set_sens_initial_conditions_from(
"""

ninputs = len(model.calculate_sensitivities)
initial_conditions = tuple([] for _ in range(ninputs))
initial_conditions: tuple[list[float], ...] = tuple([] for _ in range(ninputs))
solution = solution.last_state
for var in model.initial_conditions:
final_state = solution[var.name]
Expand All @@ -1143,10 +1143,10 @@ def _set_sens_initial_conditions_from(
slices = [y_slices[symbol][0] for symbol in model.initial_conditions.keys()]

# sort equations according to slices
concatenated_initial_conditions = [
concatenated_initial_conditions = tuple(
casadi.vertcat(*[eq for _, eq in sorted(zip(slices, init))])
for init in initial_conditions
]
)
return concatenated_initial_conditions

def process_t_interp(self, t_interp):
Expand Down
12 changes: 6 additions & 6 deletions src/pybamm/solvers/idaklu_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,9 +259,9 @@ def f_isolated(*args, **kwargs):

def jax_value(
self,
t: npt.NDArray = None,
inputs: Union[dict, None] = None,
output_variables: Union[list[str], None] = None,
t: npt.NDArray | None = None,
inputs: dict | None = None,
output_variables: list[str] | None = None,
):
"""Helper function to compute the gradient of a jaxified expression

Expand Down Expand Up @@ -292,9 +292,9 @@ def jax_value(

def jax_grad(
self,
t: npt.NDArray = None,
inputs: Union[dict, None] = None,
output_variables: Union[list[str], None] = None,
t: npt.NDArray | None = None,
inputs: dict | None = None,
output_variables: list[str] | None = None,
):
"""Helper function to compute the gradient of a jaxified expression

Expand Down
2 changes: 1 addition & 1 deletion src/pybamm/solvers/processed_variable_time_integral.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
@dataclass
class ProcessedVariableTimeIntegral:
method: Literal["discrete", "continuous"]
initial_condition: npt.NDArray
initial_condition: npt.NDArray | float
discrete_times: Optional[npt.NDArray]

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion src/pybamm/solvers/solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def __init__(
def has_sensitivities(self) -> bool:
if isinstance(self._all_sensitivities, bool):
return self._all_sensitivities
elif isinstance(self._all_sensitivities, dict):
else:
return len(self._all_sensitivities) > 0

def extract_explicit_sensitivities(self):
Expand Down
Loading