diff --git a/copt/utils.py b/copt/utils.py index 39086b94..797790a3 100644 --- a/copt/utils.py +++ b/copt/utils.py @@ -106,7 +106,7 @@ def fast_csr_vm(x, data, indptr, indices, d, idx): return res -@njit(nogil=True) +@njit(parallel=True) def fast_csr_mv(data, indptr, indices, x, idx): """ Returns the matrix vector product M[idx] * x. M is described @@ -120,10 +120,13 @@ def fast_csr_mv(data, indptr, indices, x, idx): """ res = np.zeros(len(idx)) - for i, row_idx in np.ndenumerate(idx): - for k, j in enumerate(range(indptr[row_idx], indptr[row_idx+1])): + for i in prange(len(idx)): + row_idx = idx[i] + res_i = 0.0 + for j in range(indptr[row_idx], indptr[row_idx+1]): j_idx = indices[j] - res[i] += x[j_idx] * data[j] + res_i += x[j_idx] * data[j] + res[i] = res_i return res