@@ -221,27 +221,92 @@ Solves the linear system and updates the solution vector. This includes:
221
221
- Computing the residual
222
222
- Updating the solution with optional damping
223
223
"""
224
- function solve_linear_system! (A, b, sol, soltemp, residual, linsolve, unknowns, freedofs, damping, PD, SC, stats, is_linear, timer)
225
- # Update system matrix if needed
226
- if ! SC. parameters[:constant_matrix ] || ! SC. parameters[:initialized ]
224
+ function solve_linear_system! (A, b, sol, soltemp, residual, linsolve, unknowns, freedofs, damping, PD, SC, stats, is_linear, timer, kwargs... )
225
+
226
+ @timeit timer " Lagrange restrictions" begin
227
+ # # assemble restrctions
228
+ if ! SC. parameters[:initialized ]
229
+ for restriction in PD. restrictions
230
+ @timeit timer " $(restriction. parameters[:name ]) " assemble! (restriction, A, b, sol, SC; kwargs... )
231
+ end
232
+ end
233
+ end
234
+
235
+ @timeit timer " linear solver" begin
236
+ # # start with the assembled matrix containing all assembled operators
237
+ if ! SC. parameters[:constant_matrix ] || ! SC. parameters[:initialized ]
238
+ if length (freedofs) > 0
239
+ A_unrestricted = A. entries. cscmatrix[freedofs, freedofs]
240
+ else
241
+ A_unrestricted = A. entries. cscmatrix
242
+ end
243
+ end
244
+
245
+ # we solve for A Δx = r
246
+ # and update x = sol - Δx
227
247
if length (freedofs) > 0
228
- linsolve . A = A . entries. cscmatrix[freedofs, freedofs]
248
+ b_unrestricted = residual . entries[ freedofs]
229
249
else
230
- linsolve . A = A . entries. cscmatrix
250
+ b_unrestricted = residual . entries
231
251
end
232
- end
233
252
234
- # Set right-hand side
235
- if length (freedofs) > 0
236
- linsolve. b = residual. entries[freedofs]
237
- else
238
- linsolve. b = residual. entries
239
- end
240
- SC. parameters[:initialized ] = true
253
+ @timeit timer " LM restrictions" begin
254
+ # # add possible Lagrange restrictions
255
+ @timeit timer " prepare" begin
256
+ restriction_matrices = [length (freedofs) > 0 ? re. parameters[:matrix ][freedofs, :] : re. parameters[:matrix ] for re in PD. restrictions ]
257
+ restriction_rhs = [length (freedofs) > 0 ? re. parameters[:rhs ][freedofs] : re. parameters[:rhs ] for re in PD. restrictions ]
258
+
259
+ # # we need to add the (initial) solution to the rhs, since we work with the residual equation
260
+ for (B, rhs) in zip (restriction_matrices, restriction_rhs)
261
+ rhs .+ = B' sol. entries
262
+ end
263
+ end
264
+
265
+
266
+ @timeit timer " compute blocks" begin
267
+ # block sizes for the block matrix
268
+ block_sizes = [size (A_unrestricted, 2 ), [ size (B, 2 ) for B in restriction_matrices ]. .. ]
241
269
242
- # Solve linear system
243
- push! (stats[:matrix_nnz ], nnz (linsolve. A))
244
- @timeit timer " solve! call" Δx = LinearSolve. solve! (linsolve)
270
+ total_size = sum (block_sizes)
271
+ Tv = eltype (A_unrestricted)
272
+
273
+ # # create block matrix
274
+ A_block = BlockMatrix (spzeros (Tv, total_size, total_size), block_sizes, block_sizes)
275
+ A_block[Block (1 , 1 )] = A_unrestricted
276
+
277
+ b_block = BlockVector (zeros (Tv, total_size), block_sizes)
278
+ b_block[Block (1 )] = b_unrestricted
279
+
280
+ u_unrestricted = linsolve. u
281
+ u_block = BlockVector (zeros (Tv, total_size), block_sizes)
282
+ u_block[Block (1 )] = u_unrestricted
283
+
284
+ for i in eachindex (PD. restrictions)
285
+ A_block[Block (1 , i + 1 )] = restriction_matrices[i]
286
+ A_block[Block (i + 1 , 1 )] = transpose (restriction_matrices[i])
287
+ b_block[Block (i + 1 )] = restriction_rhs[i]
288
+
289
+ end
290
+ end
291
+
292
+ @timeit timer " convert" begin
293
+
294
+ linsolve. A = sparse (A_block) # convert to CSC Matrix
295
+ linsolve. b = Vector (b_block) # convert to dense vector
296
+ linsolve. u = Vector (u_block) # convert to dense vector
297
+
298
+ end
299
+ end
300
+
301
+ SC. parameters[:initialized ] = true
302
+
303
+ # Solve linear system
304
+ push! (stats[:matrix_nnz ], nnz (linsolve. A))
305
+ @timeit timer " solve! call" blocked_Δx = LinearSolve. solve! (linsolve)
306
+
307
+ # extract the solution / dismiss the lagrange multipliers
308
+ @views Δx = blocked_Δx[1 : length (b_unrestricted)]
309
+ end
245
310
246
311
# Compute solution update
247
312
@timeit timer " update solution" begin
@@ -570,12 +635,12 @@ function CommonSolve.solve(PD::ProblemDescription, FES::Union{<:FESpace, Vector{
570
635
571
636
linsolve = SC. linsolver
572
637
573
- @timeit timer " linear solver " begin
574
- linres = solve_linear_system! (
575
- A, b, sol, soltemp, residual, linsolve, unknowns,
576
- freedofs, damping, PD, SC, stats, is_linear, timer
577
- )
578
- end
638
+
639
+ linres = solve_linear_system! (
640
+ A, b, sol, soltemp, residual, linsolve, unknowns,
641
+ freedofs, damping, PD, SC, stats, is_linear, timer
642
+ )
643
+
579
644
if SC. parameters[:verbosity ] > - 1
580
645
if is_linear
581
646
@printf " %.3e\t " linres
0 commit comments