11import dataclasses
22import itertools
33import logging
4+ import math
45import typing as t
56
67import numpy
@@ -122,6 +123,40 @@ def _normalize_observers(
122123 return ObserverSet (obs )
123124
124125
126+ def _normalize_scan_shape (
127+ patterns : Patterns , state : ReconsState
128+ ) -> t .Tuple [Patterns , ReconsState ]:
129+ """
130+ Normalizes 'patterns' and 'state' to share a common scan shape.
131+
132+ Requires that there are an equal number of patterns and scan positions.
133+ Reshapes 'state.scan' and 'patterns' to match shape, choosing the highest
134+ dimensional shape of the two. 'state.tilt' is reshaped as well.
135+ """
136+ patterns_shape = patterns .patterns .shape [:- 2 ]
137+ scan_shape = state .scan .shape [:- 1 ]
138+
139+ n_patterns = math .prod (patterns_shape )
140+ n_scan = math .prod (scan_shape )
141+ if n_scan != n_patterns :
142+ raise ValueError (f"# of scan positions { n_scan } doesn't match # of patterns { n_patterns } " )
143+
144+ # choose the highest dimensional shape
145+ new_shape = scan_shape if len (scan_shape ) > len (patterns_shape ) else patterns_shape
146+
147+ patterns .patterns = patterns .patterns .reshape ((* new_shape , * patterns .patterns .shape [- 2 :]))
148+ state .scan = state .scan .reshape ((* new_shape , 2 ))
149+
150+ if state .tilt is not None :
151+ n_tilt = math .prod (state .tilt .shape [:- 1 ])
152+ if n_tilt != n_patterns :
153+ raise ValueError (f"# of tilt positions { n_scan } doesn't match # of patterns { n_patterns } " )
154+
155+ state .tilt = state .tilt .reshape ((* new_shape , 2 ))
156+
157+ return patterns , state
158+
159+
125160def load_raw_data (
126161 plan : ReconsPlan , xp : t .Any , seed : t .Any = None ,
127162 init_state : t .Union [ReconsState , PartialReconsState , None ] = None
@@ -295,6 +330,7 @@ def initialize_reconstruction(
295330 progress = ProgressState (iters = numpy .array ([]), detector_errors = numpy .array ([])),
296331 wavelength = wavelength
297332 )
333+ data , state = _normalize_scan_shape (data , state )
298334
299335 # process post_init hooks
300336 for p in plan .post_init :
@@ -305,15 +341,6 @@ def initialize_reconstruction(
305341
306342 # perform some checks on preprocessed data
307343
308- if state .scan .shape [:- 1 ] != data .patterns .shape [:- 2 ]:
309- n_pos = int (numpy .prod (state .scan .shape [:- 1 ]))
310- n_pat = int (numpy .prod (data .patterns .shape [:- 2 ]))
311- if n_pos != n_pat :
312- raise ValueError (f"# of scan positions { n_pos } doesn't match # of patterns { n_pat } " )
313-
314- # reshape patterns to match scan
315- data .patterns = data .patterns .reshape ((* state .scan .shape [:- 1 ], * data .patterns .shape [- 2 :]))
316-
317344 avg_pattern_intensity = float (numpy .nanmean (numpy .nansum (data .patterns , axis = (- 1 , - 2 ))))
318345
319346 if avg_pattern_intensity < 5.0 :
0 commit comments