-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathImplicitDifferentiationEnzymeExt.jl
More file actions
133 lines (105 loc) · 3.85 KB
/
ImplicitDifferentiationEnzymeExt.jl
File metadata and controls
133 lines (105 loc) · 3.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
module ImplicitDifferentiationEnzymeExt

# Enzyme custom rules for `ImplicitFunction`.
#
# Both rules differentiate through the conditions `c = conditions(x, y, z, args...)`
# evaluated at the primal solution, using the operators built by `build_A` / `build_B`
# (and their adjoints `build_Aᵀ` / `build_Bᵀ`):
#
#   * forward (any batch width):  dy = solve(A, -B(dx))
#   * reverse (batch width 1):    λ = solve(Aᵀ, -ȳ);  x̄ += Bᵀ(λ)
#
# NOTE(review): `linear_solver(M, Mᵀ, b, u0)` is assumed to return u with M(u) = b,
# with `u0` serving as initial guess / output template — confirm against
# ImplicitDifferentiation's solver interface.
#
# Only `x` may carry a derivative: every extra positional argument must be `Const`,
# and the `z` part of the output `(y, z)` gets no tangent / cotangent.

using ADTypes: AutoEnzyme
using EnzymeCore
import EnzymeCore: EnzymeRules
using ImplicitDifferentiation:
    ImplicitFunction,
    ImplicitFunctionPreparation,
    IterativeLeastSquaresSolver,
    build_A,
    build_Aᵀ,
    build_B,
    build_Bᵀ
import .EnzymeRules: AugmentedReturn

# Every activity annotation that carries a shadow: single or batched, with or without
# the primal value being needed.
const AnyDuplicated{T} = Union{
    Duplicated{T},BatchDuplicated{T},DuplicatedNoNeed{T},BatchDuplicatedNoNeed{T}
}

# Forward-mode rule for `(y, z) = implicit(x, args...)`.
#
# For every tangent dₖx carried by `func_x`, the output tangent dₖy solves
# A(dₖy) = -B(dₖx); the tangent of `z` is `nothing`. Returns `Duplicated` /
# `BatchDuplicated` of `((y, z), shadow)` when Enzyme needs the primal, otherwise just
# the shadow: a `(dy, nothing)` tuple, or an `NTuple` of them when batched.
function EnzymeRules.forward(
    config,
    func::Const{<:ImplicitFunction},
    ::Type{<:AnyDuplicated},
    func_x::AnyDuplicated,
    func_args::Vararg{Const},
)
    implicit = func.val
    x = func_x.val
    dx = func_x.dval
    args = map(arg -> arg.val, func_args)
    prep = ImplicitFunctionPreparation(eltype(x))
    (; conditions, linear_solver) = implicit

    y, z = implicit(x, args...)
    c = conditions(x, y, z, args...)
    y0 = zero(y)

    forward_backend = AutoEnzyme(; mode=Forward)
    reverse_backend = AutoEnzyme(; mode=Reverse)
    A = build_A(implicit, prep, x, y, z, c, args...; suggested_backend=forward_backend)
    B = build_B(implicit, prep, x, y, z, c, args...; suggested_backend=forward_backend)
    # Only the iterative least-squares solver consumes the adjoint operator.
    Aᵀ = if linear_solver isa IterativeLeastSquaresSolver
        build_Aᵀ(implicit, prep, x, y, z, c, args...; suggested_backend=reverse_backend)
    else
        nothing
    end

    # Implicit function theorem: A dy = -B dx. Sharing this helper between the single
    # and batched paths keeps the sign consistent (and consistent with `reverse`).
    # The `::typeof(y0)` assertion pins the solver's output type for inference.
    solve_tangent(dₖx) = linear_solver(A, Aᵀ, -B(dₖx), y0)::typeof(y0)

    width = EnzymeRules.width(config)
    if width == 1
        shadow = (solve_tangent(dx), nothing)
        if EnzymeRules.needs_primal(config)
            return Duplicated((y, z), shadow)
        else
            return shadow
        end
    else
        shadows = ntuple(Val(width)) do k
            (solve_tangent(dx[k]), nothing)
        end
        if EnzymeRules.needs_primal(config)
            return BatchDuplicated((y, z), shadows)
        else
            # NOTE(review): the solver's return type may not infer on its own; assert
            # the batched shadow type generically (previously hard-coded to
            # `Vector{Float64}`, which rejected every other output type).
            return shadows::NTuple{width,Tuple{typeof(y0),Nothing}}
        end
    end
end

# Augmented primal for reverse mode (batch width 1 only).
#
# Computes `(y, z)` and stores on the tape what `reverse` needs: the adjoint operators
# `Aᵀ` and `Bᵀ`, the forward operator `A` (built only for `IterativeLeastSquaresSolver`,
# otherwise `nothing`), the linear solver, the shadow `dy` of `y` that is handed to
# Enzyme and later read back as the output cotangent, and a zero template `c0` shaped
# like the conditions.
function EnzymeRules.augmented_primal(
    config,
    func::Const{<:ImplicitFunction},
    RT::Type{<:AnyDuplicated},
    func_x::AnyDuplicated,
    func_args::Vararg{Const},
)
    EnzymeRules.width(config) == 1 || throw(
        ArgumentError("Enzyme reverse mode with batch width > 1 is not supported"),
    )
    implicit = func.val
    x = func_x.val
    args = map(arg -> arg.val, func_args)
    prep = ImplicitFunctionPreparation(eltype(x))
    (; conditions, linear_solver) = implicit

    y, z = implicit(x, args...)
    c = conditions(x, y, z, args...)
    c0 = zero(c)

    forward_backend = AutoEnzyme(; mode=Forward)
    reverse_backend = AutoEnzyme(; mode=Reverse)
    Aᵀ = build_Aᵀ(implicit, prep, x, y, z, c, args...; suggested_backend=reverse_backend)
    Bᵀ = build_Bᵀ(implicit, prep, x, y, z, c, args...; suggested_backend=reverse_backend)
    # Only the iterative least-squares solver consumes the forward operator here.
    A = if linear_solver isa IterativeLeastSquaresSolver
        build_A(implicit, prep, x, y, z, c, args...; suggested_backend=forward_backend)
    else
        nothing
    end

    primal = EnzymeRules.needs_primal(config) ? (y, z) : nothing
    # `dy` is always allocated because `reverse` reads it from the tape even when
    # Enzyme does not ask for the shadow.
    dy = EnzymeCore.make_zero(y)
    shadow = EnzymeRules.needs_shadow(config) ? (dy, EnzymeCore.make_zero(z)) : nothing

    tape = (; Aᵀ, Bᵀ, A, linear_solver, dy, c0)
    AR = EnzymeRules.augmented_rule_return_type(config, RT)
    return AR(primal, shadow, tape)
end

# Reverse pass: solve Aᵀ λ = -ȳ for the condition multiplier λ, then accumulate
# x̄ += Bᵀ(λ) into the shadow of `x` in place.
# Returns one `nothing` per differentiated argument (`x` plus each `Const` argument),
# since none of them is `Active`.
function EnzymeRules.reverse(
    config,
    ::Const{<:ImplicitFunction},
    ::Type,
    tape,
    func_x::AnyDuplicated,
    ::Vararg{Const,P},
) where {P}
    (; Aᵀ, Bᵀ, A, linear_solver, dy, c0) = tape
    dc = linear_solver(Aᵀ, A, -dy, c0)
    func_x.dval .+= Bᵀ(dc)
    return ntuple(_ -> nothing, Val(P + 1))
end

end # module