Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit aa85e56

Browse files
committed
complete SparseKernel1d/2d/3d
1 parent 849fde2 commit aa85e56

File tree

6 files changed

+308
-12
lines changed

6 files changed

+308
-12
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1010
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
1111
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
1212
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
13+
Polynomials = "f27b6e38-b328-58d1-80ce-0feddd5e7a45"
14+
SpecialPolynomials = "a25cea48-d430-424a-8ee7-0d3ad3742e9e"
1315
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
1416
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1517

src/NeuralOperators.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@ module NeuralOperators
77
using KernelAbstractions
88
using Zygote
99
using ChainRulesCore
10+
using Polynomials
11+
using SpecialPolynomials
1012

13+
include("utils.jl")
14+
include("polynomials.jl")
1115
include("fourier.jl")
1216
include("wavelet.jl")
1317
include("model.jl")

src/polynomials.jl

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
"""
    legendre_ϕ_ψ(k)

Build `k` scaling polynomials `ϕ` and `k` pairs of wavelet polynomials
`(ψ1, ψ2)` from shifted Legendre polynomials.

Row `ki + 1` of the internal coefficient matrices holds the coefficients of

- `sqrt(2ki + 1) * P_ki(2x - 1)` for `ϕ`, and
- `sqrt(2(2ki + 1)) * P_ki(4x - 1)` as the starting point for `ψ1`,

where `P_ki` is the degree-`ki` Legendre polynomial. Each wavelet row is then
orthogonalised (Gram–Schmidt) against every `ϕ` row and against the wavelet rows
that were already finished, normalised, and finally entries with magnitude
below the `zero_out!` tolerance are set to zero.

`proj_factor` integrates a coefficient product over `[0, 0.5]` (or over
`[0.5, 1]` with `complement=true`); `ψ1`/`ψ2` are presumably the pieces of each
wavelet on those two half intervals — confirm against the caller.

# Arguments
- `k`: number of basis functions to generate (size of the `k × k` coefficient
  matrices).

# Returns
A tuple `(ϕ, ψ1, ψ2)` of three length-`k` vectors of `Polynomial`s.
"""
function legendre_ϕ_ψ(k)
    # TODO: row-major -> column major
    ϕ_coefs = zeros(k, k)
    ϕ_2x_coefs = zeros(k, k)

    p = Polynomial([-1, 2])   # 2x-1
    p2 = Polynomial([-1, 4])  # 4x-1

    for ki in 0:(k - 1)
        l = convert(Polynomial, gen_poly(Legendre, ki))  # Legendre of n=ki
        ϕ_coefs[ki + 1, 1:(ki + 1)] .= sqrt(2 * ki + 1) .* coeffs(l(p))
        ϕ_2x_coefs[ki + 1, 1:(ki + 1)] .= sqrt(2 * (2 * ki + 1)) .* coeffs(l(p2))
    end

    ψ1_coefs = zeros(k, k)
    ψ2_coefs = zeros(k, k)
    for ki in 0:(k - 1)
        # Every wavelet row starts from the matching ϕ(2x) row.
        view(ψ1_coefs, ki + 1, :) .= view(ϕ_2x_coefs, ki + 1, :)

        # The same slice is projected in both loops below; build it once.
        a = ϕ_2x_coefs[ki + 1, 1:(ki + 1)]

        # Remove the components along every scaling function.
        for i in 0:(k - 1)
            b = ϕ_coefs[i + 1, 1:(i + 1)]
            proj_ = proj_factor(a, b)
            view(ψ1_coefs, ki + 1, :) .-= proj_ .* view(ϕ_coefs, i + 1, :)
            view(ψ2_coefs, ki + 1, :) .-= proj_ .* view(ϕ_coefs, i + 1, :)
        end

        # Remove the components along the wavelets that are already finished.
        # Only rows j < ki qualify: row ki is the one being built and later rows
        # have not been processed yet.
        for j in 0:(ki - 1)
            b = ψ1_coefs[j + 1, :]
            proj_ = proj_factor(a, b)
            view(ψ1_coefs, ki + 1, :) .-= proj_ .* view(ψ1_coefs, j + 1, :)
            view(ψ2_coefs, ki + 1, :) .-= proj_ .* view(ψ2_coefs, j + 1, :)
        end

        # Normalise using the sum of both half-interval contributions.
        a1 = ψ1_coefs[ki + 1, :]
        norm1 = proj_factor(a1, a1)

        a2 = ψ2_coefs[ki + 1, :]
        norm2 = proj_factor(a2, a2, complement=true)
        norm_ = sqrt(norm1 + norm2)
        ψ1_coefs[ki + 1, :] ./= norm_
        ψ2_coefs[ki + 1, :] ./= norm_
        zero_out!(ψ1_coefs)
        zero_out!(ψ2_coefs)
    end

    ϕ = [Polynomial(ϕ_coefs[i, :]) for i in 1:k]
    ψ1 = [Polynomial(ψ1_coefs[i, :]) for i in 1:k]
    ψ2 = [Polynomial(ψ2_coefs[i, :]) for i in 1:k]

    return ϕ, ψ1, ψ2
