Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
name = "ImplicitDifferentiation"
uuid = "57b37032-215b-411a-8a7c-41a003a55207"
authors = ["Guillaume Dalle", "Mohamed Tarek"]
version = "0.7.3"
version = "0.8.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearMaps = "7a12625a-238d-50fd-b39a-03d52299707e"
LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125"

[weakdeps]
Expand All @@ -22,12 +24,30 @@ ImplicitDifferentiationZygoteExt = "Zygote"

[compat]
ADTypes = "1.9.0"
Aqua = "0.8.13"
ChainRulesCore = "1.25.0"
DifferentiationInterface = "0.6.1"
ChainRulesTestUtils = "1.13.0"
ComponentArrays = "0.15.27"
DifferentiationInterface = "0.6.1,0.7"
Documenter = "1.12.0"
ExplicitImports = "1"
FiniteDiff = "2.27.0"
ForwardDiff = "0.10.36, 1"
IterativeSolvers = "0.9.4"
JET = "0.9, 0.10"
JuliaFormatter = "2.1.2"
Krylov = "0.9.6, 0.10"
LinearAlgebra = "1.10"
LinearMaps = "3.11.4"
LinearOperators = "2.8.0"
NLsolve = "4.5.1"
Optim = "1.12.0"
Random = "1"
SparseArrays = "1"
StaticArrays = "1.9.13"
Test = "1"
TestItemRunner = "1.1.0"
TestItems = "1.0.0"
Zygote = "0.7.4"
julia = "1.10"

Expand All @@ -51,7 +71,9 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["ADTypes", "Aqua", "ChainRulesCore", "ChainRulesTestUtils", "ComponentArrays", "DifferentiationInterface", "Documenter", "ExplicitImports", "FiniteDiff", "ForwardDiff", "JET", "JuliaFormatter", "NLsolve", "Optim", "Random", "SparseArrays", "StaticArrays", "Test", "Zygote"]
test = ["ADTypes", "Aqua", "ChainRulesCore", "ChainRulesTestUtils", "ComponentArrays", "DifferentiationInterface", "Documenter", "ExplicitImports", "FiniteDiff", "ForwardDiff", "JET", "JuliaFormatter", "NLsolve", "Optim", "Random", "SparseArrays", "StaticArrays", "Test", "TestItems", "TestItemRunner", "Zygote"]
3 changes: 1 addition & 2 deletions docs/src/api.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
```@meta
CurrentModule = ImplicitDifferentiation
CollapsedDocStrings = true
```

Expand All @@ -20,7 +19,7 @@ ImplicitFunction
### Settings

