From 8890cb1d22a1a298b68a78166727d63e7d52b010 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Thu, 13 Nov 2025 23:11:43 +0100 Subject: [PATCH 1/3] add forward rule for Enzyme --- Project.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/Project.toml b/Project.toml index d9fd8b2a..4bbce547 100644 --- a/Project.toml +++ b/Project.toml @@ -11,11 +11,13 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] ImplicitDifferentiationChainRulesCoreExt = "ChainRulesCore" +ImplicitDifferentiationEnzymeExt = "Enzyme" ImplicitDifferentiationForwardDiffExt = "ForwardDiff" ImplicitDifferentiationZygoteExt = "Zygote" @@ -27,6 +29,7 @@ ChainRulesTestUtils = "1.13.0" ComponentArrays = "0.15.27" DifferentiationInterface = "0.6.1,0.7" Documenter = "1.12.0" +EnzymeCore = "0.8" ExplicitImports = "1" FiniteDiff = "2.27.0" ForwardDiff = "0.10.36, 1" @@ -54,6 +57,7 @@ ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" From d4875268a5ff051b9027a5c9609b09d3e4d0a8c8 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Fri, 14 Nov 2025 06:52:48 +0100 Subject: [PATCH 2/3] fixup! add forward rule for Enzyme --- ext/ImplicitDifferentiationEnzyme.jl | 80 ++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 ext/ImplicitDifferentiationEnzyme.jl diff --git a/ext/ImplicitDifferentiationEnzyme.jl b/ext/ImplicitDifferentiationEnzyme.jl new file mode 100644 index 00000000..7bb5ec39 --- /dev/null +++ b/ext/ImplicitDifferentiationEnzyme.jl @@ -0,0 +1,80 @@ +module ImplicitDifferentiationForwardDiffExt + +using ADTypes: AutoEnzyme +using EnzymeCore +import EnzymeCore: EnzymeRules +using ImplicitDifferentiation: + ImplicitFunction, + ImplicitFunctionPreparation, + IterativeLeastSquaresSolver, + build_A, + build_Aᵀ, + build_B + +function EnzymeRules.forward(config, implicit::Const{<:ImplicitFunction}, RT::Type, x, args...) + prep = ImplicitFunctionPreparation(eltype(x.val)) + EnzymeRules.forward(config, implicit, RT, Const(prep), x, args...) +end + +@inline function EnzymeRules.forward(config, implicit::Const{<:ImplicitFunction}, RT::Type, prep::Const{<:ImplicitFunctionPreparation{R}}, x, args...) where R + implicit = implicit.val + prep = prep.val + + dx = x.dval + # dargs = ntuple(length(args)) do i + # args[i].dval + # end + + x = x.val + args = ntuple(length(args)) do i + @assert args[i] isa Const + args[i].val + end + + (; conditions, linear_solver) = implicit + + y, z = implicit(x, args...) + c = conditions(x, y, z, args...) + + y0 = zero(y) + forward_backend = AutoEnzyme(mode=Forward) + reverse_backend = AutoEnzyme(mode=Reverse) + + A = build_A(implicit, prep, x, y, z, c, args...; suggested_backend=forward_backend) + B = build_B(implicit, prep, x, y, z, c, args...; suggested_backend=forward_backend) + Aᵀ = if linear_solver isa IterativeLeastSquaresSolver + build_Aᵀ(implicit, prep, x, y, z, c, args...; suggested_backend=reverse_backend) + else + nothing + end + + if EnzymeRules.width(config) == 1 + dc = B(dx) + dy = linear_solver(A, Aᵀ, dc, y0)::typeof(y0) + dz = Enzyme.make_zero(z) + + if EnzymeRules.needs_primal(config) + return Duplicated((y, z), (dy, dz)) + else + return dy, dz + end + else + dc = map(B, dx) + dy = map(dc) do dₖc + linear_solver(A, Aᵀ, -dₖc, y0)::typeof(y0) + end + + df = ntuple(Val(EnzymeRules.width(config))) do i + (dy[i]::typeof(y0), Enzyme.make_zero(z)) + end + + if EnzymeRules.needs_primal(config) + return BatchDuplicated((y, z), df) + else + return df + end + end +end + + +end From b9d8a51574628be6b8382dc8893e60fb104b7247 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Fri, 14 Nov 2025 19:18:10 +0100 Subject: [PATCH 3/3] add reverse rule and fix things --- Project.toml | 4 +- examples/1_basic.jl | 11 +- ext/ImplicitDifferentiationEnzyme.jl | 80 -------------- ext/ImplicitDifferentiationEnzymeExt.jl | 133 ++++++++++++++++++++++++ 4 files changed, 146 insertions(+), 82 deletions(-) delete mode 100644 ext/ImplicitDifferentiationEnzyme.jl create mode 100644 ext/ImplicitDifferentiationEnzymeExt.jl diff --git a/Project.toml b/Project.toml index 4bbce547..47b58275 100644 --- a/Project.toml +++ b/Project.toml @@ -17,7 +17,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] ImplicitDifferentiationChainRulesCoreExt = "ChainRulesCore" -ImplicitDifferentiationEnzymeExt = "Enzyme" +ImplicitDifferentiationEnzymeExt = ["EnzymeCore", "ADTypes"] ImplicitDifferentiationForwardDiffExt = "ForwardDiff" ImplicitDifferentiationZygoteExt = "Zygote" @@ -58,6 +58,7 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" @@ -84,6 +85,7 @@ test = [ "ComponentArrays", "DifferentiationInterface", "Documenter", + "Enzyme", "ExplicitImports", "FiniteDiff", "ForwardDiff", diff --git a/examples/1_basic.jl b/examples/1_basic.jl index dae14007..602ea584 100644 --- a/examples/1_basic.jl +++ b/examples/1_basic.jl @@ -1,5 +1,7 @@ # # Basic use cases +versioninfo() + #= We show how to differentiate through very common routines: - an unconstrained optimization problem @@ -16,6 +18,7 @@ using NLsolve using Optim using Test #src using Zygote +using Enzyme #= In all three cases, we will use the square root as our forward mapping, but expressed in three different ways. @@ -75,7 +78,7 @@ end; We now have all the ingredients to construct our implicit function. =# -implicit_optim = ImplicitFunction(forward_optim, conditions_optim) +const implicit_optim = ImplicitFunction(forward_optim, conditions_optim) # And indeed, it behaves as it should when we call it: @@ -87,6 +90,12 @@ first(implicit_optim(x, LBFGS())) .^ 2 ForwardDiff.jacobian(_x -> first(implicit_optim(_x, LBFGS())), x) @test ForwardDiff.jacobian(_x -> first(implicit_optim(_x, LBFGS())), x) ≈ J #src +Enzyme.jacobian(Forward, _x -> first(implicit_optim(_x, LBFGS())), x) |> only +Enzyme.jacobian(Reverse, _x -> first(implicit_optim(_x, LBFGS())), x) |> only + +# Fails due to mismatched activity. +# Enzyme.jacobian(Forward, _x -> first(forward_optim(_x, LBFGS())), x) + #= In this instance, we could use ForwardDiff.jl directly on the solver: =# diff --git a/ext/ImplicitDifferentiationEnzyme.jl b/ext/ImplicitDifferentiationEnzyme.jl deleted file mode 100644 index 7bb5ec39..00000000 --- a/ext/ImplicitDifferentiationEnzyme.jl +++ /dev/null @@ -1,80 +0,0 @@ -module ImplicitDifferentiationForwardDiffExt - -using ADTypes: AutoEnzyme -using EnzymeCore -import EnzymeCore: EnzymeRules -using ImplicitDifferentiation: - ImplicitFunction, - ImplicitFunctionPreparation, - IterativeLeastSquaresSolver, - build_A, - build_Aᵀ, - build_B - -function EnzymeRules.forward(config, implicit::Const{<:ImplicitFunction}, RT::Type, x, args...) - prep = ImplicitFunctionPreparation(eltype(x.val)) - EnzymeRules.forward(config, implicit, RT, Const(prep), x, args...) -end - -@inline function EnzymeRules.forward(config, implicit::Const{<:ImplicitFunction}, RT::Type, prep::Const{<:ImplicitFunctionPreparation{R}}, x, args...) where R - implicit = implicit.val - prep = prep.val - - dx = x.dval - # dargs = ntuple(length(args)) do i - # args[i].dval - # end - - x = x.val - args = ntuple(length(args)) do i - @assert args[i] isa Const - args[i].val - end - - (; conditions, linear_solver) = implicit - - y, z = implicit(x, args...) - c = conditions(x, y, z, args...) - - y0 = zero(y) - forward_backend = AutoEnzyme(mode=Forward) - reverse_backend = AutoEnzyme(mode=Reverse) - - A = build_A(implicit, prep, x, y, z, c, args...; suggested_backend=forward_backend) - B = build_B(implicit, prep, x, y, z, c, args...; suggested_backend=forward_backend) - Aᵀ = if linear_solver isa IterativeLeastSquaresSolver - build_Aᵀ(implicit, prep, x, y, z, c, args...; suggested_backend=reverse_backend) - else - nothing - end - - if EnzymeRules.width(config) == 1 - dc = B(dx) - dy = linear_solver(A, Aᵀ, dc, y0)::typeof(y0) - dz = Enzyme.make_zero(z) - - if EnzymeRules.needs_primal(config) - return Duplicated((y, z), (dy, dz)) - else - return dy, dz - end - else - dc = map(B, dx) - dy = map(dc) do dₖc - linear_solver(A, Aᵀ, -dₖc, y0)::typeof(y0) - end - - df = ntuple(Val(EnzymeRules.width(config))) do i - (dy[i]::typeof(y0), Enzyme.make_zero(z)) - end - - if EnzymeRules.needs_primal(config) - return BatchDuplicated((y, z), df) - else - return df - end - end -end - - -end diff --git a/ext/ImplicitDifferentiationEnzymeExt.jl b/ext/ImplicitDifferentiationEnzymeExt.jl new file mode 100644 index 00000000..31b0abdb --- /dev/null +++ b/ext/ImplicitDifferentiationEnzymeExt.jl @@ -0,0 +1,133 @@ +module ImplicitDifferentiationEnzymeExt + +using ADTypes: AutoEnzyme +using EnzymeCore +import EnzymeCore: EnzymeRules +using ImplicitDifferentiation: + ImplicitFunction, + ImplicitFunctionPreparation, + IterativeLeastSquaresSolver, + build_A, + build_Aᵀ, + build_B, + build_Bᵀ + +import .EnzymeRules: AugmentedReturn + +const AnyDuplicated{T} = Union{Duplicated{T}, BatchDuplicated{T}, DuplicatedNoNeed{T}, BatchDuplicatedNoNeed{T}} + +function EnzymeRules.forward(config, implicit::Const{<:ImplicitFunction}, ::Type{<:AnyDuplicated}, x::AnyDuplicated, args::Vararg{<:Const}) + implicit = implicit.val + + dx = x.dval + x = x.val + args = ntuple(length(args)) do i + args[i].val + end + + prep = ImplicitFunctionPreparation(eltype(x)) + (; conditions, linear_solver) = implicit + + y, z = implicit(x, args...) + c = conditions(x, y, z, args...) + + y0 = zero(y) + forward_backend = AutoEnzyme(mode = Forward) + reverse_backend = AutoEnzyme(mode = Reverse) + + A = build_A(implicit, prep, x, y, z, c, args...; suggested_backend = forward_backend) + B = build_B(implicit, prep, x, y, z, c, args...; suggested_backend = forward_backend) + Aᵀ = if linear_solver isa IterativeLeastSquaresSolver + build_Aᵀ(implicit, prep, x, y, z, c, args...; suggested_backend = reverse_backend) + else + nothing + end + + return if EnzymeRules.width(config) == 1 + dc = B(dx) + dy = linear_solver(A, Aᵀ, dc, y0)::typeof(y0) + dz = nothing + + if EnzymeRules.needs_primal(config) + return Duplicated((y, z), (dy, dz)) + else + return dy, dz + end + else + dc = map(B, dx) + dy = map(dc) do dₖc + linear_solver(A, Aᵀ, -dₖc, y0) + end + + df = ntuple(Val(EnzymeRules.width(config))) do i + (dy[i]::typeof(y0), nothing) + end + + if EnzymeRules.needs_primal(config) + return BatchDuplicated((y, z), df) + else + # TODO: We need to heal the type instability from the linear solver here + # df::NTuple{EnzymeRules.width(config), Tuple{typeof(y0), Nothing}} + return df::NTuple{EnzymeRules.width(config), Tuple{Vector{Float64}, Nothing}} + end + end +end + +function EnzymeRules.augmented_primal(config, implicit::Const{<:ImplicitFunction}, RT::Type{<:AnyDuplicated}, x::AnyDuplicated, args::Vararg{<:Const}) + @assert EnzymeRules.width(config) == 1 + implicit = implicit.val + + x = x.val + args = ntuple(length(args)) do i + args[i].val + end + + prep = ImplicitFunctionPreparation(eltype(x)) + (; conditions, linear_solver) = implicit + + y, z = implicit(x, args...) + c = conditions(x, y, z, args...) + c0 = zero(c) + + forward_backend = AutoEnzyme(mode = Forward) + reverse_backend = AutoEnzyme(mode = Reverse) + + Aᵀ = build_Aᵀ(implicit, prep, x, y, z, c, args...; suggested_backend = reverse_backend) + Bᵀ = build_Bᵀ(implicit, prep, x, y, z, c, args...; suggested_backend = reverse_backend) + if linear_solver isa IterativeLeastSquaresSolver + A = build_A(implicit, prep, x, y, z, c, args...; suggested_backend = forward_backend) + else + A = nothing + end + + if EnzymeRules.needs_primal(config) + primal = (y, z) + else + primal = nothing + end + + dy = EnzymeCore.make_zero(y) + if EnzymeRules.needs_shadow(config) + shadow = (dy, EnzymeCore.make_zero(z)) + else + shadow = nothing + end + + tape = (; Aᵀ, Bᵀ, A, linear_solver, dy, c0) + + AR = EnzymeRules.augmented_rule_return_type(config, RT) + + return AR(primal, shadow, tape) +end + +function EnzymeRules.reverse(_, ::Const{<:ImplicitFunction}, ::Type, tape, x::AnyDuplicated, ::Vararg{<:Const}) + dx = x.dval + (; Aᵀ, Bᵀ, A, linear_solver, dy, c0) = tape + + dc = linear_solver(Aᵀ, A, -dy, c0) + dx .+= Bᵀ(dc) + + return (nothing, nothing) +end + +end # modul