Skip to content

Commit cce283a

Browse files
committed
Apply tilts to conventional solvers
1 parent 029ccfc commit cce283a

5 files changed

Lines changed: 53 additions & 37 deletions

File tree

phaser/engines/common/simulation.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -191,34 +191,33 @@ def make_propagators(state: ReconsState, bwlim_frac: t.Optional[float] = 2/3) ->
191191

192192

193193
def tilt_propagators(
194+
ky: NDArray[numpy.floating], kx: NDArray[numpy.floating],
194195
state: ReconsState,
195-
base_props: t.Optional[NDArray[numpy.complexfloating]], # shape: (Nz-1, Ny, Nx)
196-
group_tilts: NDArray[numpy.floating], # shape: (batch, 2), in mrad
197-
kx: NDArray[numpy.floating], # shape: (Ny, Nx)
198-
ky: NDArray[numpy.floating], # shape: (Ny, Nx)
199-
delta_zs: NDArray[numpy.floating], # shape: (Nz-1)
196+
props: t.Optional[NDArray[numpy.complexfloating]], # shape: (Nz-1, Ny, Nx)
197+
tilts: NDArray[numpy.floating] # shape: (..., 2), in mrad
200198
) -> t.Optional[NDArray[numpy.complexfloating]]:
201199
"""
202-
Applies tilt and slice-dependent propagation phase shifts to base_props.
200+
Applies tilt and slice-dependent propagation phase shifts to props.
203201
-------
204202
NDArray[complex] or None
205-
Tilted propagators of shape (batch, Nz-1, Ny, Nx), or None if no slices.
203+
Tilted propagators of shape (n_layers-1, ..., Ny, Nx), or None if no slices.
206204
"""
207-
if base_props is None:
205+
if props is None:
208206
return None
209207

210208
xp = get_array_module(state.probe.data)
211209
dtype = to_real_dtype(state.probe.data.dtype)
212210
complex_dtype = to_complex_dtype(dtype)
211+
delta_zs = state.object.thicknesses[:-1]
213212