end
52+
53+
# function chebyshev_ϕ_ψ(k)
54+
# ϕ_coefs = zeros(k, k)
55+
# ϕ_2x_coefs = zeros(k, k)
56+
57+
# p = Polynomial([-1, 2]) # 2x-1
58+
# p2 = Polynomial([-1, 4]) # 4x-1
59+
60+
# for ki in 0:(k-1)
61+
# if ki == 0
62+
# ϕ_coefs[ki+1, 1:(ki+1)] .= sqrt(2/π)
63+
# ϕ_2x_coefs[ki+1, 1:(ki+1)] .= sqrt(4/π)
64+
# else
65+
# c = convert(Polynomial, gen_poly(Chebyshev, ki)) # Chebyshev of n=ki
66+
# ϕ_coefs[ki+1, 1:(ki+1)] .= 2/sqrt(π) .* coeffs(c(p))
67+
# ϕ_2x_coefs[ki+1, 1:(ki+1)] .= sqrt(2) * 2/sqrt(π) .* coeffs(c(p2))
68+
# end
69+
# end
70+
71+
# ϕ = [ϕ_(ϕ_coefs[i, :]) for i in 1:k]
72+
73+
# k_use = 2k
74+
75+
# # phi = [partial(phi_, phi_coeff[i,:]) for i in range(k)]
76+
77+
# # x = Symbol('x')
78+
# # kUse = 2*k
79+
# # roots = Poly(chebyshevt(kUse, 2*x-1)).all_roots()
80+
# # x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64)
81+
# # # x_m[x_m==0.5] = 0.5 + 1e-8 # add small noise to avoid the case of 0.5 belonging to both phi(2x) and phi(2x-1)
82+
# # # not needed for our purpose here, we use even k always to avoid
83+
# # wm = np.pi / kUse / 2
84+
85+
# # psi1_coeff = np.zeros((k, k))
86+
# # psi2_coeff = np.zeros((k, k))
87+
88+
# # psi1 = [[] for _ in range(k)]
89+
# # psi2 = [[] for _ in range(k)]
90+
91+
# # for ki in range(k):
92+
# # psi1_coeff[ki,:] = phi_2x_coeff[ki,:]
93+
# # for i in range(k):
94+
# # proj_ = (wm * phi[i](x_m) * np.sqrt(2)* phi[ki](2*x_m)).sum()
95+
# # psi1_coeff[ki,:] -= proj_ * phi_coeff[i,:]
96+
# # psi2_coeff[ki,:] -= proj_ * phi_coeff[i,:]
97+
98+
# # for j in range(ki):
99+
# # proj_ = (wm * psi1[j](x_m) * np.sqrt(2) * phi[ki](2*x_m)).sum()
100+
# # psi1_coeff[ki,:] -= proj_ * psi1_coeff[j,:]
101+
# # psi2_coeff[ki,:] -= proj_ * psi2_coeff[j,:]
102+
103+
# # psi1[ki] = partial(phi_, psi1_coeff[ki,:], lb = 0, ub = 0.5)
104+
# # psi2[ki] = partial(phi_, psi2_coeff[ki,:], lb = 0.5, ub = 1)
105+
106+
# # norm1 = (wm * psi1[ki](x_m) * psi1[ki](x_m)).sum()
107+
# # norm2 = (wm * psi2[ki](x_m) * psi2[ki](x_m)).sum()
108+
109+
# # norm_ = np.sqrt(norm1 + norm2)
110+
# # psi1_coeff[ki,:] /= norm_
111+
# # psi2_coeff[ki,:] /= norm_
112+
# # psi1_coeff[np.abs(psi1_coeff)<1e-8] = 0
113+
# # psi2_coeff[np.abs(psi2_coeff)<1e-8] = 0
114+
115+
# # psi1[ki] = partial(phi_, psi1_coeff[ki,:], lb = 0, ub = 0.5+1e-16)
116+
# # psi2[ki] = partial(phi_, psi2_coeff[ki,:], lb = 0.5+1e-16, ub = 1)
117+
118+
# # return phi, psi1, psi2
119+
# end
120+
121+
"""
    legendre_filter(k)

Placeholder for the Legendre filter construction.

Not implemented yet: the body holds only the Python reference implementation
(kept as comments for the port, where it returns `H0, H1, G0, G1, PHI0, PHI1`)
and the function currently returns `nothing`.
"""
function legendre_filter(k)
    # Python reference to be ported:
    #
    # x = Symbol('x')
    # H0 = np.zeros((k,k))
    # H1 = np.zeros((k,k))
    # G0 = np.zeros((k,k))
    # G1 = np.zeros((k,k))
    # PHI0 = np.zeros((k,k))
    # PHI1 = np.zeros((k,k))
    # phi, psi1, psi2 = get_phi_psi(k, base)

    # ----------------------------------------------------------

    # roots = Poly(legendre(k, 2*x-1)).all_roots()
    # x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64)
    # wm = 1/k/legendreDer(k,2*x_m-1)/eval_legendre(k-1,2*x_m-1)

    # for ki in range(k):
    #     for kpi in range(k):
    #         H0[ki, kpi] = 1/np.sqrt(2) * (wm * phi[ki](x_m/2) * phi[kpi](x_m)).sum()
    #         G0[ki, kpi] = 1/np.sqrt(2) * (wm * psi(psi1, psi2, ki, x_m/2) * phi[kpi](x_m)).sum()
    #         H1[ki, kpi] = 1/np.sqrt(2) * (wm * phi[ki]((x_m+1)/2) * phi[kpi](x_m)).sum()
    #         G1[ki, kpi] = 1/np.sqrt(2) * (wm * psi(psi1, psi2, ki, (x_m+1)/2) * phi[kpi](x_m)).sum()

    # PHI0 = np.eye(k)
    # PHI1 = np.eye(k)

    # ----------------------------------------------------------

    # H0[np.abs(H0)<1e-8] = 0
    # H1[np.abs(H1)<1e-8] = 0
    # G0[np.abs(G0)<1e-8] = 0
    # G1[np.abs(G1)<1e-8] = 0

    # return H0, H1, G0, G1, PHI0, PHI1
    return nothing
