Skip to content

Commit 3295efc

Browse files
committed
add new linearizatino method
1 parent ce3fb3a commit 3295efc

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

src/approximations/linearization.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,13 @@ function local_linearization(r::Real, splitg::S, x_hat::H) where {S, H}
7777
return local_linearization_split(r, fA, x_hat)
7878
end
7979

80+
# In case if `g(x_hat)` returns a vector, but input is a number
81+
function local_linearization(result::AbstractVector, g::G, x_hat::Tuple{T}) where {G, T <: Real}
82+
A = ForwardDiff.derivative(g, first(x_hat))
83+
b = result - A * first(x_hat)
84+
return (A, b)
85+
end
86+
8087
# In case if `g(x_hat)` returns a vector, but inputs are numbers
8188
function local_linearization(r::AbstractVector, splitg::S, x_hat::H) where {S, H}
8289
# `r` is a vector, so we need to use `jacobian` instead of `gradient`

test/approximations/linearization_tests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@
1111
@test @inferred(approximate(Linearization(), (x, y) -> x - y, (1, 2))) == ([1 -1], 0)
1212
@test @inferred(approximate(Linearization(), (x, y) -> x .- y, ([1.0, 2.0], 1.0))) == ([1.0 0.0 -1.0; 0.0 1.0 -1.0], [0.0, 0.0])
1313
@test @inferred(approximate(Linearization(), (x, y) -> x .- y, ([1.0, 2.0], [1.0, 1.0]))) == ([1.0 0.0 -1.0 0.0; 0.0 1.0 0.0 -1.0], [0.0, 0.0])
14+
@test @inferred(approximate(Linearization(), (x) -> x .- [1, 1], (1.0,))) == ([1.0, 1.0], [-1.0, -1.0])
1415
end
1516
end

0 commit comments

Comments
 (0)