From 12c80a71da3759c369c6bba045687052b2b7405e Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 23 Jul 2025 21:42:37 +0200 Subject: [PATCH 1/7] feat: add least-squares solver --- ...mplicitDifferentiationChainRulesCoreExt.jl | 32 +++++---- ext/ImplicitDifferentiationForwardDiffExt.jl | 23 ++++--- ext/ImplicitDifferentiationZygoteExt.jl | 5 +- src/ImplicitDifferentiation.jl | 4 +- src/implicit_function.jl | 4 +- src/settings.jl | 66 +++++++++++++++++-- 6 files changed, 102 insertions(+), 32 deletions(-) diff --git a/ext/ImplicitDifferentiationChainRulesCoreExt.jl b/ext/ImplicitDifferentiationChainRulesCoreExt.jl index 4b3d0397..91b511ee 100644 --- a/ext/ImplicitDifferentiationChainRulesCoreExt.jl +++ b/ext/ImplicitDifferentiationChainRulesCoreExt.jl @@ -1,31 +1,35 @@ module ImplicitDifferentiationChainRulesCoreExt -using ADTypes: AutoChainRules +using ADTypes: AutoChainRules, AutoForwardDiff using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, RuleConfig using ChainRulesCore: unthunk, @not_implemented using ImplicitDifferentiation: ImplicitDifferentiation, ImplicitFunction, ImplicitFunctionPreparation, + IterativeLeastSquaresSolver, build_Aᵀ, build_Bᵀ, - chainrules_suggested_backend + suggested_forward_backend, + suggested_reverse_backend # not covered by Codecov for now -ImplicitDifferentiation.chainrules_suggested_backend(rc::RuleConfig) = AutoChainRules(rc) +ImplicitDifferentiation.suggested_forward_backend(rc::RuleConfig) = AutoForwardDiff() +ImplicitDifferentiation.suggested_reverse_backend(rc::RuleConfig) = AutoChainRules(rc) -struct ImplicitPullback{TA,TB,TL,TC,TP,Nargs} +struct ImplicitPullback{Nargs,TA,TB,TA2,TL,TC,TP} Aᵀ::TA Bᵀ::TB + A::TA2 linear_solver::TL c0::TC project_x::TP _Nargs::Val{Nargs} end -function (pb::ImplicitPullback{TA,TB,TL,TC,TP,Nargs})((dy, dz)) where {TA,TB,TL,TP,TC,Nargs} - (; Aᵀ, Bᵀ, linear_solver, c0, project_x) = pb - dc = linear_solver(Aᵀ, -unthunk(dy), c0) +function 
(pb::ImplicitPullback{Nargs})((dy, dz)) where {Nargs} + (; Aᵀ, Bᵀ, A, linear_solver, c0, project_x) = pb + dc = linear_solver(Aᵀ, A, -unthunk(dy), c0) dx = Bᵀ(dc) df = NoTangent() dargs = ntuple(unimplemented_tangent, Val(Nargs)) @@ -40,13 +44,19 @@ function ChainRulesCore.rrule( c = conditions(x, y, z, args...) c0 = zero(c) - suggested_backend = chainrules_suggested_backend(rc) + forward_backend = suggested_forward_backend(rc) + reverse_backend = suggested_reverse_backend(rc) prep = ImplicitFunctionPreparation(eltype(x)) - Aᵀ = build_Aᵀ(implicit, prep, x, y, z, c, args...; suggested_backend) - Bᵀ = build_Bᵀ(implicit, prep, x, y, z, c, args...; suggested_backend) + Aᵀ = build_Aᵀ(implicit, prep, x, y, z, c, args...; suggested_backend=reverse_backend) + Bᵀ = build_Bᵀ(implicit, prep, x, y, z, c, args...; suggested_backend=reverse_backend) + if linear_solver isa IterativeLeastSquaresSolver + A = build_Aᵀ(implicit, prep, x, y, z, c, args...; suggested_backend=forward_backend) + else + A = nothing + end project_x = ProjectTo(x) - implicit_pullback = ImplicitPullback(Aᵀ, Bᵀ, linear_solver, c0, project_x, Val(N)) + implicit_pullback = ImplicitPullback(Aᵀ, Bᵀ, A, linear_solver, c0, project_x, Val(N)) return (y, z), implicit_pullback end diff --git a/ext/ImplicitDifferentiationForwardDiffExt.jl b/ext/ImplicitDifferentiationForwardDiffExt.jl index 54fa761f..d412ba60 100644 --- a/ext/ImplicitDifferentiationForwardDiffExt.jl +++ b/ext/ImplicitDifferentiationForwardDiffExt.jl @@ -3,29 +3,36 @@ module ImplicitDifferentiationForwardDiffExt using ADTypes: AutoForwardDiff using ForwardDiff: Dual, Partials, partials, value using ImplicitDifferentiation: - ImplicitFunction, ImplicitFunctionPreparation, build_A, build_B + ImplicitFunction, + ImplicitFunctionPreparation, + IterativeLeastSquaresSolver, + build_A, + build_Aᵀ, + build_B function (implicit::ImplicitFunction)( prep::ImplicitFunctionPreparation{R}, x_and_dx::AbstractArray{Dual{T,R,N}}, args... 
) where {T,R,N} + (; conditions, linear_solver) = implicit x = value.(x_and_dx) y, z = implicit(x, args...) - c = implicit.conditions(x, y, z, args...) + c = conditions(x, y, z, args...) y0 = zero(y) suggested_backend = AutoForwardDiff() A = build_A(implicit, prep, x, y, z, c, args...; suggested_backend) B = build_B(implicit, prep, x, y, z, c, args...; suggested_backend) - - dX = ntuple(Val(N)) do k - partials.(x_and_dx, k) + if linear_solver isa IterativeLeastSquaresSolver + Aᵀ = build_Aᵀ(implicit, prep, x, y, z, c, args...; suggested_backend) + else + Aᵀ = nothing end + + dX = ntuple(k -> partials.(x_and_dx, k), Val(N)) dC = map(B, dX) dY = map(dC) do dₖc - dₖy = implicit.linear_solver(A, -dₖc, y0) - return dₖy + linear_solver(A, Aᵀ, -dₖc, y0) end - y_and_dy = map(y, LinearIndices(y)) do yi, i Dual{T}(yi, Partials(ntuple(k -> dY[k][i], Val(N)))) end diff --git a/ext/ImplicitDifferentiationZygoteExt.jl b/ext/ImplicitDifferentiationZygoteExt.jl index bd9e7844..ca44d8f6 100644 --- a/ext/ImplicitDifferentiationZygoteExt.jl +++ b/ext/ImplicitDifferentiationZygoteExt.jl @@ -1,9 +1,10 @@ module ImplicitDifferentiationZygoteExt -using ADTypes: AutoZygote +using ADTypes: AutoForwardDiff, AutoZygote using ImplicitDifferentiation: ImplicitDifferentiation using Zygote: ZygoteRuleConfig -ImplicitDifferentiation.chainrules_suggested_backend(::ZygoteRuleConfig) = AutoZygote() +ImplicitDifferentiation.suggested_forward_backend(::ZygoteRuleConfig) = AutoForwardDiff() +ImplicitDifferentiation.suggested_reverse_backend(::ZygoteRuleConfig) = AutoZygote() end diff --git a/src/ImplicitDifferentiation.jl b/src/ImplicitDifferentiation.jl index d5880daa..2447e9e0 100644 --- a/src/ImplicitDifferentiation.jl +++ b/src/ImplicitDifferentiation.jl @@ -19,7 +19,7 @@ using DifferentiationInterface: prepare_pushforward_same_point, pullback, pushforward -using KrylovKit: linsolve +using KrylovKit: linsolve, lssolve using LinearAlgebra: factorize include("utils.jl") @@ -30,7 +30,7 @@ 
include("execution.jl") include("callable.jl") export MatrixRepresentation, OperatorRepresentation -export IterativeLinearSolver, DirectLinearSolver +export IterativeLinearSolver, IterativeLeastSquaresSolver, DirectLinearSolver export ImplicitFunction end diff --git a/src/implicit_function.jl b/src/implicit_function.jl index fcb63768..692f0ca8 100644 --- a/src/implicit_function.jl +++ b/src/implicit_function.jl @@ -34,14 +34,14 @@ This requires solving a linear system `A * J = -B` where `A = ∂₂c`, `B = ∂ ## Keyword arguments - `representation`: defines how the partial Jacobian `A` of the conditions with respect to the output is represented. It can be either [`MatrixRepresentation`](@ref) or [`OperatorRepresentation`](@ref). -- `linear_solver`: specifies how the linear system `A * J = -B` will be solved in the implicit function theorem. It can be either [`DirectLinearSolver`](@ref) or [`IterativeLinearSolver`](@ref). +- `linear_solver`: specifies how the linear system `A * J = -B` will be solved in the implicit function theorem. It can be either [`DirectLinearSolver`](@ref), [`IterativeLinearSolver`](@ref) or [`IterativeLeastSquaresSolver`](@ref). - `backends::AbstractADType`: specifies how the `conditions` will be differentiated with respect to `x` and `y`. It can be either, `nothing`, which means that the external autodiff system will be used, or a named tuple `(; x=AutoSomething(), y=AutoSomethingElse())` of backend objects from [ADTypes.jl](https://github.com/SciML/ADTypes.jl). - `strict::Val`: specifies whether preparation inside [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl) should enforce a strict match between the primal variables and the provided tangents. 
""" struct ImplicitFunction{ F, C, - L, + L<:AbstractSolver, R<:AbstractRepresentation, B<:Union{ Nothing, # diff --git a/src/settings.jl b/src/settings.jl index a917e21b..6b11d48b 100644 --- a/src/settings.jl +++ b/src/settings.jl @@ -1,47 +1,96 @@ ## Linear solver +abstract type AbstractSolver end + """ DirectLinearSolver Specify that linear systems `Ax = b` should be solved with a direct method. +!!! warning + Can only be used when the `solver` and the `conditions` both output an `AbstractVector`. + # See also - [`ImplicitFunction`](@ref) - [`IterativeLinearSolver`](@ref) +- [`IterativeLeastSquaresSolver`](@ref) """ -struct DirectLinearSolver end +struct DirectLinearSolver <: AbstractSolver end -function (solver::DirectLinearSolver)(A, b::AbstractVector, x0::AbstractVector) +function (solver::DirectLinearSolver)( + A::AbstractMatrix, _Aᵀ, b::AbstractVector, x0::AbstractVector +) return A \ b end +abstract type AbstractIterativeSolver <: AbstractSolver end + """ IterativeLinearSolver Specify that linear systems `Ax = b` should be solved with an iterative method. +!!! warning + Can only be used when the `solver` and the `conditions` both output `AbstractArray`s with the same type and length. + # See also - [`ImplicitFunction`](@ref) - [`DirectLinearSolver`](@ref) +- [`IterativeLeastSquaresSolver`](@ref) """ -struct IterativeLinearSolver{K} +struct IterativeLinearSolver{K} <: AbstractIterativeSolver kwargs::K function IterativeLinearSolver(; kwargs...) return new{typeof(kwargs)}(kwargs) end end -function (solver::IterativeLinearSolver)(A, b, x0) +function (solver::IterativeLinearSolver)(A, _Aᵀ, b, x0) sol, info = linsolve(A, b, x0; solver.kwargs...) @assert info.converged == 1 return sol end -function Base.show(io::IO, linear_solver::IterativeLinearSolver) +""" + IterativeLeastSquaresSolver + +Specify that linear systems `Ax = b` should be solved with an iterative least-squares method. + +!!! 
tip + Can be used when the `solver` and the `conditions` output `AbstractArray`s with different types or different lengths. + +!!! warning + To ensure performance, remember to specify both `backends` used to differentiate `conditions`. + +# See also + +- [`ImplicitFunction`](@ref) +- [`DirectLinearSolver`](@ref) +- [`IterativeLinearSolver`](@ref) +""" +struct IterativeLeastSquaresSolver{K} <: AbstractIterativeSolver + kwargs::K + function IterativeLeastSquaresSolver(; kwargs...) + return new{typeof(kwargs)}(kwargs) + end +end + +function (solver::IterativeLeastSquaresSolver)(A, Aᵀ, b, x0) + sol, info = lssolve((A, Aᵀ), b; solver.kwargs...) + @assert info.converged == 1 + return sol +end + +function Base.show(io::IO, linear_solver::AbstractIterativeSolver) (; kwargs) = linear_solver - print(io, repr(IterativeLinearSolver; context=io), "(;") + T = if linear_solver isa IterativeLinearSolver + IterativeLinearSolver + else + IterativeLeastSquaresSolver + end + print(io, repr(T; context=io), "(;") for p in pairs(kwargs) print(io, " ", p[1], "=", repr(p[2]; context=io), ",") end @@ -76,4 +125,7 @@ Specify that the matrix `A` involved in the implicit function theorem should be """ struct OperatorRepresentation <: AbstractRepresentation end -function chainrules_suggested_backend end +## Backends + +function suggested_forward_backend end +function suggested_reverse_backend end From 4df5cec4b3837d3e3ad410864e86a55d49f16ae4 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 16 Aug 2025 13:14:58 +0200 Subject: [PATCH 2/7] Fixes --- docs/src/api.md | 1 + test/systematic.jl | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/src/api.md b/docs/src/api.md index 5c6c3da8..e264d9c9 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -22,6 +22,7 @@ ImplicitFunction MatrixRepresentation OperatorRepresentation IterativeLinearSolver +IterativeLeastSquaresSolver DirectLinearSolver ``` diff --git a/test/systematic.jl
b/test/systematic.jl index 3d3f6576..2dc9f6bc 100644 --- a/test/systematic.jl +++ b/test/systematic.jl @@ -29,6 +29,7 @@ end; IterativeLinearSolver(), IterativeLinearSolver(; rtol=1e-8), IterativeLinearSolver(; issymmetric=true, isposdef=true), + IterativeLeastSquaresSolver(; issymmetric=true, isposdef=true), ], [nothing, (; x=AutoForwardDiff(), y=AutoZygote())], [float.(1:3), reshape(float.(1:6), 3, 2)], @@ -53,7 +54,7 @@ end; solver=default_solver, conditions=default_conditions, x=x, - implicit_kwargs=(; strict=Val(false)), + implicit_kwargs=(; linear_solver=IterativeLeastSquaresSolver()), ) scen2 = add_arg_mult(scen) test_implicit(scen) From 88cd8f56565f0e0f344239da5bd8126d71bb487e Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 16 Aug 2025 13:16:03 +0200 Subject: [PATCH 3/7] Kwargs --- test/systematic.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/systematic.jl b/test/systematic.jl index 2dc9f6bc..b0b67c12 100644 --- a/test/systematic.jl +++ b/test/systematic.jl @@ -29,7 +29,7 @@ end; IterativeLinearSolver(), IterativeLinearSolver(; rtol=1e-8), IterativeLinearSolver(; issymmetric=true, isposdef=true), - IterativeLeastSquaresSolver(; issymmetric=true, isposdef=true), + IterativeLeastSquaresSolver(), ], [nothing, (; x=AutoForwardDiff(), y=AutoZygote())], [float.(1:3), reshape(float.(1:6), 3, 2)], From 1dd6e746e270e0a7199d40783749f90dccc906f8 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 16 Aug 2025 17:35:32 +0200 Subject: [PATCH 4/7] Fix type-stability --- ext/ImplicitDifferentiationForwardDiffExt.jl | 6 +++--- test/systematic.jl | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/ext/ImplicitDifferentiationForwardDiffExt.jl b/ext/ImplicitDifferentiationForwardDiffExt.jl index d412ba60..4ffd9d97 100644 --- a/ext/ImplicitDifferentiationForwardDiffExt.jl +++ b/ext/ImplicitDifferentiationForwardDiffExt.jl @@ -22,10 
+22,10 @@ function (implicit::ImplicitFunction)( suggested_backend = AutoForwardDiff() A = build_A(implicit, prep, x, y, z, c, args...; suggested_backend) B = build_B(implicit, prep, x, y, z, c, args...; suggested_backend) - if linear_solver isa IterativeLeastSquaresSolver - Aᵀ = build_Aᵀ(implicit, prep, x, y, z, c, args...; suggested_backend) + Aᵀ = if linear_solver isa IterativeLeastSquaresSolver + build_Aᵀ(implicit, prep, x, y, z, c, args...; suggested_backend) else - Aᵀ = nothing + nothing end dX = ntuple(k -> partials.(x_and_dx, k), Val(N)) diff --git a/test/systematic.jl b/test/systematic.jl index b0b67c12..cbb70055 100644 --- a/test/systematic.jl +++ b/test/systematic.jl @@ -28,7 +28,6 @@ end; [ IterativeLinearSolver(), IterativeLinearSolver(; rtol=1e-8), - IterativeLinearSolver(; issymmetric=true, isposdef=true), IterativeLeastSquaresSolver(), ], [nothing, (; x=AutoForwardDiff(), y=AutoZygote())], From 8105687a1b734ffe285268038054a900c690a394 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 16 Aug 2025 18:23:40 +0200 Subject: [PATCH 5/7] Codecov v5 --- .github/workflows/CI.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index a64887cd..9084de46 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -32,7 +32,7 @@ jobs: - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 - uses: julia-actions/julia-processcoverage@v1 - - uses: codecov/codecov-action@v4 + - uses: codecov/codecov-action@v5 with: files: lcov.info token: ${{ secrets.CODECOV_TOKEN }} From a80c8ee224eeea73ee574f729114f89acdef81b3 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 16 Aug 2025 20:33:34 +0200 Subject: [PATCH 6/7] Better tests --- ext/ImplicitDifferentiationChainRulesCoreExt.jl | 3 ++- test/printing.jl | 3 +++ test/utils.jl | 6 ++++-- 3 files changed, 9 insertions(+), 
3 deletions(-) diff --git a/ext/ImplicitDifferentiationChainRulesCoreExt.jl b/ext/ImplicitDifferentiationChainRulesCoreExt.jl index 91b511ee..89a977d3 100644 --- a/ext/ImplicitDifferentiationChainRulesCoreExt.jl +++ b/ext/ImplicitDifferentiationChainRulesCoreExt.jl @@ -8,6 +8,7 @@ using ImplicitDifferentiation: ImplicitFunction, ImplicitFunctionPreparation, IterativeLeastSquaresSolver, + build_A, build_Aᵀ, build_Bᵀ, suggested_forward_backend, @@ -50,7 +51,7 @@ function ChainRulesCore.rrule( Aᵀ = build_Aᵀ(implicit, prep, x, y, z, c, args...; suggested_backend=reverse_backend) Bᵀ = build_Bᵀ(implicit, prep, x, y, z, c, args...; suggested_backend=reverse_backend) if linear_solver isa IterativeLeastSquaresSolver - A = build_Aᵀ(implicit, prep, x, y, z, c, args...; suggested_backend=forward_backend) + A = build_A(implicit, prep, x, y, z, c, args...; suggested_backend=forward_backend) else A = nothing end diff --git a/test/printing.jl b/test/printing.jl index 0190649a..7194e778 100644 --- a/test/printing.jl +++ b/test/printing.jl @@ -4,4 +4,7 @@ using TestItems @test contains(string(ImplicitFunction(nothing, nothing)), "ImplicitFunction") @test contains(string(IterativeLinearSolver()), "IterativeLinearSolver") @test contains(string(IterativeLinearSolver(; rtol=1e-3)), "IterativeLinearSolver") + @test contains( + string(IterativeLeastSquaresSolver(; rtol=1e-3)), "IterativeLeastSquaresSolver" + ) end diff --git a/test/utils.jl b/test/utils.jl index b7fcd460..b6f5f1cb 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -216,5 +216,7 @@ function test_implicit( end end -default_solver(x) = vcat(sqrt.(x .+ 2), -sqrt.(x)), 2 -default_conditions(x, y, z) = abs2.(y) .- vcat(x .+ z, x) +# use vcat to ensure Bᵀ != B +# use reverse to ensure Aᵀ != A +default_solver(x) = reverse(vcat(sqrt.(x .+ 2), -sqrt.(x))), 2 +default_conditions(x, y, z) = reverse(abs2.(y)) .- vcat(x .+ z, x) From e9aba6f9b0f187b3ea3c4ef57f532091aeeaf460 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle 
<22795598+gdalle@users.noreply.github.com> Date: Sat, 16 Aug 2025 21:14:16 +0200 Subject: [PATCH 7/7] Allow Factorization --- src/ImplicitDifferentiation.jl | 2 +- src/settings.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ImplicitDifferentiation.jl b/src/ImplicitDifferentiation.jl index 2447e9e0..33e3bb51 100644 --- a/src/ImplicitDifferentiation.jl +++ b/src/ImplicitDifferentiation.jl @@ -20,7 +20,7 @@ using DifferentiationInterface: pullback, pushforward using KrylovKit: linsolve, lssolve -using LinearAlgebra: factorize +using LinearAlgebra: Factorization, factorize include("utils.jl") include("settings.jl") diff --git a/src/settings.jl b/src/settings.jl index 6b11d48b..c5a3dcf1 100644 --- a/src/settings.jl +++ b/src/settings.jl @@ -19,7 +19,7 @@ Specify that linear systems `Ax = b` should be solved with a direct method. struct DirectLinearSolver <: AbstractSolver end function (solver::DirectLinearSolver)( - A::AbstractMatrix, _Aᵀ, b::AbstractVector, x0::AbstractVector + A::Union{AbstractMatrix,Factorization}, _Aᵀ, b::AbstractVector, x0::AbstractVector ) return A \ b end