diff --git a/Project.toml b/Project.toml index 9b2331ae..63743d10 100644 --- a/Project.toml +++ b/Project.toml @@ -1,13 +1,15 @@ name = "ImplicitDifferentiation" uuid = "57b37032-215b-411a-8a7c-41a003a55207" authors = ["Guillaume Dalle", "Mohamed Tarek"] -version = "0.7.3" +version = "0.8.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LinearMaps = "7a12625a-238d-50fd-b39a-03d52299707e" LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125" [weakdeps] @@ -22,12 +24,30 @@ ImplicitDifferentiationZygoteExt = "Zygote" [compat] ADTypes = "1.9.0" +Aqua = "0.8.13" ChainRulesCore = "1.25.0" -DifferentiationInterface = "0.6.1" +ChainRulesTestUtils = "1.13.0" +ComponentArrays = "0.15.27" +DifferentiationInterface = "0.6.1,0.7" +Documenter = "1.12.0" +ExplicitImports = "1" +FiniteDiff = "2.27.0" ForwardDiff = "0.10.36, 1" +IterativeSolvers = "0.9.4" +JET = "0.9, 0.10" +JuliaFormatter = "2.1.2" Krylov = "0.9.6, 0.10" LinearAlgebra = "1.10" +LinearMaps = "3.11.4" LinearOperators = "2.8.0" +NLsolve = "4.5.1" +Optim = "1.12.0" +Random = "1" +SparseArrays = "1" +StaticArrays = "1.9.13" +Test = "1" +TestItemRunner = "1.1.0" +TestItems = "1.0.0" Zygote = "0.7.4" julia = "1.10" @@ -51,7 +71,9 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a" +TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["ADTypes", "Aqua", "ChainRulesCore", "ChainRulesTestUtils", "ComponentArrays", "DifferentiationInterface", "Documenter", "ExplicitImports", "FiniteDiff", "ForwardDiff", "JET", "JuliaFormatter", "NLsolve", "Optim", "Random", "SparseArrays", "StaticArrays", "Test", "Zygote"] +test = ["ADTypes", "Aqua", "ChainRulesCore", "ChainRulesTestUtils", "ComponentArrays", "DifferentiationInterface", "Documenter", "ExplicitImports", "FiniteDiff", "ForwardDiff", "JET", "JuliaFormatter", "NLsolve", "Optim", "Random", "SparseArrays", "StaticArrays", "Test", "TestItems", "TestItemRunner", "Zygote"] diff --git a/docs/src/api.md b/docs/src/api.md index 594c8226..8eb6ba27 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -1,5 +1,4 @@ ```@meta -CurrentModule = ImplicitDifferentiation CollapsedDocStrings = true ``` @@ -20,7 +19,7 @@ ImplicitFunction ### Settings ```@docs -KrylovLinearSolver +IterativeLinearSolver MatrixRepresentation OperatorRepresentation NoPreparation diff --git a/docs/src/faq.md b/docs/src/faq.md index dbd0fd65..3e7e4edf 100644 --- a/docs/src/faq.md +++ b/docs/src/faq.md @@ -46,7 +46,10 @@ Say your forward mapping takes multiple inputs and returns multiple outputs, suc The trick is to leverage [ComponentArrays.jl](https://github.com/jonniedie/ComponentArrays.jl) to wrap all the inputs inside a single a `ComponentVector`, and do the same for all the outputs. See the examples for a demonstration. -!!! warning "Warning" +!!! warning + The default linear operator representation does not support ComponentArrays.jl: you need to select `representation=OperatorRepresentation{:LinearMaps}()` in the [`ImplicitFunction`](@ref) constructor for it to work. + +!!! warning You may run into issues trying to differentiate through the `ComponentVector` constructor. For instance, Zygote.jl will throw `ERROR: Mutating arrays is not supported`. Check out [this issue](https://github.com/gdalle/ImplicitDifferentiation.jl/issues/67) for a dirty workaround involving custom chain rules for the constructor. diff --git a/examples/3_tricks.jl b/examples/3_tricks.jl index 7b667b90..dd43ddd2 100644 --- a/examples/3_tricks.jl +++ b/examples/3_tricks.jl @@ -7,7 +7,6 @@ We demonstrate several features that may come in handy for some users. using ComponentArrays using ForwardDiff using ImplicitDifferentiation -using Krylov using LinearAlgebra using Test #src using Zygote @@ -43,9 +42,13 @@ function conditions_components(x::ComponentVector, y::ComponentVector, _z) return c end; -# And build your implicit function like so. +# And build your implicit function like so, switching the operator representation to avoid errors with ComponentArrays. -implicit_components = ImplicitFunction(forward_components, conditions_components); +implicit_components = ImplicitFunction( + forward_components, + conditions_components; + representation=OperatorRepresentation{:LinearMaps}(), +); # Now we're good to go. diff --git a/ext/ImplicitDifferentiationChainRulesCoreExt.jl b/ext/ImplicitDifferentiationChainRulesCoreExt.jl index 030fb876..997ab1f7 100644 --- a/ext/ImplicitDifferentiationChainRulesCoreExt.jl +++ b/ext/ImplicitDifferentiationChainRulesCoreExt.jl @@ -17,17 +17,18 @@ function ChainRulesCore.rrule( rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray, args::Vararg{Any,N}; ) where {N} y, z = implicit(x, args...) + c = implicit.conditions(x, y, z, args...) suggested_backend = chainrules_suggested_backend(rc) - Aᵀ = build_Aᵀ(implicit, x, y, z, args...; suggested_backend) - Bᵀ = build_Bᵀ(implicit, x, y, z, args...; suggested_backend) + Aᵀ = build_Aᵀ(implicit, x, y, z, c, args...; suggested_backend) + Bᵀ = build_Bᵀ(implicit, x, y, z, c, args...; suggested_backend) project_x = ProjectTo(x) function implicit_pullback((dy, dz)) dy = unthunk(dy) dy_vec = vec(dy) dc_vec = implicit.linear_solver(Aᵀ, -dy_vec) - dx_vec = Bᵀ * dc_vec + dx_vec = Bᵀ(dc_vec) dx = reshape(dx_vec, size(x)) df = NoTangent() dargs = ntuple(unimplemented_tangent, N) diff --git a/ext/ImplicitDifferentiationForwardDiffExt.jl b/ext/ImplicitDifferentiationForwardDiffExt.jl index f7b9a8e6..b5358c5d 100644 --- a/ext/ImplicitDifferentiationForwardDiffExt.jl +++ b/ext/ImplicitDifferentiationForwardDiffExt.jl @@ -9,22 +9,24 @@ function (implicit::ImplicitFunction)( ) where {T,R,N} x = value.(x_and_dx) y, z = implicit(x, args...) + c = implicit.conditions(x, y, z, args...) suggested_backend = AutoForwardDiff() - A = build_A(implicit, x, y, z, args...; suggested_backend) - B = build_B(implicit, x, y, z, args...; suggested_backend) + A = build_A(implicit, x, y, z, c, args...; suggested_backend) + B = build_B(implicit, x, y, z, c, args...; suggested_backend) dX = ntuple(Val(N)) do k partials.(x_and_dx, k) end dC_mat = mapreduce(hcat, dX) do dₖx dₖx_vec = vec(dₖx) - dₖc_vec = B * dₖx_vec + dₖc_vec = B(dₖx_vec) + return dₖc_vec end dY_mat = implicit.linear_solver(A, -dC_mat) - y_and_dy = map(LinearIndices(y)) do i - Dual{T}(y[i], Partials(ntuple(k -> dY_mat[i, k], Val(N)))) + y_and_dy = map(y, LinearIndices(y)) do yi, i + Dual{T}(yi, Partials(ntuple(k -> dY_mat[i, k], Val(N)))) end return y_and_dy, z diff --git a/src/ImplicitDifferentiation.jl b/src/ImplicitDifferentiation.jl index 91c351b2..0e044d48 100644 --- a/src/ImplicitDifferentiation.jl +++ b/src/ImplicitDifferentiation.jl @@ -18,9 +18,13 @@ using DifferentiationInterface: prepare_pushforward, prepare_pushforward_same_point, pullback!, - pushforward! -using Krylov: gmres + pullback, + pushforward!, + pushforward +using Krylov: Krylov +using IterativeSolvers: IterativeSolvers using LinearOperators: LinearOperator +using LinearMaps: FunctionMap using LinearAlgebra: factorize include("utils.jl") @@ -29,7 +33,7 @@ include("preparation.jl") include("implicit_function.jl") include("execution.jl") -export KrylovLinearSolver +export IterativeLinearSolver export MatrixRepresentation, OperatorRepresentation export NoPreparation, ForwardPreparation, ReversePreparation, BothPreparation export ImplicitFunction diff --git a/src/execution.jl b/src/execution.jl index 74a3bc5a..1190bd69 100644 --- a/src/execution.jl +++ b/src/execution.jl @@ -1,31 +1,32 @@ -const SYMMETRIC = false -const HERMITIAN = false - -struct JVP!{F,P,B,I,C} +struct JVP!{F,P,B,I,V,C} f::F prep::P backend::B input::I + v_buffer::V contexts::C end -struct VJP!{F,P,B,I,C} +struct VJP!{F,P,B,I,V,C} f::F prep::P backend::B input::I + v_buffer::V contexts::C end -function (po::JVP!)(res::AbstractVector, v::AbstractVector) - (; f, backend, input, contexts, prep) = po - pushforward!(f, (res,), prep, backend, input, (v,), contexts...) +function (po::JVP!)(res::AbstractVector, v_wrongtype::AbstractVector) + (; f, backend, input, v_buffer, contexts, prep) = po + copyto!(v_buffer, v_wrongtype) + pushforward!(f, (res,), prep, backend, input, (v_buffer,), contexts...) return res end -function (po::VJP!)(res::AbstractVector, v::AbstractVector) - (; f, backend, input, contexts, prep) = po - pullback!(f, (res,), prep, backend, input, (v,), contexts...) +function (po::VJP!)(res::AbstractVector, v_wrongtype::AbstractVector) + (; f, backend, input, v_buffer, contexts, prep) = po + copyto!(v_buffer, v_wrongtype) + pullback!(f, (res,), prep, backend, input, (v_buffer,), contexts...) return res end @@ -36,38 +37,70 @@ function build_A( x::AbstractArray, y::AbstractArray, z, + c, args...; suggested_backend::AbstractADType, ) return build_A_aux( - implicit.representation, implicit, x, y, z, args...; suggested_backend + implicit.representation, implicit, x, y, z, c, args...; suggested_backend ) end -function build_A_aux(::MatrixRepresentation, implicit, x, y, z, args...; suggested_backend) - (; conditions, backend, prep_A) = implicit - actual_backend = isnothing(backend) ? suggested_backend : backend +function build_A_aux( + ::MatrixRepresentation, implicit, x, y, z, c, args...; suggested_backend +) + (; conditions, backends, prep_A) = implicit + actual_backend = isnothing(backends) ? suggested_backend : backends.y contexts = (Constant(x), Constant(z), map(Constant, args)...) - A = jacobian(Switch12(conditions), prep_A..., actual_backend, y, contexts...) + if isnothing(prep_A) + A = jacobian(Switch12(conditions), actual_backend, y, contexts...) + else + A = jacobian(Switch12(conditions), prep_A, actual_backend, y, contexts...) + end return factorize(A) end function build_A_aux( - ::OperatorRepresentation, implicit, x, y, z, args...; suggested_backend -) - (; conditions, backend, prep_A) = implicit - actual_backend = isnothing(backend) ? suggested_backend : backend + ::OperatorRepresentation{package,symmetric,hermitian}, + implicit, + x, + y, + z, + c, + args...; + suggested_backend, +) where {package,symmetric,hermitian} + T = Base.promote_eltype(x, y, c) + (; conditions, backends, prep_A) = implicit + actual_backend = isnothing(backends) ? suggested_backend : backends.y contexts = (Constant(x), Constant(z), map(Constant, args)...) f_vec = VecToVec(Switch12(conditions), y) y_vec = vec(y) dy_vec = vec(zero(y)) - prep_A_same = prepare_pushforward_same_point( - f_vec, prep_A..., actual_backend, y_vec, (dy_vec,), contexts... - ) - prod! = JVP!(f_vec, prep_A_same, actual_backend, y_vec, contexts) - return LinearOperator( - eltype(y), length(y), length(y), SYMMETRIC, HERMITIAN, prod!, typeof(y_vec) - ) + if isnothing(prep_A) + prep_A_same = prepare_pushforward_same_point( + f_vec, actual_backend, y_vec, (dy_vec,), contexts...; strict=implicit.strict + ) + else + prep_A_same = prepare_pushforward_same_point( + f_vec, prep_A, actual_backend, y_vec, (dy_vec,), contexts... + ) + end + prod! = JVP!(f_vec, prep_A_same, actual_backend, y_vec, dy_vec, contexts) + if package == :LinearOperators + return LinearOperator( + T, length(c), length(y), symmetric, hermitian, prod!; S=typeof(dy_vec) + ) + elseif package == :LinearMaps + return FunctionMap{T}( + prod!, + length(c), + length(y); + ismutating=true, + issymmetric=symmetric, + ishermitian=hermitian, + ) + end end ## Aᵀ @@ -77,40 +110,72 @@ function build_Aᵀ( x::AbstractArray, y::AbstractArray, z, + c, args...; suggested_backend::AbstractADType, ) return build_Aᵀ_aux( - implicit.representation, implicit, x, y, z, args...; suggested_backend + implicit.representation, implicit, x, y, z, c, args...; suggested_backend ) end -function build_Aᵀ_aux(::MatrixRepresentation, implicit, x, y, z, args...; suggested_backend) - (; conditions, backend, prep_Aᵀ) = implicit - actual_backend = isnothing(backend) ? suggested_backend : backend +function build_Aᵀ_aux( + ::MatrixRepresentation, implicit, x, y, z, c, args...; suggested_backend +) + (; conditions, backends, prep_Aᵀ) = implicit + actual_backend = isnothing(backends) ? suggested_backend : backends.y contexts = (Constant(x), Constant(z), map(Constant, args)...) - Aᵀ = transpose( - jacobian(Switch12(conditions), prep_Aᵀ..., actual_backend, y, contexts...) - ) + if isnothing(prep_Aᵀ) + Aᵀ = transpose(jacobian(Switch12(conditions), actual_backend, y, contexts...)) + else + Aᵀ = transpose( + jacobian(Switch12(conditions), prep_Aᵀ, actual_backend, y, contexts...) + ) + end return factorize(Aᵀ) end function build_Aᵀ_aux( - ::OperatorRepresentation, implicit, x, y, z, args...; suggested_backend -) - (; conditions, backend, prep_Aᵀ) = implicit - actual_backend = isnothing(backend) ? suggested_backend : backend + ::OperatorRepresentation{package,symmetric,hermitian}, + implicit, + x, + y, + z, + c, + args...; + suggested_backend, +) where {package,symmetric,hermitian} + T = Base.promote_eltype(x, y, c) + (; conditions, backends, prep_Aᵀ) = implicit + actual_backend = isnothing(backends) ? suggested_backend : backends.y contexts = (Constant(x), Constant(z), map(Constant, args)...) f_vec = VecToVec(Switch12(conditions), y) y_vec = vec(y) - dc_vec = vec(zero(y)) - prep_Aᵀ_same = prepare_pullback_same_point( - f_vec, prep_Aᵀ..., actual_backend, y_vec, (dc_vec,), contexts... - ) - prod! = VJP!(f_vec, prep_Aᵀ_same, actual_backend, y_vec, contexts) - return LinearOperator( - eltype(y), length(y), length(y), SYMMETRIC, HERMITIAN, prod!, typeof(y_vec) - ) + dc_vec = vec(zero(c)) + if isnothing(prep_Aᵀ) + prep_Aᵀ_same = prepare_pullback_same_point( + f_vec, actual_backend, y_vec, (dc_vec,), contexts...; strict=implicit.strict + ) + else + prep_Aᵀ_same = prepare_pullback_same_point( + f_vec, prep_Aᵀ, actual_backend, y_vec, (dc_vec,), contexts... + ) + end + prod! = VJP!(f_vec, prep_Aᵀ_same, actual_backend, y_vec, dc_vec, contexts) + if package == :LinearOperators + return LinearOperator( + T, length(y), length(c), symmetric, hermitian, prod!; S=typeof(dc_vec) + ) + elseif package == :LinearMaps + return FunctionMap{T}( + prod!, + length(y), + length(c); + ismutating=true, + issymmetric=symmetric, + ishermitian=hermitian, + ) + end end ## B @@ -120,37 +185,32 @@ function build_B( x::AbstractArray, y::AbstractArray, z, + c, args...; suggested_backend::AbstractADType, ) - return build_B_aux( - implicit.representation, implicit, x, y, z, args...; suggested_backend - ) -end - -function build_B_aux(::MatrixRepresentation, implicit, x, y, z, args...; suggested_backend) - (; conditions, backend, prep_B) = implicit - actual_backend = isnothing(backend) ? suggested_backend : backend - contexts = (Constant(y), Constant(z), map(Constant, args)...) - return jacobian(conditions, prep_B..., actual_backend, x, contexts...) -end - -function build_B_aux( - ::OperatorRepresentation, implicit, x, y, z, args...; suggested_backend -) - (; conditions, backend, prep_B) = implicit - actual_backend = isnothing(backend) ? suggested_backend : backend + (; conditions, backends, prep_B) = implicit + actual_backend = isnothing(backends) ? suggested_backend : backends.x contexts = (Constant(y), Constant(z), map(Constant, args)...) f_vec = VecToVec(conditions, x) x_vec = vec(x) dx_vec = vec(zero(x)) - prep_B_same = prepare_pushforward_same_point( - f_vec, prep_B..., actual_backend, x_vec, (dx_vec,), contexts... - ) - prod! = JVP!(f_vec, prep_B_same, actual_backend, x_vec, contexts) - return LinearOperator( - eltype(y), length(y), length(x), SYMMETRIC, HERMITIAN, prod!, typeof(x_vec) - ) + if isnothing(prep_B) + prep_B_same = prepare_pushforward_same_point( + f_vec, actual_backend, x_vec, (dx_vec,), contexts... + ) + else + prep_B_same = prepare_pushforward_same_point( + f_vec, prep_B, actual_backend, x_vec, (dx_vec,), contexts... + ) + end + function B_fun(dx_vec_wrongtype) + copyto!(dx_vec, dx_vec_wrongtype) + return pushforward( + f_vec, prep_B_same, actual_backend, x_vec, (dx_vec,), contexts... + )[1] + end + return B_fun end ## Bᵀ @@ -160,35 +220,28 @@ function build_Bᵀ( x::AbstractArray, y::AbstractArray, z, + c, args...; suggested_backend::AbstractADType, ) - return build_Bᵀ_aux( - implicit.representation, implicit, x, y, z, args...; suggested_backend - ) -end - -function build_Bᵀ_aux(::MatrixRepresentation, implicit, x, y, z, args...; suggested_backend) - (; conditions, backend, prep_Bᵀ) = implicit - actual_backend = isnothing(backend) ? suggested_backend : backend - contexts = (Constant(y), Constant(z), map(Constant, args)...) - return transpose(jacobian(conditions, prep_Bᵀ..., actual_backend, x, contexts...)) -end - -function build_Bᵀ_aux( - ::OperatorRepresentation, implicit, x, y, z, args...; suggested_backend -) - (; conditions, backend, prep_Bᵀ) = implicit - actual_backend = isnothing(backend) ? suggested_backend : backend + (; conditions, backends, prep_Bᵀ) = implicit + actual_backend = isnothing(backends) ? suggested_backend : backends.x contexts = (Constant(y), Constant(z), map(Constant, args)...) f_vec = VecToVec(conditions, x) x_vec = vec(x) - dc_vec = vec(zero(y)) - prep_Bᵀ_same = prepare_pullback_same_point( - f_vec, prep_Bᵀ..., actual_backend, x_vec, (dc_vec,), contexts... - ) - prod! = VJP!(f_vec, prep_Bᵀ_same, actual_backend, x_vec, contexts) - return LinearOperator( - eltype(y), length(x), length(y), SYMMETRIC, HERMITIAN, prod!, typeof(x_vec) - ) + dc_vec = vec(zero(c)) + if isnothing(prep_Bᵀ) + prep_Bᵀ_same = prepare_pullback_same_point( + f_vec, actual_backend, x_vec, (dc_vec,), contexts...; strict=implicit.strict + ) + else + prep_Bᵀ_same = prepare_pullback_same_point( + f_vec, prep_Bᵀ, actual_backend, x_vec, (dc_vec,), contexts... + ) + end + function Bᵀ_fun(dc_vec_wrongtype) + copyto!(dc_vec, dc_vec_wrongtype) + return pullback(f_vec, prep_Bᵀ_same, actual_backend, x_vec, (dc_vec,), contexts...)[1] + end + return Bᵀ_fun end diff --git a/src/implicit_function.jl b/src/implicit_function.jl index acba5687..faeaa626 100644 --- a/src/implicit_function.jl +++ b/src/implicit_function.jl @@ -19,9 +19,9 @@ This requires solving a linear system `A * J = -B` where `A = ∂₂c`, `B = ∂ ImplicitFunction( solver, - conditions, - linear_solver=KrylovLinearSolver(), + conditions; representation=OperatorRepresentation(), + linear_solver=IterativeLinearSolver(), backend=nothing, preparation=nothing, input_example=nothing, @@ -34,66 +34,119 @@ This requires solving a linear system `A * J = -B` where `A = ∂₂c`, `B = ∂ ## Keyword arguments -- `linear_solver`: a callable to solve linear systems with two required methods, one for `(A, b)` (single solve) and one for `(A, B)` (batched solve) (defaults to [`KrylovLinearSolver`](@ref)) - `representation`: either [`MatrixRepresentation`](@ref) or [`OperatorRepresentation`](@ref) -- `backend::AbstractADType`: either `nothing` or an object from [ADTypes.jl](https://github.com/SciML/ADTypes.jl) dictating how how the conditions will be differentiated -- `preparation`: either `nothing` or a mode object from [ADTypes.jl](https://github.com/SciML/ADTypes.jl): `ADTypes.ForwardMode()`, `ADTypes.ReverseMode()` or `ADTypes.ForwardOrReverseMode()` -- `input_example`: either `nothing` or a tuple `(x, args...)` used to prepare differentiation +- `linear_solver`: a callable to solve linear systems with two required methods, one for `(A, b)` (single solve) and one for `(A, B)` (batched solve). It defaults to [`IterativeLinearSolver`](@ref) but can also be the built-in `\\`, or a user-provided function. +- `backend::AbstractADType`: specifies how the `conditions` will be differentiated with respect to `x` and `y`. It can be either + - `nothing`, which means that the external autodiff system will be used + - a single object from [ADTypes.jl](https://github.com/SciML/ADTypes.jl) + - a named tuple `(; x, y)` of objects from [ADTypes.jl](https://github.com/SciML/ADTypes.jl) +- `preparation`: either `nothing` or a mode object from [ADTypes.jl](https://github.com/SciML/ADTypes.jl): `ADTypes.ForwardMode()`, `ADTypes.ReverseMode()` or `ADTypes.ForwardOrReverseMode()`. +- `input_example`: either `nothing` or a tuple `(x, args...)` used to prepare differentiation. +- `strict::Val=Val(true)`: whether or not to enforce a strict match in [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl) between the preparation and the execution types. Relaxing this to `strict=Val(false)` can prove necessary when working with custom array types like ComponentArrays.jl, which are not always compatible with iterative linear solvers. """ struct ImplicitFunction{ F, C, L, R<:AbstractRepresentation, - B<:Union{Nothing,AbstractADType}, + B<:Union{ + Nothing, # + AbstractADType, + NamedTuple{(:x, :y),<:Tuple{AbstractADType,AbstractADType}}, + }, P<:Union{Nothing,AbstractMode}, PA, PAT, PB, PBT, + _strict, } solver::F conditions::C linear_solver::L representation::R - backend::B + backends::B preparation::P prep_A::PA prep_Aᵀ::PAT prep_B::PB prep_Bᵀ::PBT + strict::Val{_strict} end function ImplicitFunction( solver, conditions; - linear_solver=KrylovLinearSolver(), representation=OperatorRepresentation(), - backend=nothing, + linear_solver=IterativeLinearSolver(), + backends=nothing, preparation=nothing, input_example=nothing, + strict::Val=Val(true), ) - if isnothing(preparation) || isnothing(backend) || isnothing(input_example) - prep_A = () - prep_Aᵀ = () - prep_B = () - prep_Bᵀ = () + if isnothing(preparation) || isnothing(backends) || isnothing(input_example) + prep_A = nothing + prep_Aᵀ = nothing + prep_B = nothing + prep_Bᵀ = nothing else + real_backends = backends isa AbstractADType ? (; x=backends, y=backends) : backends x, args = first(input_example), Base.tail(input_example) y, z = solver(x, args...) + c = conditions(x, y, z, args...) if preparation isa Union{ForwardMode,ForwardOrReverseMode} - prep_A = (prepare_A(representation, x, y, z, args...; conditions, backend),) - prep_B = (prepare_B(representation, x, y, z, args...; conditions, backend),) + prep_A = prepare_A( + representation, + x, + y, + z, + c, + args...; + conditions, + backend=real_backends.y, + strict, + ) + prep_B = prepare_B( + representation, + x, + y, + z, + c, + args...; + conditions, + backend=real_backends.x, + strict, + ) else - prep_A = () - prep_B = () + prep_A = nothing + prep_B = nothing end if preparation isa Union{ReverseMode,ForwardOrReverseMode} - prep_Aᵀ = (prepare_Aᵀ(representation, x, y, z, args...; conditions, backend),) - prep_Bᵀ = (prepare_Bᵀ(representation, x, y, z, args...; conditions, backend),) + prep_Aᵀ = prepare_Aᵀ( + representation, + x, + y, + z, + c, + args...; + conditions, + backend=real_backends.y, + strict, + ) + prep_Bᵀ = prepare_Bᵀ( + representation, + x, + y, + z, + c, + args...; + conditions, + backend=real_backends.x, + strict, + ) else - prep_Aᵀ = () - prep_Bᵀ = () + prep_Aᵀ = nothing + prep_Bᵀ = nothing end end return ImplicitFunction( @@ -101,26 +154,27 @@ function ImplicitFunction( conditions, linear_solver, representation, - backend, + backends, preparation, prep_A, prep_Aᵀ, prep_B, prep_Bᵀ, + strict, ) end function Base.show(io::IO, implicit::ImplicitFunction) - (; solver, conditions, backend, linear_solver, representation, preparation) = implicit + (; solver, conditions, backends, linear_solver, representation, preparation) = implicit return print( io, """ ImplicitFunction( $solver, $conditions; - linear_solver=$linear_solver, representation=$representation, - backend=$backend, + linear_solver=$linear_solver, + backends=$backends, preparation=$preparation, ) """, diff --git a/src/preparation.jl b/src/preparation.jl index 41f8ccb4..c6b09b72 100644 --- a/src/preparation.jl +++ b/src/preparation.jl @@ -3,12 +3,14 @@ function prepare_A( x::AbstractArray, y::AbstractArray, z, + c, args...; conditions, backend::AbstractADType, + strict::Val, ) contexts = (Constant(x), Constant(z), map(Constant, args)...) - return prepare_jacobian(Switch12(conditions), backend, y, contexts...) + return prepare_jacobian(Switch12(conditions), backend, y, contexts...; strict) end function prepare_A( @@ -16,15 +18,17 @@ function prepare_A( x::AbstractArray, y::AbstractArray, z, + c, args...; conditions, backend::AbstractADType, + strict::Val, ) contexts = (Constant(x), Constant(z), map(Constant, args)...) f_vec = VecToVec(Switch12(conditions), y) y_vec = vec(y) dy_vec = vec(zero(y)) - return prepare_pushforward(f_vec, backend, y_vec, (dy_vec,), contexts...) + return prepare_pushforward(f_vec, backend, y_vec, (dy_vec,), contexts...; strict) end function prepare_Aᵀ( @@ -32,12 +36,14 @@ function prepare_Aᵀ( x::AbstractArray, y::AbstractArray, z, + c, args...; conditions, backend::AbstractADType, + strict::Val, ) contexts = (Constant(x), Constant(z), map(Constant, args)...) - return prepare_jacobian(Switch12(conditions), backend, y, contexts...) + return prepare_jacobian(Switch12(conditions), backend, y, contexts...; strict) end function prepare_Aᵀ( @@ -45,71 +51,51 @@ function prepare_Aᵀ( x::AbstractArray, y::AbstractArray, z, + c, args...; conditions, backend::AbstractADType, + strict::Val, ) contexts = (Constant(x), Constant(z), map(Constant, args)...) f_vec = VecToVec(Switch12(conditions), y) y_vec = vec(y) - dc_vec = vec(zero(y)) # same size - return prepare_pullback(f_vec, backend, y_vec, (dc_vec,), contexts...) + dc_vec = vec(zero(c)) + return prepare_pullback(f_vec, backend, y_vec, (dc_vec,), contexts...; strict) end function prepare_B( - ::MatrixRepresentation, - x::AbstractArray, - y::AbstractArray, - z, - args...; - conditions, - backend::AbstractADType, -) - contexts = (Constant(y), Constant(z), map(Constant, args)...) - return prepare_jacobian(conditions, backend, x, contexts...) -end - -function prepare_B( - ::OperatorRepresentation, + ::AbstractRepresentation, x::AbstractArray, y::AbstractArray, z, + c, args...; conditions, backend::AbstractADType, + strict::Val, ) contexts = (Constant(y), Constant(z), map(Constant, args)...) f_vec = VecToVec(conditions, x) x_vec = vec(x) dx_vec = vec(zero(x)) - return prepare_pushforward(f_vec, backend, x_vec, (dx_vec,), contexts...) + return prepare_pushforward(f_vec, backend, x_vec, (dx_vec,), contexts...; strict) end function prepare_Bᵀ( - ::MatrixRepresentation, - x::AbstractArray, - y::AbstractArray, - z, - args...; - conditions, - backend::AbstractADType, -) - contexts = (Constant(y), Constant(z), map(Constant, args)...) - return prepare_jacobian(conditions, backend, x, contexts...) -end - -function prepare_Bᵀ( - ::OperatorRepresentation, + ::AbstractRepresentation, x::AbstractArray, y::AbstractArray, z, + c, args...; conditions, backend::AbstractADType, + strict::Val, ) contexts = (Constant(y), Constant(z), map(Constant, args)...) f_vec = VecToVec(conditions, x) x_vec = vec(x) - dc_vec = vec(zero(y)) # same size - return prepare_pullback(f_vec, backend, x_vec, (dc_vec,), contexts...) + dc_vec = vec(zero(c)) + return prepare_pullback(f_vec, backend, x_vec, (dc_vec,), contexts...; strict) end diff --git a/src/settings.jl b/src/settings.jl index e2f4c5b9..dbfb876c 100644 --- a/src/settings.jl +++ b/src/settings.jl @@ -1,45 +1,63 @@ ## Linear solver """ - KrylovLinearSolver + IterativeLinearSolver Callable object that can solve linear systems `Ax = b` and `AX = B` in the same way as the built-in `\\`. -Uses an iterative solver from [Krylov.jl](https://github.com/JuliaSmoothOptimizers/Krylov.jl) under the hood. # Constructor - KrylovLinearSolver(; verbose=true) + IterativeLinearSolver{package}(; kwargs...) -If `verbose` is `true`, the solver logs a warning in case of failure. -Otherwise it will fail silently, and may return solutions that do not exactly satisfy the linear system. +The type parameter `package` can be either: + +- `:Krylov` to use the solver `gmres` from [Krylov.jl](https://github.com/JuliaSmoothOptimizers/Krylov.jl) +- `:IterativeSolvers` to use the solver `gmres` from [IterativeSolvers.jl](https://github.com/JuliaLinearAlgebra/IterativeSolvers.jl) + +Keyword arguments are passed on to the respective solver. # Callable behavior - (::KylovLinearSolver)(A, b::AbstractVector) + (::IterativeLinearSolver)(A, b::AbstractVector) Solve a linear system with a single right-hand side. - (::KrylovLinearSolver)(A, B::AbstractMatrix) + (::IterativeLinearSolver)(A, B::AbstractMatrix) Solve a linear system with multiple right-hand sides. """ -Base.@kwdef struct KrylovLinearSolver - verbose::Bool = true +struct IterativeLinearSolver{package,K} + kwargs::K + function IterativeLinearSolver{package}(; kwargs...) where {package} + @assert package in [:Krylov, :IterativeSolvers] + return new{package,typeof(kwargs)}(kwargs) + end +end + +IterativeLinearSolver() = IterativeLinearSolver{:Krylov}() + +function (solver::IterativeLinearSolver{:Krylov})(A, b::AbstractVector) + x, stats = Krylov.gmres(A, b; solver.kwargs...) + return x end -function (solver::KrylovLinearSolver)(A, b::AbstractVector) - x, stats = gmres(A, b) - if !stats.solved || stats.inconsistent - solver.verbose && - @warn "Failed to solve the linear system in the implicit function theorem with `Krylov.gmres`" stats +function (solver::IterativeLinearSolver{:Krylov})(A, B::AbstractMatrix) + # TODO: use block_gmres + X = mapreduce(hcat, eachcol(B)) do b + x, _ = Krylov.gmres(A, b; solver.kwargs...) + x end + return X +end + +function (solver::IterativeLinearSolver{:IterativeSolvers})(A, b::AbstractVector) + x = IterativeSolvers.gmres(A, b; solver.kwargs...) return x end -function (solver::KrylovLinearSolver)(A, B::AbstractMatrix) - # X, stats = block_gmres(A, B) # https://github.com/JuliaSmoothOptimizers/Krylov.jl/issues/854 +function (solver::IterativeLinearSolver{:IterativeSolvers})(A, B::AbstractMatrix) X = mapreduce(hcat, eachcol(B)) do b - solver(A, b) + IterativeSolvers.gmres(A, b; solver.kwargs...) end return X end @@ -51,7 +69,7 @@ abstract type AbstractRepresentation end """ MatrixRepresentation -Specify that the matrices involved in the implicit function theorem should be represented explicitly, with all their coefficients. +Specify that the matrix `A` involved in the implicit function theorem should be represented explicitly, with all its coefficients. # See also @@ -63,14 +81,34 @@ struct MatrixRepresentation <: AbstractRepresentation end """ OperatorRepresentation -Specify that the matrices involved in the implicit function theorem should be represented lazily, as linear operators from [LinearOperators.jl](https://github.com/JuliaSmoothOptimizers/LinearOperators.jl). +Specify that the matrix `A` involved in the implicit function theorem should be represented lazily. + +# Constructors + + OperatorRepresentation{package}(; symmetric=false, hermitian=false) + +The type parameter `package` can be either: + +- `:LinearOperators` to use a wrapper from [LinearOperators.jl](https://github.com/JuliaSmoothOptimizers/LinearOperators.jl) (the default) +- `:LinearMaps` to use a wrapper from [LinearMaps.jl](https://github.com/JuliaLinearAlgebra/LinearMaps.jl) + +The keyword arguments `symmetric` and `hermitian` give additional properties of the Jacobian of the `conditions` with respect to the solution `y`, in case you can prove them. # See also - [`ImplicitFunction`](@ref) - [`MatrixRepresentation`](@ref) """ -struct OperatorRepresentation <: AbstractRepresentation end +struct OperatorRepresentation{package,symmetric,hermitian} <: AbstractRepresentation + function OperatorRepresentation{package}(; + symmetric::Bool=false, hermitian::Bool=false + ) where {package} + @assert package in [:LinearOperators, :LinearMaps] + return new{package,symmetric,hermitian}() + end +end + +OperatorRepresentation() = OperatorRepresentation{:LinearOperators}() ## Preparation diff --git a/test/examples.jl b/test/examples.jl new file mode 100644 index 00000000..117b5e78 --- /dev/null +++ b/test/examples.jl @@ -0,0 +1,15 @@ +@testitem "intro" begin + include(joinpath(dirname(@__DIR__), "examples", "0_intro.jl")) +end + +@testitem "basic" begin + include(joinpath(dirname(@__DIR__), "examples", "1_basic.jl")) +end + +@testitem "advanced" begin + include(joinpath(dirname(@__DIR__), "examples", "2_advanced.jl")) +end + +@testitem "tricks" begin + include(joinpath(dirname(@__DIR__), "examples", "3_tricks.jl")) +end diff --git a/test/formalities.jl b/test/formalities.jl new file mode 100644 index 00000000..09e8b2f2 --- /dev/null +++ b/test/formalities.jl @@ -0,0 +1,39 @@ +using TestItems + +@testitem "Code quality" begin + using Aqua + using ForwardDiff: ForwardDiff + using Zygote: Zygote + Aqua.test_all(ImplicitDifferentiation; ambiguities=false, undocumented_names=true) +end +@testitem "Formatting" begin + using JuliaFormatter + @test format(ImplicitDifferentiation; verbose=false, overwrite=false) +end +@testitem "Static checking" begin + using JET + using ForwardDiff: ForwardDiff + using Zygote: Zygote + JET.test_package(ImplicitDifferentiation; target_defined_modules=true) +end +@testitem "Imports" begin + using ExplicitImports + using ForwardDiff: ForwardDiff + using Zygote: Zygote + @test check_no_implicit_imports(ImplicitDifferentiation) === nothing + @test check_no_stale_explicit_imports(ImplicitDifferentiation) === nothing + @test check_all_explicit_imports_via_owners(ImplicitDifferentiation) === nothing + @test_broken check_all_explicit_imports_are_public(ImplicitDifferentiation) === nothing + @test check_all_qualified_accesses_via_owners(ImplicitDifferentiation) === nothing + @test check_no_self_qualified_accesses(ImplicitDifferentiation) === nothing +end +@testitem "Doctests" begin + using Documenter + Documenter.DocMeta.setdocmeta!( + ImplicitDifferentiation, + :DocTestSetup, + :(using ImplicitDifferentiation); + recursive=true, + ) + doctest(ImplicitDifferentiation) +end diff --git a/test/runtests.jl b/test/runtests.jl index 93bafb72..b9e874db 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,59 +1,3 @@ -## Imports +using TestItemRunner -using Aqua -using Documenter -using ExplicitImports -using ForwardDiff: ForwardDiff -using ImplicitDifferentiation -using JET -using JuliaFormatter -using Random -using Test -using Zygote: Zygote - -Documenter.DocMeta.setdocmeta!( - ImplicitDifferentiation, :DocTestSetup, :(using ImplicitDifferentiation); recursive=true -) - -EXAMPLES_DIR_JL = joinpath(dirname(@__DIR__), "examples") - -## Test sets - -@testset verbose = true "ImplicitDifferentiation.jl" begin - @testset "Code quality (Aqua.jl)" begin - Aqua.test_all( - ImplicitDifferentiation; ambiguities=false, deps_compat=(check_extras = false) - ) - end - @testset "Formatting (JuliaFormatter.jl)" begin - @test format(ImplicitDifferentiation; verbose=false, overwrite=false) - end - @testset "Static checking (JET.jl)" begin - JET.test_package(ImplicitDifferentiation; target_defined_modules=true) - end - @testset "Imports (ExplicitImports.jl)" begin - @test check_no_implicit_imports(ImplicitDifferentiation) === nothing - @test check_no_stale_explicit_imports(ImplicitDifferentiation) === nothing - @test check_all_explicit_imports_via_owners(ImplicitDifferentiation) === nothing - @test_broken check_all_explicit_imports_are_public(ImplicitDifferentiation) === - nothing - @test check_all_qualified_accesses_via_owners(ImplicitDifferentiation) === nothing - @test check_no_self_qualified_accesses(ImplicitDifferentiation) === nothing - end - @testset "Doctests (Documenter.jl)" begin - doctest(ImplicitDifferentiation) - end - @testset verbose = true "Examples" begin - @info "Example tests" - for file in readdir(EXAMPLES_DIR_JL) - @info "$file" - @testset "$file" begin - include(joinpath(EXAMPLES_DIR_JL, file)) - end - end - end - @testset verbose = true "Systematic" begin - @info "Systematic tests" - include("systematic.jl") - end -end; +@run_package_tests diff --git a/test/systematic.jl b/test/systematic.jl index 32bfd4e1..71b5a4a1 100644 --- a/test/systematic.jl +++ b/test/systematic.jl @@ -1,31 +1,32 @@ -using ADTypes -using ADTypes: ForwardMode, ReverseMode -using ForwardDiff: ForwardDiff -using ImplicitDifferentiation -using Test -using Zygote: Zygote, ZygoteRuleConfig -using FiniteDiff: FiniteDiff +using TestItems -include("utils.jl") +@testitem "Systematic tests" begin + using ADTypes + using ADTypes: ForwardMode, ReverseMode + using ForwardDiff: ForwardDiff + using ImplicitDifferentiation + using Test + using Zygote: Zygote, ZygoteRuleConfig + using FiniteDiff: FiniteDiff -## Parameter combinations + include("utils.jl") -linear_solver_candidates = [ID.KrylovLinearSolver(), \] -representation_candidates = [MatrixRepresentation(), OperatorRepresentation()] -backend_candidates = [nothing, AutoForwardDiff(), AutoZygote()] -preparation_candidates = [nothing, ForwardMode(), ReverseMode()] -x_candidates = [float.(1:6), reshape(float.(1:12), 6, 2)] + ## Parameter combinations -## Test loop + representation_linear_solver_candidates = [ + (MatrixRepresentation(), \), # + (OperatorRepresentation{:LinearOperators}(), IterativeLinearSolver{:Krylov}()), + (OperatorRepresentation{:LinearMaps}(), IterativeLinearSolver{:IterativeSolvers}()), + ] + backend_candidates = [nothing, AutoForwardDiff(), AutoZygote()] + preparation_candidates = [nothing, ForwardMode(), ReverseMode()] + x_candidates = [float.(1:6), reshape(float.(1:12), 6, 2)] -@testset verbose = true "Systematic tests" begin - @testset for representation in representation_candidates - for (linear_solver, backend, preparation, x) in Iterators.product( - linear_solver_candidates, - backend_candidates, - preparation_candidates, - x_candidates, - ) + ## Test loop + + @testset for (representation, linear_solver) in representation_linear_solver_candidates + for (backend, preparation, x) in + Iterators.product(backend_candidates, preparation_candidates, x_candidates) x_type = typeof(x) @info "Testing" linear_solver backend representation preparation x_type if (representation isa OperatorRepresentation && linear_solver == \) @@ -35,7 +36,17 @@ x_candidates = [float.(1:6), reshape(float.(1:12), 6, 2)] x = Float64.(1:6) @testset "$((; linear_solver, backend, preparation, x_type))" begin test_implicit( - outer_backends, x; representation, backend, preparation, linear_solver + outer_backends, + x; + representation, + backends=isnothing(backend) ? nothing : (; x=backend, y=backend), + preparation, + linear_solver, + strict=if linear_solver isa IterativeLinearSolver{:IterativeSolvers} + Val(false) + else + Val(true) + end, ) end end