Skip to content

Commit 8eeebc8

Browse files
authored
Update to DI v0.2 (#139)
1 parent 2579f7a commit 8eeebc8

10 files changed

Lines changed: 135 additions & 74 deletions

File tree

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ jobs:
5050
version: '1'
5151
- uses: julia-actions/cache@v1
5252
- name: Install dependencies
53-
run: julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()'
53+
run: julia --project=docs/ -e 'using Pkg; Pkg.Registry.update(); Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()'
5454
- name: Build and deploy
5555
env:
5656
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # If authenticating with GitHub Actions token

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ ImplicitDifferentiationForwardDiffExt = "ForwardDiff"
2323
[compat]
2424
ADTypes = "0.2"
2525
ChainRulesCore = "1.23.0"
26-
DifferentiationInterface = "0.1"
26+
DifferentiationInterface = "0.2"
2727
Enzyme = "0.11.20"
2828
ForwardDiff = "0.10.36"
2929
Krylov = "0.9.5"

docs/src/faq.md

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
## Supported autodiff backends
44

5-
To differentiate an `ImplicitFunction`, the following backends are supported.
5+
To differentiate through an `ImplicitFunction`, the following backends are supported.
66

77
| Backend | Forward mode | Reverse mode |
88
| :--------------------------------------------------------------------- | :----------- | :----------- |
@@ -31,9 +31,9 @@ Or better yet, wrap it in a static vector: `SVector(val)`.
3131
### Sparse arrays
3232

3333
!!! danger "Danger"
34-
Sparse arrays are not officially supported and might give incorrect values or `NaN`s!
34+
Sparse arrays are not supported and might give incorrect values or `NaN`s!
3535

36-
With ForwardDiff.jl, differentiation of sparse arrays will always give wrong results due to [sparsity pattern cancellation](https://github.com/JuliaDiff/ForwardDiff.jl/issues/658).
36+
With ForwardDiff.jl, differentiation of sparse arrays will often give wrong results due to [sparsity pattern cancellation](https://github.com/JuliaDiff/ForwardDiff.jl/issues/658).
3737
That is why we do not test behavior for sparse inputs.
3838

3939
## Number of inputs and outputs
@@ -92,6 +92,21 @@ This is mainly useful when the solution procedure creates objects such as Jacobi
9292
In that case, you may want to write the conditions differentiation rules yourself.
9393
A more advanced application is given by [DifferentiableFrankWolfe.jl](https://github.com/gdalle/DifferentiableFrankWolfe.jl).
9494

95+
## Linear system
96+
97+
### Lazy or dense
98+
99+
Usually, dense Jacobians are more efficient in low dimensions, while lazy operators become necessary in high dimensions.
100+
This choice is made via the `lazy` type parameter of [`ImplicitFunction`](@ref), with `lazy = true` being the default.
101+
102+
### Picking a solver
103+
104+
The right linear solver to use depends on the Jacobian representation.
105+
You can usually stick to the default settings:
106+
107+
- the direct solver `\` for dense Jacobians
108+
- an iterative solver from [Krylov.jl](https://github.com/JuliaSmoothOptimizers/Krylov.jl) for lazy operators
109+
95110
## Modeling tips
96111

97112
### Writing conditions

examples/1_basic.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ end;
7373
We now have all the ingredients to construct our implicit function.
7474
=#
7575

76-
implicit_optim = ImplicitFunction(; forward=forward_optim, conditions=conditions_optim)
76+
implicit_optim = ImplicitFunction(forward_optim, conditions_optim)
7777

7878
# And indeed, it behaves as it should when we call it:
7979

src/ImplicitDifferentiation.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,7 @@ module ImplicitDifferentiation
99

1010
using ADTypes: AbstractADType
1111
using DifferentiationInterface:
12-
jacobian,
13-
prepare_pushforward,
14-
prepare_pullback,
15-
pushforward!!,
16-
value_and_pullback!!_split
12+
jacobian, prepare_pushforward, prepare_pullback, pushforward!, value_and_pullback!_split
1713
using Krylov: block_gmres, gmres
1814
using LinearOperators: LinearOperator
1915
using LinearAlgebra: factorize, lu

src/implicit_function.jl

Lines changed: 64 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ function (::DefaultLinearSolver)(A, B::AbstractMatrix)
1414
end
1515

1616
"""
17-
ImplicitFunction
17+
ImplicitFunction{lazy}
1818
1919
Wrapper for an implicit function defined by a forward mapping `y` and a set of conditions `c`.
2020
@@ -27,47 +27,94 @@ This requires solving a linear system `A * J = -B` where `A = ∂c/∂y`, `B =
2727
2828
# Fields
2929
30-
- `forward`: a callable, does not need to be compatible with automatic differentiation
31-
- `conditions`: a callable, must be compatible with automatic differentiation
32-
- `linear_solver`: a callable with two methods:
33-
- `(A, b::AbstractVector) -> s::AbstractVector` such that `A * s = b`
34-
- `(A, B::AbstractVector) -> S::AbstractMatrix` such that `A * S = B`
35-
- `conditions_x_backend`: either `nothing` or an object subtyping `AbstractADType` from [ADTypes.jl](https://github.com/SciML/ADTypes.jl), defines how the conditions will be differentiated with respect to the first argument `x`
36-
- `conditions_y_backend`: same for the second argument `y`
30+
- `forward`: a callable computing `y(x)`, does not need to be compatible with automatic differentiation
31+
- `conditions`: a callable computing `c(x, y)`, must be compatible with automatic differentiation
32+
- `linear_solver`: a callable to solve the linear system `A * J = -B`
33+
- `conditions_x_backend`: defines how the conditions will be differentiated with respect to the first argument `x`
34+
- `conditions_y_backend`: defines how the conditions will be differentiated with respect to the second argument `y`
35+
36+
# Type parameters
37+
38+
- `lazy`: whether to use a `LinearOperator` from [LinearOperators.jl](https://github.com/JuliaSmoothOptimizers/LinearOperators.jl) (`lazy = true`) or a dense Jacobian matrix (`lazy = false`) for `A` and `B`
39+
40+
# Constructors
41+
42+
ImplicitFunction{lazy}(
43+
forward, conditions;
44+
linear_solver, conditions_x_backend, conditions_y_backend
45+
)
46+
47+
ImplicitFunction(
48+
forward, conditions;
49+
linear_solver, conditions_x_backend, conditions_y_backend
50+
)
51+
52+
Default values:
53+
54+
- `lazy = true`
55+
- `linear_solver`: the direct solver `\` for dense Jacobians, or an iterative solver from [Krylov.jl](https://github.com/JuliaSmoothOptimizers/Krylov.jl) for lazy operators
56+
- `conditions_x_backend = nothing`
57+
- `conditions_y_backend = nothing`
58+
59+
# Function signatures
3760
3861
There are two possible signatures for `forward` and `conditions`, which must be consistent with one another:
3962
4063
| standard | byproduct |
4164
|:---|:---|
42-
| `forward(x, args...; kwargs...) = y` | `conditions(x, y, args...; kwargs...) = c` |
65+
| `forward(x, args...; kwargs...) = y` | `conditions(x, y, args...; kwargs...) = c` |
4366
| `forward(x, args...; kwargs...) = (y, z)` | `conditions(x, y, z, args...; kwargs...) = c` |
4467
4568
In both cases, `x`, `y` and `c` must be `AbstractVector`s, with `length(y) = length(c)`.
4669
In the second case, the byproduct `z` can be an arbitrary object generated by `forward`.
4770
The positional arguments `args...` and keyword arguments `kwargs...` must be the same for both `forward` and `conditions`.
4871
4972
The byproduct `z` and the other positional arguments `args...` beyond `x` are considered constant for differentiation purposes.
73+
74+
# Linear solver
75+
76+
The provided `linear_solver` object needs to be callable, with two methods:
77+
- `(A, b::AbstractVector) -> s::AbstractVector` such that `A * s = b`
78+
- `(A, B::AbstractMatrix) -> S::AbstractMatrix` such that `A * S = B`
79+
80+
# Condition backends
81+
82+
The provided `conditions_x_backend` and `conditions_y_backend` can be either:
83+
- `nothing`, in which case the outer backend (the one differentiating through the `ImplicitFunction`) is used
84+
- an object subtyping `AbstractADType` from [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
5085
"""
51-
@kwdef struct ImplicitFunction{
52-
F,C,L,B1<:Union{Nothing,AbstractADType},B2<:Union{Nothing,AbstractADType}
86+
struct ImplicitFunction{
    lazy,F,C,L,B1<:Union{Nothing,AbstractADType},B2<:Union{Nothing,AbstractADType}
}
    # Callable for the forward mapping.
    forward::F
    # Callable for the conditions.
    conditions::C
    # Callable used to solve the linear system.
    linear_solver::L
    # `nothing` or an `AbstractADType`, for differentiating the conditions w.r.t. `x`.
    conditions_x_backend::B1
    # `nothing` or an `AbstractADType`, for differentiating the conditions w.r.t. `y`.
    conditions_y_backend::B2
end

# Keyword constructor with `lazy` given explicitly as a type parameter.
# The default solver depends on `lazy`: `DefaultLinearSolver()` when `lazy` is true,
# the backslash operator otherwise. Both backends default to `nothing`.
function ImplicitFunction{lazy}(
    forward::F,
    conditions::C;
    linear_solver::L=lazy ? DefaultLinearSolver() : \,
    conditions_x_backend::B1=nothing,
    conditions_y_backend::B2=nothing,
) where {lazy,F,C,L,B1,B2}
    # Spell out every type parameter so that `lazy` ends up in the resulting type.
    T = ImplicitFunction{lazy,F,C,L,B1,B2}
    return T(forward, conditions, linear_solver, conditions_x_backend, conditions_y_backend)
end
60107

61108
# Convenience constructor: same as `ImplicitFunction{true}`, i.e. `lazy = true`,
# with all keyword arguments forwarded unchanged.
ImplicitFunction(forward, conditions; kwargs...) =
    ImplicitFunction{true}(forward, conditions; kwargs...)
64111

65-
function Base.show(io::IO, implicit::ImplicitFunction)
112+
# Print a one-line summary: the `lazy` type parameter rendered as "lazy" or "dense",
# followed by the five fields in declaration order.
function Base.show(io::IO, implicit::ImplicitFunction{lazy}) where {lazy}
    mode = lazy ? "lazy" : "dense"
    fields = (
        implicit.forward,
        implicit.conditions,
        implicit.linear_solver,
        implicit.conditions_x_backend,
        implicit.conditions_y_backend,
    )
    return print(io, "ImplicitFunction{", mode, "}(", join(fields, ", "), ")")
end
73120

src/operators.jl

Lines changed: 35 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -81,51 +81,47 @@ end
8181

8282
# In-place product for a lazy operator: overwrite `res` with
# `α * pushforward(v) + β * res`, where the pushforward is written into `res`
# by `pushforward!`.
function (po::PushforwardOperator!)(res, v, α, β)
    keep_old = !iszero(β)
    # The old content of `res` is only needed when `β` is nonzero; save it before
    # `pushforward!` overwrites `res`.
    if keep_old
        po.res_backup .= res
    end
    pushforward!(po.f, res, po.backend, po.x, v, po.extras)
    if keep_old
        res .= α .* res .+ β .* po.res_backup
    else
        res .= α .* res
    end
    return res
end
9393

9494
"""
    PullbackOperator!{PB,R}

Callable wrapper turning an in-place pullback into a five-argument product
`(res, v, α, β)`, mirroring `PushforwardOperator!`.

# Fields

- `pullbackfunc!`: callable invoked as `pullbackfunc!(res, v)`; it is expected to
  overwrite `res` in place (its return value is ignored)
- `res_backup`: buffer holding the previous content of `res` when `β` is nonzero;
  it must have the same shape as `res`
"""
struct PullbackOperator!{PB,R}
    pullbackfunc!::PB
    res_backup::R
end

# Overwrite `res` with `α * pullback(v) + β * res` and return it.
function (po::PullbackOperator!)(res, v, α, β)
    if iszero(β)
        po.pullbackfunc!(res, v)
        # The scaling by `α` must also be applied when there is no `β` term.
        res .= α .* res
    else
        # Save the old `res` before the pullback overwrites it.
        po.res_backup .= res
        po.pullbackfunc!(res, v)
        # `β` multiplies the old content (it is not added to it).
        res .= α .* res .+ β .* po.res_backup
    end
    return res
end
110109

111110
function build_A(
112-
implicit::ImplicitFunction,
111+
implicit::ImplicitFunction{lazy},
113112
x::AbstractVector,
114113
y_or_yz,
115114
args...;
116115
suggested_backend,
117116
kwargs...,
118-
)
117+
) where {lazy}
119118
(; conditions, linear_solver, conditions_y_backend) = implicit
120119
y = output(y_or_yz)
121120
n, m = length(x), length(y)
122121
back_y = isnothing(conditions_y_backend) ? suggested_backend : conditions_y_backend
123122
cond_y = ConditionsY(conditions, x, y_or_yz, args...; kwargs...)
124-
if linear_solver isa typeof(\)
125-
J = jacobian(cond_y, back_y, y)
126-
A = factorize(J)
127-
else
128-
extras = prepare_pushforward(cond_y, back_y, y)
123+
if lazy
124+
extras = prepare_pushforward(cond_y, back_y, y, similar(y))
129125
A = LinearOperator(
130126
eltype(y),
131127
m,
@@ -135,59 +131,60 @@ function build_A(
135131
PushforwardOperator!(cond_y, back_y, y, extras, similar(y)),
136132
typeof(y),
137133
)
134+
else
135+
J = jacobian(cond_y, back_y, y)
136+
A = factorize(J)
138137
end
139138
return A
140139
end
141140

142141
function build_Aᵀ(
143-
implicit::ImplicitFunction,
142+
# Build the transposed Jacobian of the conditions with respect to `y`.
#
# When `lazy` is true, the result is an `m × m` `LinearOperator` (with `m = length(y)`)
# whose product is computed by a pullback of the conditions. Otherwise it is the
# factorization of the transposed dense Jacobian.
#
# The backend is `conditions_y_backend` when it is set, `suggested_backend` otherwise.
function build_Aᵀ(
    implicit::ImplicitFunction{lazy},
    x::AbstractVector,
    y_or_yz,
    args...;
    suggested_backend,
    kwargs...,
) where {lazy}
    (; conditions, conditions_y_backend) = implicit
    y = output(y_or_yz)
    m = length(y)
    back_y = isnothing(conditions_y_backend) ? suggested_backend : conditions_y_backend
    cond_y = ConditionsY(conditions, x, y_or_yz, args...; kwargs...)
    if lazy
        extras = prepare_pullback(cond_y, back_y, y, similar(y))
        _, pullbackfunc! = value_and_pullback!_split(cond_y, back_y, y, extras)
        # NOTE(review): `typeof(y)` is passed as the last positional argument,
        # presumably as the storage type; confirm against the `LinearOperator`
        # constructor signature.
        Aᵀ = LinearOperator(
            eltype(y),
            m,
            m,
            false,
            false,
            PullbackOperator!(pullbackfunc!, similar(y)),
            typeof(y),
        )
    else
        Jᵀ = transpose(jacobian(cond_y, back_y, y))
        Aᵀ = factorize(Jᵀ)
    end
    return Aᵀ
end
173172

174173
function build_B(
175-
implicit::ImplicitFunction,
174+
implicit::ImplicitFunction{lazy},
176175
x::AbstractVector,
177176
y_or_yz,
178177
args...;
179178
suggested_backend,
180179
kwargs...,
181-
)
180+
) where {lazy}
182181
(; conditions, linear_solver, conditions_x_backend) = implicit
183182
y = output(y_or_yz)
184183
n, m = length(x), length(y)
185184
back_x = isnothing(conditions_x_backend) ? suggested_backend : conditions_x_backend
186185
cond_x = ConditionsX(conditions, x, y_or_yz, args...; kwargs...)
187-
if linear_solver isa typeof(\)
188-
B = transpose(jacobian(cond_x, back_x, x))
189-
else
190-
extras = prepare_pushforward(cond_x, back_x, x)
186+
if lazy
187+
extras = prepare_pushforward(cond_x, back_x, x, similar(x))
191188
B = LinearOperator(
192189
eltype(y),
193190
m,
@@ -197,37 +194,39 @@ function build_B(
197194
PushforwardOperator!(cond_x, back_x, x, extras, similar(y)),
198195
typeof(x),
199196
)
197+
else
198+
B = transpose(jacobian(cond_x, back_x, x))
200199
end
201200
return B
202201
end
203202

204203
function build_Bᵀ(
205-
implicit::ImplicitFunction,
204+
# Build the transposed Jacobian of the conditions with respect to `x`.
#
# When `lazy` is true, the result is an `n × m` `LinearOperator` (with `n = length(x)`
# and `m = length(y)`) whose product is computed by a pullback of the conditions.
# Otherwise it is the transposed dense Jacobian.
#
# The backend is `conditions_x_backend` when it is set, `suggested_backend` otherwise.
function build_Bᵀ(
    implicit::ImplicitFunction{lazy},
    x::AbstractVector,
    y_or_yz,
    args...;
    suggested_backend,
    kwargs...,
) where {lazy}
    (; conditions, conditions_x_backend) = implicit
    y = output(y_or_yz)
    n, m = length(x), length(y)
    back_x = isnothing(conditions_x_backend) ? suggested_backend : conditions_x_backend
    cond_x = ConditionsX(conditions, x, y_or_yz, args...; kwargs...)
    if lazy
        extras = prepare_pullback(cond_x, back_x, x, similar(y))
        _, pullbackfunc! = value_and_pullback!_split(cond_x, back_x, x, extras)
        # The operator maps length-`m` vectors to length-`n` results, so the backup
        # buffer of `PullbackOperator!` must be shaped like `x`, not like `y`.
        # NOTE(review): `typeof(x)` is passed as the last positional argument,
        # presumably as the storage type; confirm against the `LinearOperator`
        # constructor signature.
        Bᵀ = LinearOperator(
            eltype(y),
            n,
            m,
            false,
            false,
            PullbackOperator!(pullbackfunc!, similar(x)),
            typeof(x),
        )
    else
        Bᵀ = transpose(jacobian(cond_x, back_x, x))
    end
    return Bᵀ
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ using Random
1010
using Test
1111
using Zygote: Zygote
1212

13-
DocMeta.setdocmeta!(
13+
Documenter.DocMeta.setdocmeta!(
1414
ImplicitDifferentiation, :DocTestSetup, :(using ImplicitDifferentiation); recursive=true
1515
)
1616

0 commit comments

Comments
 (0)