Commit f41651f

Clarify linear solvers and vector requirements (#140)

* Clarify linear solvers and vector requirements
* Fix parsing
1 parent 8eeebc8 commit f41651f

4 files changed

Lines changed: 97 additions & 70 deletions

docs/src/faq.md

Lines changed: 7 additions & 22 deletions
@@ -18,22 +18,22 @@ You can override the default with the `conditions_x_backend` and `conditions_y_b
 
 ### Arrays
 
-Functions that eat or spit out arbitrary arrays are supported, as long as the forward mapping _and_ conditions return arrays of the same size.
+Functions that eat or spit out arbitrary vectors are supported, as long as the forward mapping _and_ conditions return vectors of the same size.
 
-If you deal with small arrays (say, less than 100 elements), consider using [StaticArrays.jl](https://github.com/JuliaArrays/StaticArrays.jl) for increased performance.
+If you deal with small vectors (say, less than 100 elements), consider using [StaticArrays.jl](https://github.com/JuliaArrays/StaticArrays.jl) for increased performance.
 
 ### Scalars
 
 Functions that eat or spit out a single number are not supported.
-The forward mapping _and_ conditions need arrays: instead of returning `val` you should return `[val]` (a 1-element `Vector`).
+The forward mapping _and_ conditions need vectors: instead of returning `val` you should return `[val]` (a 1-element `Vector`).
 Or better yet, wrap it in a static vector: `SVector(val)`.
 
-### Sparse arrays
+### Sparse vectors
 
 !!! danger "Danger"
-    Sparse arrays are not supported and might give incorrect values or `NaN`s!
+    Sparse vectors are not supported and might give incorrect values or `NaN`s!
 
-With ForwardDiff.jl, differentiation of sparse arrays will often give wrong results due to [sparsity pattern cancellation](https://github.com/JuliaDiff/ForwardDiff.jl/issues/658).
+With ForwardDiff.jl, differentiation of sparse vectors will often give wrong results due to [sparsity pattern cancellation](https://github.com/JuliaDiff/ForwardDiff.jl/issues/658).
 That is why we do not test behavior for sparse inputs.
 
 ## Number of inputs and outputs
@@ -51,7 +51,7 @@ We now detail each of these options.
 
 ### Multiple inputs or outputs | Derivatives needed
 
-Say your forward mapping takes multiple input arrays and returns multiple output arrays, such that you want derivatives for all of them.
+Say your forward mapping takes multiple inputs and returns multiple outputs, such that you want derivatives for all of them.
 
 The trick is to leverage [ComponentArrays.jl](https://github.com/jonniedie/ComponentArrays.jl) to wrap all the inputs inside a single `ComponentVector`, and do the same for all the outputs.
 See the examples for a demonstration.
@@ -92,21 +92,6 @@ This is mainly useful when the solution procedure creates objects such as Jacobi
 In that case, you may want to write the conditions differentiation rules yourself.
 A more advanced application is given by [DifferentiableFrankWolfe.jl](https://github.com/gdalle/DifferentiableFrankWolfe.jl).
 
-## Linear system
-
-### Lazy or dense
-
-Usually, dense Jacobians are more efficient in small dimension, while lazy operators become necessary in high dimension.
-This choice is made via the `lazy` type parameter of [`ImplicitFunction`](@ref), with `lazy = true` being the default.
-
-### Picking a solver
-
-The right linear solver to use depends on the Jacobian representation.
-You can usually stick to the default settings:
-
-- the direct solver `\` for dense Jacobians
-- an iterative solver from [Krylov.jl](https://github.com/JuliaSmoothOptimizers/Krylov.jl) for lazy operators
-
 ## Modeling tips
 
 ### Writing conditions
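
The vector requirement from the FAQ changes above can be illustrated with a small sketch. The toy problem and the function names `forward` and `conditions` below are hypothetical, chosen only to show a scalar result wrapped in a 1-element `Vector`; nothing here calls the package itself.

```julia
# Hypothetical toy problem: y(x) = sqrt(x) for a single number x.
# A bare number is not supported, so both functions return 1-element vectors.
forward(x::AbstractVector) = [sqrt(x[1])]
conditions(x::AbstractVector, y::AbstractVector) = [y[1]^2 - x[1]]

x = [4.0]
y = forward(x)        # 1-element Vector holding sqrt(4.0)
c = conditions(x, y)  # 1-element Vector, zero at the solution

# The forward mapping and the conditions return vectors of the same size.
@assert y isa AbstractVector && c isa AbstractVector
@assert length(y) == length(c)
```

The same pattern should carry over to `SVector(val)` from StaticArrays.jl, as the FAQ suggests, at the cost of an extra dependency.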

src/implicit_function.jl

Lines changed: 81 additions & 35 deletions
@@ -1,11 +1,31 @@
-struct DefaultLinearSolver end
+"""
+    KrylovLinearSolver
+
+Callable object that can solve linear systems `As = b` and `AS = B` in the same way as `\`.
+Uses an iterative solver from [Krylov.jl](https://github.com/JuliaSmoothOptimizers/Krylov.jl) under the hood.
+
+# Note
 
-function (::DefaultLinearSolver)(A, b::AbstractVector)
+This name is not exported, and thus not part of the public API, but it is used in the [`ImplicitFunction`](@ref) constructors.
+"""
+struct KrylovLinearSolver end
+
+"""
+    (::KrylovLinearSolver)(A, b::AbstractVector)
+
+Solve a linear system with a single right-hand side.
+"""
+function (::KrylovLinearSolver)(A, b::AbstractVector)
     x, stats = gmres(A, b)
     return x
 end
 
-function (::DefaultLinearSolver)(A, B::AbstractMatrix)
+"""
+    (::KrylovLinearSolver)(A, B::AbstractMatrix)
+
+Solve a linear system with multiple right-hand sides.
+"""
+function (::KrylovLinearSolver)(A, B::AbstractMatrix)
     # X, stats = block_gmres(A, B) # https://github.com/JuliaSmoothOptimizers/Krylov.jl/issues/854
     X = mapreduce(hcat, eachcol(B)) do b
         first(gmres(A, b))
@@ -25,44 +45,28 @@ When a derivative is queried, the Jacobian of `y` is computed using the implicit
 
 This requires solving a linear system `A * J = -B` where `A = ∂c/∂y`, `B = ∂c/∂x` and `J = ∂y/∂x`.
 
+# Type parameters
+
+- `lazy::Bool`: whether to represent `A` and `B` with a `LinearOperator` from [LinearOperators.jl](https://github.com/JuliaSmoothOptimizers/LinearOperators.jl) (`lazy = true`) or a dense Jacobian matrix (`lazy = false`)
+
+Usually, dense Jacobians are more efficient in small dimension, while lazy operators become necessary in high dimension.
+The value of `lazy` must be chosen together with the `linear_solver`, see below.
+
 # Fields
 
 - `forward`: a callable computing `y(x)`, does not need to be compatible with automatic differentiation
 - `conditions`: a callable computing `c(x, y)`, must be compatible with automatic differentiation
-- `linear_solver`: a callable to solve the linear system `A * J = -B`
+- `linear_solver`: a callable to solve the linear system
 - `conditions_x_backend`: defines how the conditions will be differentiated with respect to the first argument `x`
 - `conditions_y_backend`: defines how the conditions will be differentiated with respect to the second argument `y`
 
-# Type parameters
-
-- `lazy`: whether to use a `LinearOperator` from [LinearOperators.jl](https://github.com/JuliaSmoothOptimizers/LinearOperators.jl) (`lazy = true`) or a dense Jacobian matrix (`lazy = false`) for `A` and `B`
-
-# Constructors
-
-    ImplicitFunction{lazy}(
-        forward, conditions;
-        linear_solver, conditions_x_backend, conditions_x_backend
-    )
-
-    ImplicitFunction(
-        forward, conditions;
-        linear_solver, conditions_x_backend, conditions_x_backend
-    )
-
-Default values:
-
-- `lazy = true`
-- `linear_solver`: the direct solver `\` for dense Jacobians, or an iterative solver from [Krylov.jl](https://github.com/JuliaSmoothOptimizers/Krylov.jl) for lazy operators
-- `conditions_x_backend = nothing`
-- `conditions_y_backend = nothing`
-
 # Function signatures
 
 There are two possible signatures for `forward` and `conditions`, which must be consistent with one another:
 
 | standard | byproduct |
 |:---|:---|
-| `forward(x, args...; kwargs...) = y` | `conditions(x, y, args...; kwargs...) = c` |
+| `forward(x, args...; kwargs...) = y` | `conditions(x, y, args...; kwargs...) = c` |
 | `forward(x, args...; kwargs...) = (y, z)` | `conditions(x, y, z, args...; kwargs...) = c` |
 
 In both cases, `x`, `y` and `c` must be `AbstractVector`s, with `length(y) = length(c)`.
@@ -75,13 +79,16 @@ The byproduct `z` and the other positional arguments `args...` beyond `x` are co
 
 The provided `linear_solver` object needs to be callable, with two methods:
 - `(A, b::AbstractVector) -> s::AbstractVector` such that `A * s = b`
-- `(A, B::AbstractVector) -> S::AbstractMatrix` such that `A * S = B`
+- `(A, B::AbstractMatrix) -> S::AbstractMatrix` such that `A * S = B`
+
+It can be either a direct solver (like `\`) or an iterative one (like [`KrylovLinearSolver`](@ref)).
+Typically, direct solvers work best with dense Jacobians (`lazy = false`) while iterative solvers work best with operators (`lazy = true`).
 
 # Condition backends
 
 The provided `conditions_x_backend` and `conditions_y_backend` can be either:
-- `nothing`, in which case the outer backend (the one differentiating through the `ImplicitFunction`) is used
-- an object subtyping `AbstractADType` from [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
+- an object subtyping `AbstractADType` from [ADTypes.jl](https://github.com/SciML/ADTypes.jl);
+- `nothing`, in which case the outer backend (the one differentiating through the `ImplicitFunction`) is used.
 """
 struct ImplicitFunction{
     lazy,F,C,L,B1<:Union{Nothing,AbstractADType},B2<:Union{Nothing,AbstractADType}
@@ -93,10 +100,28 @@ struct ImplicitFunction{
     conditions_y_backend::B2
 end
 
+"""
+    ImplicitFunction{lazy}(
+        forward, conditions;
+        linear_solver=if lazy
+            KrylovLinearSolver()
+        else
+            \
+        end,
+        conditions_x_backend=nothing,
+        conditions_y_backend=nothing,
+    )
+
+Constructor for an [`ImplicitFunction`](@ref) which picks the `linear_solver` automatically based on the `lazy` parameter.
+"""
 function ImplicitFunction{lazy}(
     forward::F,
     conditions::C;
-    linear_solver::L=lazy ? DefaultLinearSolver() : \,
+    linear_solver::L=if lazy
+        KrylovLinearSolver()
+    else
+        \
+    end,
     conditions_x_backend::B1=nothing,
     conditions_y_backend::B2=nothing,
 ) where {lazy,F,C,L,B1,B2}
@@ -105,8 +130,29 @@ function ImplicitFunction{lazy}(
     )
 end
 
-function ImplicitFunction(forward, conditions; kwargs...)
-    return ImplicitFunction{true}(forward, conditions; kwargs...)
+"""
+    ImplicitFunction(
+        forward, conditions;
+        linear_solver=KrylovLinearSolver(),
+        conditions_x_backend=nothing,
+        conditions_y_backend=nothing,
+    )
+
+Constructor for an [`ImplicitFunction`](@ref) which picks the `lazy` parameter automatically based on the `linear_solver`, using the following heuristic:
+
+    lazy = linear_solver != \
+"""
+function ImplicitFunction(
+    forward,
+    conditions;
+    linear_solver=KrylovLinearSolver(),
+    conditions_x_backend=nothing,
+    conditions_y_backend=nothing,
+)
+    lazy = linear_solver != \
+    return ImplicitFunction{lazy}(
+        forward, conditions; linear_solver, conditions_x_backend, conditions_y_backend
+    )
 end
 
 function Base.show(io::IO, implicit::ImplicitFunction{lazy}) where {lazy}
@@ -119,7 +165,7 @@ function Base.show(io::IO, implicit::ImplicitFunction{lazy}) where {lazy}
 end
 
 """
-    (implicit::ImplicitFunction)(x::AbstractArray, args...; kwargs...)
+    (implicit::ImplicitFunction)(x::AbstractVector, args...; kwargs...)
 
 Return `implicit.forward(x, args...; kwargs...)`, which can be either an `AbstractVector` `y` or a tuple `(y, z)`.
 
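
The linear solver interface described in the docstring above (one method per right-hand side shape) and the `lazy` heuristic of the second constructor can be sketched in plain Julia. This is only an illustration under stated assumptions: `ColumnwiseDirectSolver` and `pick_lazy` are hypothetical names that do not exist in the package, and the sketch uses the direct solver `\` rather than Krylov.jl.

```julia
using LinearAlgebra

# Hypothetical solver type (not part of the package): a direct solver
# exposing the two methods the docstring asks for.
struct ColumnwiseDirectSolver end

# Single right-hand side: return s such that A * s = b.
(::ColumnwiseDirectSolver)(A, b::AbstractVector) = A \ b

# Multiple right-hand sides: solve column by column and stack the
# solutions side by side, so that A * S = B.
function (solver::ColumnwiseDirectSolver)(A, B::AbstractMatrix)
    return reduce(hcat, [solver(A, b) for b in eachcol(B)])
end

A = [2.0 0.0; 0.0 4.0]
b = [2.0, 4.0]
B = [2.0 4.0; 4.0 8.0]

solver = ColumnwiseDirectSolver()
s = solver(A, b)
S = solver(A, B)

@assert A * s ≈ b
@assert A * S ≈ B

# Hypothetical helper mirroring the heuristic `lazy = linear_solver != \`:
# anything other than the backslash function itself selects lazy operators.
pick_lazy(linear_solver) = linear_solver != (\)

@assert pick_lazy(\) == false
@assert pick_lazy(solver) == true
```

Note that under this heuristic a custom direct solver such as the one above would still select `lazy = true`, so passing the `lazy` type parameter explicitly may be preferable in that situation.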

test/systematic.jl

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ backends = [
 
 linear_solver_candidates = (
     \, #
-    ID.DefaultLinearSolver(),
+    ID.KrylovLinearSolver(),
 )
 
 conditions_backend_candidates = (

test/utils.jl

Lines changed: 8 additions & 12 deletions
@@ -31,35 +31,31 @@ end
 
 ## Various signatures
 
-function make_implicit_sqrt(; linear_solver, kwargs...)
-    lazy = !(linear_solver isa typeof(\))
+function make_implicit_sqrt(; kwargs...)
     forward(x) = mysqrt(x)
     conditions(x, y) = abs2.(y) .- abs.(x)
-    implicit = ImplicitFunction{lazy}(forward, conditions; kwargs...)
+    implicit = ImplicitFunction(forward, conditions; kwargs...)
     return implicit
 end
 
-function make_implicit_sqrt_byproduct(; linear_solver, kwargs...)
-    lazy = !(linear_solver isa typeof(\))
+function make_implicit_sqrt_byproduct(; kwargs...)
     forward(x) = one(eltype(x)) .* mysqrt(x), one(eltype(x))
     conditions(x, y, z) = abs2.(y ./ z) .- abs.(x)
-    implicit = ImplicitFunction{lazy}(forward, conditions; linear_solver, kwargs...)
+    implicit = ImplicitFunction(forward, conditions; kwargs...)
     return implicit
 end
 
-function make_implicit_sqrt_args(; linear_solver, kwargs...)
-    lazy = !(linear_solver isa typeof(\))
+function make_implicit_sqrt_args(; kwargs...)
     forward(x, p) = p .* mysqrt(x)
     conditions(x, y, p) = abs2.(y ./ p) .- abs.(x)
-    implicit = ImplicitFunction{lazy}(forward, conditions; linear_solver, kwargs...)
+    implicit = ImplicitFunction(forward, conditions; kwargs...)
     return implicit
 end
 
-function make_implicit_sqrt_kwargs(; linear_solver, kwargs...)
-    lazy = !(linear_solver isa typeof(\))
+function make_implicit_sqrt_kwargs(; kwargs...)
     forward(x; p) = p .* mysqrt(x)
     conditions(x, y; p) = abs2.(y ./ p) .- abs.(x)
-    implicit = ImplicitFunction{lazy}(forward, conditions; linear_solver, kwargs...)
+    implicit = ImplicitFunction(forward, conditions; kwargs...)
     return implicit
 end
 
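
For readers who want to see what these square-root test helpers exercise, the linear system `A * J = -B` from the `ImplicitFunction` docstring can be worked out by hand. The sketch below assumes strictly positive inputs (so that `abs.(x) == x`), writes the partial Jacobians manually instead of using automatic differentiation, and does not call the package.

```julia
using LinearAlgebra

# Same toy problem as the test helpers, restricted to positive inputs:
# forward mapping y(x) = sqrt.(x), conditions c(x, y) = y.^2 .- x.
x = [4.0, 9.0]
y = sqrt.(x)

# Hand-written partial Jacobians of the conditions.
A = Matrix(Diagonal(2 .* y))     # ∂c/∂y
B = -Matrix{Float64}(I, 2, 2)    # ∂c/∂x

# Implicit function theorem: A * J = -B, solved with the direct solver `\`.
J = A \ (-B)

# Compare with the known derivative of the elementwise square root.
expected = Matrix(Diagonal(1 ./ (2 .* sqrt.(x))))
@assert J ≈ expected
```

With dense matrices of this size the direct solver is the natural choice, which matches the `lazy = false` guidance in the docstring.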
