diff --git a/ext/ImplicitDifferentiationChainRulesCoreExt.jl b/ext/ImplicitDifferentiationChainRulesCoreExt.jl index e0a847cb..997ab1f7 100644 --- a/ext/ImplicitDifferentiationChainRulesCoreExt.jl +++ b/ext/ImplicitDifferentiationChainRulesCoreExt.jl @@ -14,11 +14,10 @@ using ImplicitDifferentiation: ImplicitDifferentiation.chainrules_suggested_backend(rc::RuleConfig) = AutoChainRules(rc) function ChainRulesCore.rrule( - rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray, args::Vararg{Any,N} + rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray, args::Vararg{Any,N}; ) where {N} - (; conditions, linear_solver) = implicit y, z = implicit(x, args...) - c = conditions(x, y, z, args...) + c = implicit.conditions(x, y, z, args...) suggested_backend = chainrules_suggested_backend(rc) Aᵀ = build_Aᵀ(implicit, x, y, z, c, args...; suggested_backend) @@ -28,11 +27,11 @@ function ChainRulesCore.rrule( function implicit_pullback((dy, dz)) dy = unthunk(dy) dy_vec = vec(dy) - dc_vec = linear_solver(Aᵀ, -dy_vec) + dc_vec = implicit.linear_solver(Aᵀ, -dy_vec) dx_vec = Bᵀ(dc_vec) dx = reshape(dx_vec, size(x)) df = NoTangent() - dargs = ntuple(unimplemented_tangent, Val(N)) + dargs = ntuple(unimplemented_tangent, N) return (df, project_x(dx), dargs...) end diff --git a/test/utils.jl b/test/utils.jl index 82e75ee9..e95a1c87 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -142,14 +142,6 @@ function test_implicit_rrule(scen::Scenario) @test z == z_true @test dimpl isa NoTangent @test dx ≈ dx_true - ChainRulesTestUtils.test_rrule( - implicit, - scen.x, - scen.args...; - rtol=1e-3, - check_inferred=false, - output_tangent=(copy(y), copy(z)), - ) end end