Skip to content

Commit 2b52d30

Browse files
committed
Implement first demo of Lagrange Restrictions
1 parent ce1b3a8 commit 2b52d30

File tree

8 files changed

+193
-37
lines changed

8 files changed

+193
-37
lines changed

Project.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,16 @@ version = "1.5.0"
44
authors = ["Christian Merdon <[email protected]>", "Patrick Jaap <[email protected]>"]
55

66
[deps]
7+
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
8+
Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de"
79
ChunkSplitters = "ae650224-84b6-46f8-82ea-d812ca08434e"
810
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
911
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
1012
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1113
ExtendableFEMBase = "12fb9182-3d4c-4424-8fd1-727a0899810c"
1214
ExtendableGrids = "cfc395e8-590f-11e8-1f13-43a2532b2fa8"
1315
ExtendableSparse = "95c220a8-a1cf-11e9-0c77-dbfce5f500b3"
16+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1417
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1518
GridVisualize = "5eed8a63-0fb0-45eb-886d-8d5a387d12b8"
1619
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -26,6 +29,7 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
2629

2730
[compat]
2831
Aqua = "0.8"
32+
BlockArrays = "1.7.0"
2933
ChunkSplitters = "3.1.2"
3034
CommonSolve = "0.2"
3135
DiffResults = "1"
@@ -35,6 +39,7 @@ ExplicitImports = "1"
3539
ExtendableFEMBase = "1.3.0"
3640
ExtendableGrids = "1.10.3"
3741
ExtendableSparse = "1.5.3"
42+
FillArrays = "1.13.0"
3843
ForwardDiff = "0.10.35,1"
3944
GridVisualize = "1.8.1"
4045
IncompleteLU = "0.2.1"

