Skip to content

Commit 6c87755

Browse files
author
chmerdon
committed
fixed sparsity pattern issue
1 parent e1c9e43 commit 6c87755

File tree

3 files changed

+21
-7
lines changed

3 files changed

+21
-7
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
name = "ExtendableFEM"
22
uuid = "a722555e-65e0-4074-a036-ca7ce79a4aed"
3-
version = "1.3.0"
43
authors = ["Christian Merdon <[email protected]>", "Patrick Jaap <[email protected]>"]
4+
version = "1.3.0"
55

66
[deps]
7+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
78
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
89
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
910
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
@@ -25,6 +26,7 @@ TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
2526
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
2627

2728
[compat]
29+
ADTypes = "1.16.0"
2830
Aqua = "0.8"
2931
CommonSolve = "0.2"
3032
DiffResults = "1"

src/ExtendableFEM.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,13 @@ using LinearSolve: LinearSolve, LinearProblem, UMFPACKFactorization, deleteat!,
6868
using Printf: Printf, @printf, @sprintf
6969
using SparseArrays: SparseArrays, AbstractSparseArray, SparseMatrixCSC, findnz, nnz,
7070
nzrange, rowvals, sparse
71-
71+
using ADTypes: ADTypes, KnownJacobianSparsityDetector
7272
using SparseConnectivityTracer: SparseConnectivityTracer, TracerSparsityDetector
7373
using DifferentiationInterface: DifferentiationInterface,
7474
AutoSparse,
7575
AutoForwardDiff,
7676
prepare_jacobian
77-
using SparseMatrixColorings: GreedyColoringAlgorithm #, sparsity_pattern
77+
using SparseMatrixColorings: GreedyColoringAlgorithm, sparsity_pattern
7878
using Symbolics: Symbolics
7979
using SciMLBase: SciMLBase
8080
using TimerOutputs: TimerOutput, print_timer, @timeit

src/common_operators/nonlinear_operator.jl

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -229,17 +229,21 @@ function build_assembler!(A::AbstractMatrix, b::AbstractVector, O::NonlinearOper
229229
Kj = Array{KernelEvaluator, 1}([])
230230

231231
sparse_jacobians = O.parameters[:sparse_jacobians]
232-
sparsity_pattern = O.parameters[:sparse_jacobians_pattern]
233232
use_autodiff = O.jacobian === nothing
234233
for EG in EGs
235234
## prepare parameters
236235
QPj = QPInfos(xgrid; time = time, x = ones(Tv, size(xgrid[Coordinates], 1)), params = O.parameters[:params])
237236
kernel_params = (result, input) -> (O.kernel(result, input, QPj))
238237
input_args = zeros(Tv, op_offsets_args[end] + O.parameters[:extra_inputsize])
239238
result_kernel = zeros(Tv, op_offsets_test[end])
239+
if O.parameters[:sparse_jacobians_pattern] === nothing
240+
sparsity_detector = TracerSparsityDetector()
241+
else
242+
sparsity_detector = KnownJacobianSparsityDetector(O.parameters[:sparse_jacobians_pattern])
243+
end
240244
sparse_forward_backend = AutoSparse(
241245
O.parameters[:autodiff_backend];
242-
sparsity_detector = TracerSparsityDetector(),
246+
sparsity_detector = sparsity_detector,
243247
coloring_algorithm = GreedyColoringAlgorithm()
244248
)
245249
jac_prep = prepare_jacobian(kernel_params, result_kernel, sparse_forward_backend, input_args)
@@ -287,9 +291,17 @@ function build_assembler!(A::AbstractMatrix, b::AbstractVector, O::NonlinearOper
287291
value = K.result_kernel
288292
jac_prep = K.jac_prep
289293
jac_backend = K.jac_backend
290-
# todo: get sparse jacobians to work (need to extract sparsity pattern)
291294
sparse_jacobians = false
292-
jac = zeros(Tv, length(value), length(input_args))
295+
if sparse_jacobians
296+
if O.parameters[:sparse_jacobians_pattern] === nothing
297+
jac_sparsity_pattern = DifferentiationInterface.sparsity_pattern(jac_prep)
298+
else
299+
jac_sparsity_pattern = O.parameters[:sparse_jacobians_pattern]
300+
end
301+
jac = Tv.(sparse(sparsity_pattern))
302+
else
303+
jac = zeros(Tv, length(value), length(input_args))
304+
end
293305
kernel_params = K.kernel
294306
params.time = time
295307

0 commit comments

Comments
 (0)