Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,9 @@ ImplicitFunction
### Settings

```@docs
IterativeLinearSolver
MatrixRepresentation
OperatorRepresentation
NoPreparation
ForwardPreparation
ReversePreparation
BothPreparation
IterativeLinearSolver
```

## Internals
Expand Down
3 changes: 1 addition & 2 deletions src/ImplicitDifferentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,8 @@ include("preparation.jl")
include("implicit_function.jl")
include("execution.jl")

export IterativeLinearSolver
export MatrixRepresentation, OperatorRepresentation
export NoPreparation, ForwardPreparation, ReversePreparation, BothPreparation
export IterativeLinearSolver
export ImplicitFunction

end
57 changes: 10 additions & 47 deletions src/implicit_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,27 +22,24 @@ This requires solving a linear system `A * J = -B` where `A = ∂₂c`, `B = ∂
conditions;
representation=OperatorRepresentation(),
linear_solver=IterativeLinearSolver(),
backend=nothing,
backends=nothing,
preparation=nothing,
input_example=nothing,
)

## Positional arguments

- `solver`: a callable returning `(x, args...) -> (y, z)` where `z` is an arbitrary byproduct of the solve. Both `x` and `y` must be subtypes of `AbstractArray`, while `z` and `args` can be anything.
- `conditions`: a callable returning a vector of optimality conditions `(x, y, z, args...) -> c`, must be compatible with automatic differentiation
- `conditions`: a callable returning a vector of optimality conditions `(x, y, z, args...) -> c`, must be compatible with automatic differentiation.

## Keyword arguments

- `representation`: either [`MatrixRepresentation`](@ref) or [`OperatorRepresentation`](@ref)
- `linear_solver`: a callable to solve linear systems with two required methods, one for `(A, b)` (single solve) and one for `(A, B)` (batched solve). It defaults to [`IterativeLinearSolver`](@ref) but can also be the built-in `\\`, or a user-provided function.
- `backend::AbstractADType`: specifies how the `conditions` will be differentiated with respect to `x` and `y`. It can be either
- `nothing`, which means that the external autodiff system will be used
- a single object from [ADTypes.jl](https://github.com/SciML/ADTypes.jl)
- a named tuple `(; x, y)` of objects from [ADTypes.jl](https://github.com/SciML/ADTypes.jl)
- `representation`: defines how the partial Jacobian `A` of the conditions with respect to the output is represented, either [`MatrixRepresentation`](@ref) or [`OperatorRepresentation`](@ref).
- `linear_solver`: a callable to solve linear systems with two required methods, one for `(A, b::AbstractVector)` (single solve) and one for `(A, B::AbstractMatrix)` (batched solve). It defaults to [`IterativeLinearSolver`](@ref) but can also be the built-in `\\`, or a user-provided function.
- `backends::AbstractADType`: specifies how the `conditions` will be differentiated with respect to `x` and `y`. It can be either, `nothing`, which means that the external autodiff system will be used, or a named tuple `(; x=AutoSomething(), y=AutoSomethingElse())` of backend objects from [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
- `preparation`: either `nothing` or a mode object from [ADTypes.jl](https://github.com/SciML/ADTypes.jl): `ADTypes.ForwardMode()`, `ADTypes.ReverseMode()` or `ADTypes.ForwardOrReverseMode()`.
- `input_example`: either `nothing` or a tuple `(x, args...)` used to prepare differentiation.
- `strict::Val=Val(true)`: whether or not to enforce a strict match in [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl) between the preparation and the execution types. Relaxing this to `strict=Val(false)` can prove necessary when working with custom array types like ComponentArrays.jl, which are not always compatible with iterative linear solvers.
- `strict::Val=Val(true)`: whether or not to enforce a strict match in [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl) between the preparation and the execution types.
"""
struct ImplicitFunction{
F,
Expand All @@ -51,7 +48,6 @@ struct ImplicitFunction{
R<:AbstractRepresentation,
B<:Union{
Nothing, #
AbstractADType,
NamedTuple{(:x, :y),<:Tuple{AbstractADType,AbstractADType}},
},
P<:Union{Nothing,AbstractMode},
Expand Down Expand Up @@ -90,59 +86,26 @@ function ImplicitFunction(
prep_B = nothing
prep_Bᵀ = nothing
else
real_backends = backends isa AbstractADType ? (; x=backends, y=backends) : backends
x, args = first(input_example), Base.tail(input_example)
y, z = solver(x, args...)
c = conditions(x, y, z, args...)
if preparation isa Union{ForwardMode,ForwardOrReverseMode}
prep_A = prepare_A(
representation,
x,
y,
z,
c,
args...;
conditions,
backend=real_backends.y,
strict,
representation, x, y, z, c, args...; conditions, backend=backends.y, strict
)
prep_B = prepare_B(
representation,
x,
y,
z,
c,
args...;
conditions,
backend=real_backends.x,
strict,
representation, x, y, z, c, args...; conditions, backend=backends.x, strict
)
else
prep_A = nothing
prep_B = nothing
end
if preparation isa Union{ReverseMode,ForwardOrReverseMode}
prep_Aᵀ = prepare_Aᵀ(
representation,
x,
y,
z,
c,
args...;
conditions,
backend=real_backends.y,
strict,
representation, x, y, z, c, args...; conditions, backend=backends.y, strict
)
prep_Bᵀ = prepare_Bᵀ(
representation,
x,
y,
z,
c,
args...;
conditions,
backend=real_backends.x,
strict,
representation, x, y, z, c, args...; conditions, backend=backends.x, strict
)
else
prep_Aᵀ = nothing
Expand Down
55 changes: 21 additions & 34 deletions src/settings.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@ Callable object that can solve linear systems `Ax = b` and `AX = B` in the same

# Constructor

IterativeLinearSolver(; kwargs...)
IterativeLinearSolver{package}(; kwargs...)

The type parameter `package` can be either:

- `:Krylov` to use the solver `gmres` from [Krylov.jl](https://github.com/JuliaSmoothOptimizers/Krylov.jl)
- `:Krylov` to use the solver `gmres` from [Krylov.jl](https://github.com/JuliaSmoothOptimizers/Krylov.jl) (the default)
- `:IterativeSolvers` to use the solver `gmres` from [IterativeSolvers.jl](https://github.com/JuliaLinearAlgebra/IterativeSolvers.jl)

Keyword arguments are passed on to the respective solver.
Expand All @@ -34,7 +35,15 @@ struct IterativeLinearSolver{package,K}
end
end

IterativeLinearSolver() = IterativeLinearSolver{:Krylov}()
function Base.show(io::IO, linear_solver::IterativeLinearSolver{package}) where {package}
print(io, "IterativeLinearSolver{$(repr(package))}(; ")
for (k, v) in pairs(linear_solver.kwargs)
print(io, "$k=$v, ")
end
return print(io, ")")
end

IterativeLinearSolver(; kwargs...) = IterativeLinearSolver{:Krylov}(; kwargs...)

function (solver::IterativeLinearSolver{:Krylov})(A, b::AbstractVector)
x, stats = Krylov.gmres(A, b; solver.kwargs...)
Expand Down Expand Up @@ -85,6 +94,7 @@ Specify that the matrix `A` involved in the implicit function theorem should be

# Constructors

OperatorRepresentation(; symmetric=false, hermitian=false)
OperatorRepresentation{package}(; symmetric=false, hermitian=false)

The type parameter `package` can be either:
Expand All @@ -108,38 +118,15 @@ struct OperatorRepresentation{package,symmetric,hermitian} <: AbstractRepresenta
end
end

OperatorRepresentation() = OperatorRepresentation{:LinearOperators}()

## Preparation

abstract type AbstractPreparation end

"""
ForwardPreparation

Specify that the derivatives of the conditions should be prepared for subsequent forward-mode differentiation of the implicit function.
"""
struct ForwardPreparation <: AbstractPreparation end

"""
ReversePreparation

Specify that the derivatives of the conditions should be prepared for subsequent reverse-mode differentiation of the implicit function.
"""
struct ReversePreparation <: AbstractPreparation end

"""
BothPreparation

Specify that the derivatives of the conditions should be prepared for subsequent forward- or reverse-mode differentiation of the implicit function.
"""
struct BothPreparation <: AbstractPreparation end

"""
NoPreparation
function Base.show(
io::IO, ::OperatorRepresentation{package,symmetric,hermitian}
) where {package,symmetric,hermitian}
return print(
io,
"OperatorRepresentation{$(repr(package))}(; symmetric=$symmetric, hermitian=$hermitian)",
)
end

Specify that the derivatives of the conditions should not be prepared for subsequent differentiation of the implicit function.
"""
struct NoPreparation <: AbstractPreparation end
OperatorRepresentation(; kwargs...) = OperatorRepresentation{:LinearOperators}(; kwargs...)

function chainrules_suggested_backend end
8 changes: 4 additions & 4 deletions test/examples.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
@testitem "intro" begin
@testitem "Intro" begin
include(joinpath(dirname(@__DIR__), "examples", "0_intro.jl"))
end

@testitem "basic" begin
@testitem "Basic" begin
include(joinpath(dirname(@__DIR__), "examples", "1_basic.jl"))
end

@testitem "advanced" begin
@testitem "Advanced" begin
include(joinpath(dirname(@__DIR__), "examples", "2_advanced.jl"))
end

@testitem "tricks" begin
@testitem "Tricks" begin
include(joinpath(dirname(@__DIR__), "examples", "3_tricks.jl"))
end
4 changes: 4 additions & 0 deletions test/formalities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,19 @@ using TestItems
using Zygote: Zygote
Aqua.test_all(ImplicitDifferentiation; ambiguities=false, undocumented_names=true)
end

@testitem "Formatting" begin
using JuliaFormatter
@test format(ImplicitDifferentiation; verbose=false, overwrite=false)
end

@testitem "Static checking" begin
using JET
using ForwardDiff: ForwardDiff
using Zygote: Zygote
JET.test_package(ImplicitDifferentiation; target_defined_modules=true)
end

@testitem "Imports" begin
using ExplicitImports
using ForwardDiff: ForwardDiff
Expand All @@ -27,6 +30,7 @@ end
@test check_all_qualified_accesses_via_owners(ImplicitDifferentiation) === nothing
@test check_no_self_qualified_accesses(ImplicitDifferentiation) === nothing
end

@testitem "Doctests" begin
using Documenter
Documenter.DocMeta.setdocmeta!(
Expand Down
63 changes: 63 additions & 0 deletions test/preparation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
@testitem "Preparation" begin
using ImplicitDifferentiation
using ADTypes
using ADTypes: ForwardOrReverseMode, ForwardMode, ReverseMode
using ForwardDiff: ForwardDiff
using Zygote: Zygote
using Test

solver(x) = sqrt.(x), nothing
conditions(x, y, z) = y .^ 2 .- x
x = rand(5)
input_example = (x,)

@testset "None" begin
implicit_none = ImplicitFunction(solver, conditions)
@test implicit_none.prep_A === nothing
@test implicit_none.prep_Aᵀ === nothing
@test implicit_none.prep_B === nothing
@test implicit_none.prep_Bᵀ === nothing
end

@testset "ForwardMode" begin
implicit_forward = ImplicitFunction(
solver,
conditions;
preparation=ForwardMode(),
backends=(; x=AutoForwardDiff(), y=AutoForwardDiff()),
input_example,
)
@test implicit_forward.prep_A !== nothing
@test implicit_forward.prep_Aᵀ === nothing
@test implicit_forward.prep_B !== nothing
@test implicit_forward.prep_Bᵀ === nothing
end

@testset "ReverseMode" begin
implicit_reverse = ImplicitFunction(
solver,
conditions;
preparation=ReverseMode(),
backends=(; x=AutoZygote(), y=AutoZygote()),
input_example,
)
@test implicit_reverse.prep_A === nothing
@test implicit_reverse.prep_Aᵀ !== nothing
@test implicit_reverse.prep_B === nothing
@test implicit_reverse.prep_Bᵀ !== nothing
end

@testset "Both" begin
implicit_both = ImplicitFunction(
solver,
conditions;
preparation=ForwardOrReverseMode(),
backends=(; x=AutoForwardDiff(), y=AutoZygote()),
input_example,
)
@test implicit_both.prep_A !== nothing
@test implicit_both.prep_Aᵀ !== nothing
@test implicit_both.prep_B !== nothing
@test implicit_both.prep_Bᵀ !== nothing
end
end
6 changes: 6 additions & 0 deletions test/printing.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
using TestItems

@testitem "Settings" begin
@test startswith(string(OperatorRepresentation()), "Operator")
@test startswith(string(IterativeLinearSolver(; atol=1e-5)), "Iterative")
end
6 changes: 6 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
using TestItemRunner

@testmodule TestUtils begin
include("utils.jl")
export Scenario, test_implicit, add_arg_mult
export default_solver, default_conditions
end

@run_package_tests
Loading
Loading