Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v4
- uses: codecov/codecov-action@v5
with:
files: lcov.info
token: ${{ secrets.CODECOV_TOKEN }}
Expand Down
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ ImplicitFunction
MatrixRepresentation
OperatorRepresentation
IterativeLinearSolver
IterativeLeastSquaresSolver
DirectLinearSolver
```

Expand Down
33 changes: 22 additions & 11 deletions ext/ImplicitDifferentiationChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -1,31 +1,36 @@
module ImplicitDifferentiationChainRulesCoreExt

using ADTypes: AutoChainRules
using ADTypes: AutoChainRules, AutoForwardDiff
using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, RuleConfig
using ChainRulesCore: unthunk, @not_implemented
using ImplicitDifferentiation:
ImplicitDifferentiation,
ImplicitFunction,
ImplicitFunctionPreparation,
IterativeLeastSquaresSolver,
build_A,
build_Aᵀ,
build_Bᵀ,
chainrules_suggested_backend
suggested_forward_backend,
suggested_reverse_backend

# not covered by Codecov for now
ImplicitDifferentiation.chainrules_suggested_backend(rc::RuleConfig) = AutoChainRules(rc)
ImplicitDifferentiation.suggested_forward_backend(rc::RuleConfig) = AutoForwardDiff()
ImplicitDifferentiation.suggested_reverse_backend(rc::RuleConfig) = AutoChainRules(rc)

struct ImplicitPullback{TA,TB,TL,TC,TP,Nargs}
struct ImplicitPullback{Nargs,TA,TB,TA2,TL,TC,TP}
Aᵀ::TA
Bᵀ::TB
A::TA2
linear_solver::TL
c0::TC
project_x::TP
_Nargs::Val{Nargs}
end

function (pb::ImplicitPullback{TA,TB,TL,TC,TP,Nargs})((dy, dz)) where {TA,TB,TL,TP,TC,Nargs}
(; Aᵀ, Bᵀ, linear_solver, c0, project_x) = pb
dc = linear_solver(Aᵀ, -unthunk(dy), c0)
function (pb::ImplicitPullback{Nargs})((dy, dz)) where {Nargs}
(; Aᵀ, Bᵀ, A, linear_solver, c0, project_x) = pb
dc = linear_solver(Aᵀ, A, -unthunk(dy), c0)
dx = Bᵀ(dc)
df = NoTangent()
dargs = ntuple(unimplemented_tangent, Val(Nargs))
Expand All @@ -40,13 +45,19 @@ function ChainRulesCore.rrule(
c = conditions(x, y, z, args...)
c0 = zero(c)

suggested_backend = chainrules_suggested_backend(rc)
forward_backend = suggested_forward_backend(rc)
reverse_backend = suggested_reverse_backend(rc)
prep = ImplicitFunctionPreparation(eltype(x))
Aᵀ = build_Aᵀ(implicit, prep, x, y, z, c, args...; suggested_backend)
Bᵀ = build_Bᵀ(implicit, prep, x, y, z, c, args...; suggested_backend)
Aᵀ = build_Aᵀ(implicit, prep, x, y, z, c, args...; suggested_backend=reverse_backend)
Bᵀ = build_Bᵀ(implicit, prep, x, y, z, c, args...; suggested_backend=reverse_backend)
if linear_solver isa IterativeLeastSquaresSolver
A = build_A(implicit, prep, x, y, z, c, args...; suggested_backend=forward_backend)
else
A = nothing
end
project_x = ProjectTo(x)

implicit_pullback = ImplicitPullback(Aᵀ, Bᵀ, linear_solver, c0, project_x, Val(N))
implicit_pullback = ImplicitPullback(Aᵀ, Bᵀ, A, linear_solver, c0, project_x, Val(N))
return (y, z), implicit_pullback
end

Expand Down
23 changes: 15 additions & 8 deletions ext/ImplicitDifferentiationForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,36 @@ module ImplicitDifferentiationForwardDiffExt
using ADTypes: AutoForwardDiff
using ForwardDiff: Dual, Partials, partials, value
using ImplicitDifferentiation:
ImplicitFunction, ImplicitFunctionPreparation, build_A, build_B
ImplicitFunction,
ImplicitFunctionPreparation,
IterativeLeastSquaresSolver,
build_A,
build_Aᵀ,
build_B

function (implicit::ImplicitFunction)(
prep::ImplicitFunctionPreparation{R}, x_and_dx::AbstractArray{Dual{T,R,N}}, args...
) where {T,R,N}
(; conditions, linear_solver) = implicit
x = value.(x_and_dx)
y, z = implicit(x, args...)
c = implicit.conditions(x, y, z, args...)
c = conditions(x, y, z, args...)
y0 = zero(y)

suggested_backend = AutoForwardDiff()
A = build_A(implicit, prep, x, y, z, c, args...; suggested_backend)
B = build_B(implicit, prep, x, y, z, c, args...; suggested_backend)

dX = ntuple(Val(N)) do k
partials.(x_and_dx, k)
Aᵀ = if linear_solver isa IterativeLeastSquaresSolver
build_Aᵀ(implicit, prep, x, y, z, c, args...; suggested_backend)
else
nothing
end

dX = ntuple(k -> partials.(x_and_dx, k), Val(N))
dC = map(B, dX)
dY = map(dC) do dₖc
dₖy = implicit.linear_solver(A, -dₖc, y0)
return dₖy
linear_solver(A, Aᵀ, -dₖc, y0)
end

y_and_dy = map(y, LinearIndices(y)) do yi, i
Dual{T}(yi, Partials(ntuple(k -> dY[k][i], Val(N))))
end
Expand Down
5 changes: 3 additions & 2 deletions ext/ImplicitDifferentiationZygoteExt.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
module ImplicitDifferentiationZygoteExt

using ADTypes: AutoZygote
using ADTypes: AutoForwardDiff, AutoZygote
using ImplicitDifferentiation: ImplicitDifferentiation
using Zygote: ZygoteRuleConfig

ImplicitDifferentiation.chainrules_suggested_backend(::ZygoteRuleConfig) = AutoZygote()
ImplicitDifferentiation.suggested_forward_backend(::ZygoteRuleConfig) = AutoForwardDiff()
ImplicitDifferentiation.suggested_reverse_backend(::ZygoteRuleConfig) = AutoZygote()

end
6 changes: 3 additions & 3 deletions src/ImplicitDifferentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ using DifferentiationInterface:
prepare_pushforward_same_point,
pullback,
pushforward
using KrylovKit: linsolve
using LinearAlgebra: factorize
using KrylovKit: linsolve, lssolve
using LinearAlgebra: Factorization, factorize

include("utils.jl")
include("settings.jl")
Expand All @@ -30,7 +30,7 @@ include("execution.jl")
include("callable.jl")

export MatrixRepresentation, OperatorRepresentation
export IterativeLinearSolver, DirectLinearSolver
export IterativeLinearSolver, IterativeLeastSquaresSolver, DirectLinearSolver
export ImplicitFunction

end
4 changes: 2 additions & 2 deletions src/implicit_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@ This requires solving a linear system `A * J = -B` where `A = ∂₂c`, `B = ∂
## Keyword arguments