```@docs
KrylovLinearSolver
IterativeLinearSolver
MatrixRepresentation
OperatorRepresentation
NoPreparation
Expand Down
5 changes: 4 additions & 1 deletion docs/src/faq.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@ Say your forward mapping takes multiple inputs and returns multiple outputs, suc
The trick is to leverage [ComponentArrays.jl](https://github.com/jonniedie/ComponentArrays.jl) to wrap all the inputs inside a single a `ComponentVector`, and do the same for all the outputs.
See the examples for a demonstration.

!!! warning "Warning"
!!! warning
The default linear operator representation does not support ComponentArrays.jl: you need to select `representation=OperatorRepresentation{:LinearMaps}()` in the [`ImplicitFunction`](@ref) constructor for it to work.

!!! warning
You may run into issues trying to differentiate through the `ComponentVector` constructor.
For instance, Zygote.jl will throw `ERROR: Mutating arrays is not supported`.
Check out [this issue](https://github.com/gdalle/ImplicitDifferentiation.jl/issues/67) for a dirty workaround involving custom chain rules for the constructor.
Expand Down
9 changes: 6 additions & 3 deletions examples/3_tricks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ We demonstrate several features that may come in handy for some users.
using ComponentArrays
using ForwardDiff
using ImplicitDifferentiation
using Krylov
using LinearAlgebra
using Test #src
using Zygote
Expand Down Expand Up @@ -43,9 +42,13 @@ function conditions_components(x::ComponentVector, y::ComponentVector, _z)
return c
end;

# And build your implicit function like so.
# And build your implicit function like so, switching the operator representation to avoid errors with ComponentArrays.

implicit_components = ImplicitFunction(forward_components, conditions_components);
implicit_components = ImplicitFunction(
forward_components,
conditions_components;
representation=OperatorRepresentation{:LinearMaps}(),
);

# Now we're good to go.

Expand Down
7 changes: 4 additions & 3 deletions ext/ImplicitDifferentiationChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,18 @@ function ChainRulesCore.rrule(
rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray, args::Vararg{Any,N};
) where {N}
y, z = implicit(x, args...)
c = implicit.conditions(x, y, z, args...)

suggested_backend = chainrules_suggested_backend(rc)
Aᵀ = build_Aᵀ(implicit, x, y, z, args...; suggested_backend)
Bᵀ = build_Bᵀ(implicit, x, y, z, args...; suggested_backend)
Aᵀ = build_Aᵀ(implicit, x, y, z, c, args...; suggested_backend)
Bᵀ = build_Bᵀ(implicit, x, y, z, c, args...; suggested_backend)
project_x = ProjectTo(x)

function implicit_pullback((dy, dz))
dy = unthunk(dy)
dy_vec = vec(dy)
dc_vec = implicit.linear_solver(Aᵀ, -dy_vec)
dx_vec = Bᵀ * dc_vec
dx_vec = Bᵀ(dc_vec)
dx = reshape(dx_vec, size(x))
df = NoTangent()
dargs = ntuple(unimplemented_tangent, N)
Expand Down
12 changes: 7 additions & 5 deletions ext/ImplicitDifferentiationForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,24 @@ function (implicit::ImplicitFunction)(
) where {T,R,N}
x = value.(x_and_dx)
y, z = implicit(x, args...)
c = implicit.conditions(x, y, z, args...)

suggested_backend = AutoForwardDiff()
A = build_A(implicit, x, y, z, args...; suggested_backend)
B = build_B(implicit, x, y, z, args...; suggested_backend)
A = build_A(implicit, x, y, z, c, args...; suggested_backend)
B = build_B(implicit, x, y, z, c, args...; suggested_backend)

dX = ntuple(Val(N)) do k
partials.(x_and_dx, k)
end
dC_mat = mapreduce(hcat, dX) do dₖx
dₖx_vec = vec(dₖx)
dₖc_vec = B * dₖx_vec
dₖc_vec = B(dₖx_vec)
return dₖc_vec
end
dY_mat = implicit.linear_solver(A, -dC_mat)

y_and_dy = map(LinearIndices(y)) do i
Dual{T}(y[i], Partials(ntuple(k -> dY_mat[i, k], Val(N))))
y_and_dy = map(y, LinearIndices(y)) do yi, i
Dual{T}(yi, Partials(ntuple(k -> dY_mat[i, k], Val(N))))
end

return y_and_dy, z
Expand Down
10 changes: 7 additions & 3 deletions src/ImplicitDifferentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,13 @@ using DifferentiationInterface:
prepare_pushforward,
prepare_pushforward_same_point,
pullback!,
pushforward!
using Krylov: gmres
pullback,
pushforward!,
pushforward
using Krylov: Krylov
using IterativeSolvers: IterativeSolvers
using LinearOperators: LinearOperator
using LinearMaps: FunctionMap
using LinearAlgebra: factorize

include("utils.jl")
Expand All @@ -29,7 +33,7 @@ include("preparation.jl")
include("implicit_function.jl")
include("execution.jl")

export KrylovLinearSolver
export IterativeLinearSolver
export MatrixRepresentation, OperatorRepresentation
export NoPreparation, ForwardPreparation, ReversePreparation, BothPreparation
export ImplicitFunction
Expand Down
Loading
Loading