Skip to content

Commit 1b90eb1

Browse files
committed
Refactor scan shape normalization
1 parent db91844 commit 1b90eb1

2 files changed

Lines changed: 86 additions & 13 deletions

File tree

phaser/execute.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import dataclasses
22
import itertools
33
import logging
4+
import math
45
import typing as t
56

67
import numpy
@@ -122,6 +123,40 @@ def _normalize_observers(
122123
return ObserverSet(obs)
123124

124125

126+
def _normalize_scan_shape(
127+
patterns: Patterns, state: ReconsState
128+
) -> t.Tuple[Patterns, ReconsState]:
129+
"""
130+
Normalizes 'patterns' and 'state' to share a common scan shape.
131+
132+
Requires that there are an equal number of patterns and scan positions.
133+
Reshapes 'state.scan' and 'patterns' to match shape, choosing the highest
134+
dimensional shape of the two. 'state.tilt' is reshaped as well.
135+
"""
136+
patterns_shape = patterns.patterns.shape[:-2]
137+
scan_shape = state.scan.shape[:-1]
138+
139+
n_patterns = math.prod(patterns_shape)
140+
n_scan = math.prod(scan_shape)
141+
if n_scan != n_patterns:
142+
raise ValueError(f"# of scan positions {n_scan} doesn't match # of patterns {n_patterns}")
143+
144+
# choose the highest dimensional shape
145+
new_shape = scan_shape if len(scan_shape) > len(patterns_shape) else patterns_shape
146+
147+
patterns.patterns = patterns.patterns.reshape((*new_shape, *patterns.patterns.shape[-2:]))
148+
state.scan = state.scan.reshape((*new_shape, 2))
149+
150+
if state.tilt is not None:
151+
n_tilt = math.prod(state.tilt.shape[:-1])
152+
if n_tilt != n_patterns:
153+
raise ValueError(f"# of tilt positions {n_scan} doesn't match # of patterns {n_patterns}")
154+
155+
state.tilt = state.tilt.reshape((*new_shape, 2))
156+
157+
return patterns, state
158+
159+
125160
def load_raw_data(
126161
plan: ReconsPlan, xp: t.Any, seed: t.Any = None,
127162
init_state: t.Union[ReconsState, PartialReconsState, None] = None
@@ -295,6 +330,7 @@ def initialize_reconstruction(
295330
progress=ProgressState(iters=numpy.array([]), detector_errors=numpy.array([])),
296331
wavelength=wavelength
297332
)
333+
data, state = _normalize_scan_shape(data, state)
298334

299335
# process post_init hooks
300336
for p in plan.post_init:
@@ -305,15 +341,6 @@ def initialize_reconstruction(
305341

306342
# perform some checks on preprocessed data
307343

308-
if state.scan.shape[:-1] != data.patterns.shape[:-2]:
309-
n_pos = int(numpy.prod(state.scan.shape[:-1]))
310-
n_pat = int(numpy.prod(data.patterns.shape[:-2]))
311-
if n_pos != n_pat:
312-
raise ValueError(f"# of scan positions {n_pos} doesn't match # of patterns {n_pat}")
313-
314-
# reshape patterns to match scan
315-
data.patterns = data.patterns.reshape((*state.scan.shape[:-1], *data.patterns.shape[-2:]))
316-
317344
avg_pattern_intensity = float(numpy.nanmean(numpy.nansum(data.patterns, axis=(-1, -2))))
318345

319346
if avg_pattern_intensity < 5.0:

tests/test_initialization.py

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,21 @@
1616

1717

1818
def load_empty(args, props) -> RawData:
19+
scan_shape = props['scan_shape']
20+
det_shape = props['det_shape']
21+
22+
return {
23+
'patterns': numpy.zeros((*scan_shape, *det_shape), dtype=numpy.float32),
24+
'mask': numpy.ones(det_shape, dtype=numpy.float32),
25+
'sampling': Sampling(det_shape, sampling=(1.0, 1.0)),
26+
'wavelength': 1.0,
27+
'scan_hook': None,
28+
'probe_hook': None,
29+
'seed': None,
30+
}
31+
32+
33+
def load_no_probe(args, props) -> RawData:
1934
return {
2035
'patterns': numpy.zeros((32, 32, 64, 64), dtype=numpy.float32),
2136
'mask': numpy.ones((64, 64), dtype=numpy.float32),
@@ -34,7 +49,7 @@ def load_empty(args, props) -> RawData:
3449
def test_load_raw_data_missing():
3550
plan = ReconsPlan.from_data({
3651
'name': 'test',
37-
'raw_data': 'tests.test_initialization:load_empty',
52+
'raw_data': 'tests.test_initialization:load_no_probe',
3853
'engines': [],
3954
})
4055
xp = numpy
@@ -46,7 +61,7 @@ def test_load_raw_data_missing():
4661
def test_load_raw_data_override():
4762
plan = {
4863
'name': 'test',
49-
'raw_data': 'tests.test_initialization:load_empty',
64+
'raw_data': 'tests.test_initialization:load_no_probe',
5065
'engines': [],
5166
'init': {
5267
'probe': {
@@ -90,7 +105,7 @@ def test_load_raw_data_override():
90105
def test_load_raw_data_prev_state(caplog):
91106
plan = {
92107
'name': 'test',
93-
'raw_data': 'tests.test_initialization:load_empty',
108+
'raw_data': 'tests.test_initialization:load_no_probe',
94109
'engines': [],
95110
}
96111

@@ -135,4 +150,35 @@ def test_load_raw_data_prev_state(caplog):
135150

136151
# both should be modeled
137152
assert recons.state.probe is not probe_state
138-
assert recons.state.scan is not scan_state
153+
assert recons.state.scan is not scan_state
154+
155+
156+
def test_load_3d_raw_data():
157+
scan_shape = (64, 64)
158+
det_shape = (128, 128)
159+
160+
plan = ReconsPlan.from_data({
161+
'name': 'test',
162+
'raw_data': {
163+
'type': 'tests.test_initialization:load_empty',
164+
'scan_shape': (4096,),
165+
'det_shape': det_shape,
166+
},
167+
'init': {
168+
'scan': {
169+
'type': 'raster',
170+
'shape': scan_shape,
171+
'step_size': (1.0, 1.0),
172+
},
173+
'probe': {
174+
'type': 'focused',
175+
'conv_angle': 20.0,
176+
'defocus': 300.0,
177+
}
178+
},
179+
'engines': [],
180+
})
181+
recons = initialize_reconstruction(plan)
182+
183+
assert recons.state.scan.shape == (*scan_shape, 2)
184+
assert recons.patterns.patterns.shape == (*scan_shape, *det_shape)

0 commit comments

Comments
 (0)