diff --git a/Project.toml b/Project.toml index 22c78b59..5b11776e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ImplicitDifferentiation" uuid = "57b37032-215b-411a-8a7c-41a003a55207" authors = ["Guillaume Dalle", "Mohamed Tarek"] -version = "0.7.1" +version = "0.7.2" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/docs/src/faq.md b/docs/src/faq.md index b767873c..dbd0fd65 100644 --- a/docs/src/faq.md +++ b/docs/src/faq.md @@ -15,16 +15,11 @@ However, this can be switched to any other "inner" backend compatible with [Diff ## Input and output types -### Vectors - -Functions that eat or spit out arbitrary vectors are supported, as long as the forward mapping _and_ conditions return vectors of the same size. - -If you deal with small vectors (say, less than 100 elements), consider using [StaticArrays.jl](https://github.com/JuliaArrays/StaticArrays.jl) for increased performance. - ### Arrays -Functions that eat or spit out matrices and higher-order tensors are not supported. -You can use `vec` and `reshape` for the conversion to and from vectors. +Functions that eat or spit out arbitrary arrays are supported, as long as the forward mapping _and_ conditions return arrays of the same size. + +If you deal with small arrays (say, less than 100 elements), consider using [StaticArrays.jl](https://github.com/JuliaArrays/StaticArrays.jl) for increased performance. ### Scalars diff --git a/examples/0_intro.jl b/examples/0_intro.jl index 9d4bdeae..c85cad87 100644 --- a/examples/0_intro.jl +++ b/examples/0_intro.jl @@ -29,7 +29,7 @@ This is essentially the componentwise square root function but with an additiona We can check that it does what it's supposed to do. =# -x = [4.0, 9.0] +x = [1.0 2.0; 3.0 4.0] badsqrt(x) @test badsqrt(x) ≈ sqrt.(x) #src @@ -37,7 +37,7 @@ badsqrt(x) Of course the Jacobian has an explicit formula. =# -J = Diagonal(0.5 ./ sqrt.(x)) +J = Diagonal(0.5 ./ vec(sqrt.(x))) #= However, things start to go wrong when we compute it with autodiff, due to the [limitations of ForwardDiff.jl](https://juliadiff.org/ForwardDiff.jl/stable/user/limitations/) and [those of Zygote.jl](https://fluxml.ai/Zygote.jl/stable/limitations/). diff --git a/ext/ImplicitDifferentiationChainRulesCoreExt.jl b/ext/ImplicitDifferentiationChainRulesCoreExt.jl index e55e4408..030fb876 100644 --- a/ext/ImplicitDifferentiationChainRulesCoreExt.jl +++ b/ext/ImplicitDifferentiationChainRulesCoreExt.jl @@ -14,7 +14,7 @@ using ImplicitDifferentiation: ImplicitDifferentiation.chainrules_suggested_backend(rc::RuleConfig) = AutoChainRules(rc) function ChainRulesCore.rrule( - rc::RuleConfig, implicit::ImplicitFunction, x::AbstractVector, args::Vararg{Any,N}; + rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray, args::Vararg{Any,N}; ) where {N} y, z = implicit(x, args...) @@ -25,8 +25,10 @@ function ChainRulesCore.rrule( function implicit_pullback((dy, dz)) dy = unthunk(dy) - dc = implicit.linear_solver(Aᵀ, -dy) - dx = Bᵀ * dc + dy_vec = vec(dy) + dc_vec = implicit.linear_solver(Aᵀ, -dy_vec) + dx_vec = Bᵀ * dc_vec + dx = reshape(dx_vec, size(x)) df = NoTangent() dargs = ntuple(unimplemented_tangent, N) return (df, project_x(dx), dargs...) diff --git a/ext/ImplicitDifferentiationForwardDiffExt.jl b/ext/ImplicitDifferentiationForwardDiffExt.jl index fff6823d..f7b9a8e6 100644 --- a/ext/ImplicitDifferentiationForwardDiffExt.jl +++ b/ext/ImplicitDifferentiationForwardDiffExt.jl @@ -5,7 +5,7 @@ using ForwardDiff: Dual, Partials, partials, value using ImplicitDifferentiation: ImplicitFunction, build_A, build_B function (implicit::ImplicitFunction)( - x_and_dx::AbstractVector{Dual{T,R,N}}, args... + x_and_dx::AbstractArray{Dual{T,R,N}}, args... ) where {T,R,N} x = value.(x_and_dx) y, z = implicit(x, args...) @@ -14,16 +14,17 @@ function (implicit::ImplicitFunction)( A = build_A(implicit, x, y, z, args...; suggested_backend) B = build_B(implicit, x, y, z, args...; suggested_backend) - dX = map(1:N) do k + dX = ntuple(Val(N)) do k partials.(x_and_dx, k) end - dC = mapreduce(hcat, dX) do dₖx - B * dₖx + dC_mat = mapreduce(hcat, dX) do dₖx + dₖx_vec = vec(dₖx) + dₖc_vec = B * dₖx_vec end - dY = implicit.linear_solver(A, -dC) + dY_mat = implicit.linear_solver(A, -dC_mat) - y_and_dy = map(eachindex(y)) do i - Dual{T}(y[i], Partials(ntuple(k -> dY[i, k], Val(N)))) + y_and_dy = map(LinearIndices(y)) do i + Dual{T}(y[i], Partials(ntuple(k -> dY_mat[i, k], Val(N)))) end return y_and_dy, z diff --git a/src/ImplicitDifferentiation.jl b/src/ImplicitDifferentiation.jl index 17e1122c..91c351b2 100644 --- a/src/ImplicitDifferentiation.jl +++ b/src/ImplicitDifferentiation.jl @@ -23,6 +23,7 @@ using Krylov: gmres using LinearOperators: LinearOperator using LinearAlgebra: factorize +include("utils.jl") include("settings.jl") include("preparation.jl") include("implicit_function.jl") diff --git a/src/execution.jl b/src/execution.jl index b34f64a4..74a3bc5a 100644 --- a/src/execution.jl +++ b/src/execution.jl @@ -1,31 +1,31 @@ const SYMMETRIC = false const HERMITIAN = false -struct JVP!{F,P,B,X,C} +struct JVP!{F,P,B,I,C} f::F prep::P backend::B - x::X + input::I contexts::C end -struct VJP!{F,P,B,X,C} +struct VJP!{F,P,B,I,C} f::F prep::P backend::B - x::X + input::I contexts::C end function (po::JVP!)(res::AbstractVector, v::AbstractVector) - (; f, backend, x, contexts, prep) = po - pushforward!(f, (res,), prep, backend, x, (v,), contexts...) + (; f, backend, input, contexts, prep) = po + pushforward!(f, (res,), prep, backend, input, (v,), contexts...) return res end function (po::VJP!)(res::AbstractVector, v::AbstractVector) - (; f, backend, x, contexts, prep) = po - pullback!(f, (res,), prep, backend, x, (v,), contexts...) + (; f, backend, input, contexts, prep) = po + pullback!(f, (res,), prep, backend, input, (v,), contexts...) return res end @@ -33,8 +33,8 @@ end function build_A( implicit::ImplicitFunction, - x::AbstractVector, - y::AbstractVector, + x::AbstractArray, + y::AbstractArray, z, args...; suggested_backend::AbstractADType, @@ -58,12 +58,15 @@ function build_A_aux( (; conditions, backend, prep_A) = implicit actual_backend = isnothing(backend) ? suggested_backend : backend contexts = (Constant(x), Constant(z), map(Constant, args)...) + f_vec = VecToVec(Switch12(conditions), y) + y_vec = vec(y) + dy_vec = vec(zero(y)) prep_A_same = prepare_pushforward_same_point( - Switch12(conditions), prep_A..., actual_backend, y, (zero(y),), contexts... + f_vec, prep_A..., actual_backend, y_vec, (dy_vec,), contexts... ) - prod! = JVP!(Switch12(conditions), prep_A_same, actual_backend, y, contexts) + prod! = JVP!(f_vec, prep_A_same, actual_backend, y_vec, contexts) return LinearOperator( - eltype(y), length(y), length(y), SYMMETRIC, HERMITIAN, prod!, typeof(y) + eltype(y), length(y), length(y), SYMMETRIC, HERMITIAN, prod!, typeof(y_vec) ) end @@ -71,8 +74,8 @@ end function build_Aᵀ( implicit::ImplicitFunction, - x::AbstractVector, - y::AbstractVector, + x::AbstractArray, + y::AbstractArray, z, args...; suggested_backend::AbstractADType, @@ -98,12 +101,15 @@ function build_Aᵀ_aux( (; conditions, backend, prep_Aᵀ) = implicit actual_backend = isnothing(backend) ? suggested_backend : backend contexts = (Constant(x), Constant(z), map(Constant, args)...) + f_vec = VecToVec(Switch12(conditions), y) + y_vec = vec(y) + dc_vec = vec(zero(y)) prep_Aᵀ_same = prepare_pullback_same_point( - Switch12(conditions), prep_Aᵀ..., actual_backend, y, (zero(y),), contexts... + f_vec, prep_Aᵀ..., actual_backend, y_vec, (dc_vec,), contexts... ) - prod! = VJP!(Switch12(conditions), prep_Aᵀ_same, actual_backend, y, contexts) + prod! = VJP!(f_vec, prep_Aᵀ_same, actual_backend, y_vec, contexts) return LinearOperator( - eltype(y), length(y), length(y), SYMMETRIC, HERMITIAN, prod!, typeof(y) + eltype(y), length(y), length(y), SYMMETRIC, HERMITIAN, prod!, typeof(y_vec) ) end @@ -111,8 +117,8 @@ end function build_B( implicit::ImplicitFunction, - x::AbstractVector, - y::AbstractVector, + x::AbstractArray, + y::AbstractArray, z, args...; suggested_backend::AbstractADType, @@ -135,12 +141,15 @@ function build_B_aux( (; conditions, backend, prep_B) = implicit actual_backend = isnothing(backend) ? suggested_backend : backend contexts = (Constant(y), Constant(z), map(Constant, args)...) + f_vec = VecToVec(conditions, x) + x_vec = vec(x) + dx_vec = vec(zero(x)) prep_B_same = prepare_pushforward_same_point( - conditions, prep_B..., actual_backend, x, (zero(x),), contexts... + f_vec, prep_B..., actual_backend, x_vec, (dx_vec,), contexts... ) - prod! = JVP!(conditions, prep_B_same, actual_backend, x, contexts) + prod! = JVP!(f_vec, prep_B_same, actual_backend, x_vec, contexts) return LinearOperator( - eltype(y), length(y), length(x), SYMMETRIC, HERMITIAN, prod!, typeof(x) + eltype(y), length(y), length(x), SYMMETRIC, HERMITIAN, prod!, typeof(x_vec) ) end @@ -148,8 +157,8 @@ end function build_Bᵀ( implicit::ImplicitFunction, - x::AbstractVector, - y::AbstractVector, + x::AbstractArray, + y::AbstractArray, z, args...; suggested_backend::AbstractADType, @@ -172,11 +181,14 @@ function build_Bᵀ_aux( (; conditions, backend, prep_Bᵀ) = implicit actual_backend = isnothing(backend) ? suggested_backend : backend contexts = (Constant(y), Constant(z), map(Constant, args)...) + f_vec = VecToVec(conditions, x) + x_vec = vec(x) + dc_vec = vec(zero(y)) prep_Bᵀ_same = prepare_pullback_same_point( - conditions, prep_Bᵀ..., actual_backend, x, (zero(y),), contexts... + f_vec, prep_Bᵀ..., actual_backend, x_vec, (dc_vec,), contexts... ) - prod! = VJP!(conditions, prep_Bᵀ_same, actual_backend, x, contexts) + prod! = VJP!(f_vec, prep_Bᵀ_same, actual_backend, x_vec, contexts) return LinearOperator( - eltype(y), length(x), length(y), SYMMETRIC, HERMITIAN, prod!, typeof(x) + eltype(y), length(x), length(y), SYMMETRIC, HERMITIAN, prod!, typeof(x_vec) ) end diff --git a/src/implicit_function.jl b/src/implicit_function.jl index 3658af3e..acba5687 100644 --- a/src/implicit_function.jl +++ b/src/implicit_function.jl @@ -29,7 +29,7 @@ This requires solving a linear system `A * J = -B` where `A = ∂₂c`, `B = ∂ ## Positional arguments -- `solver`: a callable returning `(x, args...) -> (y, z)` where `z` is an arbitrary byproduct of the solve. Both `x` and `y` must be subtypes of `AbstractVector`, while `z` and `args` can be anything. +- `solver`: a callable returning `(x, args...) -> (y, z)` where `z` is an arbitrary byproduct of the solve. Both `x` and `y` must be subtypes of `AbstractArray`, while `z` and `args` can be anything. - `conditions`: a callable returning a vector of optimality conditions `(x, y, z, args...) -> c`, must be compatible with automatic differentiation ## Keyword arguments @@ -127,6 +127,6 @@ function Base.show(io::IO, implicit::ImplicitFunction) ) end -function (implicit::ImplicitFunction)(x::AbstractVector, args::Vararg{Any,N}) where {N} +function (implicit::ImplicitFunction)(x::AbstractArray, args::Vararg{Any,N}) where {N} return implicit.solver(x, args...) end diff --git a/src/preparation.jl b/src/preparation.jl index 7ab1efe8..41f8ccb4 100644 --- a/src/preparation.jl +++ b/src/preparation.jl @@ -1,15 +1,7 @@ -struct Switch12{F} - f::F -end - -function (s12::Switch12)(arg1, arg2, other_args::Vararg{Any,N}) where {N} - return s12.f(arg2, arg1, other_args...) -end - function prepare_A( ::MatrixRepresentation, - x::AbstractVector, - y::AbstractVector, + x::AbstractArray, + y::AbstractArray, z, args...; conditions, @@ -21,21 +13,24 @@ end function prepare_A( ::OperatorRepresentation, - x::AbstractVector, - y::AbstractVector, + x::AbstractArray, + y::AbstractArray, z, args...; conditions, backend::AbstractADType, ) contexts = (Constant(x), Constant(z), map(Constant, args)...) - return prepare_pushforward(Switch12(conditions), backend, y, (zero(y),), contexts...) + f_vec = VecToVec(Switch12(conditions), y) + y_vec = vec(y) + dy_vec = vec(zero(y)) + return prepare_pushforward(f_vec, backend, y_vec, (dy_vec,), contexts...) end function prepare_Aᵀ( ::MatrixRepresentation, - x::AbstractVector, - y::AbstractVector, + x::AbstractArray, + y::AbstractArray, z, args...; conditions, @@ -47,21 +42,24 @@ end function prepare_Aᵀ( ::OperatorRepresentation, - x::AbstractVector, - y::AbstractVector, + x::AbstractArray, + y::AbstractArray, z, args...; conditions, backend::AbstractADType, ) contexts = (Constant(x), Constant(z), map(Constant, args)...) - return prepare_pullback(Switch12(conditions), backend, y, (zero(y),), contexts...) + f_vec = VecToVec(Switch12(conditions), y) + y_vec = vec(y) + dc_vec = vec(zero(y)) # same size + return prepare_pullback(f_vec, backend, y_vec, (dc_vec,), contexts...) end function prepare_B( ::MatrixRepresentation, - x::AbstractVector, - y::AbstractVector, + x::AbstractArray, + y::AbstractArray, z, args...; conditions, @@ -73,21 +71,24 @@ end function prepare_B( ::OperatorRepresentation, - x::AbstractVector, - y::AbstractVector, + x::AbstractArray, + y::AbstractArray, z, args...; conditions, backend::AbstractADType, ) contexts = (Constant(y), Constant(z), map(Constant, args)...) - return prepare_pushforward(conditions, backend, x, (zero(x),), contexts...) + f_vec = VecToVec(conditions, x) + x_vec = vec(x) + dx_vec = vec(zero(x)) + return prepare_pushforward(f_vec, backend, x_vec, (dx_vec,), contexts...) end function prepare_Bᵀ( ::MatrixRepresentation, - x::AbstractVector, - y::AbstractVector, + x::AbstractArray, + y::AbstractArray, z, args...; conditions, @@ -99,13 +100,16 @@ end function prepare_Bᵀ( ::OperatorRepresentation, - x::AbstractVector, - y::AbstractVector, + x::AbstractArray, + y::AbstractArray, z, args...; conditions, backend::AbstractADType, ) contexts = (Constant(y), Constant(z), map(Constant, args)...) - return prepare_pullback(conditions, backend, x, (zero(y),), contexts...) + f_vec = VecToVec(conditions, x) + x_vec = vec(x) + dc_vec = vec(zero(y)) # same size + return prepare_pullback(f_vec, backend, x_vec, (dc_vec,), contexts...) end diff --git a/src/utils.jl b/src/utils.jl new file mode 100644 index 00000000..5dc612eb --- /dev/null +++ b/src/utils.jl @@ -0,0 +1,36 @@ +""" + Switch12 + +Represent a function which behaves like `f`, except that the first and second arguments are switched: + f(a1, a2, a3) = b +becomes + g(a2, a1, a3) = f(a1, a2, a3) +""" +struct Switch12{F} + f::F +end + +function (s12::Switch12)(arg1, arg2, other_args::Vararg{Any,N}) where {N} + return s12.f(arg2, arg1, other_args...) +end + +""" + VecToVec + +Represent a function which behaves like `f`, except that the first argument is expected as a vector, and the return is converted to a vector: + f(a1, a2, a3) = b +becomes + g(a1_vec, a2, a3) = vec(f(reshape(a1_vec, size(a1)), a2, a3)) +""" +struct VecToVec{F,N} + f::F + arg1_size::NTuple{N,Int} +end + +VecToVec(f::F, arg1_example::AbstractArray) where {F} = VecToVec(f, size(arg1_example)) + +function (v2v::VecToVec)(arg1_vec::AbstractVector, other_args::Vararg{Any,N}) where {N} + arg1 = reshape(arg1_vec, v2v.arg1_size) + res = v2v.f(arg1, other_args...) + return vec(res) +end diff --git a/test/systematic.jl b/test/systematic.jl index f92a3984..32bfd4e1 100644 --- a/test/systematic.jl +++ b/test/systematic.jl @@ -14,21 +14,26 @@ linear_solver_candidates = [ID.KrylovLinearSolver(), \] representation_candidates = [MatrixRepresentation(), OperatorRepresentation()] backend_candidates = [nothing, AutoForwardDiff(), AutoZygote()] preparation_candidates = [nothing, ForwardMode(), ReverseMode()] +x_candidates = [float.(1:6), reshape(float.(1:12), 6, 2)] ## Test loop @testset verbose = true "Systematic tests" begin @testset for representation in representation_candidates - for (linear_solver, backend, preparation) in Iterators.product( - linear_solver_candidates, backend_candidates, preparation_candidates + for (linear_solver, backend, preparation, x) in Iterators.product( + linear_solver_candidates, + backend_candidates, + preparation_candidates, + x_candidates, ) - @info "Testing $((; linear_solver, backend, representation, preparation))" + x_type = typeof(x) + @info "Testing" linear_solver backend representation preparation x_type if (representation isa OperatorRepresentation && linear_solver == \) continue end outer_backends = [AutoForwardDiff(), AutoZygote()] x = Float64.(1:6) - @testset "$((; linear_solver, backend, preparation))" begin + @testset "$((; linear_solver, backend, preparation, x_type))" begin test_implicit( outer_backends, x; representation, backend, preparation, linear_solver ) diff --git a/test/utils.jl b/test/utils.jl index ee862ec3..5dd30d9b 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -12,7 +12,7 @@ using Zygote: Zygote, ZygoteRuleConfig ## -function identity_break_autodiff(x::AbstractVector{R}) where {R} +function identity_break_autodiff(x::AbstractArray{R}) where {R} float(first(x)) # break ForwardDiff (Vector{R}(undef, 1))[1] = first(x) # break Zygote result = try @@ -23,7 +23,7 @@ function identity_break_autodiff(x::AbstractVector{R}) where {R} return result end -mysqrt(x::AbstractVector) = identity_break_autodiff(sqrt.(x)) +mysqrt(x::AbstractArray) = identity_break_autodiff(sqrt.(x)) ## Various signatures @@ -43,7 +43,7 @@ function make_implicit_sqrt_args(x; kwargs...) return implicit end -function test_implicit_call(x::AbstractVector{T}; kwargs...) where {T} +function test_implicit_call(x::AbstractArray{T}; kwargs...) where {T} imf1 = make_implicit_sqrt_byproduct(x; kwargs...) imf2 = make_implicit_sqrt_args(x; kwargs...) @@ -59,9 +59,9 @@ function test_implicit_call(x::AbstractVector{T}; kwargs...) where {T} end end -tag(::AbstractVector{<:ForwardDiff.Dual{T}}) where {T} = T +tag(::AbstractArray{<:ForwardDiff.Dual{T}}) where {T} = T -function test_implicit_duals(x::AbstractVector{T}; kwargs...) where {T} +function test_implicit_duals(x::AbstractArray{T}; kwargs...) where {T} imf1 = make_implicit_sqrt_byproduct(x; kwargs...) imf2 = make_implicit_sqrt_args(x; kwargs...) @@ -85,7 +85,7 @@ function test_implicit_duals(x::AbstractVector{T}; kwargs...) where {T} end end -function test_implicit_rrule(rc, x::AbstractVector{T}; kwargs...) where {T} +function test_implicit_rrule(rc, x::AbstractArray{T}; kwargs...) where {T} imf1 = make_implicit_sqrt_byproduct(x; kwargs...) imf2 = make_implicit_sqrt_args(x; kwargs...) @@ -117,7 +117,7 @@ end ## High-level tests per backend function test_implicit_backend( - outer_backend::ADTypes.AbstractADType, x::AbstractVector{T}; kwargs... + outer_backend::ADTypes.AbstractADType, x::AbstractArray{T}; kwargs... ) where {T} imf1 = make_implicit_sqrt_byproduct(x; kwargs...) imf2 = make_implicit_sqrt_args(x; kwargs...)