Skip to content

Commit 4ffb853

Browse files
Incomplete sensitivities for 1D+ output_variables (#52)
* Iterate correctly over sensitivities for multidimensional variables * Use reference rather than writing directly to yS * Cache matrix properties * use size_t for nnz_out constants * Add comments --------- Co-authored-by: Agriya Khetarpal <[email protected]>
1 parent c68fd83 commit 4ffb853

File tree

1 file changed

+55
-18
lines changed

1 file changed

+55
-18
lines changed

src/pybammsolvers/idaklu_source/IDAKLUSolverOpenMP.inl

Lines changed: 55 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -855,32 +855,69 @@ void IDAKLUSolverOpenMP<ExprSet>::SetStepOutputSensitivities(
855855
sunrealtype &tval,
856856
sunrealtype *y_val,
857857
const vector<sunrealtype*>& yS_val,
858-
int &i_save
858+
int &i_save
859859
) {
860860
DEBUG("IDAKLUSolver::SetStepOutputSensitivities");
861-
// Calculate sensitivities
862-
vector<sunrealtype> dens_dvar_dp = vector<sunrealtype>(number_of_parameters, 0);
863-
for (size_t dvar_k=0; dvar_k<functions->dvar_dy_fcns.size(); dvar_k++) {
864-
// Isolate functions
861+
862+
// Running index over the flattened outputs
863+
size_t global_out_idx = 0;
864+
865+
// Loop over each variable
866+
for (size_t dvar_k = 0; dvar_k < functions->var_fcns.size(); ++dvar_k) {
865867
Expression* dvar_dy = functions->dvar_dy_fcns[dvar_k];
866868
Expression* dvar_dp = functions->dvar_dp_fcns[dvar_k];
869+
867870
// Calculate dvar/dy
868871
(*dvar_dy)({&tval, y_val, functions->inputs.data()}, {&res_dvar_dy[0]});
869-
// Calculate dvar/dp and convert to dense array for indexing
872+
// Calculate dvar/dp
870873
(*dvar_dp)({&tval, y_val, functions->inputs.data()}, {&res_dvar_dp[0]});
871-
for (int k=0; k<number_of_parameters; k++) {
872-
dens_dvar_dp[k]=0;
873-
}
874-
for (int k=0; k<dvar_dp->nnz_out(); k++) {
875-
dens_dvar_dp[dvar_dp->get_row()[k]] = res_dvar_dp[k];
876-
}
877-
// Calculate sensitivities
878-
for (int paramk=0; paramk<number_of_parameters; paramk++) {
879-
auto &yS_back_paramk = yS[i_save][paramk];
880-
yS_back_paramk[dvar_k] = dens_dvar_dp[paramk];
881874

882-
for (int spk=0; spk<dvar_dy->nnz_out(); spk++) {
883-
yS_back_paramk[dvar_k] += res_dvar_dy[spk] * yS_val[paramk][dvar_dy->get_col()[spk]];
875+
// Get number of output components for this function (e.g., scalar → 1; vector → >1)
876+
const size_t n_rows = functions->var_fcns[dvar_k]->nnz_out();
877+
878+
// Number of nonzeros in the sparse Jacobians (for dvar/dy and dvar/dp)
879+
const size_t dvar_dy_nnz = dvar_dy->nnz_out();
880+
const size_t dvar_dp_nnz = dvar_dp->nnz_out();
881+
882+
// Row/column indices of nonzero entries (compressed sparse row format)
883+
const auto& dvar_dy_row = dvar_dy->get_row(); // output component (row) for y
884+
const auto& dvar_dy_col = dvar_dy->get_col(); // state variable index (column) for y
885+
const auto& dvar_dp_row = dvar_dp->get_row(); // output component (row) for p
886+
const auto& dvar_dp_col = dvar_dp->get_col(); // parameter index (column) for p
887+
888+
// Temporary dense vector to hold doutput_row/dp_k for each parameter
889+
vector<sunrealtype> dvar_dp_dense(number_of_parameters, 0.0);
890+
891+
// Loop over each scalar component (row) of the output function
892+
for (size_t row = 0; row < n_rows; ++row, ++global_out_idx) {
893+
// Dense dvar_row/dp_k vector (reset to zero)
894+
std::fill(dvar_dp_dense.begin(), dvar_dp_dense.end(), 0.0);
895+
896+
// Fill in dvar_row/dp_k from sparse structure
897+
for (size_t nz = 0; nz < dvar_dp_nnz; ++nz) {
898+
if (dvar_dp_row[nz] == static_cast<int>(row)) {
899+
// dvar_dp_col[nz] is parameter index k
900+
dvar_dp_dense[dvar_dp_col[nz]] = res_dvar_dp[nz]; // direct derivative
901+
}
902+
}
903+
904+
// For each parameter p_k, compute total d(output_row)/d(p_k)
905+
for (int paramk = 0; paramk < number_of_parameters; paramk++) {
906+
auto &yS_back_paramk = yS[i_save][paramk]; // Sensitivity vector for p_k at save step i_save
907+
908+
// Start with direct contribution doutput/dp_k
909+
sunrealtype sens = dvar_dp_dense[paramk];
910+
911+
// Add chain rule term
912+
for (size_t nz = 0; nz < dvar_dy_nnz; ++nz) {
913+
if (dvar_dy_row[nz] == static_cast<int>(row)) {
914+
// dvar_dy_col[nz] = j (state index)
915+
// yS_val[paramk][j] = dy_j/dp_k
916+
sens += res_dvar_dy[nz] * yS_val[paramk][dvar_dy_col[nz]];
917+
}
918+
}
919+
920+
yS_back_paramk[global_out_idx] = sens;
884921
}
885922
}
886923
}

0 commit comments

Comments
 (0)