@@ -14,17 +14,18 @@ using ImplicitDifferentiation:
1414# not covered by Codecov for now
1515ImplicitDifferentiation. chainrules_suggested_backend (rc:: RuleConfig ) = AutoChainRules (rc)
1616
17- struct ImplicitPullback{TA,TB,TL,TP,Nargs}
17+ struct ImplicitPullback{TA,TB,TL,TC, TP,Nargs}
1818 Aᵀ:: TA
1919 Bᵀ:: TB
2020 linear_solver:: TL
21+ c0:: TC
2122 project_x:: TP
2223 _Nargs:: Val{Nargs}
2324end
2425
25- function (pb:: ImplicitPullback{TA,TB,TL,TP,Nargs} )((dy, dz)) where {TA,TB,TL,TP,Nargs}
26- (; Aᵀ, Bᵀ, linear_solver, project_x) = pb
27- dc = linear_solver (Aᵀ, - unthunk (dy))
26+ function (pb:: ImplicitPullback{TA,TB,TL,TC, TP,Nargs} )((dy, dz)) where {TA,TB,TL,TP,TC ,Nargs}
27+ (; Aᵀ, Bᵀ, linear_solver, c0, project_x) = pb
28+ dc = linear_solver (Aᵀ, - unthunk (dy), c0 )
2829 dx = Bᵀ (dc)
2930 df = NoTangent ()
3031 dargs = ntuple (unimplemented_tangent, Val (Nargs))
@@ -37,14 +38,15 @@ function ChainRulesCore.rrule(
3738 (; conditions, linear_solver) = implicit
3839 y, z = implicit (x, args... )
3940 c = conditions (x, y, z, args... )
41+ c0 = zero (c)
4042
4143 suggested_backend = chainrules_suggested_backend (rc)
4244 prep = ImplicitFunctionPreparation (eltype (x))
4345 Aᵀ = build_Aᵀ (implicit, prep, x, y, z, c, args... ; suggested_backend)
4446 Bᵀ = build_Bᵀ (implicit, prep, x, y, z, c, args... ; suggested_backend)
4547 project_x = ProjectTo (x)
4648
47- implicit_pullback = ImplicitPullback (Aᵀ, Bᵀ, linear_solver, project_x, Val (N))
49+ implicit_pullback = ImplicitPullback (Aᵀ, Bᵀ, linear_solver, c0, project_x, Val (N))
4850 return (y, z), implicit_pullback
4951end
5052
0 commit comments