Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions diffsol/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ pub enum OdeSolverError {
StopTimeBeforeCurrentTime { stop_time: f64, state_time: f64 },
#[error("Mass matrix not supported for this solver")]
MassMatrixNotSupported,
#[error("Stochastic RHS term not supported for this solver")]
StochNotSupported,
#[error("Stop time is at the current state time")]
StopTimeAtCurrentTime,
#[error("Interpolation time is after current time")]
Expand Down
3 changes: 2 additions & 1 deletion diffsol/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ pub use ode_equations::{
sens_equations::SensInit, sens_equations::SensRhs, AugmentedOdeEquations,
AugmentedOdeEquationsImplicit, NoAug, OdeEquations, OdeEquationsAdjoint, OdeEquationsImplicit,
OdeEquationsImplicitAdjoint, OdeEquationsImplicitSens, OdeEquationsRef, OdeEquationsSens,
OdeEquationsStoch, OdeSolverEquations,
OdeSolverEquations, StochEnum,
};
use ode_solver::jacobian_update::JacobianUpdate;
pub use ode_solver::sde::SdeSolverMethod;
Expand All @@ -201,6 +201,7 @@ pub use ode_solver::{
method::AugmentedOdeSolverMethod, method::OdeSolverMethod, method::OdeSolverStopReason,
problem::OdeSolverProblem, sdirk::Sdirk, sdirk_state::RkState,
sensitivities::SensitivitiesOdeSolverMethod, state::OdeSolverState, tableau::Tableau,
tableau_sde::TableauSde,
};
pub use op::constant_op::{ConstantOp, ConstantOpSens, ConstantOpSensAdjoint};
pub use op::linear_op::{LinearOp, LinearOpSens, LinearOpTranspose};
Expand Down
18 changes: 18 additions & 0 deletions diffsol/src/matrix/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ pub trait MatrixView<'a>:
type Owned;

fn into_owned(self) -> Self::Owned;



/// Perform a matrix-vector multiplication `y = self * x + beta * y`.
fn gemv_v(
Expand All @@ -129,6 +131,19 @@ pub trait MatrixView<'a>:
);

fn gemv_o(&self, alpha: Self::T, x: &Self::V, beta: Self::T, y: &mut Self::V);

/// Perform a matrix-vector multiplication that is scaled by a vector instead of a scalar `y += alpha .* self * x`.
fn scaled_gemv_o(
&self,
alpha: &Self::V,
x: &Self::V,
y: &mut Self::V,
) {
let mut temp = Self::V::zeros(y.len(), self.context().clone());
self.gemv(Self::T::one(), x, Self::T::zero(), &mut temp);
temp.mul_assign(alpha);
y.add_assign(&temp);
}
}

/// A base matrix trait (including sparse and dense matrices)
Expand All @@ -153,6 +168,9 @@ pub trait Matrix: MatrixCommon + Mul<Scale<Self::T>, Output = Self> + Clone + 's

/// Perform a matrix-vector multiplication `y = alpha * self * x + beta * y`.
fn gemv(&self, alpha: Self::T, x: &Self::V, beta: Self::T, y: &mut Self::V);




/// Copy the contents of `other` into `self`
fn copy_from(&mut self, other: &Self);
Expand Down
41 changes: 23 additions & 18 deletions diffsol/src/ode_equations/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,8 @@ pub trait OdeEquationsRef<'a, ImplicitBounds: Sealed = Bounds<&'a Self>>: Op {
type Root: NonLinearOp<M = Self::M, V = Self::V, T = Self::T, C = Self::C>;
type Init: ConstantOp<M = Self::M, V = Self::V, T = Self::T, C = Self::C>;
type Out: NonLinearOp<M = Self::M, V = Self::V, T = Self::T, C = Self::C>;
type Stoch: NonLinearOp<M = Self::M, V = Self::V, T = Self::T, C = Self::C>;
type StochAdditive: LinearOp<M = Self::M, V = Self::V, T = Self::T, C = Self::C>;
}

impl<'a, T: OdeEquationsRef<'a>> OdeEquationsRef<'a> for &T {
Expand All @@ -205,6 +207,15 @@ impl<'a, T: OdeEquationsRef<'a>> OdeEquationsRef<'a> for &T {
type Root = <T as OdeEquationsRef<'a>>::Root;
type Init = <T as OdeEquationsRef<'a>>::Init;
type Out = <T as OdeEquationsRef<'a>>::Out;
type Stoch = <T as OdeEquationsRef<'a>>::Stoch;
type StochAdditive = <T as OdeEquationsRef<'a>>::StochAdditive;
}

