Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 55 additions & 18 deletions src/pybammsolvers/idaklu_source/IDAKLUSolverOpenMP.inl
Original file line number Diff line number Diff line change
Expand Up @@ -855,32 +855,69 @@ void IDAKLUSolverOpenMP<ExprSet>::SetStepOutputSensitivities(
sunrealtype &tval,
sunrealtype *y_val,
const vector<sunrealtype*>& yS_val,
int &i_save
int &i_save
) {
DEBUG("IDAKLUSolver::SetStepOutputSensitivities");
// Calculate sensitivities
vector<sunrealtype> dens_dvar_dp = vector<sunrealtype>(number_of_parameters, 0);
for (size_t dvar_k=0; dvar_k<functions->dvar_dy_fcns.size(); dvar_k++) {
// Isolate functions

// Running index over the flattened outputs
size_t global_out_idx = 0;

// Loop over each variable
for (size_t dvar_k = 0; dvar_k < functions->var_fcns.size(); ++dvar_k) {
Expression* dvar_dy = functions->dvar_dy_fcns[dvar_k];
Expression* dvar_dp = functions->dvar_dp_fcns[dvar_k];

// Calculate dvar/dy
(*dvar_dy)({&tval, y_val, functions->inputs.data()}, {&res_dvar_dy[0]});
// Calculate dvar/dp and convert to dense array for indexing
// Calculate dvar/dp
(*dvar_dp)({&tval, y_val, functions->inputs.data()}, {&res_dvar_dp[0]});
for (int k=0; k<number_of_parameters; k++) {
dens_dvar_dp[k]=0;
}
for (int k=0; k<dvar_dp->nnz_out(); k++) {
dens_dvar_dp[dvar_dp->get_row()[k]] = res_dvar_dp[k];
}
// Calculate sensitivities
for (int paramk=0; paramk<number_of_parameters; paramk++) {
auto &yS_back_paramk = yS[i_save][paramk];
yS_back_paramk[dvar_k] = dens_dvar_dp[paramk];

for (int spk=0; spk<dvar_dy->nnz_out(); spk++) {
yS_back_paramk[dvar_k] += res_dvar_dy[spk] * yS_val[paramk][dvar_dy->get_col()[spk]];
// Get number of output components for this function (e.g., scalar → 1; vector → >1)
const size_t n_rows = functions->var_fcns[dvar_k]->nnz_out();

// Number of nonzeros in the sparse Jacobians (for dvar/dy and dvar/dp)
const size_t dvar_dy_nnz = dvar_dy->nnz_out();
const size_t dvar_dp_nnz = dvar_dp->nnz_out();

// Row/column indices of nonzero entries (compressed sparse row format)
const auto& dvar_dy_row = dvar_dy->get_row(); // output component (row) for y
const auto& dvar_dy_col = dvar_dy->get_col(); // state variable index (column) for y
const auto& dvar_dp_row = dvar_dp->get_row(); // output component (row) for p
const auto& dvar_dp_col = dvar_dp->get_col(); // parameter index (column) for p

// Temporary dense vector to hold doutput_row/dp_k for each parameter
vector<sunrealtype> dvar_dp_dense(number_of_parameters, 0.0);

// Loop over each scalar component (row) of the output function
for (size_t row = 0; row < n_rows; ++row, ++global_out_idx) {
// Dense dvar_row/dp_k vector (reset to zero)
std::fill(dvar_dp_dense.begin(), dvar_dp_dense.end(), 0.0);

// Fill in dvar_row/dp_k from sparse structure
for (size_t nz = 0; nz < dvar_dp_nnz; ++nz) {
if (dvar_dp_row[nz] == static_cast<int>(row)) {
// dvar_dp_col[nz] is parameter index k
dvar_dp_dense[dvar_dp_col[nz]] = res_dvar_dp[nz]; // direct derivative
}
}

// For each parameter p_k, compute total d(output_row)/d(p_k)
for (int paramk = 0; paramk < number_of_parameters; paramk++) {
auto &yS_back_paramk = yS[i_save][paramk]; // Sensitivity vector for p_k at save step i_save

// Start with direct contribution doutput/dp_k
sunrealtype sens = dvar_dp_dense[paramk];

// Add chain rule term
for (size_t nz = 0; nz < dvar_dy_nnz; ++nz) {
if (dvar_dy_row[nz] == static_cast<int>(row)) {
// dvar_dy_col[nz] = j (state index)
// yS_val[paramk][j] = dy_j/dp_k
sens += res_dvar_dy[nz] * yS_val[paramk][dvar_dy_col[nz]];
}
}

yS_back_paramk[global_out_idx] = sens;
}
}
}
Expand Down
Loading