Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@ version = "0.9.0"
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearMaps = "7a12625a-238d-50fd-b39a-03d52299707e"
LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -35,9 +33,8 @@ ForwardDiff = "0.10.36, 1"
JET = "0.9, 0.10"
JuliaFormatter = "2.1.2"
Krylov = "0.9.6, 0.10"
KrylovKit = "0.9.5"
LinearAlgebra = "1"
LinearMaps = "3.11.4"
LinearOperators = "2.8.0"
NLsolve = "4.5.1"
Optim = "1.12.0"
Random = "1"
Expand Down
3 changes: 1 addition & 2 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ImplicitDifferentiation = "57b37032-215b-411a-8a7c-41a003a55207"
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
Expand All @@ -16,4 +15,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Documenter = "1.3"
Documenter = "1.3"
1 change: 0 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ MatrixRepresentation
OperatorRepresentation
IterativeLinearSolver
DirectLinearSolver
prepare_implicit
```

## Internals
Expand Down
4 changes: 3 additions & 1 deletion examples/3_tricks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ end;

# And build your implicit function like so:

implicit_components = ImplicitFunction(forward_components, conditions_components);
implicit_components = ImplicitFunction(
forward_components, conditions_components; strict=Val(false)
);

# Now we're good to go.

Expand Down
43 changes: 23 additions & 20 deletions ext/ImplicitDifferentiationChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,40 +14,43 @@ using ImplicitDifferentiation:
# not covered by Codecov for now
ImplicitDifferentiation.chainrules_suggested_backend(rc::RuleConfig) = AutoChainRules(rc)

struct ImplicitPullback{TA,TB,TL,TP,Nargs}
Aᵀ::TA
Bᵀ::TB
linear_solver::TL
project_x::TP
_Nargs::Val{Nargs}
end

function (pb::ImplicitPullback{TA,TB,TL,TP,Nargs})((dy, dz)) where {TA,TB,TL,TP,Nargs}
(; Aᵀ, Bᵀ, linear_solver, project_x) = pb
dc = linear_solver(Aᵀ, -unthunk(dy))
dx = Bᵀ(dc)
df = NoTangent()
dargs = ntuple(unimplemented_tangent, Val(Nargs))
return (df, project_x(dx), dargs...)
end

function ChainRulesCore.rrule(
rc::RuleConfig,
implicit::ImplicitFunction,
prep::ImplicitFunctionPreparation,
x::AbstractArray,
args::Vararg{Any,N};
rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray, args::Vararg{Any,N};
) where {N}
(; conditions, linear_solver) = implicit
y, z = implicit(x, args...)
c = implicit.conditions(x, y, z, args...)
c = conditions(x, y, z, args...)

suggested_backend = chainrules_suggested_backend(rc)
prep = ImplicitFunctionPreparation(eltype(x))
Aᵀ = build_Aᵀ(implicit, prep, x, y, z, c, args...; suggested_backend)
Bᵀ = build_Bᵀ(implicit, prep, x, y, z, c, args...; suggested_backend)
project_x = ProjectTo(x)

function implicit_pullback_prepared((dy, dz))
dy = unthunk(dy)
dy_vec = vec(dy)
dc_vec = implicit.linear_solver(Aᵀ, -dy_vec)
dx_vec = Bᵀ(dc_vec)
dx = reshape(dx_vec, size(x))
df = NoTangent()
dprep = @not_implemented("Tangents for mutable arguments are not defined")
dargs = ntuple(unimplemented_tangent, N)
return (df, dprep, project_x(dx), dargs...)
end

return (y, z), implicit_pullback_prepared
implicit_pullback = ImplicitPullback(Aᵀ, Bᵀ, linear_solver, project_x, Val(N))
return (y, z), implicit_pullback
end

function unimplemented_tangent(_)
return @not_implemented(
"Tangents for positional arguments of an `ImplicitFunction` beyond `x` (the first one) are not implemented"
)
end

end
20 changes: 11 additions & 9 deletions ext/ImplicitDifferentiationForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using ImplicitDifferentiation:
ImplicitFunction, ImplicitFunctionPreparation, build_A, build_B

function (implicit::ImplicitFunction)(
prep::ImplicitFunctionPreparation, x_and_dx::AbstractArray{Dual{T,R,N}}, args...
prep::ImplicitFunctionPreparation{R}, x_and_dx::AbstractArray{Dual{T,R,N}}, args...
) where {T,R,N}
x = value.(x_and_dx)
y, z = implicit(x, args...)
Expand All @@ -19,14 +19,9 @@ function (implicit::ImplicitFunction)(
dX = ntuple(Val(N)) do k
partials.(x_and_dx, k)
end
dC_vec = map(dX) do dₖx
dₖx_vec = vec(dₖx)
dₖc_vec = B(dₖx_vec)
return dₖc_vec
end
dY = map(dC_vec) do dₖc_vec
dₖy_vec = implicit.linear_solver(A, -dₖc_vec)
dₖy = reshape(dₖy_vec, size(y))
dC = map(B, dX)
dY = map(dC) do dₖc
dₖy = implicit.linear_solver(A, -dₖc)
return dₖy
end

Expand All @@ -37,4 +32,11 @@ function (implicit::ImplicitFunction)(
return y_and_dy, z
end

function (implicit::ImplicitFunction)(
x_and_dx::AbstractArray{Dual{T,R,N}}, args...
) where {T,R,N}
prep = ImplicitFunctionPreparation(R)
return implicit(prep, x_and_dx, args...)
end

end
7 changes: 1 addition & 6 deletions src/ImplicitDifferentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,9 @@ using DifferentiationInterface:
prepare_pullback_same_point,
prepare_pushforward,
prepare_pushforward_same_point,
pullback!,
pullback,
pushforward!,
pushforward
using Krylov: Krylov, krylov_workspace, krylov_solve!, solution
using LinearOperators: LinearOperator
using LinearMaps: FunctionMap
using KrylovKit: linsolve
using LinearAlgebra: factorize

include("utils.jl")
Expand All @@ -36,6 +32,5 @@ include("callable.jl")
export MatrixRepresentation, OperatorRepresentation
export IterativeLinearSolver, DirectLinearSolver
export ImplicitFunction
export prepare_implicit

end
6 changes: 3 additions & 3 deletions src/callable.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
function (implicit::ImplicitFunction)(x::AbstractArray, args::Vararg{Any,N}) where {N}
return implicit(ImplicitFunctionPreparation(), x, args...)
return implicit(ImplicitFunctionPreparation(eltype(x)), x, args...)
end

function (implicit::ImplicitFunction)(
::ImplicitFunctionPreparation, x::AbstractArray, args::Vararg{Any,N}
) where {N}
::ImplicitFunctionPreparation{R}, x::AbstractArray{R}, args::Vararg{Any,N}
) where {R<:Real,N}
return implicit.solver(x, args...)
end
Loading
Loading