1
- struct System{N}
1
+ abstract type Symmetric end # symmetric = true
2
+ abstract type BlockSymmetric <: Symmetric end # symmetric = true
3
+ abstract type SparsitySymmetric end # symmetric = false, Any matrix is sparsity symmetric since zero entries can be assumed nonzero
4
+
5
+ struct System{N,T}
2
6
matrix_entries:: SparseMatrixCSC{GraphBasedSystems.Entry, Int64}
3
7
vector_entries:: Vector{Entry}
4
8
diagonal_inverses:: Vector{Entry}
@@ -9,23 +13,21 @@ struct System{N}
9
13
graph:: SimpleGraph{Int64}
10
14
dfs_graph:: SimpleDiGraph{Int64}
11
15
12
- function System {T} (A, dims; force_static = false ) where T
16
+ function System {T} (A, dims; force_static = false , symmetric = false ) where T
13
17
N = length (dims)
18
+ if symmetric
19
+ all (dims .== 1 ) ? S = Symmetric : S = BlockSymmetric
20
+ else
21
+ S = SparsitySymmetric
22
+ end
14
23
static = force_static || all (dims.<= 10 )
15
24
16
25
full_graph = Graph (A)
17
26
18
27
matrix_entries = spzeros (Entry,N,N)
19
28
20
29
for (i,dimi) in enumerate (dims)
21
- for (j,dimj) in enumerate (dims)
22
- if i == j
23
- matrix_entries[i,j] = Entry {T} (dimi, dimj, static = static)
24
- elseif j ∈ all_neighbors (full_graph,i)
25
- matrix_entries[i,j] = Entry {T} (dimi, dimj, static = static)
26
- matrix_entries[j,i] = Entry {T} (dimj, dimi, static = static)
27
- end
28
- end
30
+ matrix_entries[i,i] = Entry {T} (dimi, dimi, static = static)
29
31
end
30
32
31
33
vector_entries = [Entry {T} (dim, static = static) for dim in dims]
@@ -65,23 +67,36 @@ struct System{N}
65
67
acyclic_children[v] = setdiff (acyclic_children[v], cyclic_children)
66
68
for c in cyclic_children
67
69
matrix_entries[v,c] = Entry {T} (dims[v], dims[c], static = static)
68
- matrix_entries[c,v] = Entry {T} (dims[c], dims[v], static = static)
70
+ ! symmetric && ( matrix_entries[c,v] = Entry {T} (dims[c], dims[v], static = static) )
69
71
70
72
v ∉ parents[c] && push! (parents[c],v)
71
73
end
72
- end
74
+ end
75
+ for v in sub_dfs_list
76
+ for c in acyclic_children[v]
77
+ matrix_entries[v,c] = Entry {T} (dims[v], dims[c], static = static)
78
+ ! symmetric && (matrix_entries[c,v] = Entry {T} (dims[c], dims[v], static = static))
79
+ end
80
+ end
73
81
end
74
82
75
83
full_dfs_graph = SimpleDiGraph (edgelist)
76
84
cyclic_children = [unique (vcat (cycles[i]. .. )) for i= 1 : N]
77
85
78
- new {N} (matrix_entries, vector_entries, diagonal_inverses, acyclic_children, cyclic_children, parents, dfs_list, full_graph, full_dfs_graph)
86
+ new {N,S } (matrix_entries, vector_entries, diagonal_inverses, acyclic_children, cyclic_children, parents, dfs_list, full_graph, full_dfs_graph)
79
87
end
80
88
81
- System (A, dims; force_static = false ) = System {Float64} (A, dims; force_static = force_static )
89
+ System (A, dims; force_static = false , symmetric = false ) = System {Float64} (A, dims; force_static, symmetric )
82
90
end
83
91
84
- function Base. show (io:: IO , mime:: MIME{Symbol("text/plain")} , system:: System{N} ) where {N}
92
+ function Base. show (io:: IO , mime:: MIME{Symbol("text/plain")} , system:: System{N,S} ) where {N,S}
93
+ if S <: BlockSymmetric
94
+ print (io, " BlockSymmetric " )
95
+ elseif S <: Symmetric
96
+ print (io, " Symmetric " )
97
+ elseif S <: SparsitySymmetric
98
+ print (io, " SparsitySymmetric " )
99
+ end
85
100
println (io, " System with " * string (N)* " nodes." )
86
101
SparseArrays. _show_with_braille_patterns (io, system. matrix_entries)
87
102
end
@@ -109,4 +124,24 @@ function full_matrix(system::System{N}) where N
109
124
return A
110
125
end
111
126
127
+ # function full_matrix(system::System{N,<:Symmetric}) where N
128
+ # dims = [length(system.vector_entries[i].value) for i=1:N]
129
+
130
+ # range = [1:dims[1]]
131
+
132
+ # for (i,dim) in enumerate(collect(Iterators.rest(dims, 2)))
133
+ # push!(range,sum(dims[1:i])+1:sum(dims[1:i])+dim)
134
+ # end
135
+ # A = zeros(sum(dims),sum(dims))
136
+
137
+ # for (i,row) in enumerate(system.matrix_entries.rowval)
138
+ # col = findfirst(x->i<x,system.matrix_entries.colptr)-1
139
+ # A[range[row],range[col]] = system.matrix_entries[row,col].value
140
+ # if col != row
141
+ # A[range[col],range[row]] = system.matrix_entries[row,col].value'
142
+ # end
143
+ # end
144
+ # return A
145
+ # end
146
+
112
147
full_vector (system) = vcat (getfield .(system. vector_entries,:value )... )
0 commit comments