From 76bb497755883f74c8a0d21df114ef9e12b0c760 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 2 Jul 2025 08:33:11 +0200 Subject: [PATCH 01/12] revamp!: use KrylovKit for type flexibility (beyond Vector) --- Project.toml | 32 +--- docs/Project.toml | 3 +- examples/3_tricks.jl | 5 +- ...mplicitDifferentiationChainRulesCoreExt.jl | 17 +- ext/ImplicitDifferentiationForwardDiffExt.jl | 20 +-- src/ImplicitDifferentiation.jl | 6 +- src/callable.jl | 6 +- src/execution.jl | 153 +++++------------- src/implicit_function.jl | 2 + src/preparation.jl | 33 ++-- src/settings.jl | 71 ++------ test/systematic.jl | 6 +- test/utils.jl | 33 ++-- 13 files changed, 108 insertions(+), 279 deletions(-) diff --git a/Project.toml b/Project.toml index f4bf0b47..cac22e5e 100644 --- a/Project.toml +++ b/Project.toml @@ -6,10 +6,8 @@ version = "0.9.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" -Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" +KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -LinearMaps = "7a12625a-238d-50fd-b39a-03d52299707e" -LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -35,9 +33,8 @@ ForwardDiff = "0.10.36, 1" JET = "0.9, 0.10" JuliaFormatter = "2.1.2" Krylov = "0.9.6, 0.10" +KrylovKit = "0.9.5" LinearAlgebra = "1" -LinearMaps = "3.11.4" -LinearOperators = "2.8.0" NLsolve = "4.5.1" Optim = "1.12.0" Random = "1" @@ -75,27 +72,4 @@ TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = [ - "ADTypes", - "Aqua", - "ChainRulesCore", - "ChainRulesTestUtils", - "ComponentArrays", - "DifferentiationInterface", - "Documenter", - "ExplicitImports", - "FiniteDiff", - "ForwardDiff", - "JET", - "JuliaFormatter", - "LinearAlgebra", - "NLsolve", - "Optim", - "Random", - "SparseArrays", - "StaticArrays", - "Test", - "TestItems", - "TestItemRunner", - "Zygote", -] +test = ["ADTypes", "Aqua", "ChainRulesCore", "ChainRulesTestUtils", "ComponentArrays", "DifferentiationInterface", "Documenter", "ExplicitImports", "FiniteDiff", "ForwardDiff", "JET", "JuliaFormatter", "LinearAlgebra", "NLsolve", "Optim", "Random", "SparseArrays", "StaticArrays", "Test", "TestItems", "TestItemRunner", "Zygote"] diff --git a/docs/Project.toml b/docs/Project.toml index d3a46cbc..b328c3ba 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -6,7 +6,6 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" ImplicitDifferentiation = "57b37032-215b-411a-8a7c-41a003a55207" -Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" @@ -16,4 +15,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -Documenter = "1.3" \ No newline at end of file +Documenter = "1.3" diff --git a/examples/3_tricks.jl b/examples/3_tricks.jl index 3d47251c..07744ef5 100644 --- a/examples/3_tricks.jl +++ b/examples/3_tricks.jl @@ -44,13 +44,16 @@ end; # And build your implicit function like so: -implicit_components = ImplicitFunction(forward_components, conditions_components); +implicit_components = ImplicitFunction( + forward_components, conditions_components; strict=Val(false) +); # Now we're good to go. a, b, m = [1.0, 2.0], [3.0, 4.0, 5.0], 6.0 x = ComponentVector(; a=a, b=b, m=m) y, z = implicit_components(x) +conditions_components(x, y, z) # And it works with both ForwardDiff.jl and Zygote.jl diff --git a/ext/ImplicitDifferentiationChainRulesCoreExt.jl b/ext/ImplicitDifferentiationChainRulesCoreExt.jl index 341fc37d..c0b2a744 100644 --- a/ext/ImplicitDifferentiationChainRulesCoreExt.jl +++ b/ext/ImplicitDifferentiationChainRulesCoreExt.jl @@ -15,30 +15,23 @@ using ImplicitDifferentiation: ImplicitDifferentiation.chainrules_suggested_backend(rc::RuleConfig) = AutoChainRules(rc) function ChainRulesCore.rrule( - rc::RuleConfig, - implicit::ImplicitFunction, - prep::ImplicitFunctionPreparation, - x::AbstractArray, - args::Vararg{Any,N}; + rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray, args::Vararg{Any,N}; ) where {N} y, z = implicit(x, args...) c = implicit.conditions(x, y, z, args...) suggested_backend = chainrules_suggested_backend(rc) + prep = ImplicitFunctionPreparation(eltype(x)) Aᵀ = build_Aᵀ(implicit, prep, x, y, z, c, args...; suggested_backend) Bᵀ = build_Bᵀ(implicit, prep, x, y, z, c, args...; suggested_backend) project_x = ProjectTo(x) function implicit_pullback_prepared((dy, dz)) - dy = unthunk(dy) - dy_vec = vec(dy) - dc_vec = implicit.linear_solver(Aᵀ, -dy_vec) - dx_vec = Bᵀ(dc_vec) - dx = reshape(dx_vec, size(x)) + dc = implicit.linear_solver(Aᵀ, -unthunk(dy)) + dx = Bᵀ(dc) df = NoTangent() - dprep = @not_implemented("Tangents for mutable arguments are not defined") dargs = ntuple(unimplemented_tangent, N) - return (df, dprep, project_x(dx), dargs...) + return (df, project_x(dx), dargs...) end return (y, z), implicit_pullback_prepared diff --git a/ext/ImplicitDifferentiationForwardDiffExt.jl b/ext/ImplicitDifferentiationForwardDiffExt.jl index 9f8bae71..5228c270 100644 --- a/ext/ImplicitDifferentiationForwardDiffExt.jl +++ b/ext/ImplicitDifferentiationForwardDiffExt.jl @@ -6,7 +6,7 @@ using ImplicitDifferentiation: ImplicitFunction, ImplicitFunctionPreparation, build_A, build_B function (implicit::ImplicitFunction)( - prep::ImplicitFunctionPreparation, x_and_dx::AbstractArray{Dual{T,R,N}}, args... + prep::ImplicitFunctionPreparation{R}, x_and_dx::AbstractArray{Dual{T,R,N}}, args... ) where {T,R,N} x = value.(x_and_dx) y, z = implicit(x, args...) @@ -19,14 +19,9 @@ function (implicit::ImplicitFunction)( dX = ntuple(Val(N)) do k partials.(x_and_dx, k) end - dC_vec = map(dX) do dₖx - dₖx_vec = vec(dₖx) - dₖc_vec = B(dₖx_vec) - return dₖc_vec - end - dY = map(dC_vec) do dₖc_vec - dₖy_vec = implicit.linear_solver(A, -dₖc_vec) - dₖy = reshape(dₖy_vec, size(y)) + dC = map(B, dX) + dY = map(dC) do dₖc + dₖy = implicit.linear_solver(A, -dₖc) return dₖy end @@ -37,4 +32,11 @@ function (implicit::ImplicitFunction)( return y_and_dy, z end +function (implicit::ImplicitFunction)( + x_and_dx::AbstractArray{Dual{T,R,N}}, args... +) where {T,R,N} + prep = ImplicitFunctionPreparation(R) + return implicit(prep, x_and_dx, args...) +end + end diff --git a/src/ImplicitDifferentiation.jl b/src/ImplicitDifferentiation.jl index c601f669..f334f6a9 100644 --- a/src/ImplicitDifferentiation.jl +++ b/src/ImplicitDifferentiation.jl @@ -17,13 +17,9 @@ using DifferentiationInterface: prepare_pullback_same_point, prepare_pushforward, prepare_pushforward_same_point, - pullback!, pullback, - pushforward!, pushforward -using Krylov: Krylov, krylov_workspace, krylov_solve!, solution -using LinearOperators: LinearOperator -using LinearMaps: FunctionMap +using KrylovKit: GMRES, linsolve using LinearAlgebra: factorize include("utils.jl") diff --git a/src/callable.jl b/src/callable.jl index 5ce5f172..b9187af3 100644 --- a/src/callable.jl +++ b/src/callable.jl @@ -1,9 +1,9 @@ function (implicit::ImplicitFunction)(x::AbstractArray, args::Vararg{Any,N}) where {N} - return implicit(ImplicitFunctionPreparation(), x, args...) + return implicit(ImplicitFunctionPreparation(eltype(x)), x, args...) end function (implicit::ImplicitFunction)( - ::ImplicitFunctionPreparation, x::AbstractArray, args::Vararg{Any,N} -) where {N} + ::ImplicitFunctionPreparation{R}, x::AbstractArray{R}, args::Vararg{Any,N} +) where {R<:Real,N} return implicit.solver(x, args...) end diff --git a/src/execution.jl b/src/execution.jl index 5d543de9..07123ce3 100644 --- a/src/execution.jl +++ b/src/execution.jl @@ -1,43 +1,29 @@ -struct JVP!{F,P,B,I,V,C} +struct JVP{F,P,B,I,C} f::F prep::P backend::B input::I - v_buffer::V contexts::C end -struct VJP!{F,P,B,I,V,C} +struct VJP{F,P,B,I,C} f::F prep::P backend::B input::I - v_buffer::V contexts::C end -function (po::JVP!)(res::AbstractVector, v_wrongtype::AbstractVector) - (; f, backend, input, v_buffer, contexts, prep) = po - if typeof(v_buffer) == typeof(v_wrongtype) - v = v_wrongtype - else - copyto!(v_buffer, v_wrongtype) - v = v_buffer - end - pushforward!(f, (res,), prep, backend, input, (v,), contexts...) - return res +function (po::JVP)(v) + (; f, backend, input, contexts, prep) = po + res = pushforward(f, prep, backend, input, (v,), contexts...) + return only(res) end -function (po::VJP!)(res::AbstractVector, v_wrongtype::AbstractVector) - (; f, backend, input, v_buffer, contexts, prep) = po - if typeof(v_buffer) == typeof(v_wrongtype) - v = v_wrongtype - else - copyto!(v_buffer, v_wrongtype) - v = v_buffer - end - pullback!(f, (res,), prep, backend, input, (v,), contexts...) - return res +function (po::VJP)(v) + (; f, backend, input, contexts, prep) = po + res = pullback(f, prep, backend, input, (v,), contexts...) + return only(res) end ## A @@ -64,56 +50,35 @@ function build_A_aux( (; prep_A) = prep actual_backend = isnothing(backends) ? suggested_backend : backends.y contexts = (Constant(x), Constant(z), map(Constant, args)...) + f = Switch12(conditions) if isnothing(prep_A) - A = jacobian(Switch12(conditions), actual_backend, y, contexts...) + A = jacobian(f, actual_backend, y, contexts...) else - A = jacobian(Switch12(conditions), prep_A, actual_backend, y, contexts...) + A = jacobian(f, prep_A, actual_backend, y, contexts...) end return factorize(A) end function build_A_aux( - ::OperatorRepresentation{package,symmetric,hermitian,posdef}, - implicit, - prep, - x, - y, - z, - c, - args...; - suggested_backend, -) where {package,symmetric,hermitian,posdef} + ::OperatorRepresentation, implicit, prep, x, y, z, c, args...; suggested_backend +) T = Base.promote_eltype(x, y, c) (; conditions, backends) = implicit (; prep_A) = prep actual_backend = isnothing(backends) ? suggested_backend : backends.y contexts = (Constant(x), Constant(z), map(Constant, args)...) - f_vec = VecToVec(Switch12(conditions), y) - y_vec = vec(y) - dy_vec = vec(zero(y)) + f = Switch12(conditions) if isnothing(prep_A) prep_A_same = prepare_pushforward_same_point( - f_vec, actual_backend, y_vec, (dy_vec,), contexts...; strict=implicit.strict + f, actual_backend, y, (zero(y),), contexts...; strict=implicit.strict ) else prep_A_same = prepare_pushforward_same_point( - f_vec, prep_A, actual_backend, y_vec, (dy_vec,), contexts... - ) - end - prod! = JVP!(f_vec, prep_A_same, actual_backend, y_vec, dy_vec, contexts) - if package == :LinearOperators - return LinearOperator(T, length(c), length(y), symmetric, hermitian, prod!) - elseif package == :LinearMaps - return FunctionMap{T}( - prod!, - length(c), - length(y); - ismutating=true, - issymmetric=symmetric, - ishermitian=hermitian, - isposdef=posdef, + f, prep_A, actual_backend, y, (zero(y),), contexts... ) end + A = JVP(f, prep_A_same, actual_backend, y, contexts) + return A end ## Aᵀ @@ -140,58 +105,35 @@ function build_Aᵀ_aux( (; prep_Aᵀ) = prep actual_backend = isnothing(backends) ? suggested_backend : backends.y contexts = (Constant(x), Constant(z), map(Constant, args)...) + f = Switch12(conditions) if isnothing(prep_Aᵀ) - Aᵀ = transpose(jacobian(Switch12(conditions), actual_backend, y, contexts...)) + Aᵀ = transpose(jacobian(f, actual_backend, y, contexts...)) else - Aᵀ = transpose( - jacobian(Switch12(conditions), prep_Aᵀ, actual_backend, y, contexts...) - ) + Aᵀ = transpose(jacobian(f, prep_Aᵀ, actual_backend, y, contexts...)) end return factorize(Aᵀ) end function build_Aᵀ_aux( - ::OperatorRepresentation{package,symmetric,hermitian,posdef}, - implicit, - prep, - x, - y, - z, - c, - args...; - suggested_backend, -) where {package,symmetric,hermitian,posdef} + ::OperatorRepresentation, implicit, prep, x, y, z, c, args...; suggested_backend +) T = Base.promote_eltype(x, y, c) (; conditions, backends) = implicit (; prep_Aᵀ) = prep actual_backend = isnothing(backends) ? suggested_backend : backends.y contexts = (Constant(x), Constant(z), map(Constant, args)...) - f_vec = VecToVec(Switch12(conditions), y) - y_vec = vec(y) - dc_vec = vec(zero(c)) + f = Switch12(conditions) if isnothing(prep_Aᵀ) prep_Aᵀ_same = prepare_pullback_same_point( - f_vec, actual_backend, y_vec, (dc_vec,), contexts...; strict=implicit.strict + f, actual_backend, y, (zero(c),), contexts...; strict=implicit.strict ) else prep_Aᵀ_same = prepare_pullback_same_point( - f_vec, prep_Aᵀ, actual_backend, y_vec, (dc_vec,), contexts... - ) - end - prod! = VJP!(f_vec, prep_Aᵀ_same, actual_backend, y_vec, dc_vec, contexts) - if package == :LinearOperators - return LinearOperator(T, length(y), length(c), symmetric, hermitian, prod!) - elseif package == :LinearMaps - return FunctionMap{T}( - prod!, - length(y), - length(c); - ismutating=true, - issymmetric=symmetric, - ishermitian=hermitian, - isposdef=posdef, + f, prep_Aᵀ, actual_backend, y, (zero(c),), contexts... ) end + A = VJP(f, prep_Aᵀ_same, actual_backend, y, contexts) + return A end ## B @@ -210,26 +152,17 @@ function build_B( (; prep_B) = prep actual_backend = isnothing(backends) ? suggested_backend : backends.x contexts = (Constant(y), Constant(z), map(Constant, args)...) - f_vec = VecToVec(conditions, x) - x_vec = vec(x) - dx_vec = vec(zero(x)) if isnothing(prep_B) prep_B_same = prepare_pushforward_same_point( - f_vec, actual_backend, x_vec, (dx_vec,), contexts... + conditions, actual_backend, x, (zero(x),), contexts... ) else prep_B_same = prepare_pushforward_same_point( - f_vec, prep_B, actual_backend, x_vec, (dx_vec,), contexts... + conditions, prep_B, actual_backend, x, (zero(x),), contexts... ) end - function B_fun(dx_vec_wrongtype) - @assert typeof(dx_vec) == typeof(dx_vec_wrongtype) - dx_vec = dx_vec_wrongtype - return pushforward( - f_vec, prep_B_same, actual_backend, x_vec, (dx_vec,), contexts... - )[1] - end - return B_fun + B = JVP(conditions, prep_B_same, actual_backend, x, contexts) + return B end ## Bᵀ @@ -248,25 +181,15 @@ function build_Bᵀ( (; prep_Bᵀ) = prep actual_backend = isnothing(backends) ? suggested_backend : backends.x contexts = (Constant(y), Constant(z), map(Constant, args)...) - f_vec = VecToVec(conditions, x) - x_vec = vec(x) - dc_vec = vec(zero(c)) if isnothing(prep_Bᵀ) prep_Bᵀ_same = prepare_pullback_same_point( - f_vec, actual_backend, x_vec, (dc_vec,), contexts...; strict=implicit.strict + conditions, actual_backend, x, (zero(c),), contexts...; strict=implicit.strict ) else prep_Bᵀ_same = prepare_pullback_same_point( - f_vec, prep_Bᵀ, actual_backend, x_vec, (dc_vec,), contexts... + conditions, prep_Bᵀ, actual_backend, x, (zero(c),), contexts... ) end - function Bᵀ_fun(dc_vec_wrongtype) - if typeof(dc_vec) == typeof(dc_vec_wrongtype) - dc_vec = dc_vec_wrongtype - else - copyto!(dc_vec, dc_vec_wrongtype) - end - return pullback(f_vec, prep_Bᵀ_same, actual_backend, x_vec, (dc_vec,), contexts...)[1] - end - return Bᵀ_fun + Bᵀ = VJP(conditions, prep_Bᵀ_same, actual_backend, x, contexts) + return Bᵀ end diff --git a/src/implicit_function.jl b/src/implicit_function.jl index a12d72e4..9d3e81a9 100644 --- a/src/implicit_function.jl +++ b/src/implicit_function.jl @@ -23,6 +23,7 @@ This requires solving a linear system `A * J = -B` where `A = ∂₂c`, `B = ∂ representation=OperatorRepresentation(), linear_solver=IterativeLinearSolver(), backends=nothing, + strict=Val(true), ) ## Positional arguments @@ -35,6 +36,7 @@ This requires solving a linear system `A * J = -B` where `A = ∂₂c`, `B = ∂ - `representation`: defines how the partial Jacobian `A` of the conditions with respect to the output is represented, either [`MatrixRepresentation`](@ref) or [`OperatorRepresentation`](@ref). - `linear_solver`: specifies how the linear system `A * J = -B` will be solved in the implicit function theorem. It can be either [`DirectLinearSolver`](@ref) or [`IterativeLinearSolver`](@ref). - `backends::AbstractADType`: specifies how the `conditions` will be differentiated with respect to `x` and `y`. It can be either, `nothing`, which means that the external autodiff system will be used, or a named tuple `(; x=AutoSomething(), y=AutoSomethingElse())` of backend objects from [ADTypes.jl](https://github.com/SciML/ADTypes.jl). +- `strict::Val`: specifies whether preparation inside [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl) should enforce a strict match between the primal variables and the provided tangents. """ struct ImplicitFunction{ F, diff --git a/src/preparation.jl b/src/preparation.jl index 4183f757..a84220d1 100644 --- a/src/preparation.jl +++ b/src/preparation.jl @@ -8,15 +8,16 @@ - `prep_B`: preparation for `B` (derivative of conditions with respect to `x`) in forward mode - `prep_Bᵀ`: preparation for `B` (derivative of conditions with respect to `x`) in reverse mode """ -struct ImplicitFunctionPreparation{PA,PAT,PB,PBT} +struct ImplicitFunctionPreparation{R<:Real,PA,PAT,PB,PBT} + _R::Type{R} prep_A::PA prep_Aᵀ::PAT prep_B::PB prep_Bᵀ::PBT end -function ImplicitFunctionPreparation() - return ImplicitFunctionPreparation(nothing, nothing, nothing, nothing) +function ImplicitFunctionPreparation(::Type{R}) where {R<:Real} + return ImplicitFunctionPreparation(R, nothing, nothing, nothing, nothing) end """ @@ -65,7 +66,7 @@ function prepare_implicit( prep_Bᵀ = nothing end end - return ImplicitFunctionPreparation(prep_A, prep_Aᵀ, prep_B, prep_Bᵀ) + return ImplicitFunctionPreparation(eltype(x), prep_A, prep_Aᵀ, prep_B, prep_Bᵀ) end function prepare_A( @@ -95,10 +96,9 @@ function prepare_A( strict::Val, ) contexts = (Constant(x), Constant(z), map(Constant, args)...) - f_vec = VecToVec(Switch12(conditions), y) - y_vec = vec(y) - dy_vec = vec(zero(y)) - return prepare_pushforward(f_vec, backend, y_vec, (dy_vec,), contexts...; strict) + return prepare_pushforward( + Switch12(conditions), backend, y, (zero(y),), contexts...; strict + ) end function prepare_Aᵀ( @@ -128,10 +128,9 @@ function prepare_Aᵀ( strict::Val, ) contexts = (Constant(x), Constant(z), map(Constant, args)...) - f_vec = VecToVec(Switch12(conditions), y) - y_vec = vec(y) - dc_vec = vec(zero(c)) - return prepare_pullback(f_vec, backend, y_vec, (dc_vec,), contexts...; strict) + return prepare_pullback( + Switch12(conditions), backend, y, (zero(c),), contexts...; strict + ) end function prepare_B( @@ -146,10 +145,7 @@ function prepare_B( strict::Val, ) contexts = (Constant(y), Constant(z), map(Constant, args)...) - f_vec = VecToVec(conditions, x) - x_vec = vec(x) - dx_vec = vec(zero(x)) - return prepare_pushforward(f_vec, backend, x_vec, (dx_vec,), contexts...; strict) + return prepare_pushforward(conditions, backend, x, (zero(x),), contexts...; strict) end function prepare_Bᵀ( @@ -164,8 +160,5 @@ function prepare_Bᵀ( strict::Val, ) contexts = (Constant(y), Constant(z), map(Constant, args)...) - f_vec = VecToVec(conditions, x) - x_vec = vec(x) - dc_vec = vec(zero(c)) - return prepare_pullback(f_vec, backend, x_vec, (dc_vec,), contexts...; strict) + return prepare_pullback(conditions, backend, x, (zero(c),), contexts...; strict) end diff --git a/src/settings.jl b/src/settings.jl index 3b739f25..279b3568 100644 --- a/src/settings.jl +++ b/src/settings.jl @@ -15,39 +15,20 @@ end IterativeLinearSolver Specify that linear systems `Ax = b` should be solved with an iterative method. - -# Constructor - - IterativeLinearSolver(::Val{method}=Val(:gmres); kwargs...) - -The `method` symbol is used to pick the appropriate algorithm from [Krylov.jl](https://github.com/JuliaSmoothOptimizers/Krylov.jl). -Keyword arguments are passed on to that algorithm. """ -struct IterativeLinearSolver{method,K} - _method::Val{method} +struct IterativeLinearSolver{A,K} + algorithm::A kwargs::K - function IterativeLinearSolver((::Val{method})=Val(:gmres); kwargs...) where {method} - return new{method,typeof(kwargs)}(Val(method), kwargs) - end -end - -function Base.show(io::IO, linear_solver::IterativeLinearSolver{method}) where {method} - print(io, "IterativeLinearSolver{$(repr(method))}") - if isempty(linear_solver.kwargs) - print(io, "()") - else - print(io, "(; ") - for (k, v) in pairs(linear_solver.kwargs) - print(io, "$k=$(repr(v)), ") - end - print(io, ")") + function IterativeLinearSolver(algorithm=GMRES(); kwargs...) + return new{typeof(algorithm),typeof(kwargs)}(algorithm, kwargs) end end -function (solver::IterativeLinearSolver{method})(A, b::AbstractVector) where {method} - workspace = krylov_workspace(Val(method), A, b) - krylov_solve!(workspace, A, b) - return solution(workspace) +function (solver::IterativeLinearSolver)(A, b) + x0 = zero(b) + sol, info = linsolve(A, b, x0, solver.algorithm; solver.kwargs...) + @assert info.converged == 1 + return sol end ## Representation @@ -69,43 +50,13 @@ struct MatrixRepresentation <: AbstractRepresentation end """ OperatorRepresentation -Specify that the matrix `A` involved in the implicit function theorem should be represented lazily. - -# Constructors - - OperatorRepresentation(; symmetric=false, hermitian=false, posdef=false) - OperatorRepresentation{package}(; symmetric=false, hermitian=false, posdef=false) - -The type parameter `package` can be either: - -- `:LinearOperators` to use a wrapper from [LinearOperators.jl](https://github.com/JuliaSmoothOptimizers/LinearOperators.jl) (the default) -- `:LinearMaps` to use a wrapper from [LinearMaps.jl](https://github.com/JuliaLinearAlgebra/LinearMaps.jl) - -The keyword arguments `symmetric`, `hermitian` and `posdef` give additional properties of the Jacobian of the `conditions` with respect to the solution `y`, which are useful to the solver in case you can prove them. +Specify that the matrix `A` involved in the implicit function theorem should be represented lazily, as a function. # See also - [`ImplicitFunction`](@ref) - [`MatrixRepresentation`](@ref) """ -struct OperatorRepresentation{package,symmetric,hermitian,posdef} <: AbstractRepresentation - function OperatorRepresentation{package}(; - symmetric::Bool=false, hermitian::Bool=false, posdef::Bool=false - ) where {package} - @assert package in [:LinearOperators, :LinearMaps] - return new{package,symmetric,hermitian,posdef}() - end -end - -function Base.show( - io::IO, ::OperatorRepresentation{package,symmetric,hermitian,posdef} -) where {package,symmetric,hermitian,posdef} - return print( - io, - "OperatorRepresentation{$(repr(package))}(; symmetric=$symmetric, hermitian=$hermitian, posdef=$posdef)", - ) -end - -OperatorRepresentation(; kwargs...) = OperatorRepresentation{:LinearOperators}(; kwargs...) +struct OperatorRepresentation <: AbstractRepresentation end function chainrules_suggested_backend end diff --git a/test/systematic.jl b/test/systematic.jl index c358f583..1a677676 100644 --- a/test/systematic.jl +++ b/test/systematic.jl @@ -2,10 +2,8 @@ using TestItems @testitem "Direct" setup = [TestUtils] begin using ADTypes, .TestUtils - for (backends, x) in Iterators.product( - [nothing, (; x=AutoForwardDiff(), y=AutoZygote())], - [float.(1:3), reshape(float.(1:6), 3, 2)], - ) + for (backends, x) in + Iterators.product([nothing, (; x=AutoForwardDiff(), y=AutoZygote())], [float.(1:3)]) yield() scen = Scenario(; solver=default_solver, diff --git a/test/utils.jl b/test/utils.jl index 1b6d5992..164b1aa8 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -118,7 +118,7 @@ function test_implicit_duals(scen::Scenario) @testset "Duals" begin @testset "Prepared" begin - y_and_dy, z = implicit(prep, x_and_dx, scen.args...) + y_and_dy, z = @inferred implicit(prep, x_and_dx, scen.args...) T = tag(y_and_dy) y = ForwardDiff.value.(y_and_dy) dy = ForwardDiff.extract_derivative.(T, y_and_dy) @@ -127,7 +127,7 @@ function test_implicit_duals(scen::Scenario) @test z == z_true end @testset "Unrepared" begin - y_and_dy, z = implicit(x_and_dx, scen.args...) + y_and_dy, z = @inferred implicit(x_and_dx, scen.args...) T = tag(y_and_dy) y = ForwardDiff.value.(y_and_dy) dy = ForwardDiff.extract_derivative.(T, y_and_dy) @@ -144,7 +144,6 @@ function test_implicit_rrule(scen::Scenario) implicit = ImplicitFunction( NonDifferentiable(scen.solver), scen.conditions; scen.implicit_kwargs... ) - prep = prepare_implicit(ReverseMode(), implicit, scen.x_prep, scen.args_prep...) y_true, z_true = scen.solver(scen.x, scen.args...) dy = similar(y_true) @@ -156,17 +155,11 @@ function test_implicit_rrule(scen::Scenario) )[1] @testset "ChainRule" begin - @testset "Prepared" begin - (y, z), pb = rrule(ZygoteRuleConfig(), implicit, prep, scen.x, scen.args...) - dimpl, dprep, dx = pb((dy, dz)) - @test y ≈ y_true - @test z == z_true - @test dimpl isa NoTangent - @test dx ≈ dx_true - end @testset "Unprepared" begin - (y, z), pb = rrule_via_ad(ZygoteRuleConfig(), implicit, scen.x, scen.args...) - dimpl, dx = pb((dy, dz)) + (y, z), pb = @inferred rrule_via_ad( + ZygoteRuleConfig(), implicit, scen.x, scen.args... + ) + dimpl, dx = @inferred pb((dy, dz)) @test y ≈ y_true @test z == z_true @test dimpl isa NoTangent @@ -187,11 +180,13 @@ function test_implicit_jacobian(scen::Scenario, outer_backend::AbstractADType) ) @testset "Jacobian - $outer_backend" begin - @testset "Prepared" begin - jac = DI.jacobian( - x -> first(implicit(prep, x, scen.args...)), outer_backend, scen.x - ) - @test jac ≈ jac_true + if outer_backend isa AutoForwardDiff + @testset "Prepared" begin + jac = DI.jacobian( + x -> first(implicit(prep, x, scen.args...)), outer_backend, scen.x + ) + @test jac ≈ jac_true + end end @testset "Unprepared" begin jac = DI.jacobian( @@ -203,7 +198,7 @@ function test_implicit_jacobian(scen::Scenario, outer_backend::AbstractADType) end function test_implicit(scen::Scenario, outer_backends=[AutoForwardDiff(), AutoZygote()]) - @testset "$scen" begin + return @testset "$scen" begin test_implicit_call(scen) test_implicit_duals(scen) test_implicit_rrule(scen) From 38d4c68414b63a0151f18a4f091e2a717dc313cb Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 2 Jul 2025 08:39:58 +0200 Subject: [PATCH 02/12] Fix --- test/systematic.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/systematic.jl b/test/systematic.jl index 1a677676..36306259 100644 --- a/test/systematic.jl +++ b/test/systematic.jl @@ -33,7 +33,7 @@ end; conditions=default_conditions, x=x, implicit_kwargs=(; - representation=OperatorRepresentation{:LinearOperators}(), + representation=OperatorRepresentation(), linear_solver=IterativeLinearSolver(), backends, ), From 7758221c0775176cce13d7bee4f0269e46d6fbb2 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 2 Jul 2025 09:05:29 +0200 Subject: [PATCH 03/12] Add type stability test --- ...mplicitDifferentiationChainRulesCoreExt.jl | 32 +++++++++----- src/implicit_function.jl | 2 +- src/settings.jl | 10 +++++ src/utils.jl | 21 --------- test/systematic.jl | 4 +- test/utils.jl | 44 +++++++++++-------- 6 files changed, 60 insertions(+), 53 deletions(-) diff --git a/ext/ImplicitDifferentiationChainRulesCoreExt.jl b/ext/ImplicitDifferentiationChainRulesCoreExt.jl index c0b2a744..c86bafa4 100644 --- a/ext/ImplicitDifferentiationChainRulesCoreExt.jl +++ b/ext/ImplicitDifferentiationChainRulesCoreExt.jl @@ -14,11 +14,29 @@ using ImplicitDifferentiation: # not covered by Codecov for now ImplicitDifferentiation.chainrules_suggested_backend(rc::RuleConfig) = AutoChainRules(rc) +struct ImplicitPullback{TA,TB,TL,TP,Nargs} + Aᵀ::TA + Bᵀ::TB + linear_solver::TL + project_x::TP + _Nargs::Val{Nargs} +end + +function (pb::ImplicitPullback{TA,TB,TL,TP,Nargs})((dy, dz)) where {TA,TB,TL,TP,Nargs} + (; Aᵀ, Bᵀ, linear_solver, project_x) = pb + dc = linear_solver(Aᵀ, -unthunk(dy)) + dx = Bᵀ(dc) + df = NoTangent() + dargs = ntuple(unimplemented_tangent, Val(Nargs)) + return (df, project_x(dx), dargs...) +end + function ChainRulesCore.rrule( rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray, args::Vararg{Any,N}; ) where {N} + (; conditions, linear_solver) = implicit y, z = implicit(x, args...) - c = implicit.conditions(x, y, z, args...) + c = conditions(x, y, z, args...) suggested_backend = chainrules_suggested_backend(rc) prep = ImplicitFunctionPreparation(eltype(x)) @@ -26,15 +44,8 @@ function ChainRulesCore.rrule( Bᵀ = build_Bᵀ(implicit, prep, x, y, z, c, args...; suggested_backend) project_x = ProjectTo(x) - function implicit_pullback_prepared((dy, dz)) - dc = implicit.linear_solver(Aᵀ, -unthunk(dy)) - dx = Bᵀ(dc) - df = NoTangent() - dargs = ntuple(unimplemented_tangent, N) - return (df, project_x(dx), dargs...) - end - - return (y, z), implicit_pullback_prepared + implicit_pullback = ImplicitPullback(Aᵀ, Bᵀ, linear_solver, project_x, Val(N)) + return (y, z), implicit_pullback end function unimplemented_tangent(_) @@ -42,5 +53,4 @@ function unimplemented_tangent(_) "Tangents for positional arguments of an `ImplicitFunction` beyond `x` (the first one) are not implemented" ) end - end diff --git a/src/implicit_function.jl b/src/implicit_function.jl index 9d3e81a9..8e855d31 100644 --- a/src/implicit_function.jl +++ b/src/implicit_function.jl @@ -33,7 +33,7 @@ This requires solving a linear system `A * J = -B` where `A = ∂₂c`, `B = ∂ ## Keyword arguments -- `representation`: defines how the partial Jacobian `A` of the conditions with respect to the output is represented, either [`MatrixRepresentation`](@ref) or [`OperatorRepresentation`](@ref). +- `representation`: defines how the partial Jacobian `A` of the conditions with respect to the output is represented. It can be either [`MatrixRepresentation`](@ref) or [`OperatorRepresentation`](@ref). - `linear_solver`: specifies how the linear system `A * J = -B` will be solved in the implicit function theorem. It can be either [`DirectLinearSolver`](@ref) or [`IterativeLinearSolver`](@ref). - `backends::AbstractADType`: specifies how the `conditions` will be differentiated with respect to `x` and `y`. It can be either, `nothing`, which means that the external autodiff system will be used, or a named tuple `(; x=AutoSomething(), y=AutoSomethingElse())` of backend objects from [ADTypes.jl](https://github.com/SciML/ADTypes.jl). - `strict::Val`: specifies whether preparation inside [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl) should enforce a strict match between the primal variables and the provided tangents. diff --git a/src/settings.jl b/src/settings.jl index 279b3568..7b278811 100644 --- a/src/settings.jl +++ b/src/settings.jl @@ -4,6 +4,11 @@ DirectLinearSolver Specify that linear systems `Ax = b` should be solved with a direct method. + +# See also + +- [`ImplicitFunction`](@ref) +- [`IterativeLinearSolver`](@ref) """ struct DirectLinearSolver end @@ -15,6 +20,11 @@ end IterativeLinearSolver Specify that linear systems `Ax = b` should be solved with an iterative method. + +# See also + +- [`ImplicitFunction`](@ref) +- [`DirectLinearSolver`](@ref) """ struct IterativeLinearSolver{A,K} algorithm::A diff --git a/src/utils.jl b/src/utils.jl index 5dc612eb..807e6606 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -13,24 +13,3 @@ end function (s12::Switch12)(arg1, arg2, other_args::Vararg{Any,N}) where {N} return s12.f(arg2, arg1, other_args...) end - -""" - VecToVec - -Represent a function which behaves like `f`, except that the first argument is expected as a vector, and the return is converted to a vector: - f(a1, a2, a3) = b -becomes - g(a1_vec, a2, a3) = vec(f(reshape(a1_vec, size(a1)), a2, a3)) -""" -struct VecToVec{F,N} - f::F - arg1_size::NTuple{N,Int} -end - -VecToVec(f::F, arg1_example::AbstractArray) where {F} = VecToVec(f, size(arg1_example)) - -function (v2v::VecToVec)(arg1_vec::AbstractVector, other_args::Vararg{Any,N}) where {N} - arg1 = reshape(arg1_vec, v2v.arg1_size) - res = v2v.f(arg1, other_args...) - return vec(res) -end diff --git a/test/systematic.jl b/test/systematic.jl index 36306259..eaf8f10a 100644 --- a/test/systematic.jl +++ b/test/systematic.jl @@ -39,8 +39,8 @@ end; ), ) scen2 = add_arg_mult(scen) - test_implicit(scen) - test_implicit(scen2) + test_implicit(scen; type_stability=true) + test_implicit(scen2; type_stability=true) end end; diff --git a/test/utils.jl b/test/utils.jl index 164b1aa8..b7fcd460 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -97,7 +97,7 @@ end tag(::AbstractArray{<:ForwardDiff.Dual{T}}) where {T} = T -function test_implicit_duals(scen::Scenario) +function test_implicit_duals(scen::Scenario; type_stability::Bool) implicit = ImplicitFunction( NonDifferentiable(scen.solver), scen.conditions; scen.implicit_kwargs... ) @@ -118,29 +118,33 @@ function test_implicit_duals(scen::Scenario) @testset "Duals" begin @testset "Prepared" begin - y_and_dy, z = @inferred implicit(prep, x_and_dx, scen.args...) + y_and_dy, z = implicit(prep, x_and_dx, scen.args...) T = tag(y_and_dy) y = ForwardDiff.value.(y_and_dy) dy = ForwardDiff.extract_derivative.(T, y_and_dy) @test y ≈ y_true @test dy ≈ dy_true @test z == z_true + if type_stability + @inferred implicit(prep, x_and_dx, scen.args...) + end end @testset "Unrepared" begin - y_and_dy, z = @inferred implicit(x_and_dx, scen.args...) + y_and_dy, z = implicit(x_and_dx, scen.args...) T = tag(y_and_dy) y = ForwardDiff.value.(y_and_dy) dy = ForwardDiff.extract_derivative.(T, y_and_dy) @test y ≈ y_true @test dy ≈ dy_true @test z == z_true + if type_stability + @inferred implicit(x_and_dx, scen.args...) + end end end end -function compare_pullbacks(dimpl, dx, dx_true) end - -function test_implicit_rrule(scen::Scenario) +function test_implicit_rrule(scen::Scenario; type_stability::Bool) implicit = ImplicitFunction( NonDifferentiable(scen.solver), scen.conditions; scen.implicit_kwargs... ) @@ -155,15 +159,15 @@ function test_implicit_rrule(scen::Scenario) )[1] @testset "ChainRule" begin - @testset "Unprepared" begin - (y, z), pb = @inferred rrule_via_ad( - ZygoteRuleConfig(), implicit, scen.x, scen.args... - ) - dimpl, dx = @inferred pb((dy, dz)) - @test y ≈ y_true - @test z == z_true - @test dimpl isa NoTangent - @test dx ≈ dx_true + (y, z), pb = rrule_via_ad(ZygoteRuleConfig(), implicit, scen.x, scen.args...) + dimpl, dx = pb((dy, dz)) + @test y ≈ y_true + @test z == z_true + @test dimpl isa NoTangent + @test dx ≈ dx_true + if type_stability + @inferred rrule_via_ad(ZygoteRuleConfig(), implicit, scen.x, scen.args...) + @inferred pb((dy, dz)) end end end @@ -197,11 +201,15 @@ function test_implicit_jacobian(scen::Scenario, outer_backend::AbstractADType) end end -function test_implicit(scen::Scenario, outer_backends=[AutoForwardDiff(), AutoZygote()]) +function test_implicit( + scen::Scenario, + outer_backends=[AutoForwardDiff(), AutoZygote()]; + type_stability::Bool=false, +) return @testset "$scen" begin test_implicit_call(scen) - test_implicit_duals(scen) - test_implicit_rrule(scen) + test_implicit_duals(scen; type_stability) + test_implicit_rrule(scen; type_stability) for outer_backend in outer_backends test_implicit_jacobian(scen, outer_backend) end From 46c738bffd28d6ce298c17b011fabc1f843e1821 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 2 Jul 2025 09:06:10 +0200 Subject: [PATCH 04/12] Typo --- examples/3_tricks.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/3_tricks.jl b/examples/3_tricks.jl index 07744ef5..820628a8 100644 --- a/examples/3_tricks.jl +++ b/examples/3_tricks.jl @@ -53,7 +53,6 @@ implicit_components = ImplicitFunction( a, b, m = [1.0, 2.0], [3.0, 4.0, 5.0], 6.0 x = ComponentVector(; a=a, b=b, m=m) y, z = implicit_components(x) -conditions_components(x, y, z) # And it works with both ForwardDiff.jl and Zygote.jl From 6d936bba4a6dcd0d399a672dcccacca9ee21bafb Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 2 Jul 2025 09:08:30 +0200 Subject: [PATCH 05/12] Fix --- test/printing.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/printing.jl b/test/printing.jl index 2d94f78c..dca56350 100644 --- a/test/printing.jl +++ b/test/printing.jl @@ -2,7 +2,4 @@ using TestItems @testitem "Settings" begin @test startswith(string(ImplicitFunction(nothing, nothing)), "ImplicitFunction") - @test startswith(string(OperatorRepresentation()), "Operator") - @test startswith(string(IterativeLinearSolver(; atol=1e-5)), "Iterative") - @test startswith(string(IterativeLinearSolver()), "Iterative") end From 8a62b19a73ba76c774feef4710ce16a3c8bde116 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 2 Jul 2025 09:42:13 +0200 Subject: [PATCH 06/12] Fixes --- src/ImplicitDifferentiation.jl | 1 - src/execution.jl | 2 -- src/preparation.jl | 7 ++++++- test/systematic.jl | 4 ++-- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/ImplicitDifferentiation.jl b/src/ImplicitDifferentiation.jl index f334f6a9..9f74cb4c 100644 --- a/src/ImplicitDifferentiation.jl +++ b/src/ImplicitDifferentiation.jl @@ -32,6 +32,5 @@ include("callable.jl") export MatrixRepresentation, OperatorRepresentation export IterativeLinearSolver, DirectLinearSolver export ImplicitFunction -export prepare_implicit end diff --git a/src/execution.jl b/src/execution.jl index 07123ce3..200b426a 100644 --- a/src/execution.jl +++ b/src/execution.jl @@ -62,7 +62,6 @@ end function build_A_aux( ::OperatorRepresentation, implicit, prep, x, y, z, c, args...; suggested_backend ) - T = Base.promote_eltype(x, y, c) (; conditions, backends) = implicit (; prep_A) = prep actual_backend = isnothing(backends) ? suggested_backend : backends.y @@ -117,7 +116,6 @@ end function build_Aᵀ_aux( ::OperatorRepresentation, implicit, prep, x, y, z, c, args...; suggested_backend ) - T = Base.promote_eltype(x, y, c) (; conditions, backends) = implicit (; prep_Aᵀ) = prep actual_backend = isnothing(backends) ? suggested_backend : backends.y diff --git a/src/preparation.jl b/src/preparation.jl index a84220d1..889d1e10 100644 --- a/src/preparation.jl +++ b/src/preparation.jl @@ -29,7 +29,12 @@ end strict=Val(true) ) -Uses the preparation mechanism from [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl) to speed up subsequent calls to `implicit(x, args...)` where `(x, args...)` are similar to `(x_prep, args_prep...)`. +Uses the preparation mechanism from [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl) to speed up subsequent differentiated calls to `implicit(x, args...)` where `(x, args...)` are similar to `(x_prep, args_prep...)`. + +The `mode` argument is an object from [ADTypes.jl](https://github.com/SciML/ADTypes.jl) that specifies whether the preparation should target [`ForwardMode`](@extref ADTypes.ForwardMode), [`ReverseMode`](@extref ADTypes.ReverseMode) or both ([`ForwardOrReverseMode`](ext@ref ADTypes.ForwardOrReverseMode)). + +!!! warning + This mechanism is not yet part of the public API, use it at your own risk. """ function prepare_implicit( mode::AbstractMode, implicit::ImplicitFunction, x, args::Vararg{Any,N}; strict=Val(true) diff --git a/test/systematic.jl b/test/systematic.jl index eaf8f10a..9e4a9a4e 100644 --- a/test/systematic.jl +++ b/test/systematic.jl @@ -39,8 +39,8 @@ end; ), ) scen2 = add_arg_mult(scen) - test_implicit(scen; type_stability=true) - test_implicit(scen2; type_stability=true) + test_implicit(scen; type_stability=VERSION >= v"1.11") + test_implicit(scen2; type_stability=VERSION >= v"1.11") end end; From e885024d2481141ab552e01a15ca6b0503e3344d Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 2 Jul 2025 09:49:41 +0200 Subject: [PATCH 07/12] Fix --- docs/src/api.md | 1 - src/preparation.jl | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index d72019b6..5c6c3da8 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -23,7 +23,6 @@ MatrixRepresentation OperatorRepresentation IterativeLinearSolver DirectLinearSolver -prepare_implicit ``` ## Internals diff --git a/src/preparation.jl b/src/preparation.jl index 889d1e10..dd01b90c 100644 --- a/src/preparation.jl +++ b/src/preparation.jl @@ -31,7 +31,7 @@ end Uses the preparation mechanism from [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl) to speed up subsequent differentiated calls to `implicit(x, args...)` where `(x, args...)` are similar to `(x_prep, args_prep...)`. -The `mode` argument is an object from [ADTypes.jl](https://github.com/SciML/ADTypes.jl) that specifies whether the preparation should target [`ForwardMode`](@extref ADTypes.ForwardMode), [`ReverseMode`](@extref ADTypes.ReverseMode) or both ([`ForwardOrReverseMode`](ext@ref ADTypes.ForwardOrReverseMode)). +The `mode` argument is an object from [ADTypes.jl](https://github.com/SciML/ADTypes.jl) that specifies whether the preparation should target `ForwardMode`, `ReverseMode` or both (`ForwardOrReverseMode`). !!! warning This mechanism is not yet part of the public API, use it at your own risk. From 206b30c56fc9d7c95b5f2a7f34eb898204f4878e Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 2 Jul 2025 09:53:11 +0200 Subject: [PATCH 08/12] Strict --- src/execution.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/execution.jl b/src/execution.jl index 200b426a..39393830 100644 --- a/src/execution.jl +++ b/src/execution.jl @@ -152,7 +152,7 @@ function build_B( contexts = (Constant(y), Constant(z), map(Constant, args)...) if isnothing(prep_B) prep_B_same = prepare_pushforward_same_point( - conditions, actual_backend, x, (zero(x),), contexts... + conditions, actual_backend, x, (zero(x),), contexts...; strict=implicit.strict ) else prep_B_same = prepare_pushforward_same_point( From 0ece953f4c82a97e37cf7f84d30cf6f94d240db9 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 2 Jul 2025 10:12:35 +0200 Subject: [PATCH 09/12] Fix coverage --- Project.toml | 25 ++++++++++++++++++++++++- test/preparation.jl | 27 ++++++++++++++++++++++++++- 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index cac22e5e..c525894c 100644 --- a/Project.toml +++ b/Project.toml @@ -72,4 +72,27 @@ TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["ADTypes", "Aqua", "ChainRulesCore", "ChainRulesTestUtils", "ComponentArrays", "DifferentiationInterface", "Documenter", "ExplicitImports", "FiniteDiff", "ForwardDiff", "JET", "JuliaFormatter", "LinearAlgebra", "NLsolve", "Optim", "Random", "SparseArrays", "StaticArrays", "Test", "TestItems", "TestItemRunner", "Zygote"] +test = [ + "ADTypes", + "Aqua", + "ChainRulesCore", + "ChainRulesTestUtils", + "ComponentArrays", + "DifferentiationInterface", + "Documenter", + "ExplicitImports", + "FiniteDiff", + "ForwardDiff", + "JET", + "JuliaFormatter", + "LinearAlgebra", + "NLsolve", + "Optim", + "Random", + "SparseArrays", + "StaticArrays", + "Test", + "TestItems", + "TestItemRunner", + "Zygote", +] diff --git a/test/preparation.jl b/test/preparation.jl index 32d59b74..d898b7f9 100644 --- a/test/preparation.jl +++ b/test/preparation.jl @@ -1,5 +1,7 @@ @testitem "Preparation" begin using ImplicitDifferentiation + using ImplicitDifferentiation: + prepare_implicit, build_A, build_Aᵀ, build_B, build_Bᵀ, JVP, VJP using ADTypes using ADTypes: ForwardOrReverseMode, ForwardMode, ReverseMode using ForwardDiff: ForwardDiff @@ -8,11 +10,22 @@ solver(x) = sqrt.(x), nothing conditions(x, y, z) = y .^ 2 .- x + implicit = ImplicitFunction( + solver, + conditions; + backends=(; x=AutoForwardDiff(), y=AutoForwardDiff()), + representation=MatrixRepresentation(), + ) + implicit_iterative = ImplicitFunction( solver, conditions; backends=(; x=AutoForwardDiff(), y=AutoForwardDiff()) ) implicit_nobackends = ImplicitFunction(solver, conditions) + x = rand(5) + y, z = implicit(x) + c = conditions(x, y, z) + suggested_backend = AutoEnzyme() @testset "None" begin prep = prepare_implicit(ForwardOrReverseMode(), implicit_nobackends, x) @@ -28,6 +41,10 @@ @test prep.prep_Aᵀ === nothing @test prep.prep_B !== nothing @test prep.prep_Bᵀ === nothing + @test build_A(implicit, prep, x, y, z, c; suggested_backend) isa AbstractMatrix + @test build_Aᵀ(implicit, prep, x, y, z, c; suggested_backend) isa AbstractMatrix + @test build_B(implicit, prep, x, y, z, c; suggested_backend) isa JVP + @test build_Bᵀ(implicit, prep, x, y, z, c; suggested_backend) isa VJP end @testset "ReverseMode" begin @@ -36,13 +53,21 @@ @test prep.prep_Aᵀ !== nothing @test prep.prep_B === nothing @test prep.prep_Bᵀ !== nothing + @test build_A(implicit, prep, x, y, z, c; suggested_backend) isa AbstractMatrix + @test build_Aᵀ(implicit, prep, x, y, z, c; suggested_backend) isa AbstractMatrix + @test build_B(implicit, prep, x, y, z, c; suggested_backend) isa JVP + @test build_Bᵀ(implicit, prep, x, y, z, c; suggested_backend) isa VJP end @testset "Both" begin - prep = prepare_implicit(ForwardOrReverseMode(), implicit, x) + prep = prepare_implicit(ForwardOrReverseMode(), implicit_iterative, x) @test prep.prep_A !== nothing @test prep.prep_Aᵀ !== nothing @test prep.prep_B !== nothing @test prep.prep_Bᵀ !== nothing + @test build_A(implicit_iterative, prep, x, y, z, c; suggested_backend) isa JVP + @test build_Aᵀ(implicit_iterative, prep, x, y, z, c; suggested_backend) isa VJP + @test build_B(implicit_iterative, prep, x, y, z, c; suggested_backend) isa JVP + @test build_Bᵀ(implicit_iterative, prep, x, y, z, c; suggested_backend) isa VJP end end From 7c6ff285668a5c4bffd63d22766d03b12285c898 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 2 Jul 2025 10:56:07 +0200 Subject: [PATCH 10/12] Linear solver settings --- src/settings.jl | 19 +++++++++++++------ test/printing.jl | 2 ++ test/systematic.jl | 11 +++++++---- 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/src/settings.jl b/src/settings.jl index 7b278811..c54e2a33 100644 --- a/src/settings.jl +++ b/src/settings.jl @@ -26,21 +26,28 @@ Specify that linear systems `Ax = b` should be solved with an iterative method. - [`ImplicitFunction`](@ref) - [`DirectLinearSolver`](@ref) """ -struct IterativeLinearSolver{A,K} - algorithm::A +struct IterativeLinearSolver{K} kwargs::K - function IterativeLinearSolver(algorithm=GMRES(); kwargs...) - return new{typeof(algorithm),typeof(kwargs)}(algorithm, kwargs) + function IterativeLinearSolver(; kwargs...) + return new{typeof(kwargs)}(kwargs) end end function (solver::IterativeLinearSolver)(A, b) - x0 = zero(b) - sol, info = linsolve(A, b, x0, solver.algorithm; solver.kwargs...) + sol, info = linsolve(A, b; solver.kwargs...) @assert info.converged == 1 return sol end +function Base.show(io::IO, linear_solver::IterativeLinearSolver) + (; kwargs) = linear_solver + print(io, repr(IterativeLinearSolver; context=io), "(;") + for p in pairs(kwargs) + print(io, " ", p[1], "=", repr(p[2]; context=io), ",") + end + return print(io, ")") +end + ## Representation abstract type AbstractRepresentation end diff --git a/test/printing.jl b/test/printing.jl index dca56350..d1c25991 100644 --- a/test/printing.jl +++ b/test/printing.jl @@ -2,4 +2,6 @@ using TestItems @testitem "Settings" begin @test startswith(string(ImplicitFunction(nothing, nothing)), "ImplicitFunction") + @test startswith(string(IterativeLinearSolver()), "IterativeLinearSolver") + @test startswith(string(IterativeLinearSolver(; rtol=1e-3)), "IterativeLinearSolver") end diff --git a/test/systematic.jl b/test/systematic.jl index 9e4a9a4e..799b67cf 100644 --- a/test/systematic.jl +++ b/test/systematic.jl @@ -23,8 +23,13 @@ end; @testitem "Iterative" setup = [TestUtils] begin using ADTypes, .TestUtils - for (backends, x) in Iterators.product( + for (backends, linear_solver, x) in Iterators.product( [nothing, (; x=AutoForwardDiff(), y=AutoZygote())], + [ + IterativeLinearSolver(), + IterativeLinearSolver(; rtol=1e-8), + IterativeLinearSolver(; issymmetric=true, isposdef=true), + ], [float.(1:3), reshape(float.(1:6), 3, 2)], ) yield() @@ -33,9 +38,7 @@ end; conditions=default_conditions, x=x, implicit_kwargs=(; - representation=OperatorRepresentation(), - linear_solver=IterativeLinearSolver(), - backends, + representation=OperatorRepresentation(), linear_solver, backends ), ) scen2 = add_arg_mult(scen) From 3720ce87d2ec67943928617e20d69a21a04bdc77 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 2 Jul 2025 11:00:11 +0200 Subject: [PATCH 11/12] More principled printin --- src/implicit_function.jl | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/src/implicit_function.jl b/src/implicit_function.jl index 8e855d31..fcb63768 100644 --- a/src/implicit_function.jl +++ b/src/implicit_function.jl @@ -74,14 +74,17 @@ function Base.show(io::IO, implicit::ImplicitFunction) (; solver, conditions, backends, linear_solver, representation) = implicit return print( io, - """ - ImplicitFunction( - $solver, - $conditions; - representation=$representation, - linear_solver=$linear_solver, - backends=$backends, - ) - """, + repr(ImplicitFunction; context=io), + "(", + repr(solver; context=io), + ", ", + repr(conditions; context=io), + "; representation=", + repr(representation; context=io), + ", linear_solver=", + repr(linear_solver; context=io), + ", backends=", + repr(backends; context=io), + ")", ) end From fd4e50f46f102303b6f17f153613f2ad5700cb35 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 2 Jul 2025 11:58:11 +0200 Subject: [PATCH 12/12] Fix tests --- src/ImplicitDifferentiation.jl | 2 +- test/printing.jl | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/ImplicitDifferentiation.jl b/src/ImplicitDifferentiation.jl index 9f74cb4c..d5880daa 100644 --- a/src/ImplicitDifferentiation.jl +++ b/src/ImplicitDifferentiation.jl @@ -19,7 +19,7 @@ using DifferentiationInterface: prepare_pushforward_same_point, pullback, pushforward -using KrylovKit: GMRES, linsolve +using KrylovKit: linsolve using LinearAlgebra: factorize include("utils.jl") diff --git a/test/printing.jl b/test/printing.jl index d1c25991..0190649a 100644 --- a/test/printing.jl +++ b/test/printing.jl @@ -1,7 +1,7 @@ using TestItems @testitem "Settings" begin - @test startswith(string(ImplicitFunction(nothing, nothing)), "ImplicitFunction") - @test startswith(string(IterativeLinearSolver()), "IterativeLinearSolver") - @test startswith(string(IterativeLinearSolver(; rtol=1e-3)), "IterativeLinearSolver") + @test contains(string(ImplicitFunction(nothing, nothing)), "ImplicitFunction") + @test contains(string(IterativeLinearSolver()), "IterativeLinearSolver") + @test contains(string(IterativeLinearSolver(; rtol=1e-3)), "IterativeLinearSolver") end