Skip to content

Commit 7205e9c

Browse files
committed
Add symmetric system constructor
1 parent fb8f561 commit 7205e9c

File tree

1 file changed

+50
-15
lines changed

1 file changed

+50
-15
lines changed

src/system/system.jl

Lines changed: 50 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
struct System{N}
1+
abstract type Symmetric end # symmetric = true
2+
abstract type BlockSymmetric <: Symmetric end # symmetric = true
3+
abstract type SparsitySymmetric end # symmetric = false, Any matrix is sparsity symmetric since zero entries can be assumed nonzero
4+
5+
struct System{N,T}
26
matrix_entries::SparseMatrixCSC{GraphBasedSystems.Entry, Int64}
37
vector_entries::Vector{Entry}
48
diagonal_inverses::Vector{Entry}
@@ -9,23 +13,21 @@ struct System{N}
913
graph::SimpleGraph{Int64}
1014
dfs_graph::SimpleDiGraph{Int64}
1115

12-
function System{T}(A, dims; force_static = false) where T
16+
function System{T}(A, dims; force_static = false, symmetric = false) where T
1317
N = length(dims)
18+
if symmetric
19+
all(dims .== 1) ? S = Symmetric : S = BlockSymmetric
20+
else
21+
S = SparsitySymmetric
22+
end
1423
static = force_static || all(dims.<=10)
1524

1625
full_graph = Graph(A)
1726

1827
matrix_entries = spzeros(Entry,N,N)
1928

2029
for (i,dimi) in enumerate(dims)
21-
for (j,dimj) in enumerate(dims)
22-
if i == j
23-
matrix_entries[i,j] = Entry{T}(dimi, dimj, static = static)
24-
elseif j all_neighbors(full_graph,i)
25-
matrix_entries[i,j] = Entry{T}(dimi, dimj, static = static)
26-
matrix_entries[j,i] = Entry{T}(dimj, dimi, static = static)
27-
end
28-
end
30+
matrix_entries[i,i] = Entry{T}(dimi, dimi, static = static)
2931
end
3032

3133
vector_entries = [Entry{T}(dim, static = static) for dim in dims]
@@ -65,23 +67,36 @@ struct System{N}
6567
acyclic_children[v] = setdiff(acyclic_children[v], cyclic_children)
6668
for c in cyclic_children
6769
matrix_entries[v,c] = Entry{T}(dims[v], dims[c], static = static)
68-
matrix_entries[c,v] = Entry{T}(dims[c], dims[v], static = static)
70+
!symmetric && (matrix_entries[c,v] = Entry{T}(dims[c], dims[v], static = static))
6971

7072
v parents[c] && push!(parents[c],v)
7173
end
72-
end
74+
end
75+
for v in sub_dfs_list
76+
for c in acyclic_children[v]
77+
matrix_entries[v,c] = Entry{T}(dims[v], dims[c], static = static)
78+
!symmetric && (matrix_entries[c,v] = Entry{T}(dims[c], dims[v], static = static))
79+
end
80+
end
7381
end
7482

7583
full_dfs_graph = SimpleDiGraph(edgelist)
7684
cyclic_children = [unique(vcat(cycles[i]...)) for i=1:N]
7785

78-
new{N}(matrix_entries, vector_entries, diagonal_inverses, acyclic_children, cyclic_children, parents, dfs_list, full_graph, full_dfs_graph)
86+
new{N,S}(matrix_entries, vector_entries, diagonal_inverses, acyclic_children, cyclic_children, parents, dfs_list, full_graph, full_dfs_graph)
7987
end
8088

81-
System(A, dims; force_static = false) = System{Float64}(A, dims; force_static = force_static)
89+
System(A, dims; force_static = false, symmetric = false) = System{Float64}(A, dims; force_static, symmetric)
8290
end
8391

84-
function Base.show(io::IO, mime::MIME{Symbol("text/plain")}, system::System{N}) where {N}
92+
function Base.show(io::IO, mime::MIME{Symbol("text/plain")}, system::System{N,S}) where {N,S}
93+
if S <: BlockSymmetric
94+
print(io, "BlockSymmetric ")
95+
elseif S <: Symmetric
96+
print(io, "Symmetric ")
97+
elseif S <: SparsitySymmetric
98+
print(io, "SparsitySymmetric ")
99+
end
85100
println(io, "System with "*string(N)*" nodes.")
86101
SparseArrays._show_with_braille_patterns(io, system.matrix_entries)
87102
end
@@ -109,4 +124,24 @@ function full_matrix(system::System{N}) where N
109124
return A
110125
end
111126

127+
# function full_matrix(system::System{N,<:Symmetric}) where N
128+
# dims = [length(system.vector_entries[i].value) for i=1:N]
129+
130+
# range = [1:dims[1]]
131+
132+
# for (i,dim) in enumerate(collect(Iterators.rest(dims, 2)))
133+
# push!(range,sum(dims[1:i])+1:sum(dims[1:i])+dim)
134+
# end
135+
# A = zeros(sum(dims),sum(dims))
136+
137+
# for (i,row) in enumerate(system.matrix_entries.rowval)
138+
# col = findfirst(x->i<x,system.matrix_entries.colptr)-1
139+
# A[range[row],range[col]] = system.matrix_entries[row,col].value
140+
# if col != row
141+
# A[range[col],range[row]] = system.matrix_entries[row,col].value'
142+
# end
143+
# end
144+
# return A
145+
# end
146+
112147
full_vector(system) = vcat(getfield.(system.vector_entries,:value)...)

0 commit comments

Comments
 (0)