Skip to content

Commit af2c83a

Browse files
pjaapchmerdon
andauthored
Speedup CombineDofs (#63)
Co-authored-by: Christian Merdon <[email protected]>
1 parent 1070f7b commit af2c83a

File tree

5 files changed

+55
-84
lines changed

5 files changed

+55
-84
lines changed

CHANGELOG.md

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
# CHANGES
22

3-
4-
## v1.2.0 May 28, 2025
3+
## v1.2.0 June 4, 2025
54

65
### Changed
76
- TimerOutputs for measuring/storing/showing runtime and allocations in solve, now also for separate operators
8-
7+
- Coupling matrix result of `compute_periodic_coupling_matrix` is no longer transposed
8+
- rewrote internals of `CombineDofs` `apply_penalty` method to speed up the assembly
9+
910
### Fixed
1011
- HomogeneousData/InterpolateBoundaryData operator fix when system matrix is of type GenericMTExtendableSparseMatrixCSC
11-
12+
1213
## v1.1.1 April 29, 2025
1314

1415
### Fixed
1516
- FixDofs operator does not crash when system matrix is of type GenericMTExtendableSparseMatrixCSC
16-
17+
1718
## v1.1.0 April 17, 2025
1819

1920
### Changed

src/ExtendableFEM.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ using ExtendableGrids: ExtendableGrids, AT_NODES, AbstractElementGeometry,
5757
unique, update_trafo!, xrefFACE2xrefCELL,
5858
xrefFACE2xrefOFACE
5959
using ExtendableSparse: ExtendableSparse, ExtendableSparseMatrix, flush!,
60-
MTExtendableSparseMatrixCSC,
60+
MTExtendableSparseMatrixCSC, findindex,
6161
rawupdateindex!
6262
using ForwardDiff: ForwardDiff
6363
using GridVisualize: GridVisualize, GridVisualizer, gridplot!, reveal, save,
@@ -75,7 +75,6 @@ using SciMLBase: SciMLBase
7575
using TimerOutputs: TimerOutput, print_timer, @timeit
7676
using UnicodePlots: UnicodePlots
7777

78-
7978
## reexport stuff from ExtendableFEMBase and ExtendableGrids
8079
export FESpace, FEMatrix, FEVector
8180
export H1P1, H1P2, H1P3, H1Pk

src/common_operators/combinedofs.jl

Lines changed: 37 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,19 @@
22
### COMBINE DOFS (e.g. for periodicity) ###
33
###########################################
44

5-
mutable struct CombineDofs{UT, CT} <: AbstractOperator
5+
mutable struct CombineDofs{UT, CT, AT} <: AbstractOperator
66
uX::UT # component nr for dofsX
77
uY::UT # component nr for dofsY
88
coupling_info::CT
9+
fixed_dofs::AT
910
FESX::Any
1011
FESY::Any
1112
assembler::Any
1213
parameters::Dict{Symbol, Any}
1314
end
1415

16+
fixed_dofs(O::CombineDofs) = O.fixed_dofs
17+
1518
default_combop_kwargs() = Dict{Symbol, Tuple{Any, String}}(
1619
:name => ("CombineDofs", "name for operator used in printouts"),
1720
:penalty => (1.0e30, "penalty for fixed degrees of freedom"),
@@ -61,7 +64,14 @@ $(_myprint(default_combop_kwargs()))
6164
function CombineDofs(uX, uY, coupling_matrix::AbstractMatrix; kwargs...)
6265
parameters = Dict{Symbol, Any}(k => v[1] for (k, v) in default_combop_kwargs())
6366
_update_params!(parameters, kwargs)
64-
return CombineDofs(uX, uY, coupling_matrix, nothing, nothing, nothing, parameters)
67+
fixed_dofs = zeros(Int, 0)
68+
for dof_i in 1:size(coupling_matrix, 2)
69+
coupling_i = @views coupling_matrix[:, dof_i]
70+
if nnz(coupling_i) > 0
71+
push!(fixed_dofs, dof_i)
72+
end
73+
end
74+
return CombineDofs(uX, uY, coupling_matrix, fixed_dofs, nothing, nothing, nothing, parameters)
6575
end
6676

6777
function apply_penalties!(A, b, sol, CD::CombineDofs{UT, CT}, SC::SolverConfiguration; assemble_matrix = true, assemble_rhs = true, kwargs...) where {UT, CT}
@@ -80,118 +90,81 @@ function build_assembler!(CD::CombineDofs{UT, CT}, FE::Array{<:FEVectorBlock, 1}
8090
FESX, FESY = FE[1].FES, FE[2].FES
8191
if (CD.FESX != FESX) || (CD.FESY != FESY)
8292
coupling_matrix = CD.coupling_info
93+
fixed_dofs = CD.fixed_dofs
8394
offsetX = FE[1].offset
8495
offsetY = FE[2].offset
8596
if CD.parameters[:verbosity] > 0
8697
@info ".... coupling $(length(coupling_matrix.nzval)) dofs"
8798
end
88-
function assemble!(A::AbstractSparseArray{T}, b::AbstractVector{T}, assemble_matrix::Bool, assemble_rhs::Bool, kwargs...) where {T}
89-
90-
# transpose the matrix once for efficient row access
91-
transposed_coupling_matrix = sparse(transpose(coupling_matrix))
99+
penalty = CD.parameters[:penalty]
92100

101+
function assemble!(A::AbstractSparseArray{T}, b::AbstractVector{T}, assemble_matrix::Bool, assemble_rhs::Bool, kwargs...) where {T}
93102
if assemble_matrix
94-
# go through each coupled dof and update the FE adjacency info
95-
# from the constrained dofs here
96-
97-
for dof_i in 1:size(coupling_matrix, 2)
103+
# go through each constrained dof and update the FE adjacency info
104+
# of the coupled dofs
105+
for dof_i in fixed_dofs
98106
# this col-view is efficient
99107
coupling_i = @views coupling_matrix[:, dof_i]
100-
# do nothing if dof_k is not coupled to any constrained dof
101-
if nnz(coupling_i) == 0
102-
continue
103-
end
104108

105109
# write the FE adjacency of the constrained dofs into this row
106-
targetrow = dof_i + offsetX
110+
sourcerow = dof_i + offsetX
107111

108112
# extract the constrained dofs and the weights
109113
coupled_dofs_i, weights_i = findnz(coupling_i)
110114

111-
# parse through all cols and update the entries
112-
for dof_j in 1:size(coupling_matrix, 2)
113-
# this col-view is efficient
114-
coupling_j = @views coupling_matrix[:, dof_j]
115-
116-
# if both dof_i and dof_j are coupled to a constrained dof, then
117-
# the FE adjacency A_ij is not updated: this is covered by the linear combinations
118-
# expressed in the rows of the constrained dofs_on_boundary
119-
# Hence, check that dof_j is not coupled to anything
120-
if nnz(coupling_j) == 0
121-
targetcol = dof_j + offsetY
122-
for (dof_k, weight_ik) in zip(coupled_dofs_i, weights_i)
123-
sourcerow = dof_k + offsetX
124-
sourcecol = targetcol
125-
val = A[sourcerow, sourcecol]
126-
_addnz(A, targetrow, targetcol, val, weight_ik)
115+
# parse through sourcerow and add the contents to the coupled dofs
116+
for col in 1:size(A, 2)
117+
r = findindex(A.cscmatrix, sourcerow, col)
118+
if r > 0
119+
val = A.cscmatrix.nzval[r]
120+
if abs(val) > 1.0e-15
121+
for (dof_k, weight_ik) in zip(coupled_dofs_i, weights_i)
122+
targetrow = dof_k + offsetX
123+
_addnz(A, targetrow, col, val, weight_ik)
124+
end
127125
end
128126
end
129127
end
130128
end
131129

132130
# replace the geometric coupling rows based
133131
# on the original coupling matrix
134-
for dof_i in 1:size(transposed_coupling_matrix, 2)
135-
136-
coupling_i = transposed_coupling_matrix[:, dof_i]
137-
# do nothing if no coupling for dof_i
138-
if nnz(coupling_i) == 0
139-
continue
140-
end
132+
for dof_i in fixed_dofs
133+
coupling_i = coupling_matrix[:, dof_i]
141134

142135
# get the coupled dofs of dof_i and the corresponding weights
143136
coupled_dofs_i, weights_i = findnz(coupling_i)
144-
145137
sourcerow = dof_i + offsetX
146138

147-
# eliminate the sourcerow
148-
for col in 1:size(A, 2)
149-
A[sourcerow, col] = 0
150-
end
151-
152139
# replace sourcerow with coupling linear combination
153-
_addnz(A, sourcerow, sourcerow, -1.0, 1)
140+
_addnz(A, sourcerow, sourcerow, -1.0, penalty)
154141
for (dof_j, weight_ij) in zip(coupled_dofs_i, weights_i)
155142
# weights for ∑ⱼ wⱼdofⱼ - dofᵢ = 0
156-
_addnz(A, sourcerow, dof_j + offsetY, weight_ij, 1)
143+
_addnz(A, sourcerow, dof_j + offsetY, weight_ij, penalty)
157144
end
158-
159145
end
160146
flush!(A)
161147
end
162148

163149
if assemble_rhs
164-
165-
for dof_i in 1:size(coupling_matrix, 2)
150+
for dof_i in fixed_dofs
166151
# this col-view is efficient
167152
coupling_i = @views coupling_matrix[:, dof_i]
168-
# do nothing if no coupling for dof_i
169-
if nnz(coupling_i) == 0
170-
continue
171-
end
172153

173154
# get the coupled dofs of dof_i and the corresponding weights
174155
coupled_dofs, weights = findnz(coupling_i)
175156

176157
# transfer all assembly information to dof_i
177-
targetrow = dof_i + offsetY
158+
sourcerow = dof_i + offsetY
178159
for (dof_j, weight) in zip(coupled_dofs, weights)
179-
sourcerow = dof_j + offsetY
160+
targetrow = dof_j + offsetY
180161
b[targetrow] += weight * b[sourcerow]
181162
end
182163
end
183164

184-
185165
# now set the rows of the constrained dofs to zero to enforce the linear combination
186-
for dof_i in 1:size(transposed_coupling_matrix, 2)
187-
coupling_i = transposed_coupling_matrix[:, dof_i]
188-
# do nothing if no coupling for dof_i
189-
if nnz(coupling_i) == 0
190-
continue
191-
end
192-
166+
for dof_i in fixed_dofs
193167
b[dof_i + offsetX] = 0.0
194-
195168
end
196169
end
197170

src/helper_functions.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ function _get_periodic_coupling_matrix(
396396
# set entries
397397
for (i, target_entry) in enumerate(fe_vector_target.entries)
398398
if abs(target_entry) > sparsity_tol
399-
result[i, local_dof] = target_entry
399+
result[local_dof, i] = target_entry
400400
end
401401
end
402402
end
@@ -410,7 +410,7 @@ function _get_periodic_coupling_matrix(
410410
@warn "no coupling found. Are the grid boundary regions and the give_opposite! method correct?"
411411
end
412412

413-
return sparse(result)
413+
return sp_result
414414
end
415415

416416
"""
@@ -440,10 +440,8 @@ Example: If b_from is at x[1] = 0 and the opposite boundary is at y[1] = 1, then
440440
The return value is a (𝑛 × 𝑛) sparse matrix 𝐴 (𝑛 is the total number of dofs) containing the periodic coupling information.
441441
The relationship between the degrees of freedom is dofᵢ = ∑ⱼ Aⱼᵢ ⋅ dofⱼ.
442442
It is guaranteed that
443-
i) Aⱼᵢ=0 if dofᵢ is 𝑛𝑜𝑡 on the boundary b_from.
444-
ii) Aⱼᵢ=0 if the opposite of dofᵢ is not in the same grid cell as dofⱼ.
445-
Note that A is transposed for efficient col-wise storage.
446-
443+
i) Aᵢⱼ=0 if dofᵢ is 𝑛𝑜𝑡 on the boundary b_from.
444+
ii) Aᵢⱼ=0 if the opposite of dofᵢ is not in the same grid cell as dofⱼ.
447445
"""
448446
function get_periodic_coupling_matrix(
449447
FES,

test/test_helper_functions.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,11 @@ function run_test_helper_functions()
5454
end
5555
end
5656

57-
# row sum is 0.0 or 1.0
57+
# col sum is 0.0 or 1.0
5858
for i in 1:size(matrix, 1)
59-
row_sum = sum(matrix[i, :])
60-
if !(row_sum == 0.0 || row_sum ≈ 1.0)
61-
@show row_sum i
59+
col_cum = sum(matrix[:, i])
60+
if !(col_cum == 0.0 || col_cum ≈ 1.0)
61+
@show col_cum i
6262
return false
6363
end
6464
end
@@ -78,14 +78,14 @@ function run_test_helper_functions()
7878
let # 3D P1
7979
xgrid = simplexgrid(0:0.1:1.0, 0:0.1:1.0, 0:0.1:1.0)
8080
FES = FESpace{H1P1{1}}(xgrid)
81-
A = get_periodic_coupling_matrix(FES, xgrid, 4, 2, give_opposite!, sparsity_tol = 1.0e-8)
81+
A = get_periodic_coupling_matrix(FES, 4, 2, give_opposite!, sparsity_tol = 1.0e-8)
8282
@test test_matrix(A)
8383
end
8484

8585
let # 3D P2 with 2 components
8686
xgrid = simplexgrid(0:0.5:1.0, 0:0.5:1.0, 0:0.5:1.0)
8787
FES = FESpace{H1P2{2, 3}}(xgrid)
88-
A = get_periodic_coupling_matrix(FES, xgrid, 4, 2, give_opposite!, sparsity_tol = 1.0e-8)
88+
A = get_periodic_coupling_matrix(FES, 4, 2, give_opposite!, sparsity_tol = 1.0e-8)
8989
@test test_matrix(A)
9090
end
9191

@@ -98,7 +98,7 @@ function run_test_helper_functions()
9898
xgrid = simplexgrid(b)
9999

100100
FES = FESpace{H1P1{1}}(xgrid)
101-
A = get_periodic_coupling_matrix(FES, xgrid, 4, 2, give_opposite!, sparsity_tol = 1.0e-8)
101+
A = get_periodic_coupling_matrix(FES, 4, 2, give_opposite!, sparsity_tol = 1.0e-8)
102102
@test test_matrix(A; structured_grid = false)
103103
end
104104
end

0 commit comments

Comments
 (0)