end
156+
157+
"""
    chebyshev_filter(k)

Placeholder for the Chebyshev filter construction.

Not implemented yet: the body holds only the Python reference implementation
(kept as comments for the port, where it returns `H0, H1, G0, G1, PHI0, PHI1`)
and the function currently returns `nothing`.
"""
function chebyshev_filter(k)
    # Python reference to be ported:
    #
    # x = Symbol('x')
    # H0 = np.zeros((k,k))
    # H1 = np.zeros((k,k))
    # G0 = np.zeros((k,k))
    # G1 = np.zeros((k,k))
    # PHI0 = np.zeros((k,k))
    # PHI1 = np.zeros((k,k))
    # phi, psi1, psi2 = get_phi_psi(k, base)

    # ----------------------------------------------------------

    # x = Symbol('x')
    # kUse = 2*k
    # roots = Poly(chebyshevt(kUse, 2*x-1)).all_roots()
    # x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64)
    # # x_m[x_m==0.5] = 0.5 + 1e-8 # add small noise to avoid the case of 0.5 belonging to both phi(2x) and phi(2x-1)
    # # not needed for our purpose here, we use even k always to avoid
    # wm = np.pi / kUse / 2

    # for ki in range(k):
    #     for kpi in range(k):
    #         H0[ki, kpi] = 1/np.sqrt(2) * (wm * phi[ki](x_m/2) * phi[kpi](x_m)).sum()
    #         G0[ki, kpi] = 1/np.sqrt(2) * (wm * psi(psi1, psi2, ki, x_m/2) * phi[kpi](x_m)).sum()
    #         H1[ki, kpi] = 1/np.sqrt(2) * (wm * phi[ki]((x_m+1)/2) * phi[kpi](x_m)).sum()
    #         G1[ki, kpi] = 1/np.sqrt(2) * (wm * psi(psi1, psi2, ki, (x_m+1)/2) * phi[kpi](x_m)).sum()

    #         PHI0[ki, kpi] = (wm * phi[ki](2*x_m) * phi[kpi](2*x_m)).sum() * 2
    #         PHI1[ki, kpi] = (wm * phi[ki](2*x_m-1) * phi[kpi](2*x_m-1)).sum() * 2

    # PHI0[np.abs(PHI0)<1e-8] = 0
    # PHI1[np.abs(PHI1)<1e-8] = 0

    # ----------------------------------------------------------

    # H0[np.abs(H0)<1e-8] = 0
    # H1[np.abs(H1)<1e-8] = 0
    # G0[np.abs(G0)<1e-8] = 0
    # G1[np.abs(G1)<1e-8] = 0

    # return H0, H1, G0, G1, PHI0, PHI1
    return nothing
