Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ImplicitDifferentiation"
uuid = "57b37032-215b-411a-8a7c-41a003a55207"
authors = ["Guillaume Dalle", "Mohamed Tarek"]
version = "0.7.1"
version = "0.7.2"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
11 changes: 3 additions & 8 deletions docs/src/faq.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,11 @@ However, this can be switched to any other "inner" backend compatible with [Diff

## Input and output types

### Vectors

Functions that eat or spit out arbitrary vectors are supported, as long as the forward mapping _and_ conditions return vectors of the same size.

If you deal with small vectors (say, less than 100 elements), consider using [StaticArrays.jl](https://github.com/JuliaArrays/StaticArrays.jl) for increased performance.

### Arrays

Functions that eat or spit out matrices and higher-order tensors are not supported.
You can use `vec` and `reshape` for the conversion to and from vectors.
Functions that take or return arbitrary arrays are supported, as long as the forward mapping _and_ the conditions return arrays of the same size as each other.

If you deal with small arrays (say, fewer than 100 elements), consider using [StaticArrays.jl](https://github.com/JuliaArrays/StaticArrays.jl) for increased performance.

### Scalars

Expand Down
4 changes: 2 additions & 2 deletions examples/0_intro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ This is essentially the componentwise square root function but with an additiona
We can check that it does what it's supposed to do.
=#

x = [4.0, 9.0]
x = [1.0 2.0; 3.0 4.0]
badsqrt(x)
@test badsqrt(x) ≈ sqrt.(x) #src

#=
Of course the Jacobian has an explicit formula.
=#

J = Diagonal(0.5 ./ sqrt.(x))
J = Diagonal(0.5 ./ vec(sqrt.(x)))

#=
However, things start to go wrong when we compute it with autodiff, due to the [limitations of ForwardDiff.jl](https://juliadiff.org/ForwardDiff.jl/stable/user/limitations/) and [those of Zygote.jl](https://fluxml.ai/Zygote.jl/stable/limitations/).
Expand Down
8 changes: 5 additions & 3 deletions ext/ImplicitDifferentiationChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ using ImplicitDifferentiation:
ImplicitDifferentiation.chainrules_suggested_backend(rc::RuleConfig) = AutoChainRules(rc)

function ChainRulesCore.rrule(
rc::RuleConfig, implicit::ImplicitFunction, x::AbstractVector, args::Vararg{Any,N};
rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray, args::Vararg{Any,N};
) where {N}
y, z = implicit(x, args...)

Expand All @@ -25,8 +25,10 @@ function ChainRulesCore.rrule(

function implicit_pullback((dy, dz))
dy = unthunk(dy)
dc = implicit.linear_solver(Aᵀ, -dy)
dx = Bᵀ * dc
dy_vec = vec(dy)
dc_vec = implicit.linear_solver(Aᵀ, -dy_vec)
dx_vec = Bᵀ * dc_vec
dx = reshape(dx_vec, size(x))
df = NoTangent()
dargs = ntuple(unimplemented_tangent, N)
return (df, project_x(dx), dargs...)
Expand Down
15 changes: 8 additions & 7 deletions ext/ImplicitDifferentiationForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using ForwardDiff: Dual, Partials, partials, value
using ImplicitDifferentiation: ImplicitFunction, build_A, build_B

function (implicit::ImplicitFunction)(
x_and_dx::AbstractVector{Dual{T,R,N}}, args...
x_and_dx::AbstractArray{Dual{T,R,N}}, args...
) where {T,R,N}
x = value.(x_and_dx)
y, z = implicit(x, args...)
Expand All @@ -14,16 +14,17 @@ function (implicit::ImplicitFunction)(
A = build_A(implicit, x, y, z, args...; suggested_backend)
B = build_B(implicit, x, y, z, args...; suggested_backend)

dX = map(1:N) do k
dX = ntuple(Val(N)) do k
partials.(x_and_dx, k)
end
dC = mapreduce(hcat, dX) do dₖx
B * dₖx
dC_mat = mapreduce(hcat, dX) do dₖx
dₖx_vec = vec(dₖx)
dₖc_vec = B * dₖx_vec
end
dY = implicit.linear_solver(A, -dC)
dY_mat = implicit.linear_solver(A, -dC_mat)

y_and_dy = map(eachindex(y)) do i
Dual{T}(y[i], Partials(ntuple(k -> dY[i, k], Val(N))))
y_and_dy = map(LinearIndices(y)) do i
Dual{T}(y[i], Partials(ntuple(k -> dY_mat[i, k], Val(N))))
end

return y_and_dy, z
Expand Down
1 change: 1 addition & 0 deletions src/ImplicitDifferentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ using Krylov: gmres
using LinearOperators: LinearOperator
using LinearAlgebra: factorize

include("utils.jl")
include("settings.jl")
include("preparation.jl")
include("implicit_function.jl")
Expand Down
68 changes: 40 additions & 28 deletions src/execution.jl
Original file line number Diff line number Diff line change
@@ -1,40 +1,40 @@
const SYMMETRIC = false
const HERMITIAN = false

struct JVP!{F,P,B,X,C}
struct JVP!{F,P,B,I,C}
f::F
prep::P
backend::B
x::X
input::I
contexts::C
end

struct VJP!{F,P,B,X,C}
struct VJP!{F,P,B,I,C}
f::F
prep::P
backend::B
x::X
input::I
contexts::C
end

function (po::JVP!)(res::AbstractVector, v::AbstractVector)
(; f, backend, x, contexts, prep) = po
pushforward!(f, (res,), prep, backend, x, (v,), contexts...)
(; f, backend, input, contexts, prep) = po
pushforward!(f, (res,), prep, backend, input, (v,), contexts...)
return res
end

function (po::VJP!)(res::AbstractVector, v::AbstractVector)
(; f, backend, x, contexts, prep) = po
pullback!(f, (res,), prep, backend, x, (v,), contexts...)
(; f, backend, input, contexts, prep) = po
pullback!(f, (res,), prep, backend, input, (v,), contexts...)
return res
end

## A

function build_A(
implicit::ImplicitFunction,
x::AbstractVector,
y::AbstractVector,
x::AbstractArray,
y::AbstractArray,
z,
args...;
suggested_backend::AbstractADType,
Expand All @@ -58,21 +58,24 @@ function build_A_aux(
(; conditions, backend, prep_A) = implicit
actual_backend = isnothing(backend) ? suggested_backend : backend
contexts = (Constant(x), Constant(z), map(Constant, args)...)
f_vec = VecToVec(Switch12(conditions), y)
y_vec = vec(y)
dy_vec = vec(zero(y))
prep_A_same = prepare_pushforward_same_point(
Switch12(conditions), prep_A..., actual_backend, y, (zero(y),), contexts...
f_vec, prep_A..., actual_backend, y_vec, (dy_vec,), contexts...
)
prod! = JVP!(Switch12(conditions), prep_A_same, actual_backend, y, contexts)
prod! = JVP!(f_vec, prep_A_same, actual_backend, y_vec, contexts)
return LinearOperator(
eltype(y), length(y), length(y), SYMMETRIC, HERMITIAN, prod!, typeof(y)
eltype(y), length(y), length(y), SYMMETRIC, HERMITIAN, prod!, typeof(y_vec)
)
end

## Aᵀ

function build_Aᵀ(
implicit::ImplicitFunction,
x::AbstractVector,
y::AbstractVector,
x::AbstractArray,
y::AbstractArray,
z,
args...;
suggested_backend::AbstractADType,
Expand All @@ -98,21 +101,24 @@ function build_Aᵀ_aux(
(; conditions, backend, prep_Aᵀ) = implicit
actual_backend = isnothing(backend) ? suggested_backend : backend
contexts = (Constant(x), Constant(z), map(Constant, args)...)
f_vec = VecToVec(Switch12(conditions), y)
y_vec = vec(y)
dc_vec = vec(zero(y))
prep_Aᵀ_same = prepare_pullback_same_point(
Switch12(conditions), prep_Aᵀ..., actual_backend, y, (zero(y),), contexts...
f_vec, prep_Aᵀ..., actual_backend, y_vec, (dc_vec,), contexts...
)
prod! = VJP!(Switch12(conditions), prep_Aᵀ_same, actual_backend, y, contexts)
prod! = VJP!(f_vec, prep_Aᵀ_same, actual_backend, y_vec, contexts)
return LinearOperator(
eltype(y), length(y), length(y), SYMMETRIC, HERMITIAN, prod!, typeof(y)
eltype(y), length(y), length(y), SYMMETRIC, HERMITIAN, prod!, typeof(y_vec)
)
end

## B

function build_B(
implicit::ImplicitFunction,
x::AbstractVector,
y::AbstractVector,
x::AbstractArray,
y::AbstractArray,
z,
args...;
suggested_backend::AbstractADType,
Expand All @@ -135,21 +141,24 @@ function build_B_aux(
(; conditions, backend, prep_B) = implicit
actual_backend = isnothing(backend) ? suggested_backend : backend
contexts = (Constant(y), Constant(z), map(Constant, args)...)
f_vec = VecToVec(conditions, x)
x_vec = vec(x)
dx_vec = vec(zero(x))
prep_B_same = prepare_pushforward_same_point(
conditions, prep_B..., actual_backend, x, (zero(x),), contexts...
f_vec, prep_B..., actual_backend, x_vec, (dx_vec,), contexts...
)
prod! = JVP!(conditions, prep_B_same, actual_backend, x, contexts)
prod! = JVP!(f_vec, prep_B_same, actual_backend, x_vec, contexts)
return LinearOperator(
eltype(y), length(y), length(x), SYMMETRIC, HERMITIAN, prod!, typeof(x)
eltype(y), length(y), length(x), SYMMETRIC, HERMITIAN, prod!, typeof(x_vec)
)
end

## Bᵀ

function build_Bᵀ(
implicit::ImplicitFunction,
x::AbstractVector,
y::AbstractVector,
x::AbstractArray,
y::AbstractArray,
z,
args...;
suggested_backend::AbstractADType,
Expand All @@ -172,11 +181,14 @@ function build_Bᵀ_aux(
(; conditions, backend, prep_Bᵀ) = implicit
actual_backend = isnothing(backend) ? suggested_backend : backend
contexts = (Constant(y), Constant(z), map(Constant, args)...)
f_vec = VecToVec(conditions, x)
x_vec = vec(x)
dc_vec = vec(zero(y))
prep_Bᵀ_same = prepare_pullback_same_point(
conditions, prep_Bᵀ..., actual_backend, x, (zero(y),), contexts...
f_vec, prep_Bᵀ..., actual_backend, x_vec, (dc_vec,), contexts...
)
prod! = VJP!(conditions, prep_Bᵀ_same, actual_backend, x, contexts)
prod! = VJP!(f_vec, prep_Bᵀ_same, actual_backend, x_vec, contexts)
return LinearOperator(
eltype(y), length(x), length(y), SYMMETRIC, HERMITIAN, prod!, typeof(x)
eltype(y), length(x), length(y), SYMMETRIC, HERMITIAN, prod!, typeof(x_vec)
)
end
4 changes: 2 additions & 2 deletions src/implicit_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ This requires solving a linear system `A * J = -B` where `A = ∂₂c`, `B = ∂

## Positional arguments

- `solver`: a callable returning `(x, args...) -> (y, z)` where `z` is an arbitrary byproduct of the solve. Both `x` and `y` must be subtypes of `AbstractVector`, while `z` and `args` can be anything.
- `solver`: a callable `(x, args...) -> (y, z)` where `z` is an arbitrary byproduct of the solve. Both `x` and `y` must be subtypes of `AbstractArray`, while `z` and `args` can be anything.
- `conditions`: a callable `(x, y, z, args...) -> c` returning an array `c` of optimality conditions with the same size as `y`; it must be compatible with automatic differentiation

## Keyword arguments
Expand Down Expand Up @@ -127,6 +127,6 @@ function Base.show(io::IO, implicit::ImplicitFunction)
)
end

function (implicit::ImplicitFunction)(x::AbstractVector, args::Vararg{Any,N}) where {N}
function (implicit::ImplicitFunction)(x::AbstractArray, args::Vararg{Any,N}) where {N}
return implicit.solver(x, args...)
end
Loading