1- # # Partial conditions
2-
3- struct ConditionsXNoByproduct{C,Y,A,K}
1+ struct ConditionsX{C,K}
42 conditions:: C
5- y:: Y
6- args:: A
73 kwargs:: K
84end
95
10- function (conditions_x_nobyproduct:: ConditionsXNoByproduct )(x:: AbstractVector )
11- (; conditions, y, args, kwargs) = conditions_x_nobyproduct
12- return conditions (x, y, args... ; kwargs... )
13- end
14-
15- struct ConditionsYNoByproduct{C,X,A,K}
6+ struct ConditionsY{C,K}
167 conditions:: C
17- x:: X
18- args:: A
198 kwargs:: K
209end
2110
22- function (conditions_y_nobyproduct:: ConditionsYNoByproduct )(y:: AbstractVector )
23- (; conditions, x, args, kwargs) = conditions_y_nobyproduct
24- return conditions (x, y, args... ; kwargs... )
11+ function (cx:: ConditionsX )(x, y, args... )
12+ return cx. conditions (x, y, args... ; cx. kwargs... )
2513end
2614
27- struct ConditionsXByproduct{C,Y,Z,A,K}
28- conditions:: C
29- y:: Y
30- z:: Z
31- args:: A
32- kwargs:: K
33- end
34-
35- function (conditions_x_byproduct:: ConditionsXByproduct )(x:: AbstractVector )
36- (; conditions, y, z, args, kwargs) = conditions_x_byproduct
37- return conditions (x, y, z, args... ; kwargs... )
15+ function (cy:: ConditionsY )(y, x, args... ) # order switch
16+ return cy. conditions (x, y, args... ; cy. kwargs... )
3817end
3918
40- struct ConditionsYByproduct{C,X,Z,A,K}
41- conditions:: C
19+ struct PushforwardOperator!{F,P,B,X,C,R}
20+ f:: F
21+ prep:: P
22+ backend:: B
4223 x:: X
43- z:: Z
44- args:: A
45- kwargs:: K
46- end
47-
48- function (conditions_y_byproduct:: ConditionsYByproduct )(y:: AbstractVector )
49- (; conditions, x, z, args, kwargs) = conditions_y_byproduct
50- return conditions (x, y, z, args... ; kwargs... )
51- end
52-
53- function ConditionsX (conditions, x, y_or_yz, args... ; kwargs... )
54- y = output (y_or_yz)
55- if y_or_yz isa Tuple
56- z = byproduct (y_or_yz)
57- return ConditionsXByproduct (conditions, y, z, args, kwargs)
58- else
59- return ConditionsXNoByproduct (conditions, y, args, kwargs)
60- end
61- end
62-
63- function ConditionsY (conditions, x, y_or_yz, args... ; kwargs... )
64- if y_or_yz isa Tuple
65- z = byproduct (y_or_yz)
66- return ConditionsYByproduct (conditions, x, z, args, kwargs)
67- else
68- return ConditionsYNoByproduct (conditions, x, args, kwargs)
69- end
24+ contexts:: C
25+ res_backup:: R
7026end
7127
72- # # Lazy operators
73-
74- struct PushforwardOperator!{F,B,X,E,R}
28+ struct PullbackOperator!{F,P,B,X,C,R}
7529 f:: F
30+ prep:: P
7631 backend:: B
7732 x:: X
78- extras :: E
33+ contexts :: C
7934 res_backup:: R
8035end
8136
37+ function PushforwardOperator! (f, prep, backend, x, contexts)
38+ res_backup = similar (f (x, map (unwrap, contexts)... ))
39+ return PushforwardOperator! (f, prep, backend, x, contexts, res_backup)
40+ end
41+
42+ function PullbackOperator! (f, prep, backend, x, contexts)
43+ res_backup = similar (x)
44+ return PullbackOperator! (f, prep, backend, x, contexts, res_backup)
45+ end
46+
8247function (po:: PushforwardOperator! )(res, v, α, β)
48+ (; f, backend, x, contexts, prep, res_backup) = po
8349 if iszero (β)
84- pushforward! (po. f, res, po. backend, po. x, v, po. extras)
85- res .= α .* res
50+ pushforward! (f, (res,), prep, backend, x, (v,), contexts... )
51+ if ! isone (α)
52+ res .*= α
53+ end
8654 else
87- po . res_backup . = res
88- pushforward! (po . f, res, po . backend, po . x, v, po . extras )
89- res . = α .* res .+ β .* po . res_backup
55+ copyto! ( res_backup, res)
56+ pushforward! (f, ( res,), prep, backend, x, (v,), contexts ... )
57+ axpby! (β, res_backup, α, res)
9058 end
9159 return res
9260end
9361
94- struct PullbackOperator!{F,B,X,E,R}
95- f:: F
96- backend:: B
97- x:: X
98- extras:: E
99- res_backup:: R
100- end
101-
10262function (po:: PullbackOperator! )(res, v, α, β)
63+ (; f, backend, x, contexts, prep, res_backup) = po
10364 if iszero (β)
104- pullback! (po. f, res, po. backend, po. x, v, po. extras)
105- res .= α .* res
65+ pullback! (f, (res,), prep, backend, x, (v,), contexts... )
66+ if ! isone (α)
67+ res .*= α
68+ end
10669 else
107- po . res_backup . = res
108- pullback! (po . f, res, po . backend, po . x, v, po . extras )
109- res . = α .* res .+ β .+ po . res_backup
70+ copyto! ( res_backup, res)
71+ pullback! (f, ( res,), prep, backend, x, (v,), contexts ... )
72+ axpby! (β, res_backup, α, res)
11073 end
11174 return res
11275end
@@ -119,24 +82,25 @@ function build_A(
11982 suggested_backend,
12083 kwargs... ,
12184) where {lazy}
122- (; conditions, linear_solver, conditions_y_backend) = implicit
85+ (; conditions, conditions_y_backend) = implicit
12386 y = output (y_or_yz)
12487 n, m = length (x), length (y)
12588 back_y = isnothing (conditions_y_backend) ? suggested_backend : conditions_y_backend
126- cond_y = ConditionsY (conditions, x, y_or_yz, args... ; kwargs... )
89+ cond_y = ConditionsY (conditions, kwargs)
90+ contexts = (Constant (x), map (Constant, rest (y_or_yz))... , map (Constant, args)... )
12791 if lazy
128- extras = prepare_pushforward_same_point (cond_y, back_y, y, zero (y))
92+ prep = prepare_pushforward_same_point (cond_y, back_y, y, ( zero (y),), contexts ... )
12993 A = LinearOperator (
13094 eltype (y),
13195 m,
13296 m,
13397 false ,
13498 false ,
135- PushforwardOperator! (cond_y, back_y, y, extras, similar (y) ),
99+ PushforwardOperator! (cond_y, prep, back_y, y, contexts ),
136100 typeof (y),
137101 )
138102 else
139- J = jacobian (cond_y, back_y, y)
103+ J = jacobian (cond_y, back_y, y, contexts ... )
140104 A = factorize (J)
141105 end
142106 return A
@@ -150,24 +114,25 @@ function build_Aᵀ(
150114 suggested_backend,
151115 kwargs... ,
152116) where {lazy}
153- (; conditions, linear_solver, conditions_y_backend) = implicit
117+ (; conditions, conditions_y_backend) = implicit
154118 y = output (y_or_yz)
155119 n, m = length (x), length (y)
156120 back_y = isnothing (conditions_y_backend) ? suggested_backend : conditions_y_backend
157- cond_y = ConditionsY (conditions, x, y_or_yz, args... ; kwargs... )
121+ cond_y = ConditionsY (conditions, kwargs)
122+ contexts = (Constant (x), map (Constant, rest (y_or_yz))... , map (Constant, args)... )
158123 if lazy
159- extras = prepare_pullback_same_point (cond_y, back_y, y, zero (y))
124+ prep = prepare_pullback_same_point (cond_y, back_y, y, ( zero (y),), contexts ... )
160125 Aᵀ = LinearOperator (
161126 eltype (y),
162127 m,
163128 m,
164129 false ,
165130 false ,
166- PullbackOperator! (cond_y, back_y, y, extras, similar (y) ),
131+ PullbackOperator! (cond_y, prep, back_y, y, contexts ),
167132 typeof (y),
168133 )
169134 else
170- Jᵀ = transpose (jacobian (cond_y, back_y, y))
135+ Jᵀ = transpose (jacobian (cond_y, back_y, y, contexts ... ))
171136 Aᵀ = factorize (Jᵀ)
172137 end
173138 return Aᵀ
@@ -181,24 +146,25 @@ function build_B(
181146 suggested_backend,
182147 kwargs... ,
183148) where {lazy}
184- (; conditions, linear_solver, conditions_x_backend) = implicit
149+ (; conditions, conditions_x_backend) = implicit
185150 y = output (y_or_yz)
186151 n, m = length (x), length (y)
187152 back_x = isnothing (conditions_x_backend) ? suggested_backend : conditions_x_backend
188- cond_x = ConditionsX (conditions, x, y_or_yz, args... ; kwargs... )
153+ cond_x = ConditionsX (conditions, kwargs)
154+ contexts = (Constant (y), map (Constant, rest (y_or_yz))... , map (Constant, args)... )
189155 if lazy
190- extras = prepare_pushforward_same_point (cond_x, back_x, x, zero (x))
156+ prep = prepare_pushforward_same_point (cond_x, back_x, x, ( zero (x),), contexts ... )
191157 B = LinearOperator (
192158 eltype (y),
193159 m,
194160 n,
195161 false ,
196162 false ,
197- PushforwardOperator! (cond_x, back_x, x, extras, similar (y) ),
163+ PushforwardOperator! (cond_x, prep, back_x, x, contexts ),
198164 typeof (x),
199165 )
200166 else
201- B = transpose (jacobian (cond_x, back_x, x))
167+ B = transpose (jacobian (cond_x, back_x, x, contexts ... ))
202168 end
203169 return B
204170end
@@ -211,24 +177,25 @@ function build_Bᵀ(
211177 suggested_backend,
212178 kwargs... ,
213179) where {lazy}
214- (; conditions, linear_solver, conditions_x_backend) = implicit
180+ (; conditions, conditions_x_backend) = implicit
215181 y = output (y_or_yz)
216182 n, m = length (x), length (y)
217183 back_x = isnothing (conditions_x_backend) ? suggested_backend : conditions_x_backend
218- cond_x = ConditionsX (conditions, x, y_or_yz, args... ; kwargs... )
184+ cond_x = ConditionsX (conditions, kwargs)
185+ contexts = (Constant (y), map (Constant, rest (y_or_yz))... , map (Constant, args)... )
219186 if lazy
220- extras = prepare_pullback_same_point (cond_x, back_x, x, zero (y))
187+ prep = prepare_pullback_same_point (cond_x, back_x, x, ( zero (y),), contexts ... )
221188 Bᵀ = LinearOperator (
222189 eltype (y),
223190 n,
224191 m,
225192 false ,
226193 false ,
227- PullbackOperator! (cond_x, back_x, x, extras, similar (x) ),
194+ PullbackOperator! (cond_x, prep, back_x, x, contexts ),
228195 typeof (x),
229196 )
230197 else
231- Bᵀ = transpose (jacobian (cond_x, back_x, x))
198+ Bᵀ = transpose (jacobian (cond_x, back_x, x, contexts ... ))
232199 end
233200 return Bᵀ
234201end
0 commit comments