end

src/utils.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# function ϕ_(ϕ_coefs; lb::Real=0., ub::Real=1.)
2+
# mask =
3+
# return Polynomial(ϕ_coefs)
4+
# end
5+
6+
# def phi_(phi_c, x, lb = 0, ub = 1):
7+
# mask = np.logical_or(x<lb, x>ub) * 1.0
8+
# return np.polynomial.polynomial.Polynomial(phi_c)(x) * (1-mask)
9+
10+
"""
    ψ(ψ1, ψ2, i, inp)

Evaluate the `i`-th piecewise function at the scalar `inp`.

`ψ1[i]` contributes where `inp <= 0.5` and `ψ2[i]` contributes elsewhere. Both
callables are always evaluated; the selection is done by multiplying with a
0/1 mask.

# Arguments
- `ψ1`, `ψ2`: indexable collections of callables.
- `i`: index of the pair to evaluate.
- `inp`: evaluation point.

# Returns
`ψ1[i](inp)` weighted by the mask plus `ψ2[i](inp)` weighted by its complement.
"""
function ψ(ψ1, ψ2, i, inp)
    # ASCII `<=` on purpose: the comparison operator had been lost from this line.
    mask = (inp <= 0.5) * 1.0
    return ψ1[i](inp) * mask + ψ2[i](inp) * (1 - mask)
end
14+
15+
"""
    zero_out!(x; tol=1e-8)

Set every entry of `x` whose absolute value is below `tol` to zero, in place.
"""
function zero_out!(x; tol=1e-8)
    negligible = abs.(x) .< tol
    return (x[negligible] .= 0)
end
16+
17+
"""
    gen_poly(poly, n)

Call `poly` on a length-`n + 1` coefficient vector that is zero everywhere
except for a one in the last position, and return the result.
"""
function gen_poly(poly, n)
    coefs = zeros(n + 1)
    coefs[lastindex(coefs)] = 1
    return poly(coefs)