- `representation`: defines how the partial Jacobian `A` of the conditions with respect to the output is represented. It can be either [`MatrixRepresentation`](@ref) or [`OperatorRepresentation`](@ref).
- `linear_solver`: specifies how the linear system `A * J = -B` will be solved in the implicit function theorem. It can be either [`DirectLinearSolver`](@ref) or [`IterativeLinearSolver`](@ref).
- `linear_solver`: specifies how the linear system `A * J = -B` will be solved in the implicit function theorem. It can be either [`DirectLinearSolver`](@ref), [`IterativeLinearSolver`](@ref) or [`IterativeLeastSquaresSolver`](@ref).
- `backends::AbstractADType`: specifies how the `conditions` will be differentiated with respect to `x` and `y`. It can be either, `nothing`, which means that the external autodiff system will be used, or a named tuple `(; x=AutoSomething(), y=AutoSomethingElse())` of backend objects from [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
- `strict::Val`: specifies whether preparation inside [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl) should enforce a strict match between the primal variables and the provided tangents.
"""
struct ImplicitFunction{
F,
C,
L,
L<:AbstractSolver,
R<:AbstractRepresentation,
B<:Union{
Nothing, #
Expand Down
66 changes: 59 additions & 7 deletions src/settings.jl
Original file line number Diff line number Diff line change
@@ -1,47 +1,96 @@
## Linear solver

abstract type AbstractSolver end

"""
DirectLinearSolver

Specify that linear systems `Ax = b` should be solved with a direct method.

!!! warning
Can only be used when the `solver` and the `conditions` both output an `AbstractVector`.

# See also

- [`ImplicitFunction`](@ref)
- [`IterativeLinearSolver`](@ref)
- [`IterativeLeastSquaresSolver`](@ref)
"""
struct DirectLinearSolver end
struct DirectLinearSolver <: AbstractSolver end

function (solver::DirectLinearSolver)(A, b::AbstractVector, x0::AbstractVector)
function (solver::DirectLinearSolver)(
A::Union{AbstractMatrix,Factorization}, _Aᵀ, b::AbstractVector, x0::AbstractVector
)
return A \ b
end

abstract type AbstractIterativeSolver <: AbstractSolver end

"""
IterativeLinearSolver

Specify that linear systems `Ax = b` should be solved with an iterative method.

!!! warning
Can only be used when the `solver` and the `conditions` both output `AbstractArray`s with the same type and length.

# See also

- [`ImplicitFunction`](@ref)
- [`DirectLinearSolver`](@ref)
- [`IterativeLeastSquaresSolver`](@ref)
"""
struct IterativeLinearSolver{K}
struct IterativeLinearSolver{K} <: AbstractIterativeSolver
kwargs::K
function IterativeLinearSolver(; kwargs...)
return new{typeof(kwargs)}(kwargs)
end
end

function (solver::IterativeLinearSolver)(A, b, x0)
function (solver::IterativeLinearSolver)(A, _Aᵀ, b, x0)
sol, info = linsolve(A, b, x0; solver.kwargs...)
@assert info.converged == 1
return sol
end

function Base.show(io::IO, linear_solver::IterativeLinearSolver)
"""
IterativeLeastSquaresSolver

Specify that linear systems `Ax = b` should be solved with an iterative least-squares method.

!!! tip
Can be used when the `solver` and the `conditions` output `AbstractArray`s with different types or different lengths.

!!! warning
To ensure performance, remember to specify both `backends` used to differentiate `condtions`.

# See also

- [`ImplicitFunction`](@ref)
- [`DirectLinearSolver`](@ref)
- [`IterativeLinearSolver`](@ref)
"""
struct IterativeLeastSquaresSolver{K} <: AbstractIterativeSolver
kwargs::K
function IterativeLeastSquaresSolver(; kwargs...)
return new{typeof(kwargs)}(kwargs)
end
end

function (solver::IterativeLeastSquaresSolver)(A, Aᵀ, b, x0)
sol, info = lssolve((A, Aᵀ), b; solver.kwargs...)
@assert info.converged == 1
return sol
end

function Base.show(io::IO, linear_solver::AbstractIterativeSolver)
(; kwargs) = linear_solver
print(io, repr(IterativeLinearSolver; context=io), "(;")
T = if linear_solver isa IterativeLinearSolver
IterativeLinearSolver
else
IterativeLeastSquaresSolver
end
print(io, repr(T; context=io), "(;")
for p in pairs(kwargs)
print(io, " ", p[1], "=", repr(p[2]; context=io), ",")
end
Expand Down Expand Up @@ -76,4 +125,7 @@ Specify that the matrix `A` involved in the implicit function theorem should be
"""
struct OperatorRepresentation <: AbstractRepresentation end

function chainrules_suggested_backend end
## Backends

function suggested_forward_backend end
function suggested_reverse_backend end
3 changes: 3 additions & 0 deletions test/printing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,7 @@ using TestItems
@test contains(string(ImplicitFunction(nothing, nothing)), "ImplicitFunction")
@test contains(string(IterativeLinearSolver()), "IterativeLinearSolver")
@test contains(string(IterativeLinearSolver(; rtol=1e-3)), "IterativeLinearSolver")
@test contains(
string(IterativeLeastSquaresSolver(; rtol=1e-3)), "IterativeLeastSquaresSolver"
)
end
4 changes: 2 additions & 2 deletions test/systematic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ end;
[
IterativeLinearSolver(),
IterativeLinearSolver(; rtol=1e-8),
IterativeLinearSolver(; issymmetric=true, isposdef=true),
IterativeLeastSquaresSolver(),
],
[nothing, (; x=AutoForwardDiff(), y=AutoZygote())],
[float.(1:3), reshape(float.(1:6), 3, 2)],
Expand All @@ -53,7 +53,7 @@ end;
solver=default_solver,
conditions=default_conditions,
x=x,
implicit_kwargs=(; strict=Val(false)),
implicit_kwargs=(; linear_solver=IterativeLeastSquaresSolver()),
)
scen2 = add_arg_mult(scen)
test_implicit(scen)
Expand Down
6 changes: 4 additions & 2 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -216,5 +216,7 @@ function test_implicit(
end
end

default_solver(x) = vcat(sqrt.(x .+ 2), -sqrt.(x)), 2
default_conditions(x, y, z) = abs2.(y) .- vcat(x .+ z, x)
# use vcat to ensure Bᵀ != B
# use reverse to ensure Aᵀ != A
default_solver(x) = reverse(vcat(sqrt.(x .+ 2), -sqrt.(x))), 2
default_conditions(x, y, z) = reverse(abs2.(y)) .- vcat(x .+ z, x)
Loading