diff --git a/docs/src/faq.md b/docs/src/faq.md index 8a9717f7..b767873c 100644 --- a/docs/src/faq.md +++ b/docs/src/faq.md @@ -58,7 +58,7 @@ See the examples for a demonstration. ### Multiple inputs | Derivatives not needed -If your forward mapping (or conditions) takes multiple inputs but you don't care about derivatives, then you can add further positional and keyword arguments beyond `x`. +If your forward mapping (or conditions) takes multiple inputs but you don't care about derivatives, then you can add further positional arguments beyond `x`. It is important to make sure that the forward mapping and conditions accept the same set of arguments, even if each of these functions only uses a subset of them. ```julia diff --git a/ext/ImplicitDifferentiationChainRulesCoreExt.jl b/ext/ImplicitDifferentiationChainRulesCoreExt.jl index 5a63b8e0..e55e4408 100644 --- a/ext/ImplicitDifferentiationChainRulesCoreExt.jl +++ b/ext/ImplicitDifferentiationChainRulesCoreExt.jl @@ -14,13 +14,9 @@ using ImplicitDifferentiation: ImplicitDifferentiation.chainrules_suggested_backend(rc::RuleConfig) = AutoChainRules(rc) function ChainRulesCore.rrule( - rc::RuleConfig, - implicit::ImplicitFunction, - x::AbstractVector, - args::Vararg{Any,N}; - kwargs..., + rc::RuleConfig, implicit::ImplicitFunction, x::AbstractVector, args::Vararg{Any,N}; ) where {N} - y, z = implicit(x, args...; kwargs...) + y, z = implicit(x, args...) suggested_backend = chainrules_suggested_backend(rc) Aᵀ = build_Aᵀ(implicit, x, y, z, args...; suggested_backend) diff --git a/ext/ImplicitDifferentiationForwardDiffExt.jl b/ext/ImplicitDifferentiationForwardDiffExt.jl index 06c53a8a..fff6823d 100644 --- a/ext/ImplicitDifferentiationForwardDiffExt.jl +++ b/ext/ImplicitDifferentiationForwardDiffExt.jl @@ -5,10 +5,10 @@ using ForwardDiff: Dual, Partials, partials, value using ImplicitDifferentiation: ImplicitFunction, build_A, build_B function (implicit::ImplicitFunction)( - x_and_dx::AbstractVector{Dual{T,R,N}}, args...; kwargs... + x_and_dx::AbstractVector{Dual{T,R,N}}, args... ) where {T,R,N} x = value.(x_and_dx) - y, z = implicit(x, args...; kwargs...) + y, z = implicit(x, args...) suggested_backend = AutoForwardDiff() A = build_A(implicit, x, y, z, args...; suggested_backend)