diff --git a/Project.toml b/Project.toml index f4bf0b47..c525894c 100644 --- a/Project.toml +++ b/Project.toml @@ -6,10 +6,8 @@ version = "0.9.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" -Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" +KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -LinearMaps = "7a12625a-238d-50fd-b39a-03d52299707e" -LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -35,9 +33,8 @@ ForwardDiff = "0.10.36, 1" JET = "0.9, 0.10" JuliaFormatter = "2.1.2" Krylov = "0.9.6, 0.10" +KrylovKit = "0.9.5" LinearAlgebra = "1" -LinearMaps = "3.11.4" -LinearOperators = "2.8.0" NLsolve = "4.5.1" Optim = "1.12.0" Random = "1" diff --git a/docs/Project.toml b/docs/Project.toml index d3a46cbc..b328c3ba 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -6,7 +6,6 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" ImplicitDifferentiation = "57b37032-215b-411a-8a7c-41a003a55207" -Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" @@ -16,4 +15,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -Documenter = "1.3" \ No newline at end of file +Documenter = "1.3" diff --git a/docs/src/api.md b/docs/src/api.md index d72019b6..5c6c3da8 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -23,7 +23,6 @@ MatrixRepresentation OperatorRepresentation IterativeLinearSolver DirectLinearSolver -prepare_implicit ``` ## Internals diff --git a/examples/3_tricks.jl b/examples/3_tricks.jl index 3d47251c..820628a8 100644 --- a/examples/3_tricks.jl +++ b/examples/3_tricks.jl @@ -44,7 +44,9 @@ end; # And build your implicit function like so: -implicit_components = ImplicitFunction(forward_components, conditions_components); +implicit_components = ImplicitFunction( + forward_components, conditions_components; strict=Val(false) +); # Now we're good to go. diff --git a/ext/ImplicitDifferentiationChainRulesCoreExt.jl b/ext/ImplicitDifferentiationChainRulesCoreExt.jl index 341fc37d..c86bafa4 100644 --- a/ext/ImplicitDifferentiationChainRulesCoreExt.jl +++ b/ext/ImplicitDifferentiationChainRulesCoreExt.jl @@ -14,34 +14,38 @@ using ImplicitDifferentiation: # not covered by Codecov for now ImplicitDifferentiation.chainrules_suggested_backend(rc::RuleConfig) = AutoChainRules(rc) +struct ImplicitPullback{TA,TB,TL,TP,Nargs} + Aᵀ::TA + Bᵀ::TB + linear_solver::TL + project_x::TP + _Nargs::Val{Nargs} +end + +function (pb::ImplicitPullback{TA,TB,TL,TP,Nargs})((dy, dz)) where {TA,TB,TL,TP,Nargs} + (; Aᵀ, Bᵀ, linear_solver, project_x) = pb + dc = linear_solver(Aᵀ, -unthunk(dy)) + dx = Bᵀ(dc) + df = NoTangent() + dargs = ntuple(unimplemented_tangent, Val(Nargs)) + return (df, project_x(dx), dargs...) +end + function ChainRulesCore.rrule( - rc::RuleConfig, - implicit::ImplicitFunction, - prep::ImplicitFunctionPreparation, - x::AbstractArray, - args::Vararg{Any,N}; + rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray, args::Vararg{Any,N}; ) where {N} + (; conditions, linear_solver) = implicit y, z = implicit(x, args...) - c = implicit.conditions(x, y, z, args...) + c = conditions(x, y, z, args...) suggested_backend = chainrules_suggested_backend(rc) + prep = ImplicitFunctionPreparation(eltype(x)) Aᵀ = build_Aᵀ(implicit, prep, x, y, z, c, args...; suggested_backend) Bᵀ = build_Bᵀ(implicit, prep, x, y, z, c, args...; suggested_backend) project_x = ProjectTo(x) - function implicit_pullback_prepared((dy, dz)) - dy = unthunk(dy) - dy_vec = vec(dy) - dc_vec = implicit.linear_solver(Aᵀ, -dy_vec) - dx_vec = Bᵀ(dc_vec) - dx = reshape(dx_vec, size(x)) - df = NoTangent() - dprep = @not_implemented("Tangents for mutable arguments are not defined") - dargs = ntuple(unimplemented_tangent, N) - return (df, dprep, project_x(dx), dargs...) - end - - return (y, z), implicit_pullback_prepared + implicit_pullback = ImplicitPullback(Aᵀ, Bᵀ, linear_solver, project_x, Val(N)) + return (y, z), implicit_pullback end function unimplemented_tangent(_) @@ -49,5 +53,4 @@ function unimplemented_tangent(_) "Tangents for positional arguments of an `ImplicitFunction` beyond `x` (the first one) are not implemented" ) end - end diff --git a/ext/ImplicitDifferentiationForwardDiffExt.jl b/ext/ImplicitDifferentiationForwardDiffExt.jl index 9f8bae71..5228c270 100644 --- a/ext/ImplicitDifferentiationForwardDiffExt.jl +++ b/ext/ImplicitDifferentiationForwardDiffExt.jl @@ -6,7 +6,7 @@ using ImplicitDifferentiation: ImplicitFunction, ImplicitFunctionPreparation, build_A, build_B function (implicit::ImplicitFunction)( - prep::ImplicitFunctionPreparation, x_and_dx::AbstractArray{Dual{T,R,N}}, args... + prep::ImplicitFunctionPreparation{R}, x_and_dx::AbstractArray{Dual{T,R,N}}, args... ) where {T,R,N} x = value.(x_and_dx) y, z = implicit(x, args...) @@ -19,14 +19,9 @@ function (implicit::ImplicitFunction)( dX = ntuple(Val(N)) do k partials.(x_and_dx, k) end - dC_vec = map(dX) do dₖx - dₖx_vec = vec(dₖx) - dₖc_vec = B(dₖx_vec) - return dₖc_vec - end - dY = map(dC_vec) do dₖc_vec - dₖy_vec = implicit.linear_solver(A, -dₖc_vec) - dₖy = reshape(dₖy_vec, size(y)) + dC = map(B, dX) + dY = map(dC) do dₖc + dₖy = implicit.linear_solver(A, -dₖc) return dₖy end @@ -37,4 +32,11 @@ function (implicit::ImplicitFunction)( return y_and_dy, z end +function (implicit::ImplicitFunction)( + x_and_dx::AbstractArray{Dual{T,R,N}}, args... +) where {T,R,N} + prep = ImplicitFunctionPreparation(R) + return implicit(prep, x_and_dx, args...) +end + end diff --git a/src/ImplicitDifferentiation.jl b/src/ImplicitDifferentiation.jl index c601f669..d5880daa 100644 --- a/src/ImplicitDifferentiation.jl +++ b/src/ImplicitDifferentiation.jl @@ -17,13 +17,9 @@ using DifferentiationInterface: prepare_pullback_same_point, prepare_pushforward, prepare_pushforward_same_point, - pullback!, pullback, - pushforward!, pushforward -using Krylov: Krylov, krylov_workspace, krylov_solve!, solution -using LinearOperators: LinearOperator -using LinearMaps: FunctionMap +using KrylovKit: linsolve using LinearAlgebra: factorize include("utils.jl") @@ -36,6 +32,5 @@ include("callable.jl") export MatrixRepresentation, OperatorRepresentation export IterativeLinearSolver, DirectLinearSolver export ImplicitFunction -export prepare_implicit end diff --git a/src/callable.jl b/src/callable.jl index 5ce5f172..b9187af3 100644 --- a/src/callable.jl +++ b/src/callable.jl @@ -1,9 +1,9 @@ function (implicit::ImplicitFunction)(x::AbstractArray, args::Vararg{Any,N}) where {N} - return implicit(ImplicitFunctionPreparation(), x, args...) + return implicit(ImplicitFunctionPreparation(eltype(x)), x, args...) end function (implicit::ImplicitFunction)( - ::ImplicitFunctionPreparation, x::AbstractArray, args::Vararg{Any,N} -) where {N} + ::ImplicitFunctionPreparation{R}, x::AbstractArray{R}, args::Vararg{Any,N} +) where {R<:Real,N} return implicit.solver(x, args...) end diff --git a/src/execution.jl b/src/execution.jl index 5d543de9..39393830 100644 --- a/src/execution.jl +++ b/src/execution.jl @@ -1,43 +1,29 @@ -struct JVP!{F,P,B,I,V,C} +struct JVP{F,P,B,I,C} f::F prep::P backend::B input::I - v_buffer::V contexts::C end -struct VJP!{F,P,B,I,V,C} +struct VJP{F,P,B,I,C} f::F prep::P backend::B input::I - v_buffer::V contexts::C end -function (po::JVP!)(res::AbstractVector, v_wrongtype::AbstractVector) - (; f, backend, input, v_buffer, contexts, prep) = po - if typeof(v_buffer) == typeof(v_wrongtype) - v = v_wrongtype - else - copyto!(v_buffer, v_wrongtype) - v = v_buffer - end - pushforward!(f, (res,), prep, backend, input, (v,), contexts...) - return res +function (po::JVP)(v) + (; f, backend, input, contexts, prep) = po + res = pushforward(f, prep, backend, input, (v,), contexts...) + return only(res) end -function (po::VJP!)(res::AbstractVector, v_wrongtype::AbstractVector) - (; f, backend, input, v_buffer, contexts, prep) = po - if typeof(v_buffer) == typeof(v_wrongtype) - v = v_wrongtype - else - copyto!(v_buffer, v_wrongtype) - v = v_buffer - end - pullback!(f, (res,), prep, backend, input, (v,), contexts...) - return res +function (po::VJP)(v) + (; f, backend, input, contexts, prep) = po + res = pullback(f, prep, backend, input, (v,), contexts...) + return only(res) end ## A @@ -64,56 +50,34 @@ function build_A_aux( (; prep_A) = prep actual_backend = isnothing(backends) ? suggested_backend : backends.y contexts = (Constant(x), Constant(z), map(Constant, args)...) + f = Switch12(conditions) if isnothing(prep_A) - A = jacobian(Switch12(conditions), actual_backend, y, contexts...) + A = jacobian(f, actual_backend, y, contexts...) else - A = jacobian(Switch12(conditions), prep_A, actual_backend, y, contexts...) + A = jacobian(f, prep_A, actual_backend, y, contexts...) end return factorize(A) end function build_A_aux( - ::OperatorRepresentation{package,symmetric,hermitian,posdef}, - implicit, - prep, - x, - y, - z, - c, - args...; - suggested_backend, -) where {package,symmetric,hermitian,posdef} - T = Base.promote_eltype(x, y, c) + ::OperatorRepresentation, implicit, prep, x, y, z, c, args...; suggested_backend +) (; conditions, backends) = implicit (; prep_A) = prep actual_backend = isnothing(backends) ? suggested_backend : backends.y contexts = (Constant(x), Constant(z), map(Constant, args)...) - f_vec = VecToVec(Switch12(conditions), y) - y_vec = vec(y) - dy_vec = vec(zero(y)) + f = Switch12(conditions) if isnothing(prep_A) prep_A_same = prepare_pushforward_same_point( - f_vec, actual_backend, y_vec, (dy_vec,), contexts...; strict=implicit.strict + f, actual_backend, y, (zero(y),), contexts...; strict=implicit.strict ) else prep_A_same = prepare_pushforward_same_point( - f_vec, prep_A, actual_backend, y_vec, (dy_vec,), contexts... - ) - end - prod! = JVP!(f_vec, prep_A_same, actual_backend, y_vec, dy_vec, contexts) - if package == :LinearOperators - return LinearOperator(T, length(c), length(y), symmetric, hermitian, prod!) - elseif package == :LinearMaps - return FunctionMap{T}( - prod!, - length(c), - length(y); - ismutating=true, - issymmetric=symmetric, - ishermitian=hermitian, - isposdef=posdef, + f, prep_A, actual_backend, y, (zero(y),), contexts... ) end + A = JVP(f, prep_A_same, actual_backend, y, contexts) + return A end ## Aᵀ @@ -140,58 +104,34 @@ function build_Aᵀ_aux( (; prep_Aᵀ) = prep actual_backend = isnothing(backends) ? suggested_backend : backends.y contexts = (Constant(x), Constant(z), map(Constant, args)...) + f = Switch12(conditions) if isnothing(prep_Aᵀ) - Aᵀ = transpose(jacobian(Switch12(conditions), actual_backend, y, contexts...)) + Aᵀ = transpose(jacobian(f, actual_backend, y, contexts...)) else - Aᵀ = transpose( - jacobian(Switch12(conditions), prep_Aᵀ, actual_backend, y, contexts...) - ) + Aᵀ = transpose(jacobian(f, prep_Aᵀ, actual_backend, y, contexts...)) end return factorize(Aᵀ) end function build_Aᵀ_aux( - ::OperatorRepresentation{package,symmetric,hermitian,posdef}, - implicit, - prep, - x, - y, - z, - c, - args...; - suggested_backend, -) where {package,symmetric,hermitian,posdef} - T = Base.promote_eltype(x, y, c) + ::OperatorRepresentation, implicit, prep, x, y, z, c, args...; suggested_backend +) (; conditions, backends) = implicit (; prep_Aᵀ) = prep actual_backend = isnothing(backends) ? suggested_backend : backends.y contexts = (Constant(x), Constant(z), map(Constant, args)...) - f_vec = VecToVec(Switch12(conditions), y) - y_vec = vec(y) - dc_vec = vec(zero(c)) + f = Switch12(conditions) if isnothing(prep_Aᵀ) prep_Aᵀ_same = prepare_pullback_same_point( - f_vec, actual_backend, y_vec, (dc_vec,), contexts...; strict=implicit.strict + f, actual_backend, y, (zero(c),), contexts...; strict=implicit.strict ) else prep_Aᵀ_same = prepare_pullback_same_point( - f_vec, prep_Aᵀ, actual_backend, y_vec, (dc_vec,), contexts... - ) - end - prod! = VJP!(f_vec, prep_Aᵀ_same, actual_backend, y_vec, dc_vec, contexts) - if package == :LinearOperators - return LinearOperator(T, length(y), length(c), symmetric, hermitian, prod!) - elseif package == :LinearMaps - return FunctionMap{T}( - prod!, - length(y), - length(c); - ismutating=true, - issymmetric=symmetric, - ishermitian=hermitian, - isposdef=posdef, + f, prep_Aᵀ, actual_backend, y, (zero(c),), contexts... ) end + A = VJP(f, prep_Aᵀ_same, actual_backend, y, contexts) + return A end ## B @@ -210,26 +150,17 @@ function build_B( (; prep_B) = prep actual_backend = isnothing(backends) ? suggested_backend : backends.x contexts = (Constant(y), Constant(z), map(Constant, args)...) - f_vec = VecToVec(conditions, x) - x_vec = vec(x) - dx_vec = vec(zero(x)) if isnothing(prep_B) prep_B_same = prepare_pushforward_same_point( - f_vec, actual_backend, x_vec, (dx_vec,), contexts... + conditions, actual_backend, x, (zero(x),), contexts...; strict=implicit.strict ) else prep_B_same = prepare_pushforward_same_point( - f_vec, prep_B, actual_backend, x_vec, (dx_vec,), contexts... + conditions, prep_B, actual_backend, x, (zero(x),), contexts... ) end - function B_fun(dx_vec_wrongtype) - @assert typeof(dx_vec) == typeof(dx_vec_wrongtype) - dx_vec = dx_vec_wrongtype - return pushforward( - f_vec, prep_B_same, actual_backend, x_vec, (dx_vec,), contexts... - )[1] - end - return B_fun + B = JVP(conditions, prep_B_same, actual_backend, x, contexts) + return B end ## Bᵀ @@ -248,25 +179,15 @@ function build_Bᵀ( (; prep_Bᵀ) = prep actual_backend = isnothing(backends) ? suggested_backend : backends.x contexts = (Constant(y), Constant(z), map(Constant, args)...) - f_vec = VecToVec(conditions, x) - x_vec = vec(x) - dc_vec = vec(zero(c)) if isnothing(prep_Bᵀ) prep_Bᵀ_same = prepare_pullback_same_point( - f_vec, actual_backend, x_vec, (dc_vec,), contexts...; strict=implicit.strict + conditions, actual_backend, x, (zero(c),), contexts...; strict=implicit.strict ) else prep_Bᵀ_same = prepare_pullback_same_point( - f_vec, prep_Bᵀ, actual_backend, x_vec, (dc_vec,), contexts... + conditions, prep_Bᵀ, actual_backend, x, (zero(c),), contexts... ) end - function Bᵀ_fun(dc_vec_wrongtype) - if typeof(dc_vec) == typeof(dc_vec_wrongtype) - dc_vec = dc_vec_wrongtype - else - copyto!(dc_vec, dc_vec_wrongtype) - end - return pullback(f_vec, prep_Bᵀ_same, actual_backend, x_vec, (dc_vec,), contexts...)[1] - end - return Bᵀ_fun + Bᵀ = VJP(conditions, prep_Bᵀ_same, actual_backend, x, contexts) + return Bᵀ end diff --git a/src/implicit_function.jl b/src/implicit_function.jl index a12d72e4..fcb63768 100644 --- a/src/implicit_function.jl +++ b/src/implicit_function.jl @@ -23,6 +23,7 @@ This requires solving a linear system `A * J = -B` where `A = ∂₂c`, `B = ∂ representation=OperatorRepresentation(), linear_solver=IterativeLinearSolver(), backends=nothing, + strict=Val(true), ) ## Positional arguments @@ -32,9 +33,10 @@ This requires solving a linear system `A * J = -B` where `A = ∂₂c`, `B = ∂ ## Keyword arguments -- `representation`: defines how the partial Jacobian `A` of the conditions with respect to the output is represented, either [`MatrixRepresentation`](@ref) or [`OperatorRepresentation`](@ref). +- `representation`: defines how the partial Jacobian `A` of the conditions with respect to the output is represented. It can be either [`MatrixRepresentation`](@ref) or [`OperatorRepresentation`](@ref). - `linear_solver`: specifies how the linear system `A * J = -B` will be solved in the implicit function theorem. It can be either [`DirectLinearSolver`](@ref) or [`IterativeLinearSolver`](@ref). - `backends::AbstractADType`: specifies how the `conditions` will be differentiated with respect to `x` and `y`. It can be either, `nothing`, which means that the external autodiff system will be used, or a named tuple `(; x=AutoSomething(), y=AutoSomethingElse())` of backend objects from [ADTypes.jl](https://github.com/SciML/ADTypes.jl). +- `strict::Val`: specifies whether preparation inside [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl) should enforce a strict match between the primal variables and the provided tangents. """ struct ImplicitFunction{ F, @@ -72,14 +74,17 @@ function Base.show(io::IO, implicit::ImplicitFunction) (; solver, conditions, backends, linear_solver, representation) = implicit return print( io, - """ - ImplicitFunction( - $solver, - $conditions; - representation=$representation, - linear_solver=$linear_solver, - backends=$backends, - ) - """, + repr(ImplicitFunction; context=io), + "(", + repr(solver; context=io), + ", ", + repr(conditions; context=io), + "; representation=", + repr(representation; context=io), + ", linear_solver=", + repr(linear_solver; context=io), + ", backends=", + repr(backends; context=io), + ")", ) end diff --git a/src/preparation.jl b/src/preparation.jl index 4183f757..dd01b90c 100644 --- a/src/preparation.jl +++ b/src/preparation.jl @@ -8,15 +8,16 @@ - `prep_B`: preparation for `B` (derivative of conditions with respect to `x`) in forward mode - `prep_Bᵀ`: preparation for `B` (derivative of conditions with respect to `x`) in reverse mode """ -struct ImplicitFunctionPreparation{PA,PAT,PB,PBT} +struct ImplicitFunctionPreparation{R<:Real,PA,PAT,PB,PBT} + _R::Type{R} prep_A::PA prep_Aᵀ::PAT prep_B::PB prep_Bᵀ::PBT end -function ImplicitFunctionPreparation() - return ImplicitFunctionPreparation(nothing, nothing, nothing, nothing) +function ImplicitFunctionPreparation(::Type{R}) where {R<:Real} + return ImplicitFunctionPreparation(R, nothing, nothing, nothing, nothing) end """ @@ -28,7 +29,12 @@ end strict=Val(true) ) -Uses the preparation mechanism from [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl) to speed up subsequent calls to `implicit(x, args...)` where `(x, args...)` are similar to `(x_prep, args_prep...)`. +Uses the preparation mechanism from [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl) to speed up subsequent differentiated calls to `implicit(x, args...)` where `(x, args...)` are similar to `(x_prep, args_prep...)`. + +The `mode` argument is an object from [ADTypes.jl](https://github.com/SciML/ADTypes.jl) that specifies whether the preparation should target `ForwardMode`, `ReverseMode` or both (`ForwardOrReverseMode`). + +!!! warning + This mechanism is not yet part of the public API, use it at your own risk. """ function prepare_implicit( mode::AbstractMode, implicit::ImplicitFunction, x, args::Vararg{Any,N}; strict=Val(true) @@ -65,7 +71,7 @@ function prepare_implicit( prep_Bᵀ = nothing end end - return ImplicitFunctionPreparation(prep_A, prep_Aᵀ, prep_B, prep_Bᵀ) + return ImplicitFunctionPreparation(eltype(x), prep_A, prep_Aᵀ, prep_B, prep_Bᵀ) end function prepare_A( @@ -95,10 +101,9 @@ function prepare_A( strict::Val, ) contexts = (Constant(x), Constant(z), map(Constant, args)...) - f_vec = VecToVec(Switch12(conditions), y) - y_vec = vec(y) - dy_vec = vec(zero(y)) - return prepare_pushforward(f_vec, backend, y_vec, (dy_vec,), contexts...; strict) + return prepare_pushforward( + Switch12(conditions), backend, y, (zero(y),), contexts...; strict + ) end function prepare_Aᵀ( @@ -128,10 +133,9 @@ function prepare_Aᵀ( strict::Val, ) contexts = (Constant(x), Constant(z), map(Constant, args)...) - f_vec = VecToVec(Switch12(conditions), y) - y_vec = vec(y) - dc_vec = vec(zero(c)) - return prepare_pullback(f_vec, backend, y_vec, (dc_vec,), contexts...; strict) + return prepare_pullback( + Switch12(conditions), backend, y, (zero(c),), contexts...; strict + ) end function prepare_B( @@ -146,10 +150,7 @@ function prepare_B( strict::Val, ) contexts = (Constant(y), Constant(z), map(Constant, args)...) - f_vec = VecToVec(conditions, x) - x_vec = vec(x) - dx_vec = vec(zero(x)) - return prepare_pushforward(f_vec, backend, x_vec, (dx_vec,), contexts...; strict) + return prepare_pushforward(conditions, backend, x, (zero(x),), contexts...; strict) end function prepare_Bᵀ( @@ -164,8 +165,5 @@ function prepare_Bᵀ( strict::Val, ) contexts = (Constant(y), Constant(z), map(Constant, args)...) - f_vec = VecToVec(conditions, x) - x_vec = vec(x) - dc_vec = vec(zero(c)) - return prepare_pullback(f_vec, backend, x_vec, (dc_vec,), contexts...; strict) + return prepare_pullback(conditions, backend, x, (zero(c),), contexts...; strict) end diff --git a/src/settings.jl b/src/settings.jl index 3b739f25..c54e2a33 100644 --- a/src/settings.jl +++ b/src/settings.jl @@ -4,6 +4,11 @@ DirectLinearSolver Specify that linear systems `Ax = b` should be solved with a direct method. + +# See also + +- [`ImplicitFunction`](@ref) +- [`IterativeLinearSolver`](@ref) """ struct DirectLinearSolver end @@ -16,38 +21,31 @@ end Specify that linear systems `Ax = b` should be solved with an iterative method. -# Constructor - - IterativeLinearSolver(::Val{method}=Val(:gmres); kwargs...) +# See also -The `method` symbol is used to pick the appropriate algorithm from [Krylov.jl](https://github.com/JuliaSmoothOptimizers/Krylov.jl). -Keyword arguments are passed on to that algorithm. +- [`ImplicitFunction`](@ref) +- [`DirectLinearSolver`](@ref) """ -struct IterativeLinearSolver{method,K} - _method::Val{method} +struct IterativeLinearSolver{K} kwargs::K - function IterativeLinearSolver((::Val{method})=Val(:gmres); kwargs...) where {method} - return new{method,typeof(kwargs)}(Val(method), kwargs) + function IterativeLinearSolver(; kwargs...) + return new{typeof(kwargs)}(kwargs) end end -function Base.show(io::IO, linear_solver::IterativeLinearSolver{method}) where {method} - print(io, "IterativeLinearSolver{$(repr(method))}") - if isempty(linear_solver.kwargs) - print(io, "()") - else - print(io, "(; ") - for (k, v) in pairs(linear_solver.kwargs) - print(io, "$k=$(repr(v)), ") - end - print(io, ")") - end +function (solver::IterativeLinearSolver)(A, b) + sol, info = linsolve(A, b; solver.kwargs...) + @assert info.converged == 1 + return sol end -function (solver::IterativeLinearSolver{method})(A, b::AbstractVector) where {method} - workspace = krylov_workspace(Val(method), A, b) - krylov_solve!(workspace, A, b) - return solution(workspace) +function Base.show(io::IO, linear_solver::IterativeLinearSolver) + (; kwargs) = linear_solver + print(io, repr(IterativeLinearSolver; context=io), "(;") + for p in pairs(kwargs) + print(io, " ", p[1], "=", repr(p[2]; context=io), ",") + end + return print(io, ")") end ## Representation @@ -69,43 +67,13 @@ struct MatrixRepresentation <: AbstractRepresentation end """ OperatorRepresentation -Specify that the matrix `A` involved in the implicit function theorem should be represented lazily. - -# Constructors - - OperatorRepresentation(; symmetric=false, hermitian=false, posdef=false) - OperatorRepresentation{package}(; symmetric=false, hermitian=false, posdef=false) - -The type parameter `package` can be either: - -- `:LinearOperators` to use a wrapper from [LinearOperators.jl](https://github.com/JuliaSmoothOptimizers/LinearOperators.jl) (the default) -- `:LinearMaps` to use a wrapper from [LinearMaps.jl](https://github.com/JuliaLinearAlgebra/LinearMaps.jl) - -The keyword arguments `symmetric`, `hermitian` and `posdef` give additional properties of the Jacobian of the `conditions` with respect to the solution `y`, which are useful to the solver in case you can prove them. +Specify that the matrix `A` involved in the implicit function theorem should be represented lazily, as a function. # See also - [`ImplicitFunction`](@ref) - [`MatrixRepresentation`](@ref) """ -struct OperatorRepresentation{package,symmetric,hermitian,posdef} <: AbstractRepresentation - function OperatorRepresentation{package}(; - symmetric::Bool=false, hermitian::Bool=false, posdef::Bool=false - ) where {package} - @assert package in [:LinearOperators, :LinearMaps] - return new{package,symmetric,hermitian,posdef}() - end -end - -function Base.show( - io::IO, ::OperatorRepresentation{package,symmetric,hermitian,posdef} -) where {package,symmetric,hermitian,posdef} - return print( - io, - "OperatorRepresentation{$(repr(package))}(; symmetric=$symmetric, hermitian=$hermitian, posdef=$posdef)", - ) -end - -OperatorRepresentation(; kwargs...) = OperatorRepresentation{:LinearOperators}(; kwargs...) +struct OperatorRepresentation <: AbstractRepresentation end function chainrules_suggested_backend end diff --git a/src/utils.jl b/src/utils.jl index 5dc612eb..807e6606 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -13,24 +13,3 @@ end function (s12::Switch12)(arg1, arg2, other_args::Vararg{Any,N}) where {N} return s12.f(arg2, arg1, other_args...) end - -""" - VecToVec - -Represent a function which behaves like `f`, except that the first argument is expected as a vector, and the return is converted to a vector: - f(a1, a2, a3) = b -becomes - g(a1_vec, a2, a3) = vec(f(reshape(a1_vec, size(a1)), a2, a3)) -""" -struct VecToVec{F,N} - f::F - arg1_size::NTuple{N,Int} -end - -VecToVec(f::F, arg1_example::AbstractArray) where {F} = VecToVec(f, size(arg1_example)) - -function (v2v::VecToVec)(arg1_vec::AbstractVector, other_args::Vararg{Any,N}) where {N} - arg1 = reshape(arg1_vec, v2v.arg1_size) - res = v2v.f(arg1, other_args...) - return vec(res) -end diff --git a/test/preparation.jl b/test/preparation.jl index 32d59b74..d898b7f9 100644 --- a/test/preparation.jl +++ b/test/preparation.jl @@ -1,5 +1,7 @@ @testitem "Preparation" begin using ImplicitDifferentiation + using ImplicitDifferentiation: + prepare_implicit, build_A, build_Aᵀ, build_B, build_Bᵀ, JVP, VJP using ADTypes using ADTypes: ForwardOrReverseMode, ForwardMode, ReverseMode using ForwardDiff: ForwardDiff @@ -8,11 +10,22 @@ solver(x) = sqrt.(x), nothing conditions(x, y, z) = y .^ 2 .- x + implicit = ImplicitFunction( + solver, + conditions; + backends=(; x=AutoForwardDiff(), y=AutoForwardDiff()), + representation=MatrixRepresentation(), + ) + implicit_iterative = ImplicitFunction( solver, conditions; backends=(; x=AutoForwardDiff(), y=AutoForwardDiff()) ) implicit_nobackends = ImplicitFunction(solver, conditions) + x = rand(5) + y, z = implicit(x) + c = conditions(x, y, z) + suggested_backend = AutoEnzyme() @testset "None" begin prep = prepare_implicit(ForwardOrReverseMode(), implicit_nobackends, x) @@ -28,6 +41,10 @@ @test prep.prep_Aᵀ === nothing @test prep.prep_B !== nothing @test prep.prep_Bᵀ === nothing + @test build_A(implicit, prep, x, y, z, c; suggested_backend) isa AbstractMatrix + @test build_Aᵀ(implicit, prep, x, y, z, c; suggested_backend) isa AbstractMatrix + @test build_B(implicit, prep, x, y, z, c; suggested_backend) isa JVP + @test build_Bᵀ(implicit, prep, x, y, z, c; suggested_backend) isa VJP end @testset "ReverseMode" begin @@ -36,13 +53,21 @@ @test prep.prep_Aᵀ !== nothing @test prep.prep_B === nothing @test prep.prep_Bᵀ !== nothing + @test build_A(implicit, prep, x, y, z, c; suggested_backend) isa AbstractMatrix + @test build_Aᵀ(implicit, prep, x, y, z, c; suggested_backend) isa AbstractMatrix + @test build_B(implicit, prep, x, y, z, c; suggested_backend) isa JVP + @test build_Bᵀ(implicit, prep, x, y, z, c; suggested_backend) isa VJP end @testset "Both" begin - prep = prepare_implicit(ForwardOrReverseMode(), implicit, x) + prep = prepare_implicit(ForwardOrReverseMode(), implicit_iterative, x) @test prep.prep_A !== nothing @test prep.prep_Aᵀ !== nothing @test prep.prep_B !== nothing @test prep.prep_Bᵀ !== nothing + @test build_A(implicit_iterative, prep, x, y, z, c; suggested_backend) isa JVP + @test build_Aᵀ(implicit_iterative, prep, x, y, z, c; suggested_backend) isa VJP + @test build_B(implicit_iterative, prep, x, y, z, c; suggested_backend) isa JVP + @test build_Bᵀ(implicit_iterative, prep, x, y, z, c; suggested_backend) isa VJP end end diff --git a/test/printing.jl b/test/printing.jl index 2d94f78c..0190649a 100644 --- a/test/printing.jl +++ b/test/printing.jl @@ -1,8 +1,7 @@ using TestItems @testitem "Settings" begin - @test startswith(string(ImplicitFunction(nothing, nothing)), "ImplicitFunction") - @test startswith(string(OperatorRepresentation()), "Operator") - @test startswith(string(IterativeLinearSolver(; atol=1e-5)), "Iterative") - @test startswith(string(IterativeLinearSolver()), "Iterative") + @test contains(string(ImplicitFunction(nothing, nothing)), "ImplicitFunction") + @test contains(string(IterativeLinearSolver()), "IterativeLinearSolver") + @test contains(string(IterativeLinearSolver(; rtol=1e-3)), "IterativeLinearSolver") end diff --git a/test/systematic.jl b/test/systematic.jl index c358f583..799b67cf 100644 --- a/test/systematic.jl +++ b/test/systematic.jl @@ -2,10 +2,8 @@ using TestItems @testitem "Direct" setup = [TestUtils] begin using ADTypes, .TestUtils - for (backends, x) in Iterators.product( - [nothing, (; x=AutoForwardDiff(), y=AutoZygote())], - [float.(1:3), reshape(float.(1:6), 3, 2)], - ) + for (backends, x) in + Iterators.product([nothing, (; x=AutoForwardDiff(), y=AutoZygote())], [float.(1:3)]) yield() scen = Scenario(; solver=default_solver, @@ -25,8 +23,13 @@ end; @testitem "Iterative" setup = [TestUtils] begin using ADTypes, .TestUtils - for (backends, x) in Iterators.product( + for (backends, linear_solver, x) in Iterators.product( [nothing, (; x=AutoForwardDiff(), y=AutoZygote())], + [ + IterativeLinearSolver(), + IterativeLinearSolver(; rtol=1e-8), + IterativeLinearSolver(; issymmetric=true, isposdef=true), + ], [float.(1:3), reshape(float.(1:6), 3, 2)], ) yield() @@ -35,14 +38,12 @@ end; conditions=default_conditions, x=x, implicit_kwargs=(; - representation=OperatorRepresentation{:LinearOperators}(), - linear_solver=IterativeLinearSolver(), - backends, + representation=OperatorRepresentation(), linear_solver, backends ), ) scen2 = add_arg_mult(scen) - test_implicit(scen) - test_implicit(scen2) + test_implicit(scen; type_stability=VERSION >= v"1.11") + test_implicit(scen2; type_stability=VERSION >= v"1.11") end end; diff --git a/test/utils.jl b/test/utils.jl index 1b6d5992..b7fcd460 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -97,7 +97,7 @@ end tag(::AbstractArray{<:ForwardDiff.Dual{T}}) where {T} = T -function test_implicit_duals(scen::Scenario) +function test_implicit_duals(scen::Scenario; type_stability::Bool) implicit = ImplicitFunction( NonDifferentiable(scen.solver), scen.conditions; scen.implicit_kwargs... ) @@ -125,6 +125,9 @@ function test_implicit_duals(scen::Scenario) @test y ≈ y_true @test dy ≈ dy_true @test z == z_true + if type_stability + @inferred implicit(prep, x_and_dx, scen.args...) + end end @testset "Unrepared" begin y_and_dy, z = implicit(x_and_dx, scen.args...) @@ -134,17 +137,17 @@ function test_implicit_duals(scen::Scenario) @test y ≈ y_true @test dy ≈ dy_true @test z == z_true + if type_stability + @inferred implicit(x_and_dx, scen.args...) + end end end end -function compare_pullbacks(dimpl, dx, dx_true) end - -function test_implicit_rrule(scen::Scenario) +function test_implicit_rrule(scen::Scenario; type_stability::Bool) implicit = ImplicitFunction( NonDifferentiable(scen.solver), scen.conditions; scen.implicit_kwargs... ) - prep = prepare_implicit(ReverseMode(), implicit, scen.x_prep, scen.args_prep...) y_true, z_true = scen.solver(scen.x, scen.args...) dy = similar(y_true) @@ -156,21 +159,15 @@ function test_implicit_rrule(scen::Scenario) )[1] @testset "ChainRule" begin - @testset "Prepared" begin - (y, z), pb = rrule(ZygoteRuleConfig(), implicit, prep, scen.x, scen.args...) - dimpl, dprep, dx = pb((dy, dz)) - @test y ≈ y_true - @test z == z_true - @test dimpl isa NoTangent - @test dx ≈ dx_true - end - @testset "Unprepared" begin - (y, z), pb = rrule_via_ad(ZygoteRuleConfig(), implicit, scen.x, scen.args...) - dimpl, dx = pb((dy, dz)) - @test y ≈ y_true - @test z == z_true - @test dimpl isa NoTangent - @test dx ≈ dx_true + (y, z), pb = rrule_via_ad(ZygoteRuleConfig(), implicit, scen.x, scen.args...) + dimpl, dx = pb((dy, dz)) + @test y ≈ y_true + @test z == z_true + @test dimpl isa NoTangent + @test dx ≈ dx_true + if type_stability + @inferred rrule_via_ad(ZygoteRuleConfig(), implicit, scen.x, scen.args...) + @inferred pb((dy, dz)) end end end @@ -187,11 +184,13 @@ function test_implicit_jacobian(scen::Scenario, outer_backend::AbstractADType) ) @testset "Jacobian - $outer_backend" begin - @testset "Prepared" begin - jac = DI.jacobian( - x -> first(implicit(prep, x, scen.args...)), outer_backend, scen.x - ) - @test jac ≈ jac_true + if outer_backend isa AutoForwardDiff + @testset "Prepared" begin + jac = DI.jacobian( + x -> first(implicit(prep, x, scen.args...)), outer_backend, scen.x + ) + @test jac ≈ jac_true + end end @testset "Unprepared" begin jac = DI.jacobian( @@ -202,11 +201,15 @@ function test_implicit_jacobian(scen::Scenario, outer_backend::AbstractADType) end end -function test_implicit(scen::Scenario, outer_backends=[AutoForwardDiff(), AutoZygote()]) - @testset "$scen" begin +function test_implicit( + scen::Scenario, + outer_backends=[AutoForwardDiff(), AutoZygote()]; + type_stability::Bool=false, +) + return @testset "$scen" begin test_implicit_call(scen) - test_implicit_duals(scen) - test_implicit_rrule(scen) + test_implicit_duals(scen; type_stability) + test_implicit_rrule(scen; type_stability) for outer_backend in outer_backends test_implicit_jacobian(scen, outer_backend) end