Thread a third positional argument through the linear solvers (0.9.0 -> 0.9.1).

Summary of the hunks below:
- Project.toml: version bumped from 0.9.0 to 0.9.1.
- ext/ImplicitDifferentiationChainRulesCoreExt.jl: ImplicitPullback gains a
  field `c0` (type parameter `TC`). `rrule` sets `c0 = zero(c)` and the
  pullback passes it as the third argument of `linear_solver`.
- ext/ImplicitDifferentiationForwardDiffExt.jl: `y0 = zero(y)` is computed once
  and passed as the third argument of `implicit.linear_solver`.
- src/settings.jl: both solver functors accept a third positional argument
  `x0`. DirectLinearSolver still returns `A \ b` and does not read `x0`;
  IterativeLinearSolver forwards `x0` to `linsolve`.

NOTE(review): `x0` is presumably the starting point of the iterative solve;
confirm against the `linsolve` documentation.
NOTE(review): convergence is still checked with `@assert` (unchanged context
line); consider an explicit error in a follow-up change.

diff --git a/Project.toml b/Project.toml
index c525894c..7e8f4f0a 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "ImplicitDifferentiation"
 uuid = "57b37032-215b-411a-8a7c-41a003a55207"
 authors = ["Guillaume Dalle", "Mohamed Tarek"]
-version = "0.9.0"
+version = "0.9.1"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
diff --git a/ext/ImplicitDifferentiationChainRulesCoreExt.jl b/ext/ImplicitDifferentiationChainRulesCoreExt.jl
index c86bafa4..4b3d0397 100644
--- a/ext/ImplicitDifferentiationChainRulesCoreExt.jl
+++ b/ext/ImplicitDifferentiationChainRulesCoreExt.jl
@@ -14,17 +14,18 @@ using ImplicitDifferentiation:
 # not covered by Codecov for now
 ImplicitDifferentiation.chainrules_suggested_backend(rc::RuleConfig) = AutoChainRules(rc)
 
-struct ImplicitPullback{TA,TB,TL,TP,Nargs}
+struct ImplicitPullback{TA,TB,TL,TC,TP,Nargs}
     Aᵀ::TA
     Bᵀ::TB
     linear_solver::TL
+    c0::TC
     project_x::TP
     _Nargs::Val{Nargs}
 end
 
-function (pb::ImplicitPullback{TA,TB,TL,TP,Nargs})((dy, dz)) where {TA,TB,TL,TP,Nargs}
-    (; Aᵀ, Bᵀ, linear_solver, project_x) = pb
-    dc = linear_solver(Aᵀ, -unthunk(dy))
+function (pb::ImplicitPullback{TA,TB,TL,TC,TP,Nargs})((dy, dz)) where {TA,TB,TL,TC,TP,Nargs}
+    (; Aᵀ, Bᵀ, linear_solver, c0, project_x) = pb
+    dc = linear_solver(Aᵀ, -unthunk(dy), c0)
     dx = Bᵀ(dc)
     df = NoTangent()
     dargs = ntuple(unimplemented_tangent, Val(Nargs))
@@ -37,6 +38,7 @@ function ChainRulesCore.rrule(
     (; conditions, linear_solver) = implicit
     y, z = implicit(x, args...)
     c = conditions(x, y, z, args...)
+    c0 = zero(c)
 
     suggested_backend = chainrules_suggested_backend(rc)
     prep = ImplicitFunctionPreparation(eltype(x))
@@ -44,7 +46,7 @@ function ChainRulesCore.rrule(
     Bᵀ = build_Bᵀ(implicit, prep, x, y, z, c, args...; suggested_backend)
     project_x = ProjectTo(x)
 
-    implicit_pullback = ImplicitPullback(Aᵀ, Bᵀ, linear_solver, project_x, Val(N))
+    implicit_pullback = ImplicitPullback(Aᵀ, Bᵀ, linear_solver, c0, project_x, Val(N))
     return (y, z), implicit_pullback
 end
 
diff --git a/ext/ImplicitDifferentiationForwardDiffExt.jl b/ext/ImplicitDifferentiationForwardDiffExt.jl
index 5228c270..54fa761f 100644
--- a/ext/ImplicitDifferentiationForwardDiffExt.jl
+++ b/ext/ImplicitDifferentiationForwardDiffExt.jl
@@ -11,6 +11,7 @@ function (implicit::ImplicitFunction)(
     x = value.(x_and_dx)
     y, z = implicit(x, args...)
     c = implicit.conditions(x, y, z, args...)
+    y0 = zero(y)
 
     suggested_backend = AutoForwardDiff()
     A = build_A(implicit, prep, x, y, z, c, args...; suggested_backend)
@@ -21,7 +22,7 @@ function (implicit::ImplicitFunction)(
     end
     dC = map(B, dX)
     dY = map(dC) do dₖc
-        dₖy = implicit.linear_solver(A, -dₖc)
+        dₖy = implicit.linear_solver(A, -dₖc, y0)
         return dₖy
     end
 
diff --git a/src/settings.jl b/src/settings.jl
index c54e2a33..a917e21b 100644
--- a/src/settings.jl
+++ b/src/settings.jl
@@ -12,7 +12,7 @@ Specify that linear systems `Ax = b` should be solved with a direct method.
 """
 struct DirectLinearSolver end
 
-function (solver::DirectLinearSolver)(A, b::AbstractVector)
+function (solver::DirectLinearSolver)(A, b::AbstractVector, x0::AbstractVector)
     return A \ b
 end
 
@@ -33,8 +33,8 @@ struct IterativeLinearSolver{K}
     end
 end
 
-function (solver::IterativeLinearSolver)(A, b)
-    sol, info = linsolve(A, b; solver.kwargs...)
+function (solver::IterativeLinearSolver)(A, b, x0)
+    sol, info = linsolve(A, b, x0; solver.kwargs...)
     @assert info.converged == 1
     return sol
 end