@@ -379,28 +379,18 @@ def test_scan_remat_parity(self):
379379 )
380380
381381 # 3. Run Forward
382- with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
383- # For strict parity, we need same weights.
384- # Let's assume initialization with same key gives same weights if structure maps 1:1.
385- # scan/loop structure MIGHT differ if `scan` introduces extra variables?
386- # Usually strict weight copying is safer.
387- pass
388-
389- # Since weight copying might be tricky without access to intricate state,
390- # let's first check if they run and produce valid shapes.
391- # And if we can, assertions on output closeness IF we force same weights.
392-
393382 print ("Running scan_layers=True..." )
394- out_scan = model_scan (** inp_args )["sample" ]
395-
396- print ("Running scan_layers=False..." )
397- # To get same weights, we can try to copy state
398- # nnx.update(model_loop, nnx.state(model_scan)) # Might fail if structure differs
399- out_loop = model_loop (** inp_args )["sample" ]
400-
401- print ("Running remat_policy='full'..." )
402- out_remat = model_remat (** inp_args )["sample" ]
403-
383+ with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
384+ out_scan = model_scan (** inp_args )["sample" ]
385+
386+ print ("Running scan_layers=False..." )
387+ # To get same weights, we can try to copy state
388+ # nnx.update(model_loop, nnx.state(model_scan)) # Might fail if structure differs
389+ out_loop = model_loop (** inp_args )["sample" ]
390+
391+ print ("Running remat_policy='full'..." )
392+ out_remat = model_remat (** inp_args )["sample" ]
393+
404394 self .assertEqual (out_scan .shape , out_loop .shape )
405395 self .assertEqual (out_scan .shape , out_remat .shape )
406396
0 commit comments