@@ -91,17 +91,21 @@ function (po::PushforwardOperator!)(res, v, α, β)
9191 return res
9292end
9393
94- struct PullbackOperator!{PB,R}
95- pullbackfunc!:: PB
94+ struct PullbackOperator!{F,B,X,E,R}
95+ f:: F
96+ backend:: B
97+ x:: X
98+ extras:: E
9699 res_backup:: R
97100end
98101
99102function (po:: PullbackOperator! )(res, v, α, β)
100103 if iszero (β)
101- po. pullbackfunc! (res, v)
104+ pullback! (po. f, res, po. backend, po. x, v, po. extras)
105+ res .= α .* res
102106 else
103107 po. res_backup .= res
104- po . pullbackfunc! ( res, v )
108+ pullback! (po . f, res, po . backend, po . x, v, po . extras )
105109 res .= α .* res .+ β .+ po. res_backup
106110 end
107111 return res
@@ -121,7 +125,7 @@ function build_A(
121125 back_y = isnothing (conditions_y_backend) ? suggested_backend : conditions_y_backend
122126 cond_y = ConditionsY (conditions, x, y_or_yz, args... ; kwargs... )
123127 if lazy
124- extras = prepare_pushforward (cond_y, back_y, y, similar (y))
128+ extras = prepare_pushforward_same_point (cond_y, back_y, y, zero (y))
125129 A = LinearOperator (
126130 eltype (y),
127131 m,
@@ -152,15 +156,14 @@ function build_Aᵀ(
152156 back_y = isnothing (conditions_y_backend) ? suggested_backend : conditions_y_backend
153157 cond_y = ConditionsY (conditions, x, y_or_yz, args... ; kwargs... )
154158 if lazy
155- extras = prepare_pullback (cond_y, back_y, y, similar (y))
156- _, pullbackfunc! = value_and_pullback!_split (cond_y, back_y, y, extras)
159+ extras = prepare_pullback_same_point (cond_y, back_y, y, zero (y))
157160 Aᵀ = LinearOperator (
158161 eltype (y),
159162 m,
160163 m,
161164 false ,
162165 false ,
163- PullbackOperator! (pullbackfunc! , similar (y)),
166+ PullbackOperator! (cond_y, back_y, y, extras , similar (y)),
164167 typeof (y),
165168 )
166169 else
@@ -184,7 +187,7 @@ function build_B(
184187 back_x = isnothing (conditions_x_backend) ? suggested_backend : conditions_x_backend
185188 cond_x = ConditionsX (conditions, x, y_or_yz, args... ; kwargs... )
186189 if lazy
187- extras = prepare_pushforward (cond_x, back_x, x, similar (x))
190+ extras = prepare_pushforward_same_point (cond_x, back_x, x, zero (x))
188191 B = LinearOperator (
189192 eltype (y),
190193 m,
@@ -214,15 +217,14 @@ function build_Bᵀ(
214217 back_x = isnothing (conditions_x_backend) ? suggested_backend : conditions_x_backend
215218 cond_x = ConditionsX (conditions, x, y_or_yz, args... ; kwargs... )
216219 if lazy
217- extras = prepare_pullback (cond_x, back_x, x, similar (y))
218- _, pullbackfunc! = value_and_pullback!_split (cond_x, back_x, x, extras)
220+ extras = prepare_pullback_same_point (cond_x, back_x, x, zero (y))
219221 Bᵀ = LinearOperator (
220222 eltype (y),
221223 n,
222224 m,
223225 false ,
224226 false ,
225- PullbackOperator! (pullbackfunc!, similar (y )),
227+ PullbackOperator! (cond_x, back_x, x, extras, similar (x )),
226228 typeof (x),
227229 )
228230 else
0 commit comments