diff --git a/Project.toml b/Project.toml index 0b0d4c35..f4bf0b47 100644 --- a/Project.toml +++ b/Project.toml @@ -1,12 +1,11 @@ name = "ImplicitDifferentiation" uuid = "57b37032-215b-411a-8a7c-41a003a55207" authors = ["Guillaume Dalle", "Mohamed Tarek"] -version = "0.8.1" +version = "0.9.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" -IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearMaps = "7a12625a-238d-50fd-b39a-03d52299707e" @@ -33,11 +32,10 @@ Documenter = "1.12.0" ExplicitImports = "1" FiniteDiff = "2.27.0" ForwardDiff = "0.10.36, 1" -IterativeSolvers = "0.9.4" JET = "0.9, 0.10" JuliaFormatter = "2.1.2" Krylov = "0.9.6, 0.10" -LinearAlgebra = "1.10" +LinearAlgebra = "1" LinearMaps = "3.11.4" LinearOperators = "2.8.0" NLsolve = "4.5.1" @@ -65,6 +63,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" Optim = "429524aa-4258-5aef-a3af-852621145aeb" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -76,4 +75,27 @@ TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["ADTypes", "Aqua", "ChainRulesCore", "ChainRulesTestUtils", "ComponentArrays", "DifferentiationInterface", "Documenter", "ExplicitImports", "FiniteDiff", "ForwardDiff", "JET", "JuliaFormatter", "NLsolve", "Optim", "Random", "SparseArrays", "StaticArrays", "Test", "TestItems", "TestItemRunner", "Zygote"] +test = [ + "ADTypes", + "Aqua", + "ChainRulesCore", + "ChainRulesTestUtils", + "ComponentArrays", + "DifferentiationInterface", + "Documenter", + "ExplicitImports", + "FiniteDiff", + 
"ForwardDiff", + "JET", + "JuliaFormatter", + "LinearAlgebra", + "NLsolve", + "Optim", + "Random", + "SparseArrays", + "StaticArrays", + "Test", + "TestItems", + "TestItemRunner", + "Zygote", +] diff --git a/docs/src/api.md b/docs/src/api.md index ce95c90e..d72019b6 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -22,6 +22,8 @@ ImplicitFunction MatrixRepresentation OperatorRepresentation IterativeLinearSolver +DirectLinearSolver +prepare_implicit ``` ## Internals diff --git a/docs/src/faq.md b/docs/src/faq.md index bec4b2bf..0322055f 100644 --- a/docs/src/faq.md +++ b/docs/src/faq.md @@ -18,8 +18,7 @@ However, this can be switched to any other "inner" backend compatible with [Diff ### Arrays Functions that eat or spit out arbitrary arrays are supported, as long as the forward mapping _and_ conditions return arrays of the same size. - -If you deal with small arrays (say, less than 100 elements), consider using [StaticArrays.jl](https://github.com/JuliaArrays/StaticArrays.jl) for increased performance. +The array types involved should be mutable. 
### Scalars diff --git a/examples/3_tricks.jl b/examples/3_tricks.jl index a1cf4e89..3d47251c 100644 --- a/examples/3_tricks.jl +++ b/examples/3_tricks.jl @@ -38,7 +38,7 @@ end function conditions_components(x::ComponentVector, y::ComponentVector, _z) c_d, c_e = conditions_components_aux(x.a, x.b, x.m, y.d, y.e) - c = ComponentVector(; c_d=c_d, c_e=c_e) + c = ComponentVector(; d=c_d, e=c_e) return c end; @@ -50,7 +50,7 @@ implicit_components = ImplicitFunction(forward_components, conditions_components a, b, m = [1.0, 2.0], [3.0, 4.0, 5.0], 6.0 x = ComponentVector(; a=a, b=b, m=m) -implicit_components(x) +y, z = implicit_components(x) # And it works with both ForwardDiff.jl and Zygote.jl diff --git a/ext/ImplicitDifferentiationChainRulesCoreExt.jl b/ext/ImplicitDifferentiationChainRulesCoreExt.jl index 997ab1f7..341fc37d 100644 --- a/ext/ImplicitDifferentiationChainRulesCoreExt.jl +++ b/ext/ImplicitDifferentiationChainRulesCoreExt.jl @@ -6,6 +6,7 @@ using ChainRulesCore: unthunk, @not_implemented using ImplicitDifferentiation: ImplicitDifferentiation, ImplicitFunction, + ImplicitFunctionPreparation, build_Aᵀ, build_Bᵀ, chainrules_suggested_backend @@ -14,28 +15,33 @@ using ImplicitDifferentiation: ImplicitDifferentiation.chainrules_suggested_backend(rc::RuleConfig) = AutoChainRules(rc) function ChainRulesCore.rrule( - rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray, args::Vararg{Any,N}; + rc::RuleConfig, + implicit::ImplicitFunction, + prep::ImplicitFunctionPreparation, + x::AbstractArray, + args::Vararg{Any,N}; ) where {N} y, z = implicit(x, args...) c = implicit.conditions(x, y, z, args...) 
suggested_backend = chainrules_suggested_backend(rc) - Aᵀ = build_Aᵀ(implicit, x, y, z, c, args...; suggested_backend) - Bᵀ = build_Bᵀ(implicit, x, y, z, c, args...; suggested_backend) + Aᵀ = build_Aᵀ(implicit, prep, x, y, z, c, args...; suggested_backend) + Bᵀ = build_Bᵀ(implicit, prep, x, y, z, c, args...; suggested_backend) project_x = ProjectTo(x) - function implicit_pullback((dy, dz)) + function implicit_pullback_prepared((dy, dz)) dy = unthunk(dy) dy_vec = vec(dy) dc_vec = implicit.linear_solver(Aᵀ, -dy_vec) dx_vec = Bᵀ(dc_vec) dx = reshape(dx_vec, size(x)) df = NoTangent() + dprep = @not_implemented("Tangents for mutable arguments are not defined") dargs = ntuple(unimplemented_tangent, N) - return (df, project_x(dx), dargs...) + return (df, dprep, project_x(dx), dargs...) end - return (y, z), implicit_pullback + return (y, z), implicit_pullback_prepared end function unimplemented_tangent(_) diff --git a/ext/ImplicitDifferentiationForwardDiffExt.jl b/ext/ImplicitDifferentiationForwardDiffExt.jl index b5358c5d..9f8bae71 100644 --- a/ext/ImplicitDifferentiationForwardDiffExt.jl +++ b/ext/ImplicitDifferentiationForwardDiffExt.jl @@ -2,31 +2,36 @@ module ImplicitDifferentiationForwardDiffExt using ADTypes: AutoForwardDiff using ForwardDiff: Dual, Partials, partials, value -using ImplicitDifferentiation: ImplicitFunction, build_A, build_B +using ImplicitDifferentiation: + ImplicitFunction, ImplicitFunctionPreparation, build_A, build_B function (implicit::ImplicitFunction)( - x_and_dx::AbstractArray{Dual{T,R,N}}, args... + prep::ImplicitFunctionPreparation, x_and_dx::AbstractArray{Dual{T,R,N}}, args... ) where {T,R,N} x = value.(x_and_dx) y, z = implicit(x, args...) c = implicit.conditions(x, y, z, args...) 
suggested_backend = AutoForwardDiff() - A = build_A(implicit, x, y, z, c, args...; suggested_backend) - B = build_B(implicit, x, y, z, c, args...; suggested_backend) + A = build_A(implicit, prep, x, y, z, c, args...; suggested_backend) + B = build_B(implicit, prep, x, y, z, c, args...; suggested_backend) dX = ntuple(Val(N)) do k partials.(x_and_dx, k) end - dC_mat = mapreduce(hcat, dX) do dₖx + dC_vec = map(dX) do dₖx dₖx_vec = vec(dₖx) dₖc_vec = B(dₖx_vec) return dₖc_vec end - dY_mat = implicit.linear_solver(A, -dC_mat) + dY = map(dC_vec) do dₖc_vec + dₖy_vec = implicit.linear_solver(A, -dₖc_vec) + dₖy = reshape(dₖy_vec, size(y)) + return dₖy + end y_and_dy = map(y, LinearIndices(y)) do yi, i - Dual{T}(yi, Partials(ntuple(k -> dY_mat[i, k], Val(N)))) + Dual{T}(yi, Partials(ntuple(k -> dY[k][i], Val(N)))) end return y_and_dy, z diff --git a/src/ImplicitDifferentiation.jl b/src/ImplicitDifferentiation.jl index c7ea12a3..c601f669 100644 --- a/src/ImplicitDifferentiation.jl +++ b/src/ImplicitDifferentiation.jl @@ -21,20 +21,21 @@ using DifferentiationInterface: pullback, pushforward!, pushforward -using Krylov: Krylov -using IterativeSolvers: IterativeSolvers +using Krylov: Krylov, krylov_workspace, krylov_solve!, solution using LinearOperators: LinearOperator using LinearMaps: FunctionMap using LinearAlgebra: factorize include("utils.jl") include("settings.jl") -include("preparation.jl") include("implicit_function.jl") +include("preparation.jl") include("execution.jl") +include("callable.jl") export MatrixRepresentation, OperatorRepresentation -export IterativeLinearSolver +export IterativeLinearSolver, DirectLinearSolver export ImplicitFunction +export prepare_implicit end diff --git a/src/callable.jl b/src/callable.jl new file mode 100644 index 00000000..5ce5f172 --- /dev/null +++ b/src/callable.jl @@ -0,0 +1,9 @@ +function (implicit::ImplicitFunction)(x::AbstractArray, args::Vararg{Any,N}) where {N} + return implicit(ImplicitFunctionPreparation(), x, args...) 
+end + +function (implicit::ImplicitFunction)( + ::ImplicitFunctionPreparation, x::AbstractArray, args::Vararg{Any,N} +) where {N} + return implicit.solver(x, args...) +end diff --git a/src/execution.jl b/src/execution.jl index 3d5d00e3..5d543de9 100644 --- a/src/execution.jl +++ b/src/execution.jl @@ -18,15 +18,25 @@ end function (po::JVP!)(res::AbstractVector, v_wrongtype::AbstractVector) (; f, backend, input, v_buffer, contexts, prep) = po - copyto!(v_buffer, v_wrongtype) - pushforward!(f, (res,), prep, backend, input, (v_buffer,), contexts...) + if typeof(v_buffer) == typeof(v_wrongtype) + v = v_wrongtype + else + copyto!(v_buffer, v_wrongtype) + v = v_buffer + end + pushforward!(f, (res,), prep, backend, input, (v,), contexts...) return res end function (po::VJP!)(res::AbstractVector, v_wrongtype::AbstractVector) (; f, backend, input, v_buffer, contexts, prep) = po - copyto!(v_buffer, v_wrongtype) - pullback!(f, (res,), prep, backend, input, (v_buffer,), contexts...) + if typeof(v_buffer) == typeof(v_wrongtype) + v = v_wrongtype + else + copyto!(v_buffer, v_wrongtype) + v = v_buffer + end + pullback!(f, (res,), prep, backend, input, (v,), contexts...) return res end @@ -34,6 +44,7 @@ end function build_A( implicit::ImplicitFunction, + prep::ImplicitFunctionPreparation, x::AbstractArray, y::AbstractArray, z, @@ -42,14 +53,15 @@ function build_A( suggested_backend::AbstractADType, ) return build_A_aux( - implicit.representation, implicit, x, y, z, c, args...; suggested_backend + implicit.representation, implicit, prep, x, y, z, c, args...; suggested_backend ) end function build_A_aux( - ::MatrixRepresentation, implicit, x, y, z, c, args...; suggested_backend + ::MatrixRepresentation, implicit, prep, x, y, z, c, args...; suggested_backend ) - (; conditions, backends, prep_A) = implicit + (; conditions, backends) = implicit + (; prep_A) = prep actual_backend = isnothing(backends) ? 
suggested_backend : backends.y contexts = (Constant(x), Constant(z), map(Constant, args)...) if isnothing(prep_A) @@ -61,17 +73,19 @@ function build_A_aux( end function build_A_aux( - ::OperatorRepresentation{package,symmetric,hermitian,posdef,keep_input_type}, + ::OperatorRepresentation{package,symmetric,hermitian,posdef}, implicit, + prep, x, y, z, c, args...; suggested_backend, -) where {package,symmetric,hermitian,posdef,keep_input_type} +) where {package,symmetric,hermitian,posdef} T = Base.promote_eltype(x, y, c) - (; conditions, backends, prep_A) = implicit + (; conditions, backends) = implicit + (; prep_A) = prep actual_backend = isnothing(backends) ? suggested_backend : backends.y contexts = (Constant(x), Constant(z), map(Constant, args)...) f_vec = VecToVec(Switch12(conditions), y) @@ -88,15 +102,7 @@ function build_A_aux( end prod! = JVP!(f_vec, prep_A_same, actual_backend, y_vec, dy_vec, contexts) if package == :LinearOperators - return LinearOperator( - T, - length(c), - length(y), - symmetric, - hermitian, - prod!; - S=keep_input_type ? typeof(dy_vec) : Vector{T}, - ) + return LinearOperator(T, length(c), length(y), symmetric, hermitian, prod!) elseif package == :LinearMaps return FunctionMap{T}( prod!, @@ -114,6 +120,7 @@ end function build_Aᵀ( implicit::ImplicitFunction, + prep::ImplicitFunctionPreparation, x::AbstractArray, y::AbstractArray, z, @@ -122,14 +129,15 @@ function build_Aᵀ( suggested_backend::AbstractADType, ) return build_Aᵀ_aux( - implicit.representation, implicit, x, y, z, c, args...; suggested_backend + implicit.representation, implicit, prep, x, y, z, c, args...; suggested_backend ) end function build_Aᵀ_aux( - ::MatrixRepresentation, implicit, x, y, z, c, args...; suggested_backend + ::MatrixRepresentation, implicit, prep, x, y, z, c, args...; suggested_backend ) - (; conditions, backends, prep_Aᵀ) = implicit + (; conditions, backends) = implicit + (; prep_Aᵀ) = prep actual_backend = isnothing(backends) ? 
suggested_backend : backends.y contexts = (Constant(x), Constant(z), map(Constant, args)...) if isnothing(prep_Aᵀ) @@ -143,17 +151,19 @@ function build_Aᵀ_aux( end function build_Aᵀ_aux( - ::OperatorRepresentation{package,symmetric,hermitian,posdef,keep_input_type}, + ::OperatorRepresentation{package,symmetric,hermitian,posdef}, implicit, + prep, x, y, z, c, args...; suggested_backend, -) where {package,symmetric,hermitian,posdef,keep_input_type} +) where {package,symmetric,hermitian,posdef} T = Base.promote_eltype(x, y, c) - (; conditions, backends, prep_Aᵀ) = implicit + (; conditions, backends) = implicit + (; prep_Aᵀ) = prep actual_backend = isnothing(backends) ? suggested_backend : backends.y contexts = (Constant(x), Constant(z), map(Constant, args)...) f_vec = VecToVec(Switch12(conditions), y) @@ -170,15 +180,7 @@ function build_Aᵀ_aux( end prod! = VJP!(f_vec, prep_Aᵀ_same, actual_backend, y_vec, dc_vec, contexts) if package == :LinearOperators - return LinearOperator( - T, - length(y), - length(c), - symmetric, - hermitian, - prod!; - S=keep_input_type ? typeof(dc_vec) : Vector{T}, - ) + return LinearOperator(T, length(y), length(c), symmetric, hermitian, prod!) elseif package == :LinearMaps return FunctionMap{T}( prod!, @@ -196,6 +198,7 @@ end function build_B( implicit::ImplicitFunction, + prep::ImplicitFunctionPreparation, x::AbstractArray, y::AbstractArray, z, @@ -203,7 +206,8 @@ function build_B( args...; suggested_backend::AbstractADType, ) - (; conditions, backends, prep_B) = implicit + (; conditions, backends) = implicit + (; prep_B) = prep actual_backend = isnothing(backends) ? suggested_backend : backends.x contexts = (Constant(y), Constant(z), map(Constant, args)...) 
f_vec = VecToVec(conditions, x) @@ -219,7 +223,8 @@ function build_B( ) end function B_fun(dx_vec_wrongtype) - copyto!(dx_vec, dx_vec_wrongtype) + same_type = dx_vec_wrongtype isa typeof(dx_vec) + dx = same_type ? dx_vec_wrongtype : copyto!(dx_vec, dx_vec_wrongtype) return pushforward( - f_vec, prep_B_same, actual_backend, x_vec, (dx_vec,), contexts... + f_vec, prep_B_same, actual_backend, x_vec, (dx,), contexts... )[1] @@ -231,6 +236,7 @@ end function build_Bᵀ( implicit::ImplicitFunction, + prep::ImplicitFunctionPreparation, x::AbstractArray, y::AbstractArray, z, @@ -238,7 +244,8 @@ function build_Bᵀ( args...; suggested_backend::AbstractADType, ) - (; conditions, backends, prep_Bᵀ) = implicit + (; conditions, backends) = implicit + (; prep_Bᵀ) = prep actual_backend = isnothing(backends) ? suggested_backend : backends.x contexts = (Constant(y), Constant(z), map(Constant, args)...) f_vec = VecToVec(conditions, x) @@ -254,7 +261,11 @@ function build_Bᵀ( ) end function Bᵀ_fun(dc_vec_wrongtype) - copyto!(dc_vec, dc_vec_wrongtype) + # Pick the seed WITHOUT reassigning the captured `dc_vec`: rebinding a captured + # variable inside a closure boxes it and makes every call type-unstable. + # `copyto!` returns its destination, so `dc` always has the buffer's type. + same_type = dc_vec_wrongtype isa typeof(dc_vec) + dc = same_type ? dc_vec_wrongtype : copyto!(dc_vec, dc_vec_wrongtype) - return pullback(f_vec, prep_Bᵀ_same, actual_backend, x_vec, (dc_vec,), contexts...)[1] + return pullback(f_vec, prep_Bᵀ_same, actual_backend, x_vec, (dc,), contexts...)[1] end return Bᵀ_fun diff --git a/src/implicit_function.jl b/src/implicit_function.jl index 531078c6..a12d72e4 100644 --- a/src/implicit_function.jl +++ b/src/implicit_function.jl @@ -23,8 +23,6 @@ This requires solving a linear system `A * J = -B` where `A = ∂₂c`, `B = ∂ representation=OperatorRepresentation(), linear_solver=IterativeLinearSolver(), backends=nothing, - preparation=nothing, - input_example=nothing, ) ## Positional arguments @@ -35,11 +33,8 @@ This requires solving a linear system `A * J = -B` where `A = ∂₂c`, `B = ∂ ## Keyword arguments - `representation`: defines how the partial Jacobian `A` of the conditions with respect to the output is represented, either [`MatrixRepresentation`](@ref) or [`OperatorRepresentation`](@ref). 
-- `linear_solver`: a callable to solve linear systems with two required methods, one for `(A, b::AbstractVector)` (single solve) and one for `(A, B::AbstractMatrix)` (batched solve). It defaults to [`IterativeLinearSolver`](@ref) but can also be the built-in `\\`, or a user-provided function. +- `linear_solver`: specifies how the linear system `A * J = -B` will be solved in the implicit function theorem. It can be either [`DirectLinearSolver`](@ref) or [`IterativeLinearSolver`](@ref). - `backends::AbstractADType`: specifies how the `conditions` will be differentiated with respect to `x` and `y`. It can be either, `nothing`, which means that the external autodiff system will be used, or a named tuple `(; x=AutoSomething(), y=AutoSomethingElse())` of backend objects from [ADTypes.jl](https://github.com/SciML/ADTypes.jl). -- `preparation`: either `nothing` or a mode object from [ADTypes.jl](https://github.com/SciML/ADTypes.jl): `ADTypes.ForwardMode()`, `ADTypes.ReverseMode()` or `ADTypes.ForwardOrReverseMode()`. -- `input_example`: either `nothing` or a tuple `(x, args...)` used to prepare differentiation. -- `strict::Val=Val(true)`: whether or not to enforce a strict match in [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl) between the preparation and the execution types. 
""" struct ImplicitFunction{ F, @@ -50,24 +45,14 @@ struct ImplicitFunction{ Nothing, # NamedTuple{(:x, :y),<:Tuple{AbstractADType,AbstractADType}}, }, - P<:Union{Nothing,AbstractMode}, - PA, - PAT, - PB, - PBT, - _strict, + S, } solver::F conditions::C linear_solver::L representation::R backends::B - preparation::P - prep_A::PA - prep_Aᵀ::PAT - prep_B::PB - prep_Bᵀ::PBT - strict::Val{_strict} + strict::Val{S} end function ImplicitFunction( @@ -76,59 +61,15 @@ function ImplicitFunction( representation=OperatorRepresentation(), linear_solver=IterativeLinearSolver(), backends=nothing, - preparation=nothing, - input_example=nothing, strict::Val=Val(true), ) - if isnothing(preparation) || isnothing(backends) || isnothing(input_example) - prep_A = nothing - prep_Aᵀ = nothing - prep_B = nothing - prep_Bᵀ = nothing - else - x, args = first(input_example), Base.tail(input_example) - y, z = solver(x, args...) - c = conditions(x, y, z, args...) - if preparation isa Union{ForwardMode,ForwardOrReverseMode} - prep_A = prepare_A( - representation, x, y, z, c, args...; conditions, backend=backends.y, strict - ) - prep_B = prepare_B( - representation, x, y, z, c, args...; conditions, backend=backends.x, strict - ) - else - prep_A = nothing - prep_B = nothing - end - if preparation isa Union{ReverseMode,ForwardOrReverseMode} - prep_Aᵀ = prepare_Aᵀ( - representation, x, y, z, c, args...; conditions, backend=backends.y, strict - ) - prep_Bᵀ = prepare_Bᵀ( - representation, x, y, z, c, args...; conditions, backend=backends.x, strict - ) - else - prep_Aᵀ = nothing - prep_Bᵀ = nothing - end - end return ImplicitFunction( - solver, - conditions, - linear_solver, - representation, - backends, - preparation, - prep_A, - prep_Aᵀ, - prep_B, - prep_Bᵀ, - strict, + solver, conditions, linear_solver, representation, backends, strict ) end function Base.show(io::IO, implicit::ImplicitFunction) - (; solver, conditions, backends, linear_solver, representation, preparation) = implicit + (; solver, 
conditions, backends, linear_solver, representation) = implicit return print( io, """ @@ -138,12 +79,7 @@ function Base.show(io::IO, implicit::ImplicitFunction) representation=$representation, linear_solver=$linear_solver, backends=$backends, - preparation=$preparation, ) """, ) end - -function (implicit::ImplicitFunction)(x::AbstractArray, args::Vararg{Any,N}) where {N} - return implicit.solver(x, args...) -end diff --git a/src/preparation.jl b/src/preparation.jl index c6b09b72..4183f757 100644 --- a/src/preparation.jl +++ b/src/preparation.jl @@ -1,3 +1,73 @@ +""" + ImplicitFunctionPreparation + +# Fields + +- `prep_A`: preparation for `A` (derivative of conditions with respect to `y`) in forward mode +- `prep_Aᵀ`: preparation for `A` (derivative of conditions with respect to `y`) in reverse mode +- `prep_B`: preparation for `B` (derivative of conditions with respect to `x`) in forward mode +- `prep_Bᵀ`: preparation for `B` (derivative of conditions with respect to `x`) in reverse mode +""" +struct ImplicitFunctionPreparation{PA,PAT,PB,PBT} + prep_A::PA + prep_Aᵀ::PAT + prep_B::PB + prep_Bᵀ::PBT +end + +function ImplicitFunctionPreparation() + return ImplicitFunctionPreparation(nothing, nothing, nothing, nothing) +end + +""" + prepare_implicit( + mode::ADTypes.AbstractMode, + implicit::ImplicitFunction, + x_prep, + args_prep...; + strict=Val(true) + ) + +Uses the preparation mechanism from [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl) to speed up subsequent calls to `implicit(prep, x, args...)`, where `prep` is the returned preparation object and `(x, args...)` are similar to `(x_prep, args_prep...)`. Calling `implicit(x, args...)` without `prep` falls back to an empty preparation. +""" +function prepare_implicit( + mode::AbstractMode, implicit::ImplicitFunction, x, args::Vararg{Any,N}; strict=Val(true) +) where {N} + (; solver, conditions, backends, representation) = implicit + y, z = solver(x, args...) + c = conditions(x, y, z, args...) 
+ if isnothing(backends) + prep_A = nothing + prep_Aᵀ = nothing + prep_B = nothing + prep_Bᵀ = nothing + else + if mode isa Union{ForwardMode,ForwardOrReverseMode} + prep_A = prepare_A( + representation, x, y, z, c, args...; conditions, backend=backends.y, strict + ) + prep_B = prepare_B( + representation, x, y, z, c, args...; conditions, backend=backends.x, strict + ) + else + prep_A = nothing + prep_B = nothing + end + if mode isa Union{ReverseMode,ForwardOrReverseMode} + prep_Aᵀ = prepare_Aᵀ( + representation, x, y, z, c, args...; conditions, backend=backends.y, strict + ) + prep_Bᵀ = prepare_Bᵀ( + representation, x, y, z, c, args...; conditions, backend=backends.x, strict + ) + else + prep_Aᵀ = nothing + prep_Bᵀ = nothing + end + end + return ImplicitFunctionPreparation(prep_A, prep_Aᵀ, prep_B, prep_Bᵀ) +end + function prepare_A( ::MatrixRepresentation, x::AbstractArray, diff --git a/src/settings.jl b/src/settings.jl index 5c4ba95a..3b739f25 100644 --- a/src/settings.jl +++ b/src/settings.jl @@ -1,74 +1,53 @@ ## Linear solver """ - IterativeLinearSolver - -Callable object that can solve linear systems `Ax = b` and `AX = B` in the same way as the built-in `\\`. - -# Constructor - - IterativeLinearSolver(; kwargs...) - IterativeLinearSolver{package}(; kwargs...) - -The type parameter `package` can be either: + DirectLinearSolver -- `:Krylov` to use the solver `gmres` or `block_gmres` from [Krylov.jl](https://github.com/JuliaSmoothOptimizers/Krylov.jl) (the default) -- `:IterativeSolvers` to use the solver `gmres` from [IterativeSolvers.jl](https://github.com/JuliaLinearAlgebra/IterativeSolvers.jl) +Specify that linear systems `Ax = b` should be solved with a direct method. +""" +struct DirectLinearSolver end -Keyword arguments are passed on to the respective solver. 
+function (solver::DirectLinearSolver)(A, b::AbstractVector) + return A \ b +end -# Callable behavior +""" + IterativeLinearSolver - (::IterativeLinearSolver)(A, b::AbstractVector) +Specify that linear systems `Ax = b` should be solved with an iterative method. -Solve a linear system with a single right-hand side. +# Constructor - (::IterativeLinearSolver)(A, B::AbstractMatrix) + IterativeLinearSolver(::Val{method}=Val(:gmres); kwargs...) -Solve a linear system with multiple right-hand sides. +The `method` symbol is used to pick the appropriate algorithm from [Krylov.jl](https://github.com/JuliaSmoothOptimizers/Krylov.jl). +Keyword arguments are passed on to that algorithm. """ -struct IterativeLinearSolver{package,K} +struct IterativeLinearSolver{method,K} + _method::Val{method} kwargs::K - function IterativeLinearSolver{package}(; kwargs...) where {package} - @assert package in [:Krylov, :IterativeSolvers] - return new{package,typeof(kwargs)}(kwargs) + function IterativeLinearSolver((::Val{method})=Val(:gmres); kwargs...) where {method} + return new{method,typeof(kwargs)}(Val(method), kwargs) end end -function Base.show(io::IO, linear_solver::IterativeLinearSolver{package}) where {package} - print(io, "IterativeLinearSolver{$(repr(package))}(; ") - for (k, v) in pairs(linear_solver.kwargs) - print(io, "$k=$v, ") +function Base.show(io::IO, linear_solver::IterativeLinearSolver{method}) where {method} + print(io, "IterativeLinearSolver{$(repr(method))}") + if isempty(linear_solver.kwargs) + print(io, "()") + else + print(io, "(; ") + for (k, v) in pairs(linear_solver.kwargs) + print(io, "$k=$(repr(v)), ") + end + print(io, ")") end - return print(io, ")") end -IterativeLinearSolver(; kwargs...) = IterativeLinearSolver{:Krylov}(; kwargs...) - -function (solver::IterativeLinearSolver{:Krylov})(A, b::AbstractVector) - x, stats = Krylov.gmres(A, b; solver.kwargs...) 
- return x -end - -function (solver::IterativeLinearSolver{:Krylov})(A, B::AbstractMatrix) - # TODO: use block_gmres - X = mapreduce(hcat, eachcol(B)) do b - x, _ = Krylov.gmres(A, b; solver.kwargs...) - x - end - return X -end - -function (solver::IterativeLinearSolver{:IterativeSolvers})(A, b::AbstractVector) - x = IterativeSolvers.gmres(A, b; solver.kwargs...) - return x -end - -function (solver::IterativeLinearSolver{:IterativeSolvers})(A, B::AbstractMatrix) - X = mapreduce(hcat, eachcol(B)) do b - IterativeSolvers.gmres(A, b; solver.kwargs...) - end - return X +function (solver::IterativeLinearSolver{method})(A, b::AbstractVector) where {method} + workspace = krylov_workspace(Val(method), A, b) + krylov_solve!(workspace, A, b; solver.kwargs...) + return solution(workspace) end ## Representation @@ -94,12 +73,8 @@ Specify that the matrix `A` involved in the implicit function theorem should be # Constructors - OperatorRepresentation(; - symmetric=false, hermitian=false, posdef=false, keep_input_type=false - ) - OperatorRepresentation{package}(; - symmetric=false, hermitian=false, posdef=false, keep_input_type=false - ) + OperatorRepresentation(; symmetric=false, hermitian=false, posdef=false) + OperatorRepresentation{package}(; symmetric=false, hermitian=false, posdef=false) The type parameter `package` can be either: @@ -108,32 +83,26 @@ The type parameter `package` can be either: The keyword arguments `symmetric`, `hermitian` and `posdef` give additional properties of the Jacobian of the `conditions` with respect to the solution `y`, which are useful to the solver in case you can prove them. -The keyword argument `keep_input_type` dictates whether to force the linear operator to work with the provided input type, or fall back on a default. 
- # See also - [`ImplicitFunction`](@ref) - [`MatrixRepresentation`](@ref) """ -struct OperatorRepresentation{package,symmetric,hermitian,posdef,keep_input_type} <: - AbstractRepresentation +struct OperatorRepresentation{package,symmetric,hermitian,posdef} <: AbstractRepresentation function OperatorRepresentation{package}(; - symmetric::Bool=false, - hermitian::Bool=false, - posdef::Bool=false, - keep_input_type::Bool=false, + symmetric::Bool=false, hermitian::Bool=false, posdef::Bool=false ) where {package} @assert package in [:LinearOperators, :LinearMaps] - return new{package,symmetric,hermitian,posdef,keep_input_type}() + return new{package,symmetric,hermitian,posdef}() end end function Base.show( - io::IO, ::OperatorRepresentation{package,symmetric,hermitian,posdef,keep_input_type} -) where {package,symmetric,hermitian,posdef,keep_input_type} + io::IO, ::OperatorRepresentation{package,symmetric,hermitian,posdef} +) where {package,symmetric,hermitian,posdef} return print( io, - "OperatorRepresentation{$(repr(package))}(; symmetric=$symmetric, hermitian=$hermitian, posdef=$posdef, keep_input_type=$keep_input_type)", + "OperatorRepresentation{$(repr(package))}(; symmetric=$symmetric, hermitian=$hermitian, posdef=$posdef)", ) end diff --git a/test/preparation.jl b/test/preparation.jl index b4b2e2c9..32d59b74 100644 --- a/test/preparation.jl +++ b/test/preparation.jl @@ -8,56 +8,41 @@ solver(x) = sqrt.(x), nothing conditions(x, y, z) = y .^ 2 .- x + implicit = ImplicitFunction( + solver, conditions; backends=(; x=AutoForwardDiff(), y=AutoForwardDiff()) + ) + implicit_nobackends = ImplicitFunction(solver, conditions) x = rand(5) - input_example = (x,) @testset "None" begin - implicit_none = ImplicitFunction(solver, conditions) - @test implicit_none.prep_A === nothing - @test implicit_none.prep_Aᵀ === nothing - @test implicit_none.prep_B === nothing - @test implicit_none.prep_Bᵀ === nothing + prep = prepare_implicit(ForwardOrReverseMode(), implicit_nobackends, x) + 
@test prep.prep_A === nothing + @test prep.prep_Aᵀ === nothing + @test prep.prep_B === nothing + @test prep.prep_Bᵀ === nothing end @testset "ForwardMode" begin - implicit_forward = ImplicitFunction( - solver, - conditions; - preparation=ForwardMode(), - backends=(; x=AutoForwardDiff(), y=AutoForwardDiff()), - input_example, - ) - @test implicit_forward.prep_A !== nothing - @test implicit_forward.prep_Aᵀ === nothing - @test implicit_forward.prep_B !== nothing - @test implicit_forward.prep_Bᵀ === nothing + prep = prepare_implicit(ForwardMode(), implicit, x) + @test prep.prep_A !== nothing + @test prep.prep_Aᵀ === nothing + @test prep.prep_B !== nothing + @test prep.prep_Bᵀ === nothing end @testset "ReverseMode" begin - implicit_reverse = ImplicitFunction( - solver, - conditions; - preparation=ReverseMode(), - backends=(; x=AutoZygote(), y=AutoZygote()), - input_example, - ) - @test implicit_reverse.prep_A === nothing - @test implicit_reverse.prep_Aᵀ !== nothing - @test implicit_reverse.prep_B === nothing - @test implicit_reverse.prep_Bᵀ !== nothing + prep = prepare_implicit(ReverseMode(), implicit, x) + @test prep.prep_A === nothing + @test prep.prep_Aᵀ !== nothing + @test prep.prep_B === nothing + @test prep.prep_Bᵀ !== nothing end @testset "Both" begin - implicit_both = ImplicitFunction( - solver, - conditions; - preparation=ForwardOrReverseMode(), - backends=(; x=AutoForwardDiff(), y=AutoZygote()), - input_example, - ) - @test implicit_both.prep_A !== nothing - @test implicit_both.prep_Aᵀ !== nothing - @test implicit_both.prep_B !== nothing - @test implicit_both.prep_Bᵀ !== nothing + prep = prepare_implicit(ForwardOrReverseMode(), implicit, x) + @test prep.prep_A !== nothing + @test prep.prep_Aᵀ !== nothing + @test prep.prep_B !== nothing + @test prep.prep_Bᵀ !== nothing end end diff --git a/test/printing.jl b/test/printing.jl index 4e03d04e..2d94f78c 100644 --- a/test/printing.jl +++ b/test/printing.jl @@ -1,6 +1,8 @@ using TestItems @testitem "Settings" begin + 
@test startswith(string(ImplicitFunction(nothing, nothing)), "ImplicitFunction") @test startswith(string(OperatorRepresentation()), "Operator") @test startswith(string(IterativeLinearSolver(; atol=1e-5)), "Iterative") + @test startswith(string(IterativeLinearSolver()), "Iterative") end diff --git a/test/systematic.jl b/test/systematic.jl index 568d7d50..c358f583 100644 --- a/test/systematic.jl +++ b/test/systematic.jl @@ -1,10 +1,9 @@ using TestItems -@testitem "Matrix" setup = [TestUtils] begin +@testitem "Direct" setup = [TestUtils] begin using ADTypes, .TestUtils - for (backends, preparation, x) in Iterators.product( + for (backends, x) in Iterators.product( [nothing, (; x=AutoForwardDiff(), y=AutoZygote())], - [nothing, ADTypes.ForwardOrReverseMode()], [float.(1:3), reshape(float.(1:6), 3, 2)], ) yield() @@ -14,10 +13,8 @@ using TestItems x=x, implicit_kwargs=(; representation=MatrixRepresentation(), - linear_solver=\, + linear_solver=DirectLinearSolver(), backends, - preparation, - input_example=(x,), ), ) scen2 = add_arg_mult(scen) @@ -26,11 +23,10 @@ using TestItems end end; -@testitem "Krylov" setup = [TestUtils] begin +@testitem "Iterative" setup = [TestUtils] begin using ADTypes, .TestUtils - for (backends, preparation, x) in Iterators.product( + for (backends, x) in Iterators.product( [nothing, (; x=AutoForwardDiff(), y=AutoZygote())], - [nothing, ADTypes.ForwardOrReverseMode()], [float.(1:3), reshape(float.(1:6), 3, 2)], ) yield() @@ -40,10 +36,8 @@ end; x=x, implicit_kwargs=(; representation=OperatorRepresentation{:LinearOperators}(), - linear_solver=IterativeLinearSolver{:Krylov}(), + linear_solver=IterativeLinearSolver(), backends, - preparation, - input_example=(x,), ), ) scen2 = add_arg_mult(scen) @@ -52,26 +46,16 @@ end; end end; -@testitem "IterativeSolvers" setup = [TestUtils] begin - using ADTypes, .TestUtils - for (backends, preparation, x) in Iterators.product( - [nothing, (; x=AutoForwardDiff(), y=AutoZygote())], - [nothing, 
ADTypes.ForwardOrReverseMode()], - [float.(1:3), reshape(float.(1:6), 3, 2)], +@testitem "ComponentVector" setup = [TestUtils] begin + using ComponentArrays, .TestUtils + x = ComponentVector(; a=float.(1:3), b=float.(4:6)) + scen = Scenario(; + solver=default_solver, + conditions=default_conditions, + x=x, + implicit_kwargs=(; strict=Val(false)), ) - yield() - scen = Scenario(; - solver=default_solver, - conditions=default_conditions, - x=x, - implicit_kwargs=(; - representation=OperatorRepresentation{:LinearMaps}(), - linear_solver=IterativeLinearSolver{:IterativeSolvers}(), - backends, - preparation, - input_example=(x,), - ), - ) - test_implicit(scen) - end + scen2 = add_arg_mult(scen) + test_implicit(scen) + test_implicit(scen2) end; diff --git a/test/utils.jl b/test/utils.jl index e95a1c87..1b6d5992 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,22 +1,26 @@ using ADTypes +using ADTypes: ForwardMode, ReverseMode, ForwardOrReverseMode using ChainRulesCore using ChainRulesTestUtils +using ComponentArrays import DifferentiationInterface as DI using ForwardDiff: ForwardDiff import ImplicitDifferentiation as ID -using ImplicitDifferentiation: ImplicitFunction +using ImplicitDifferentiation: ImplicitFunction, prepare_implicit using JET using LinearAlgebra using Random: rand! using Test using Zygote: Zygote, ZygoteRuleConfig -@kwdef struct Scenario{S,C,X,A,K} +@kwdef struct Scenario{S,C,X,A,K,Xp,Ap} solver::S conditions::C x::X args::A = () implicit_kwargs::K = (;) + x_prep::Xp = zero(x) + args_prep::Ap = map(zero, args) end function Base.show(io::IO, scen::Scenario) @@ -73,6 +77,8 @@ function add_arg_mult(scen::Scenario, a=3) x=scen.x, args=(a,), implicit_kwargs=implicit_kwargs_with_arg_mult, + x_prep=scen.x_prep, + args_prep=(zero(a),), ) end @@ -95,16 +101,12 @@ function test_implicit_duals(scen::Scenario) implicit = ImplicitFunction( NonDifferentiable(scen.solver), scen.conditions; scen.implicit_kwargs... 
) + prep = prepare_implicit(ForwardMode(), implicit, scen.x_prep, scen.args_prep...) dx = similar(scen.x) rand!(dx) x_and_dx = ForwardDiff.Dual.(scen.x, dx) - y_and_dy, z = implicit(x_and_dx, scen.args...) - T = tag(y_and_dy) - y = ForwardDiff.value.(y_and_dy) - dy = ForwardDiff.extract_derivative.(T, y_and_dy) - y_true, z_true = scen.solver(scen.x, scen.args...) dy_true = DI.pushforward( first ∘ scen.solver, @@ -115,33 +117,61 @@ function test_implicit_duals(scen::Scenario) )[1] @testset "Duals" begin - @test y ≈ y_true - @test dy ≈ dy_true - @test z == z_true + @testset "Prepared" begin + y_and_dy, z = implicit(prep, x_and_dx, scen.args...) + T = tag(y_and_dy) + y = ForwardDiff.value.(y_and_dy) + dy = ForwardDiff.extract_derivative.(T, y_and_dy) + @test y ≈ y_true + @test dy ≈ dy_true + @test z == z_true + end + @testset "Unrepared" begin + y_and_dy, z = implicit(x_and_dx, scen.args...) + T = tag(y_and_dy) + y = ForwardDiff.value.(y_and_dy) + dy = ForwardDiff.extract_derivative.(T, y_and_dy) + @test y ≈ y_true + @test dy ≈ dy_true + @test z == z_true + end end end +function compare_pullbacks(dimpl, dx, dx_true) end + function test_implicit_rrule(scen::Scenario) implicit = ImplicitFunction( NonDifferentiable(scen.solver), scen.conditions; scen.implicit_kwargs... ) + prep = prepare_implicit(ReverseMode(), implicit, scen.x_prep, scen.args_prep...) y_true, z_true = scen.solver(scen.x, scen.args...) dy = similar(y_true) rand!(dy) dz = NoTangent() - (y, z), pb = rrule(ZygoteRuleConfig(), implicit, scen.x, scen.args...) - dimpl, dx = pb((dy, dz)) dx_true = DI.pullback( first ∘ scen.solver, AutoZygote(), scen.x, (dy,), map(DI.Constant, scen.args)... )[1] @testset "ChainRule" begin - @test y ≈ y_true - @test z == z_true - @test dimpl isa NoTangent - @test dx ≈ dx_true + @testset "Prepared" begin + (y, z), pb = rrule(ZygoteRuleConfig(), implicit, prep, scen.x, scen.args...) 
+ dimpl, dprep, dx = pb((dy, dz)) + @test y ≈ y_true + @test z == z_true + @test dimpl isa NoTangent + @test dx ≈ dx_true + end + @testset "Unprepared" begin + (y, z), pb = rrule_via_ad(ZygoteRuleConfig(), implicit, scen.x, scen.args...) + dimpl, dx = pb((dy, dz)) + @test y ≈ y_true + @test z == z_true + @test dimpl isa NoTangent + @test dx ≈ dx_true + end end end @@ -149,15 +179,26 @@ function test_implicit_jacobian(scen::Scenario, outer_backend::AbstractADType) implicit = ImplicitFunction( NonDifferentiable(scen.solver), scen.conditions; scen.implicit_kwargs... ) - jac = DI.jacobian( - first ∘ implicit, outer_backend, scen.x, map(DI.Constant, scen.args)... + prep = prepare_implicit( + ForwardOrReverseMode(), implicit, scen.x_prep, scen.args_prep... ) jac_true = DI.jacobian( first ∘ scen.solver, outer_backend, scen.x, map(DI.Constant, scen.args)... ) @testset "Jacobian - $outer_backend" begin - @test jac ≈ jac_true + @testset "Prepared" begin + jac = DI.jacobian( + x -> first(implicit(prep, x, scen.args...)), outer_backend, scen.x + ) + @test jac ≈ jac_true + end + @testset "Unprepared" begin + jac = DI.jacobian( + first ∘ implicit, outer_backend, scen.x, map(DI.Constant, scen.args)... + ) + @test jac ≈ jac_true + end end end