diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index a64887cd..9084de46 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -32,7 +32,7 @@ jobs: - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 - uses: julia-actions/julia-processcoverage@v1 - - uses: codecov/codecov-action@v4 + - uses: codecov/codecov-action@v5 with: files: lcov.info token: ${{ secrets.CODECOV_TOKEN }} diff --git a/docs/src/api.md b/docs/src/api.md index 5c6c3da8..e264d9c9 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -22,6 +22,7 @@ ImplicitFunction MatrixRepresentation OperatorRepresentation IterativeLinearSolver +IterativeLeastSquaresSolver DirectLinearSolver ``` diff --git a/ext/ImplicitDifferentiationChainRulesCoreExt.jl b/ext/ImplicitDifferentiationChainRulesCoreExt.jl index 4b3d0397..89a977d3 100644 --- a/ext/ImplicitDifferentiationChainRulesCoreExt.jl +++ b/ext/ImplicitDifferentiationChainRulesCoreExt.jl @@ -1,31 +1,36 @@ module ImplicitDifferentiationChainRulesCoreExt -using ADTypes: AutoChainRules +using ADTypes: AutoChainRules, AutoForwardDiff using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, RuleConfig using ChainRulesCore: unthunk, @not_implemented using ImplicitDifferentiation: ImplicitDifferentiation, ImplicitFunction, ImplicitFunctionPreparation, + IterativeLeastSquaresSolver, + build_A, build_Aᵀ, build_Bᵀ, - chainrules_suggested_backend + suggested_forward_backend, + suggested_reverse_backend # not covered by Codecov for now -ImplicitDifferentiation.chainrules_suggested_backend(rc::RuleConfig) = AutoChainRules(rc) +ImplicitDifferentiation.suggested_forward_backend(rc::RuleConfig) = AutoForwardDiff() +ImplicitDifferentiation.suggested_reverse_backend(rc::RuleConfig) = AutoChainRules(rc) -struct ImplicitPullback{TA,TB,TL,TC,TP,Nargs} +struct ImplicitPullback{Nargs,TA,TB,TA2,TL,TC,TP} Aᵀ::TA Bᵀ::TB + A::TA2 linear_solver::TL c0::TC project_x::TP _Nargs::Val{Nargs} end -function 
(pb::ImplicitPullback{TA,TB,TL,TC,TP,Nargs})((dy, dz)) where {TA,TB,TL,TP,TC,Nargs} - (; Aᵀ, Bᵀ, linear_solver, c0, project_x) = pb - dc = linear_solver(Aᵀ, -unthunk(dy), c0) +function (pb::ImplicitPullback{Nargs})((dy, dz)) where {Nargs} + (; Aᵀ, Bᵀ, A, linear_solver, c0, project_x) = pb + dc = linear_solver(Aᵀ, A, -unthunk(dy), c0) dx = Bᵀ(dc) df = NoTangent() dargs = ntuple(unimplemented_tangent, Val(Nargs)) @@ -40,13 +45,19 @@ function ChainRulesCore.rrule( c = conditions(x, y, z, args...) c0 = zero(c) - suggested_backend = chainrules_suggested_backend(rc) + forward_backend = suggested_forward_backend(rc) + reverse_backend = suggested_reverse_backend(rc) prep = ImplicitFunctionPreparation(eltype(x)) - Aᵀ = build_Aᵀ(implicit, prep, x, y, z, c, args...; suggested_backend) - Bᵀ = build_Bᵀ(implicit, prep, x, y, z, c, args...; suggested_backend) + Aᵀ = build_Aᵀ(implicit, prep, x, y, z, c, args...; suggested_backend=reverse_backend) + Bᵀ = build_Bᵀ(implicit, prep, x, y, z, c, args...; suggested_backend=reverse_backend) + if linear_solver isa IterativeLeastSquaresSolver + A = build_A(implicit, prep, x, y, z, c, args...; suggested_backend=forward_backend) + else + A = nothing + end project_x = ProjectTo(x) - implicit_pullback = ImplicitPullback(Aᵀ, Bᵀ, linear_solver, c0, project_x, Val(N)) + implicit_pullback = ImplicitPullback(Aᵀ, Bᵀ, A, linear_solver, c0, project_x, Val(N)) return (y, z), implicit_pullback end diff --git a/ext/ImplicitDifferentiationForwardDiffExt.jl b/ext/ImplicitDifferentiationForwardDiffExt.jl index 54fa761f..4ffd9d97 100644 --- a/ext/ImplicitDifferentiationForwardDiffExt.jl +++ b/ext/ImplicitDifferentiationForwardDiffExt.jl @@ -3,29 +3,36 @@ module ImplicitDifferentiationForwardDiffExt using ADTypes: AutoForwardDiff using ForwardDiff: Dual, Partials, partials, value using ImplicitDifferentiation: - ImplicitFunction, ImplicitFunctionPreparation, build_A, build_B + ImplicitFunction, + ImplicitFunctionPreparation, + IterativeLeastSquaresSolver, + 
build_A, + build_Aᵀ, + build_B function (implicit::ImplicitFunction)( prep::ImplicitFunctionPreparation{R}, x_and_dx::AbstractArray{Dual{T,R,N}}, args... ) where {T,R,N} + (; conditions, linear_solver) = implicit x = value.(x_and_dx) y, z = implicit(x, args...) - c = implicit.conditions(x, y, z, args...) + c = conditions(x, y, z, args...) y0 = zero(y) suggested_backend = AutoForwardDiff() A = build_A(implicit, prep, x, y, z, c, args...; suggested_backend) B = build_B(implicit, prep, x, y, z, c, args...; suggested_backend) - - dX = ntuple(Val(N)) do k - partials.(x_and_dx, k) + Aᵀ = if linear_solver isa IterativeLeastSquaresSolver + build_Aᵀ(implicit, prep, x, y, z, c, args...; suggested_backend) + else + nothing end + + dX = ntuple(k -> partials.(x_and_dx, k), Val(N)) dC = map(B, dX) dY = map(dC) do dₖc - dₖy = implicit.linear_solver(A, -dₖc, y0) - return dₖy + linear_solver(A, Aᵀ, -dₖc, y0) end - y_and_dy = map(y, LinearIndices(y)) do yi, i Dual{T}(yi, Partials(ntuple(k -> dY[k][i], Val(N)))) end diff --git a/ext/ImplicitDifferentiationZygoteExt.jl b/ext/ImplicitDifferentiationZygoteExt.jl index bd9e7844..ca44d8f6 100644 --- a/ext/ImplicitDifferentiationZygoteExt.jl +++ b/ext/ImplicitDifferentiationZygoteExt.jl @@ -1,9 +1,10 @@ module ImplicitDifferentiationZygoteExt -using ADTypes: AutoZygote +using ADTypes: AutoForwardDiff, AutoZygote using ImplicitDifferentiation: ImplicitDifferentiation using Zygote: ZygoteRuleConfig -ImplicitDifferentiation.chainrules_suggested_backend(::ZygoteRuleConfig) = AutoZygote() +ImplicitDifferentiation.suggested_forward_backend(::ZygoteRuleConfig) = AutoForwardDiff() +ImplicitDifferentiation.suggested_reverse_backend(::ZygoteRuleConfig) = AutoZygote() end diff --git a/src/ImplicitDifferentiation.jl b/src/ImplicitDifferentiation.jl index d5880daa..33e3bb51 100644 --- a/src/ImplicitDifferentiation.jl +++ b/src/ImplicitDifferentiation.jl @@ -19,8 +19,8 @@ using DifferentiationInterface: prepare_pushforward_same_point, pullback, 
pushforward -using KrylovKit: linsolve -using LinearAlgebra: factorize +using KrylovKit: linsolve, lssolve +using LinearAlgebra: Factorization, factorize include("utils.jl") include("settings.jl") @@ -30,7 +30,7 @@ include("execution.jl") include("callable.jl") export MatrixRepresentation, OperatorRepresentation -export IterativeLinearSolver, DirectLinearSolver +export IterativeLinearSolver, IterativeLeastSquaresSolver, DirectLinearSolver export ImplicitFunction end diff --git a/src/implicit_function.jl b/src/implicit_function.jl index fcb63768..692f0ca8 100644 --- a/src/implicit_function.jl +++ b/src/implicit_function.jl @@ -34,14 +34,14 @@ This requires solving a linear system `A * J = -B` where `A = ∂₂c`, `B = ∂ ## Keyword arguments - `representation`: defines how the partial Jacobian `A` of the conditions with respect to the output is represented. It can be either [`MatrixRepresentation`](@ref) or [`OperatorRepresentation`](@ref). -- `linear_solver`: specifies how the linear system `A * J = -B` will be solved in the implicit function theorem. It can be either [`DirectLinearSolver`](@ref) or [`IterativeLinearSolver`](@ref). +- `linear_solver`: specifies how the linear system `A * J = -B` will be solved in the implicit function theorem. It can be either [`DirectLinearSolver`](@ref), [`IterativeLinearSolver`](@ref) or [`IterativeLeastSquaresSolver`](@ref). - `backends::AbstractADType`: specifies how the `conditions` will be differentiated with respect to `x` and `y`. It can be either, `nothing`, which means that the external autodiff system will be used, or a named tuple `(; x=AutoSomething(), y=AutoSomethingElse())` of backend objects from [ADTypes.jl](https://github.com/SciML/ADTypes.jl). - `strict::Val`: specifies whether preparation inside [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl) should enforce a strict match between the primal variables and the provided tangents. 
""" struct ImplicitFunction{ F, C, - L, + L<:AbstractSolver, R<:AbstractRepresentation, B<:Union{ Nothing, # diff --git a/src/settings.jl b/src/settings.jl index a917e21b..c5a3dcf1 100644 --- a/src/settings.jl +++ b/src/settings.jl @@ -1,47 +1,96 @@ ## Linear solver +abstract type AbstractSolver end + """ DirectLinearSolver Specify that linear systems `Ax = b` should be solved with a direct method. +!!! warning + Can only be used when the `solver` and the `conditions` both output an `AbstractVector`. + # See also - [`ImplicitFunction`](@ref) - [`IterativeLinearSolver`](@ref) +- [`IterativeLeastSquaresSolver`](@ref) """ -struct DirectLinearSolver end +struct DirectLinearSolver <: AbstractSolver end -function (solver::DirectLinearSolver)(A, b::AbstractVector, x0::AbstractVector) +function (solver::DirectLinearSolver)( + A::Union{AbstractMatrix,Factorization}, _Aᵀ, b::AbstractVector, x0::AbstractVector +) return A \ b end +abstract type AbstractIterativeSolver <: AbstractSolver end + """ IterativeLinearSolver Specify that linear systems `Ax = b` should be solved with an iterative method. +!!! warning + Can only be used when the `solver` and the `conditions` both output `AbstractArray`s with the same type and length. + # See also - [`ImplicitFunction`](@ref) - [`DirectLinearSolver`](@ref) +- [`IterativeLeastSquaresSolver`](@ref) """ -struct IterativeLinearSolver{K} +struct IterativeLinearSolver{K} <: AbstractIterativeSolver kwargs::K function IterativeLinearSolver(; kwargs...) return new{typeof(kwargs)}(kwargs) end end -function (solver::IterativeLinearSolver)(A, b, x0) +function (solver::IterativeLinearSolver)(A, _Aᵀ, b, x0) sol, info = linsolve(A, b, x0; solver.kwargs...) @assert info.converged == 1 return sol end -function Base.show(io::IO, linear_solver::IterativeLinearSolver) +""" + IterativeLeastSquaresSolver + +Specify that linear systems `Ax = b` should be solved with an iterative least-squares method. + +!!! 
tip + Can be used when the `solver` and the `conditions` output `AbstractArray`s with different types or different lengths. + +!!! warning + To ensure performance, remember to specify both `backends` used to differentiate `conditions`. + +# See also + +- [`ImplicitFunction`](@ref) +- [`DirectLinearSolver`](@ref) +- [`IterativeLinearSolver`](@ref) +""" +struct IterativeLeastSquaresSolver{K} <: AbstractIterativeSolver + kwargs::K + function IterativeLeastSquaresSolver(; kwargs...) + return new{typeof(kwargs)}(kwargs) + end +end + +function (solver::IterativeLeastSquaresSolver)(A, Aᵀ, b, x0) + sol, info = lssolve((A, Aᵀ), b; solver.kwargs...) + @assert info.converged == 1 + return sol +end + +function Base.show(io::IO, linear_solver::AbstractIterativeSolver) (; kwargs) = linear_solver - print(io, repr(IterativeLinearSolver; context=io), "(;") + T = if linear_solver isa IterativeLinearSolver + IterativeLinearSolver + else + IterativeLeastSquaresSolver + end + print(io, repr(T; context=io), "(;") for p in pairs(kwargs) print(io, " ", p[1], "=", repr(p[2]; context=io), ",") end @@ -76,4 +125,7 @@ Specify that the matrix `A` involved in the implicit function theorem should be """ struct OperatorRepresentation <: AbstractRepresentation end -function chainrules_suggested_backend end +## Backends + +function suggested_forward_backend end +function suggested_reverse_backend end diff --git a/test/printing.jl b/test/printing.jl index 0190649a..7194e778 100644 --- a/test/printing.jl +++ b/test/printing.jl @@ -4,4 +4,7 @@ using TestItems @test contains(string(ImplicitFunction(nothing, nothing)), "ImplicitFunction") @test contains(string(IterativeLinearSolver()), "IterativeLinearSolver") @test contains(string(IterativeLinearSolver(; rtol=1e-3)), "IterativeLinearSolver") + @test contains( + string(IterativeLeastSquaresSolver(; rtol=1e-3)), "IterativeLeastSquaresSolver" + ) end diff --git a/test/systematic.jl b/test/systematic.jl index 3d3f6576..cbb70055 100644 --- 
a/test/systematic.jl +++ b/test/systematic.jl @@ -28,7 +28,7 @@ end; [ IterativeLinearSolver(), IterativeLinearSolver(; rtol=1e-8), - IterativeLinearSolver(; issymmetric=true, isposdef=true), + IterativeLeastSquaresSolver(), ], [nothing, (; x=AutoForwardDiff(), y=AutoZygote())], [float.(1:3), reshape(float.(1:6), 3, 2)], @@ -53,7 +53,7 @@ end; solver=default_solver, conditions=default_conditions, x=x, - implicit_kwargs=(; strict=Val(false)), + implicit_kwargs=(; linear_solver=IterativeLeastSquaresSolver()), ) scen2 = add_arg_mult(scen) test_implicit(scen) diff --git a/test/utils.jl b/test/utils.jl index b7fcd460..b6f5f1cb 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -216,5 +216,7 @@ function test_implicit( end end -default_solver(x) = vcat(sqrt.(x .+ 2), -sqrt.(x)), 2 -default_conditions(x, y, z) = abs2.(y) .- vcat(x .+ z, x) +# use vcat to ensure Bᵀ != B +# use reverse to ensure Aᵀ != A +default_solver(x) = reverse(vcat(sqrt.(x .+ 2), -sqrt.(x))), 2 +default_conditions(x, y, z) = reverse(abs2.(y)) .- vcat(x .+ z, x)