diff --git a/ext/ImplicitDifferentiationChainRulesCoreExt.jl b/ext/ImplicitDifferentiationChainRulesCoreExt.jl index 997ab1f7..e0a847cb 100644 --- a/ext/ImplicitDifferentiationChainRulesCoreExt.jl +++ b/ext/ImplicitDifferentiationChainRulesCoreExt.jl @@ -14,10 +14,11 @@ using ImplicitDifferentiation: ImplicitDifferentiation.chainrules_suggested_backend(rc::RuleConfig) = AutoChainRules(rc) function ChainRulesCore.rrule( - rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray, args::Vararg{Any,N}; + rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray, args::Vararg{Any,N} ) where {N} + (; conditions, linear_solver) = implicit y, z = implicit(x, args...) - c = implicit.conditions(x, y, z, args...) + c = conditions(x, y, z, args...) suggested_backend = chainrules_suggested_backend(rc) Aᵀ = build_Aᵀ(implicit, x, y, z, c, args...; suggested_backend) @@ -27,11 +28,11 @@ function ChainRulesCore.rrule( function implicit_pullback((dy, dz)) dy = unthunk(dy) dy_vec = vec(dy) - dc_vec = implicit.linear_solver(Aᵀ, -dy_vec) + dc_vec = linear_solver(Aᵀ, -dy_vec) dx_vec = Bᵀ(dc_vec) dx = reshape(dx_vec, size(x)) df = NoTangent() - dargs = ntuple(unimplemented_tangent, N) + dargs = ntuple(unimplemented_tangent, Val(N)) return (df, project_x(dx), dargs...) end diff --git a/test/utils.jl b/test/utils.jl index e95a1c87..82e75ee9 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -142,6 +142,14 @@ function test_implicit_rrule(scen::Scenario) @test z == z_true @test dimpl isa NoTangent @test dx ≈ dx_true + ChainRulesTestUtils.test_rrule( + implicit, + scen.x, + scen.args...; + rtol=1e-3, + check_inferred=false, + output_tangent=(copy(y), copy(z)), + ) end end