diff --git a/Project.toml b/Project.toml index d9fd8b2a..47b58275 100644 --- a/Project.toml +++ b/Project.toml @@ -11,11 +11,13 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] ImplicitDifferentiationChainRulesCoreExt = "ChainRulesCore" +ImplicitDifferentiationEnzymeExt = ["EnzymeCore", "ADTypes"] ImplicitDifferentiationForwardDiffExt = "ForwardDiff" ImplicitDifferentiationZygoteExt = "Zygote" @@ -27,6 +29,7 @@ ChainRulesTestUtils = "1.13.0" ComponentArrays = "0.15.27" DifferentiationInterface = "0.6.1,0.7" Documenter = "1.12.0" +EnzymeCore = "0.8" ExplicitImports = "1" FiniteDiff = "2.27.0" ForwardDiff = "0.10.36, 1" @@ -54,6 +57,8 @@ ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" @@ -80,6 +85,7 @@ test = [ "ComponentArrays", "DifferentiationInterface", "Documenter", + "Enzyme", "ExplicitImports", "FiniteDiff", "ForwardDiff", diff --git a/examples/1_basic.jl b/examples/1_basic.jl index dae14007..602ea584 100644 --- a/examples/1_basic.jl +++ b/examples/1_basic.jl @@ -1,5 +1,7 @@ # # Basic use cases +versioninfo() + #= We show how to differentiate through very common routines: - an unconstrained optimization problem @@ -16,6 +18,7 @@ using NLsolve using Optim using Test #src using Zygote +using Enzyme #= In all three cases, we will use the square root as our forward mapping, but expressed in three different ways. @@ -75,7 +78,7 @@ end; We now have all the ingredients to construct our implicit function. =# -implicit_optim = ImplicitFunction(forward_optim, conditions_optim) +const implicit_optim = ImplicitFunction(forward_optim, conditions_optim) # And indeed, it behaves as it should when we call it: @@ -87,6 +90,12 @@ first(implicit_optim(x, LBFGS())) .^ 2 ForwardDiff.jacobian(_x -> first(implicit_optim(_x, LBFGS())), x) @test ForwardDiff.jacobian(_x -> first(implicit_optim(_x, LBFGS())), x) ≈ J #src +Enzyme.jacobian(Forward, _x -> first(implicit_optim(_x, LBFGS())), x) |> only +Enzyme.jacobian(Reverse, _x -> first(implicit_optim(_x, LBFGS())), x) |> only + +# Fails due to mismatched activity. +# Enzyme.jacobian(Forward, _x -> first(forward_optim(_x, LBFGS())), x) + #= In this instance, we could use ForwardDiff.jl directly on the solver: =# diff --git a/ext/ImplicitDifferentiationEnzymeExt.jl b/ext/ImplicitDifferentiationEnzymeExt.jl new file mode 100644 index 00000000..31b0abdb --- /dev/null +++ b/ext/ImplicitDifferentiationEnzymeExt.jl @@ -0,0 +1,133 @@ +module ImplicitDifferentiationEnzymeExt + +using ADTypes: AutoEnzyme +using EnzymeCore +import EnzymeCore: EnzymeRules +using ImplicitDifferentiation: + ImplicitFunction, + ImplicitFunctionPreparation, + IterativeLeastSquaresSolver, + build_A, + build_Aᵀ, + build_B, + build_Bᵀ + +import .EnzymeRules: AugmentedReturn + +const AnyDuplicated{T} = Union{Duplicated{T}, BatchDuplicated{T}, DuplicatedNoNeed{T}, BatchDuplicatedNoNeed{T}} + +function EnzymeRules.forward(config, implicit::Const{<:ImplicitFunction}, ::Type{<:AnyDuplicated}, x::AnyDuplicated, args::Vararg{<:Const}) + implicit = implicit.val + + dx = x.dval + x = x.val + args = ntuple(length(args)) do i + args[i].val + end + + prep = ImplicitFunctionPreparation(eltype(x)) + (; conditions, linear_solver) = implicit + + y, z = implicit(x, args...) + c = conditions(x, y, z, args...) + + y0 = zero(y) + forward_backend = AutoEnzyme(mode = Forward) + reverse_backend = AutoEnzyme(mode = Reverse) + + A = build_A(implicit, prep, x, y, z, c, args...; suggested_backend = forward_backend) + B = build_B(implicit, prep, x, y, z, c, args...; suggested_backend = forward_backend) + Aᵀ = if linear_solver isa IterativeLeastSquaresSolver + build_Aᵀ(implicit, prep, x, y, z, c, args...; suggested_backend = reverse_backend) + else + nothing + end + + return if EnzymeRules.width(config) == 1 + dc = B(dx) + dy = linear_solver(A, Aᵀ, dc, y0)::typeof(y0) + dz = nothing + + if EnzymeRules.needs_primal(config) + return Duplicated((y, z), (dy, dz)) + else + return dy, dz + end + else + dc = map(B, dx) + dy = map(dc) do dₖc + linear_solver(A, Aᵀ, -dₖc, y0) + end + + df = ntuple(Val(EnzymeRules.width(config))) do i + (dy[i]::typeof(y0), nothing) + end + + if EnzymeRules.needs_primal(config) + return BatchDuplicated((y, z), df) + else + # TODO: We need to heal the type instability from the linear solver here + # df::NTuple{EnzymeRules.width(config), Tuple{typeof(y0), Nothing}} + return df::NTuple{EnzymeRules.width(config), Tuple{Vector{Float64}, Nothing}} + end + end +end + +function EnzymeRules.augmented_primal(config, implicit::Const{<:ImplicitFunction}, RT::Type{<:AnyDuplicated}, x::AnyDuplicated, args::Vararg{<:Const}) + @assert EnzymeRules.width(config) == 1 + implicit = implicit.val + + x = x.val + args = ntuple(length(args)) do i + args[i].val + end + + prep = ImplicitFunctionPreparation(eltype(x)) + (; conditions, linear_solver) = implicit + + y, z = implicit(x, args...) + c = conditions(x, y, z, args...) + c0 = zero(c) + + forward_backend = AutoEnzyme(mode = Forward) + reverse_backend = AutoEnzyme(mode = Reverse) + + Aᵀ = build_Aᵀ(implicit, prep, x, y, z, c, args...; suggested_backend = reverse_backend) + Bᵀ = build_Bᵀ(implicit, prep, x, y, z, c, args...; suggested_backend = reverse_backend) + if linear_solver isa IterativeLeastSquaresSolver + A = build_A(implicit, prep, x, y, z, c, args...; suggested_backend = forward_backend) + else + A = nothing + end + + if EnzymeRules.needs_primal(config) + primal = (y, z) + else + primal = nothing + end + + dy = EnzymeCore.make_zero(y) + if EnzymeRules.needs_shadow(config) + shadow = (dy, EnzymeCore.make_zero(z)) + else + shadow = nothing + end + + tape = (; Aᵀ, Bᵀ, A, linear_solver, dy, c0) + + AR = EnzymeRules.augmented_rule_return_type(config, RT) + + return AR(primal, shadow, tape) +end + +function EnzymeRules.reverse(_, ::Const{<:ImplicitFunction}, ::Type, tape, x::AnyDuplicated, ::Vararg{<:Const}) + dx = x.dval + (; Aᵀ, Bᵀ, A, linear_solver, dy, c0) = tape + + dc = linear_solver(Aᵀ, A, -dy, c0) + dx .+= Bᵀ(dc) + + return (nothing, nothing) +end + +end # modul