end
22+
23+
"""
    convolve(a, b)

Return the full discrete convolution of the vectors `a` and `b`, i.e. the
coefficients of the product of the two polynomials whose coefficients are `a`
and `b`.

The result has `length(a) + length(b) - 1` entries and the promoted element
type of the inputs. If either input is empty, an empty vector is returned.
"""
function convolve(a, b)
    T = promote_type(eltype(a), eltype(b))
    (isempty(a) || isempty(b)) && return zeros(T, 0)

    # Start from zeros: the loop accumulates into `y`, so it must not begin
    # with uninitialised memory.
    y = zeros(T, length(a) + length(b) - 1)
    for (i, ai) in enumerate(a), (j, bj) in enumerate(b)
        y[i + j - 1] += ai * bj
    end
    return y
end
31+
32+
"""
    proj_factor(a, b; complement::Bool=false)

Convolve the coefficient vectors `a` and `b`, zero out negligible entries, and
return the weighted sum `Σ_r c[r] / r * w[r]` over `r = 1:length(c)`.

The weight is `w[r] = 0.5^r` by default and `w[r] = 1 - 0.5^r` when
`complement` is `true`. For polynomial coefficients in ascending order these
sums equal the integral of the product polynomial over `[0, 0.5]` and over
`[0.5, 1]` respectively.
"""
function proj_factor(a, b; complement::Bool=false)
    coefs = convolve(a, b)
    zero_out!(coefs)
    powers = 1:length(coefs)
    halves = 0.5 .^ powers
    weights = complement ? (1 .- halves) : halves
    return sum(coefs ./ powers .* weights)
end

src/wavelet.jl

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,53 @@
1-
struct SparseKernel1d{T,S}
1+
"""
    SparseKernel{T,S}

Layer made of a convolution block followed by an output projection.

# Fields
- `k`: integer passed as the first argument of the `SparseKernel1d/2d/3d`
  constructors.
- `conv_blk`: callable applied first to the input.
- `out_weight`: callable applied to the reshaped output of `conv_blk`.
"""
struct SparseKernel{T,S}
    k::Int
    conv_blk::S
    out_weight::T
end
66

7-
function SparseKernel1d(k::Int, c::Int=1; init=Flux.glorot_uniform)
7+
"""
    SparseKernel1d(k::Int, α, c::Int=1; init=Flux.glorot_uniform)

Construct a `SparseKernel` for one spatial dimension: a width-3 `Conv` with
`relu` mapping `c*k` channels to 128 channels, followed by a `Dense` layer
mapping 128 back to `c*k`.

# Arguments
- `k`, `c`: their product is the number of input/output channels.
- `α`: accepted but not used by this constructor.
  NOTE(review): the 2d/3d constructors derive the hidden width from `α`;
  confirm that the fixed width of 128 is intended here.

# Keywords
- `init`: weight initialiser forwarded to both layers.
"""
function SparseKernel1d(k::Int, α, c::Int=1; init=Flux.glorot_uniform)
    n_in = c * k
    n_hidden = 128
    # The convolution is built before the dense layer, as in the original code.
    return SparseKernel(
        k,
        Conv((3,), n_in => n_hidden, relu; stride=1, pad=1, init=init),
        Dense(n_hidden, n_in; init=init),
    )
end
1414

15-
function (l::SparseKernel1d)(X::AbstractArray)
16-
X_ = l.conv_blk(batched_transpose(X))
17-
Y = l.out_weight(batched_transpose(X_))
18-
return Y
15+
"""
    SparseKernel2d(k::Int, α, c::Int=1; init=Flux.glorot_uniform)

Construct a `SparseKernel` for two spatial dimensions: a 3×3 `Conv` with `relu`
mapping `c*k^2` channels to `α*k^2` channels, followed by a `Dense` layer
mapping `α*k^2` back to `c*k^2`.

# Keywords
- `init`: weight initialiser forwarded to both layers.
"""
function SparseKernel2d(k::Int, α, c::Int=1; init=Flux.glorot_uniform)
    n_in = c * k^2
    n_hidden = α * k^2
    # The convolution is built before the dense layer, as in the original code.
    return SparseKernel(
        k,
        Conv((3, 3), n_in => n_hidden, relu; stride=1, pad=1, init=init),
        Dense(n_hidden, n_in; init=init),
    )
