diff --git a/src/pybammsolvers/idaklu_source/IDAKLUSolverOpenMP.inl b/src/pybammsolvers/idaklu_source/IDAKLUSolverOpenMP.inl index f939ab4..6082ae8 100644 --- a/src/pybammsolvers/idaklu_source/IDAKLUSolverOpenMP.inl +++ b/src/pybammsolvers/idaklu_source/IDAKLUSolverOpenMP.inl @@ -855,32 +855,69 @@ void IDAKLUSolverOpenMP::SetStepOutputSensitivities( sunrealtype &tval, sunrealtype *y_val, const vector& yS_val, - int &i_save + int &i_save ) { DEBUG("IDAKLUSolver::SetStepOutputSensitivities"); - // Calculate sensitivities - vector dens_dvar_dp = vector(number_of_parameters, 0); - for (size_t dvar_k=0; dvar_kdvar_dy_fcns.size(); dvar_k++) { - // Isolate functions + + // Running index over the flattened outputs + size_t global_out_idx = 0; + + // Loop over each variable + for (size_t dvar_k = 0; dvar_k < functions->var_fcns.size(); ++dvar_k) { Expression* dvar_dy = functions->dvar_dy_fcns[dvar_k]; Expression* dvar_dp = functions->dvar_dp_fcns[dvar_k]; + // Calculate dvar/dy (*dvar_dy)({&tval, y_val, functions->inputs.data()}, {&res_dvar_dy[0]}); - // Calculate dvar/dp and convert to dense array for indexing + // Calculate dvar/dp (*dvar_dp)({&tval, y_val, functions->inputs.data()}, {&res_dvar_dp[0]}); - for (int k=0; knnz_out(); k++) { - dens_dvar_dp[dvar_dp->get_row()[k]] = res_dvar_dp[k]; - } - // Calculate sensitivities - for (int paramk=0; paramknnz_out(); spk++) { - yS_back_paramk[dvar_k] += res_dvar_dy[spk] * yS_val[paramk][dvar_dy->get_col()[spk]]; + // Get number of output components for this function (e.g., scalar → 1; vector → >1) + const size_t n_rows = functions->var_fcns[dvar_k]->nnz_out(); + + // Number of nonzeros in the sparse Jacobians (for dvar/dy and dvar/dp) + const size_t dvar_dy_nnz = dvar_dy->nnz_out(); + const size_t dvar_dp_nnz = dvar_dp->nnz_out(); + + // Row/column indices of nonzero entries (compressed sparse row format) + const auto& dvar_dy_row = dvar_dy->get_row(); // output component (row) for y + const auto& dvar_dy_col = dvar_dy->get_col(); // state variable index (column) for y + const auto& dvar_dp_row = dvar_dp->get_row(); // output component (row) for p + const auto& dvar_dp_col = dvar_dp->get_col(); // parameter index (column) for p + + // Temporary dense vector to hold doutput_row/dp_k for each parameter + vector dvar_dp_dense(number_of_parameters, 0.0); + + // Loop over each scalar component (row) of the output function + for (size_t row = 0; row < n_rows; ++row, ++global_out_idx) { + // Dense dvar_row/dp_k vector (reset to zero) + std::fill(dvar_dp_dense.begin(), dvar_dp_dense.end(), 0.0); + + // Fill in dvar_row/dp_k from sparse structure + for (size_t nz = 0; nz < dvar_dp_nnz; ++nz) { + if (dvar_dp_row[nz] == static_cast(row)) { + // dvar_dp_col[nz] is parameter index k + dvar_dp_dense[dvar_dp_col[nz]] = res_dvar_dp[nz]; // direct derivative + } + } + + // For each parameter p_k, compute total d(output_row)/d(p_k) + for (int paramk = 0; paramk < number_of_parameters; paramk++) { + auto &yS_back_paramk = yS[i_save][paramk]; // Sensitivity vector for p_k at save step i_save + + // Start with direct contribution doutput/dp_k + sunrealtype sens = dvar_dp_dense[paramk]; + + // Add chain rule term + for (size_t nz = 0; nz < dvar_dy_nnz; ++nz) { + if (dvar_dy_row[nz] == static_cast(row)) { + // dvar_dy_col[nz] = j (state index) + // yS_val[paramk][j] = dy_j/dp_k + sens += res_dvar_dy[nz] * yS_val[paramk][dvar_dy_col[nz]]; + } + } + + yS_back_paramk[global_out_idx] = sens; } } }