214-
tilt_ramps = xp.exp( # (batch, Nz-1, Ny, Nx)
215-
2.j * xp.pi * delta_zs[:, None, None] * (
216-
ufunc_outer(xp.multiply, xp.tan(group_tilts[:, 0] * 1e-3), ky)[:, None, ...] +
217-
ufunc_outer(xp.multiply, xp.tan(group_tilts[:, 1] * 1e-3), kx)[:, None, ...]
218-
)
213+
tilt_ramps = xp.exp( # (n_layers-1, batch, Ny, Nx)
214+
2.j * xp.pi * ufunc_outer(xp.multiply, delta_zs, (
215+
ufunc_outer(xp.multiply, xp.tan(tilts[..., 0] * 1e-3), ky) +
216+
ufunc_outer(xp.multiply, xp.tan(tilts[..., 1] * 1e-3), kx)
217+
))
219218
)
220219

221-
return base_props[None, ...] * tilt_ramps.astype(complex_dtype)
220+
return props[(slice(None), *(None,)*(tilts.ndim - 1), Ellipsis)] * tilt_ramps.astype(complex_dtype)
222221

223222

224223
@t.overload
@@ -267,21 +266,21 @@ def slice_forwards(
267266
if props is None:
268267
return f(0, None, state)
269268

270-
n_slices = props.shape[1] + 1 # props shape: (batch, Nz-1, Ny, Nx)
269+
# props shape: (N_slices - 1, [batch], Ny, Nx)
270+
n_slices = len(props) + 1
271271

272272
if is_jax(props):
273273
import jax
274274
def step_fn(carry, slice_i):
275-
# props[:, slice_i] has shape (batch, Ny, Nx)
276-
new_state = f(slice_i, props[:, slice_i], carry)
275+
new_state = f(slice_i, props[slice_i], carry)
277276
return new_state, None
278277

279278
state, _ = jax.lax.scan(step_fn, state, jax.numpy.arange(n_slices - 1))
280279
return f(n_slices - 1, None, state)
281280

282281
# fallback numpy mode
283282
for slice_i in range(n_slices - 1):
284-
state = f(slice_i, props[:, slice_i], state)
283+
state = f(slice_i, props[slice_i], state)
285284
return f(n_slices - 1, None, state)
286285

287286

phaser/engines/conventional/solvers.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from phaser.plan import ConventionalEnginePlan, LSQMLSolverPlan, EPIESolverPlan
1212
from phaser.execute import Observer
1313
from phaser.engines.common.simulation import (
14-
stream_patterns, SimulationState, cutout_group, slice_forwards, slice_backwards
14+
stream_patterns, SimulationState, cutout_group, tilt_propagators, slice_forwards, slice_backwards
1515
)
1616

1717

@@ -151,10 +151,11 @@ def run_slice(slice_i: int, prop: t.Optional[NDArray[numpy.complexfloating]], st
151151
)
152152

153153
if prop is not None:
154-
psi = ifft2(fft2(psi * group_obj[:, slice_i, None]) * prop)
154+
psi = ifft2(fft2(psi * group_obj[:, slice_i, None]) * prop[:, None])
155155

156156
return (probe_mag, psi)
157157

158+
props = tilt_propagators(sim.ky, sim.kx, sim.state, props, sim.state.tilt[tuple(group)])
158159
(probe_mag, psi) = slice_forwards(props, (probe_mag, psi), run_slice)
159160

160161
# modeled and experimental intensity
@@ -214,11 +215,12 @@ def sim_slice(slice_i: int, prop: t.Optional[NDArray[numpy.complexfloating]], st
214215

215216
if prop is not None:
216217
psi = at(psi, slice_i + 1).set(
217-
ifft2(fft2(psi[slice_i] * group_obj[:, slice_i, None]) * prop)
218+
ifft2(fft2(psi[slice_i] * group_obj[:, slice_i, None]) * prop[:, None])
218219
)
219220

220221
return (group_probe_mag, psi)
221222

223+
props = tilt_propagators(sim.ky, sim.kx, sim.state, props, sim.state.tilt[tuple(group)])
222224
(group_probe_mag, psi) = slice_forwards(props, (group_probe_mag, psi), sim_slice)
223225

224226
new_obj_mag += group_obj_mag
@@ -253,7 +255,7 @@ def update_slice(slice_i: int, prop: t.Optional[NDArray[numpy.complexfloating]],
253255
sim.state.object.data = at(sim.state.object.data, slice_i).add(obj_update)
254256

255257
if prop is not None:
256-
chi = ifft2(fft2(delta_P) * prop.conj())
258+
chi = ifft2(fft2(delta_P) * prop.conj()[:, None])
257259
elif update_probe:
258260
delta_P_avg = ifft2(xp.sum(fft2(delta_P) * subpx_filters.conj(), axis=0))
259261
delta_P_avg /= (group_obj_mag + illum_reg_probe)
@@ -392,10 +394,11 @@ def epie_dry_run(
392394

393395
def run_slice(slice_i: int, prop: t.Optional[NDArray[numpy.complexfloating]], psi):
394396
if prop is not None:
395-
psi = ifft2(fft2(psi * group_obj[:, slice_i, None]) * prop)
397+
psi = ifft2(fft2(psi * group_obj[:, slice_i, None]) * prop[:, None])
396398

397399
return psi
398400

401+
props = tilt_propagators(sim.ky, sim.kx, sim.state, props, sim.state.tilt[tuple(group)])
399402
psi = slice_forwards(props, psi, run_slice)
400403

401404
# modeled and experimental intensity
@@ -431,11 +434,12 @@ def epie_run(
431434
def sim_slice(slice_i: int, prop: t.Optional[NDArray[numpy.complexfloating]], psi):
432435
if prop is not None:
433436
psi = at(psi, slice_i + 1).set(
434-
ifft2(fft2(psi[slice_i] * group_obj[:, slice_i, None]) * prop)
437+
ifft2(fft2(psi[slice_i] * group_obj[:, slice_i, None]) * prop[:, None])
435438
)
436439

437440
return psi
438441

442+
props = tilt_propagators(sim.ky, sim.kx, sim.state, props, sim.state.tilt[tuple(group)])
439443
psi = slice_forwards(props, psi, sim_slice)
440444

441445
model_wave = fft2(psi[-1] * group_obj[:, -1, None])
@@ -468,7 +472,7 @@ def update_slice(slice_i: int, prop: t.Optional[NDArray[numpy.complexfloating]],
468472
)
469473

470474
if prop is not None:
471-
chi = ifft2(fft2(probe_update) * prop.conj())
475+
chi = ifft2(fft2(probe_update) * prop.conj()[:, None])
472476
elif update_probe:
473477
# average probe updates in group
474478
probe_update = ifft2(xp.mean(fft2(probe_update) * subpx_filters.conj(), axis=0))

phaser/engines/gradient/run.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -400,25 +400,22 @@ def run_model(
400400
group_tilts = sim.tilt
401401

402402
(ky, kx) = sim.probe.sampling.recip_grid(dtype=dtype, xp=xp)
403-
delta_zs = sim.object.thicknesses[:-1]
404403
xp = get_array_module(sim.probe.data)
405404
dtype = to_real_dtype(sim.probe.data.dtype)
406-
complex_dtype = to_complex_dtype(dtype)
407405

408406
probes = sim.probe.data
409407
group_obj = sim.object.sampling.get_view_at_pos(sim.object.data, group_scan, probes.shape[-2:])
410408
group_subpx_filters = fourier_shift_filter(ky, kx, sim.object.sampling.get_subpx_shifts(group_scan, probes.shape[-2:]))[:, None, ...]
411409
probes = ifft2(fft2(probes) * group_subpx_filters)
412-
413-
t_props = tilt_propagators(sim, props, group_tilts, kx, ky, delta_zs,)
414410

415411
def sim_slice(slice_i: int, prop: t.Optional[NDArray[numpy.complexfloating]], psi):
416412
# psi: (batch, n_probe, Ny, Nx)
417413
if prop is not None:
418414
return ifft2(fft2(psi * group_obj[:, slice_i, None]) * prop[:, None])
419415
return psi * group_obj[:, slice_i, None]
420416

421-
model_wave = fft2(slice_forwards(t_props, probes, sim_slice))
417+
props = tilt_propagators(ky, kx, sim, props, group_tilts)
418+
model_wave = fft2(slice_forwards(props, probes, sim_slice))
422419
model_intensity = xp.sum(abs2(model_wave), axis=1)
423420

424421
(loss, solver_states.noise_model_state) = noise_model.calc_loss(
@@ -447,20 +444,19 @@ def dry_run(
447444
dtype: t.Type[numpy.floating],
448445
) -> NDArray[numpy.floating]:
449446
(ky, kx) = sim.probe.sampling.recip_grid(dtype=dtype, xp=xp)
450-
delta_zs = sim.object.thicknesses[:-1]
451447

452448
probes = sim.probe.data
453449
group_obj = sim.object.sampling.get_view_at_pos(sim.object.data, sim.scan[tuple(group)], probes.shape[-2:])
454450
group_subpx_filters = fourier_shift_filter(ky, kx, sim.object.sampling.get_subpx_shifts(sim.scan[tuple(group)], probes.shape[-2:]))[:, None, ...]
455451
probes = ifft2(fft2(probes) * group_subpx_filters)
456-
t_props = tilt_propagators(sim, props, sim.tilt[tuple(group)], kx, ky, delta_zs)
452+
props = tilt_propagators(ky, kx, sim, props, sim.tilt[tuple(group)])
457453

458454
def sim_slice(slice_i: int, prop: t.Optional[NDArray[numpy.complexfloating]], psi):
459455
if prop is not None:
460456
return ifft2(fft2(psi * group_obj[:, slice_i, None]) * prop[:, None])
461457
return psi * group_obj[:, slice_i, None]
462458

463-
model_wave = fft2(slice_forwards(t_props, probes, sim_slice))
459+
model_wave = fft2(slice_forwards(props, probes, sim_slice))
464460
model_intensity = xp.sum(abs2(model_wave), axis=(1, -2, -1))
465461
exp_intensity = xp.sum(group_patterns, axis=(-2, -1))
466462

phaser/utils/_jax_kernels.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,11 @@ def get_cutouts(obj: jax.Array, start_idxs: jax.Array, cutout_shape: t.Tuple[int
5858

5959
@partial(jax.jit, static_argnums=0)
6060
def outer(ufunc: t.Any, x: jax.Array, y: jax.Array) -> jax.Array:
61-
return jax.vmap(jax.vmap(ufunc, (None, 0)), (0, None))(x, y)
61+
if x.ndim == 0 or y.ndim == 0:
62+
return ufunc(x, y)
63+
64+
out_shape = (*x.shape, *y.shape)
65+
return jax.vmap(jax.vmap(ufunc, (None, 0)), (0, None))(x.ravel(), y.ravel()).reshape(out_shape)
6266

6367

6468
@partial(jax.jit, static_argnames=('output_shape', 'order', 'mode', 'cval'))

tests/test_num.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11

22
import numpy
3-
from numpy.testing import assert_array_almost_equal
3+
from numpy.testing import assert_array_almost_equal, assert_array_equal
44
import pytest
55

66
from .utils import with_backends, get_backend_module, get_backend_scipy, mock_importerror
@@ -9,7 +9,8 @@
99
get_array_module, get_scipy_module,
1010
to_real_dtype, to_complex_dtype,
1111
fft2, ifft2, abs2,
12-
to_numpy, as_array
12+
to_numpy, as_array,
13+
ufunc_outer
1314
)
1415

1516

@@ -197,4 +198,16 @@ def test_to_array(backend: str):
197198
assert_array_almost_equal(
198199
arr,
199200
numpy.array([1., 2., 3., 4.])
200-
)
201+
)
202+
203+
204+
@with_backends('cpu', 'jax', 'cuda')
205+
def test_ufunc_outer(backend: str):
206+
xp = get_backend_module(backend)
207+
208+
xs = numpy.arange(12).reshape(4, 3)
209+
ys = numpy.arange(30).reshape(5, 6)
210+
211+
expected = numpy.multiply.outer(xs, ys)
212+
actual = to_numpy(ufunc_outer(xp.multiply, xp.array(xs), xp.array(ys)))
213+
assert_array_equal(expected, actual)

0 commit comments

Comments
 (0)