end
22+
23+
"""
    SparseKernel3d(k::Int, α, c::Int=1; init=Flux.glorot_uniform)

Construct a `SparseKernel` for three spatial dimensions: a 3×3×3 `Conv` with
`relu` mapping `α*k^2` channels to `α*k^2` channels, followed by a `Dense`
layer mapping `α*k^2` to `c*k^2`.

Unlike the 1d/2d constructors, the convolution here takes `α*k^2` input
channels rather than `c*k^2`, so the layer's input and output channel counts
differ.

# Keywords
- `init`: weight initialiser forwarded to both layers.
"""
function SparseKernel3d(k::Int, α, c::Int=1; init=Flux.glorot_uniform)
    n_out = c * k^2
    n_hidden = α * k^2
    # The convolution is built before the dense layer, as in the original code.
    return SparseKernel(
        k,
        Conv((3, 3, 3), n_hidden => n_hidden, relu; stride=1, pad=1, init=init),
        Dense(n_hidden, n_out; init=init),
    )
end
30+
31+
"""
    (l::SparseKernel)(X::AbstractArray)

Apply the layer to `X`, whose last dimension is the batch, second-to-last the
channels, and all leading dimensions are treated as spatial.

The convolution block runs on `X`, its spatial dimensions are flattened, the
output projection is applied along the channel dimension, and the result is
reshaped back to `(dims..., channels, batch)` and materialised with `collect`.
"""
function (l::SparseKernel)(X::AbstractArray)
    sz = size(X)
    bch_sz = sz[end]
    dims = sz[1:(end - 2)]  # spatial sizes; the channel and batch sizes are dropped

    feats = l.conv_blk(X)                              # (dims..., emb_dims, B)
    feats = reshape(feats, prod(dims), :, bch_sz)      # (prod(dims), emb_dims, B)
    proj = l.out_weight(batched_transpose(feats))      # (in_dims, prod(dims), B)
    out = reshape(batched_transpose(proj), dims..., :, bch_sz)  # (dims..., in_dims, B)
    return collect(out)
end
2041

2142

43+
# struct MWT_CZ1d
44+
45+
# end
46+
47+
# function MWT_CZ1d(k::Int=3, c::Int=1; init=Flux.glorot_uniform)
48+
49+
# end
50+
2251
# class MWT_CZ1d(nn.Module):
2352
# def __init__(self,
2453
# k = 3, alpha = 5,

test/wavelet.jl

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,37 @@
11
# Smoke tests for the SparseKernel layers: each layer is run forward and
# differentiated once, and the resulting shapes are checked.
using NeuralOperators
using CUDA
using Test
using Zygote

CUDA.allowscalar(false)

T = Float32
k = 3
batch_size = 32

# --- 1d: the output keeps the input shape -----------------------------------
α = 4
c = 1
in_chs = 20

l1 = NeuralOperators.SparseKernel1d(k, α, c)
X = rand(T, in_chs, c*k, batch_size)
Y = l1(X)
@test size(Y) == (in_chs, c*k, batch_size)
@test size(gradient(x->sum(l1(x)), X)[1]) == size(X)

# --- 2d: the output keeps the input shape -----------------------------------
α = 4
c = 3
Nx = 5
Ny = 7

l2 = NeuralOperators.SparseKernel2d(k, α, c)
X = rand(T, Nx, Ny, c*k^2, batch_size)
Y = l2(X)
@test size(Y) == (Nx, Ny, c*k^2, batch_size)
@test size(gradient(x->sum(l2(x)), X)[1]) == size(X)

# --- 3d: input has α*k^2 channels, output has c*k^2 channels ----------------
Nz = 13

l3 = NeuralOperators.SparseKernel3d(k, α, c)
X = rand(T, Nx, Ny, Nz, α*k^2, batch_size)
Y = l3(X)
@test size(Y) == (Nx, Ny, Nz, c*k^2, batch_size)
@test size(gradient(x->sum(l3(x)), X)[1]) == size(X)

0 commit comments

Comments
 (0)