pub enum StochEnum<A: NonLinearOp, B: LinearOp> {
Scalar(A),
Diagonal(A),
Additive(B),
None,
}

// seal the trait so that users must use the provided default type for ImplicitBounds
Expand Down Expand Up @@ -235,7 +246,9 @@ pub trait OdeEquations: for<'a> OdeEquationsRef<'a> {
fn rhs(&self) -> <Self as OdeEquationsRef<'_>>::Rhs;

/// returns the mass matrix `M` as a [LinearOp]
fn mass(&self) -> Option<<Self as OdeEquationsRef<'_>>::Mass>;
fn mass(&self) -> Option<<Self as OdeEquationsRef<'_>>::Mass> {
None
}

/// returns the root function `G(t, y)` as a [NonLinearOp]
fn root(&self) -> Option<<Self as OdeEquationsRef<'_>>::Root> {
Expand All @@ -247,6 +260,10 @@ pub trait OdeEquations: for<'a> OdeEquationsRef<'a> {
None
}

fn stoch(&self) -> StochEnum<<Self as OdeEquationsRef<'_>>::Stoch, <Self as OdeEquationsRef<'_>>::StochAdditive> {
StochEnum::None
}

/// returns the initial condition, i.e. `y(t)`, where `t` is the initial time
fn init(&self) -> <Self as OdeEquationsRef<'_>>::Init;

Expand Down Expand Up @@ -307,7 +324,11 @@ impl<T: OdeEquations> OdeEquations for &'_ T {
fn out(&self) -> Option<<Self as OdeEquationsRef<'_>>::Out> {
(*self).out()
}


fn stoch(&self) -> Option<<Self as OdeEquationsRef<'_>>::Stoch> {
(*self).stoch()
}

fn init(&self) -> <Self as OdeEquationsRef<'_>>::Init {
(*self).init()
}
Expand All @@ -331,22 +352,6 @@ impl<T> OdeEquationsImplicit for T where
{
}

pub trait OdeEquationsStoch:
OdeEquations<
Rhs: NonLinearOp<M = Self::M, V = Self::V, T = Self::T, C = Self::C>
+ StochOp<M = Self::M, V = Self::V, T = Self::T, C = Self::C>,
>
{
}

impl<T> OdeEquationsStoch for T where
T: OdeEquations<
Rhs: NonLinearOp<M = T::M, V = T::V, T = T::T, C = T::C>
+ StochOp<M = T::M, V = T::V, T = T::T, C = T::C>,
>
{
}

pub trait OdeEquationsSens:
OdeEquations<
Rhs: NonLinearOpSens<M = Self::M, V = Self::V, T = Self::T, C = Self::C>,
Expand Down
4 changes: 4 additions & 0 deletions diffsol/src/ode_solver/bdf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,10 @@ where
mut nonlinear_solver: Nls,
integrate_main_eqn: bool,
) -> Result<Self, DiffsolError> {
// check that there isn't any diffusion term
if problem.eqn.stoch().is_some() {
return Err(DiffsolError::from(OdeSolverError::StochNotSupported));
}
// kappa values for difference orders, taken from Table 1 of [1]
let kappa = [
Eqn::T::from(0.0),
Expand Down
9 changes: 9 additions & 0 deletions diffsol/src/ode_solver/explicit_rk.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use super::method::AugmentedOdeSolverMethod;
use super::runge_kutta::Rk;
use crate::error::DiffsolError;
use crate::error::OdeSolverError;
use crate::ode_solver::bdf::BdfStatistics;
use crate::vector::VectorRef;
use crate::NoAug;
Expand Down Expand Up @@ -81,6 +82,10 @@ where
tableau: Tableau<M>,
) -> Result<Self, DiffsolError> {
Rk::<Eqn, M>::check_explicit_rk(problem, &tableau)?;
// check that there isn't any diffusion term
if problem.eqn.stoch().is_some() {
return Err(DiffsolError::from(OdeSolverError::StochNotSupported));
}
Ok(Self {
rk: Rk::new(problem, state, tableau)?,
augmented_eqn: None,
Expand All @@ -94,6 +99,10 @@ where
augmented_eqn: AugmentedEqn,
) -> Result<Self, DiffsolError> {
Rk::<Eqn, M>::check_explicit_rk(problem, &tableau)?;
// check that there isn't any diffusion term
if problem.eqn.stoch().is_some() {
return Err(DiffsolError::from(OdeSolverError::StochNotSupported));
}
Ok(Self {
rk: Rk::new_augmented(problem, state, tableau, &augmented_eqn)?,
augmented_eqn: Some(augmented_eqn),
Expand Down
Loading
Loading