examples/Example212_PeriodicBoundary2D.jl

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ end
116116
function main(;
117117
order = 1,
118118
periodic = true,
119+
use_LM_restrictions = true,
119120
Plotter = nothing,
120121
force = 10.0,
121122
h = 1.0e-2,
@@ -154,8 +155,12 @@ function main(;
154155
return nothing
155156
end
156157

157-
@time coupling_matrix = get_periodic_coupling_matrix(FES, reg_left, reg_right, give_opposite!; parallel = threads > 1, threads)
158-
assign_operator!(PD, CombineDofs(u, u, coupling_matrix; kwargs...))
158+
@showtime coupling_matrix = get_periodic_coupling_matrix(FES, reg_left, reg_right, give_opposite!; parallel = threads > 1, threads)
159+
if use_LM_restrictions
160+
assign_restriction!(PD, CoupledDofsRestriction(coupling_matrix))
161+
else
162+
assign_operator!(PD, CombineDofs(u, u, coupling_matrix; kwargs...))
163+
end
159164
end
160165

161166
sol = solve(PD, FES)
@@ -172,10 +177,15 @@ end
172177

173178
generateplots = ExtendableFEM.default_generateplots(Example212_PeriodicBoundary2D, "example212.png") #hide
174179
function runtests() #hide
175-
sol, _ = main() #hide
176-
@test abs(maximum(view(sol[1])) - 1.3447465095618172) < 1.0e-3 #hide
177-
sol2, _ = main(threads = 4) #hide
178-
@test sol.entries sol2.entries #hide
180+
sol1, _ = main(use_LM_restrictions = false, threads = 1) #hide
181+
@test abs(maximum(view(sol1[1])) - 1.3447465095618172) < 1.0e-3 #hide
182+
183+
sol2, _ = main(use_LM_restrictions = false, threads = 4) #hide
184+
@test sol1.entries sol2.entries #hide
185+
186+
sol3, _ = main(use_LM_restrictions = true, threads = 4) #hide
187+
@test sol1.entries sol3.entries #hide
188+
179189
return nothing #hide
180190
end #hide
181191

examples/Example312_PeriodicBoundary3D.jl

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ end
130130
function main(;
131131
order = 1,
132132
periodic = true,
133+
use_LM_restrictions = true,
133134
Plotter = nothing,
134135
force = 1.0,
135136
h = 1.0e-4,
@@ -169,8 +170,12 @@ function main(;
169170
y[1] = width - x[1]
170171
return nothing
171172
end
172-
@time coupling_matrix = get_periodic_coupling_matrix(FES, reg_left, reg_right, give_opposite!; parallel = threads > 1, threads)
173-
assign_operator!(PD, CombineDofs(u, u, coupling_matrix; kwargs...))
173+
@showtime coupling_matrix = get_periodic_coupling_matrix(FES, reg_left, reg_right, give_opposite!; parallel = threads > 1, threads)
174+
if use_LM_restrictions
175+
assign_restriction!(PD, CoupledDofsRestriction(coupling_matrix))
176+
else
177+
assign_operator!(PD, CombineDofs(u, u, coupling_matrix; kwargs...))
178+
end
174179
end
175180

176181
## solve
@@ -185,10 +190,15 @@ end
185190

186191
generateplots = ExtendableFEM.default_generateplots(Example312_PeriodicBoundary3D, "example312.png") #hide
187192
function runtests() #hide
188-
sol, _ = main() #hide
189-
@test abs(maximum(view(sol[1])) - 1.8004602502175202) < 2.0e-3 #hide
190-
sol2, _ = main(threads = 4) #hide
191-
@test sol.entries sol2.entries #hide
193+
sol1, _ = main(use_LM_restrictions = false, threads = 1) #hide
194+
@test abs(maximum(view(sol1[1])) - 1.8004602502175202) < 2.0e-3 #hide
195+
196+
sol2, _ = main(use_LM_restrictions = false, threads = 4) #hide
197+
@test sol1.entries sol2.entries #hide
198+
199+
sol3, _ = main(use_LM_restrictions = true, threads = 4) #hide
200+
@test sol1.entries sol3.entries #hide
201+
192202
return nothing #hide
193203
end #hide
194204

src/ExtendableFEM.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ $(read(joinpath(@__DIR__, "..", "README.md"), String))
55
"""
66
module ExtendableFEM
77

8+
using BlockArrays: BlockMatrix, BlockVector, Block, BlockedMatrix, BlockedVector
89
using ChunkSplitters: chunks
910
using CommonSolve: CommonSolve
1011
using DiffResults: DiffResults
@@ -60,6 +61,7 @@ using ExtendableGrids: ExtendableGrids, AT_NODES, AbstractElementGeometry,
6061
using ExtendableSparse: ExtendableSparse, ExtendableSparseMatrix, flush!,
6162
MTExtendableSparseMatrixCSC, findindex,
6263
rawupdateindex!
64+
using FillArrays: Zeros
6365
using ForwardDiff: ForwardDiff
6466
using GridVisualize: GridVisualize, GridVisualizer, gridplot!, reveal, save,
6567
scalarplot!, vectorplot!
@@ -68,8 +70,8 @@ using LinearSolve: LinearSolve, LinearProblem, UMFPACKFactorization, deleteat!,
6870
init, solve
6971
using Printf: Printf, @printf, @sprintf
7072
using SparseArrays: SparseArrays, AbstractSparseArray, SparseMatrixCSC, findnz, nnz,
71-
nzrange, rowvals, sparse, SparseVector
72-
using StaticArrays: @MArray
73+
nzrange, rowvals, sparse, SparseVector, spzeros
74+
using StaticArrays: @MArray
7375
using SparseDiffTools: SparseDiffTools, ForwardColorJacCache,
7476
forwarddiff_color_jacobian!, matrix_colors
7577
using Symbolics: Symbolics
@@ -123,11 +125,16 @@ include("common_operators/reduction_operator.jl")
123125
#export AbstractReductionOperator
124126
#export FixbyInterpolation
125127

128+
include("restrictions.jl")
129+
include("common_restrictions/coupled_dofs_restriction.jl")
130+
export CoupledDofsRestriction
131+
126132
include("problemdescription.jl")
127133
export ProblemDescription
128134
export assign_unknown!
129135
export assign_operator!
130136
export replace_operator!
137+
export assign_restriction!
131138

132139
include("helper_functions.jl")
133140
export get_periodic_coupling_info
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
struct CoupledDofsRestriction{TM} <: AbstractRestriction
2+
coupling_matrix::TM
3+
parameters::Dict{Symbol, Any}
4+
end
5+
6+
function CoupledDofsRestriction(matrix::AbstractMatrix)
7+
return CoupledDofsRestriction(matrix, Dict{Symbol, Any}(:name => "CoupledDofsRestriction"))
8+
end
9+
10+
11+
function assemble!(R::CoupledDofsRestriction, A, b, sol, SC; kwargs...)
12+
13+
# extract all col indices
14+
_, J, _ = findnz(R.coupling_matrix)
15+
16+
# remove duplicates
17+
unique_cols = unique(J)
18+
19+
# subtract diagonal and shrink matrix to non-empty cols
20+
B = (R.coupling_matrix - LinearAlgebra.I)[:, unique_cols]
21+
22+
R.parameters[:matrix] = B
23+
return R.parameters[:rhs] = Zeros(length(unique_cols))
24+
25+
return nothing
26+
end

src/problemdescription.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ mutable struct ProblemDescription
2626
"""
2727
operators::Array{AbstractOperator, 1}
2828
#reduction_operators::Array{AbstractReductionOperator,1}
29+
30+
"""
31+
A vector of Lagrange restrictions that are involved in the problem.
32+
"""
33+
restrictions::Vector{AbstractRestriction}
2934
end
3035

3136
"""
@@ -41,7 +46,7 @@ Create an empty `ProblemDescription` with the given name.
4146
4247
"""
4348
function ProblemDescription(name = "My problem")
44-
return ProblemDescription(name, Array{Unknown, 1}(undef, 0), Array{AbstractOperator, 1}(undef, 0))
49+
return ProblemDescription(name, [], [], [])
4550
end
4651

4752

@@ -92,6 +97,12 @@ function assign_operator!(PD::ProblemDescription, o::AbstractOperator)
9297
end
9398

9499

100+
function assign_restriction!(PD::ProblemDescription, r::AbstractRestriction)
101+
push!(PD.restrictions, r)
102+
return length(PD.restrictions)
103+
end
104+
105+
95106
"""
96107
$(TYPEDSIGNATURES)
97108

src/restrictions.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
"""
2+
AbstractRestriction
3+
4+
Root type for all restrictions
5+
"""
6+
abstract type AbstractRestriction end
7+
8+
9+
function Base.show(io::IO, R::AbstractRestriction)
10+
print(io, "AbstractRestriction")
11+
return nothing
12+
end
13+
14+
# informs solver when operator needs reassembly in a time dependent setting
15+
function is_timedependent(R::AbstractRestriction)
16+
return false
17+
end
18+
19+
function assemble!(R::AbstractRestriction, A, b, sol, SC; kwargs...)
20+
## assembles internal restriction matrix in R
21+
return nothing
22+
end

src/solvers.jl

Lines changed: 87 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -221,27 +221,92 @@ Solves the linear system and updates the solution vector. This includes:
221221
- Computing the residual
222222
- Updating the solution with optional damping
223223
"""
224-
function solve_linear_system!(A, b, sol, soltemp, residual, linsolve, unknowns, freedofs, damping, PD, SC, stats, is_linear, timer)
225-
# Update system matrix if needed
226-
if !SC.parameters[:constant_matrix] || !SC.parameters[:initialized]
224+
function solve_linear_system!(A, b, sol, soltemp, residual, linsolve, unknowns, freedofs, damping, PD, SC, stats, is_linear, timer, kwargs...)
225+
226+
@timeit timer "Lagrange restrictions" begin
227+
## assemble restrctions
228+
if !SC.parameters[:initialized]
229+
for restriction in PD.restrictions
230+
@timeit timer "$(restriction.parameters[:name])" assemble!(restriction, A, b, sol, SC; kwargs...)
231+
end
232+
end
233+
end
234+
235+
@timeit timer "linear solver" begin
236+
## start with the assembled matrix containing all assembled operators
237+
if !SC.parameters[:constant_matrix] || !SC.parameters[:initialized]
238+
if length(freedofs) > 0
239+
A_unrestricted = A.entries.cscmatrix[freedofs, freedofs]
240+
else
241+
A_unrestricted = A.entries.cscmatrix
242+
end
243+
end
244+
245+
# we solve for A Δx = r
246+
# and update x = sol - Δx
227247
if length(freedofs) > 0
228-
linsolve.A = A.entries.cscmatrix[freedofs, freedofs]
248+
b_unrestricted = residual.entries[freedofs]
229249
else
230-
linsolve.A = A.entries.cscmatrix
250+
b_unrestricted = residual.entries
231251
end
232-
end
233252

234-
# Set right-hand side
235-
if length(freedofs) > 0
236-
linsolve.b = residual.entries[freedofs]
237-
else
238-
linsolve.b = residual.entries
239-
end
240-
SC.parameters[:initialized] = true
253+
@timeit timer "LM restrictions" begin
254+
## add possible Lagrange restrictions
255+
@timeit timer "prepare" begin
256+
restriction_matrices = [length(freedofs) > 0 ? re.parameters[:matrix][freedofs, :] : re.parameters[:matrix] for re in PD.restrictions ]
257+
restriction_rhs = [length(freedofs) > 0 ? re.parameters[:rhs][freedofs] : re.parameters[:rhs] for re in PD.restrictions ]
258+
259+
## we need to add the (initial) solution to the rhs, since we work with the residual equation
260+
for (B, rhs) in zip(restriction_matrices, restriction_rhs)
261+
rhs .+= B'sol.entries
262+
end
263+
end
264+
265+
266+
@timeit timer "compute blocks" begin
267+
# block sizes for the block matrix
268+
block_sizes = [size(A_unrestricted, 2), [ size(B, 2) for B in restriction_matrices ]...]
241269

242-
# Solve linear system
243-
push!(stats[:matrix_nnz], nnz(linsolve.A))
244-
@timeit timer "solve! call" Δx = LinearSolve.solve!(linsolve)
270+
total_size = sum(block_sizes)
271+
Tv = eltype(A_unrestricted)
272+
273+
## create block matrix
274+
A_block = BlockMatrix(spzeros(Tv, total_size, total_size), block_sizes, block_sizes)
275+
A_block[Block(1, 1)] = A_unrestricted
276+
277+
b_block = BlockVector(zeros(Tv, total_size), block_sizes)
278+
b_block[Block(1)] = b_unrestricted
279+
280+
u_unrestricted = linsolve.u
281+
u_block = BlockVector(zeros(Tv, total_size), block_sizes)
282+
u_block[Block(1)] = u_unrestricted
283+
284+
for i in eachindex(PD.restrictions)
285+
A_block[Block(1, i + 1)] = restriction_matrices[i]
286+
A_block[Block(i + 1, 1)] = transpose(restriction_matrices[i])
287+
b_block[Block(i + 1)] = restriction_rhs[i]
288+
289+
end
290+
end
291+
292+
@timeit timer "convert" begin
293+
294+
linsolve.A = sparse(A_block) # convert to CSC Matrix
295+
linsolve.b = Vector(b_block) # convert to dense vector
296+
linsolve.u = Vector(u_block) # convert to dense vector
297+
298+
end
299+
end
300+
301+
SC.parameters[:initialized] = true
302+
303+
# Solve linear system
304+
push!(stats[:matrix_nnz], nnz(linsolve.A))
305+
@timeit timer "solve! call" blocked_Δx = LinearSolve.solve!(linsolve)
306+
307+
# extract the solution / dismiss the lagrange multipliers
308+
@views Δx = blocked_Δx[1:length(b_unrestricted)]
309+
end
245310

246311
# Compute solution update
247312
@timeit timer "update solution" begin
@@ -570,12 +635,12 @@ function CommonSolve.solve(PD::ProblemDescription, FES::Union{<:FESpace, Vector{
570635

571636
linsolve = SC.linsolver
572637

573-
@timeit timer "linear solver" begin
574-
linres = solve_linear_system!(
575-
A, b, sol, soltemp, residual, linsolve, unknowns,
576-
freedofs, damping, PD, SC, stats, is_linear, timer
577-
)
578-
end
638+
639+
linres = solve_linear_system!(
640+
A, b, sol, soltemp, residual, linsolve, unknowns,
641+
freedofs, damping, PD, SC, stats, is_linear, timer
642+
)
643+
579644
if SC.parameters[:verbosity] > -1
580645
if is_linear
581646
@printf "%.3e\t" linres

0 commit comments

Comments
 (0)