@@ -1247,65 +1247,34 @@ def __init__(
             precision=precision,
         )
 
-    # @nnx.jit
+    # REMOVE @nnx.jit for now to ensure this logic runs
     def encode(
         self, x: jax.Array, return_dict: bool = True
     ) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:
         if x.shape[-1] != 3:
-            # reshape channel last for JAX
            x = jnp.transpose(x, (0, 2, 3, 4, 1))
         assert x.shape[-1] == 3, f"Expected input shape (N, D, H, W, 3), got {x.shape}"
 
         b, t, h, w, c = x.shape
-        all_outs = []
-
-        def scan_fn_chunk(carry, input_slice):
-            # input_slice shape is (B, H, W, C)
-            input_slice = jnp.expand_dims(input_slice, 1)  # Shape (B, 1, H, W, C) for encoder
-            out_slice, new_carry = self.encoder(input_slice, carry)
-            # out_slice shape is (B, 1, H', W', C')
-            out_slice = jnp.squeeze(out_slice, 1)  # Shape (B, H', W', C') for scan output
-            return new_carry, out_slice
-
-        # 1. Process the first frame
-        # Initialize cache for the first frame
-        init_cache_first = self.encoder.init_cache(b, h, w, x.dtype)
-        x_first_scan = jnp.expand_dims(x[:, 0, ...], axis=0)  # Shape (1, B, H, W, C) for scan
+        init_cache = self.encoder.init_cache(b, h, w, x.dtype)
 
-        _, out_first_frames = jax.lax.scan(scan_fn_chunk, init_cache_first, x_first_scan)
-        # out_first_frames shape is (1, B, H', W', C')
-        all_outs.append(jnp.swapaxes(out_first_frames, 0, 1))  # Shape (B, 1, H', W', C')
+        # Process first frame
+        out1, _ = self.encoder(x[:, :1, ...], init_cache)
 
-        # 2. Process subsequent Chunks of 4
         if t > 1:
-            num_chunks = (t - 1 + 3) // 4  # Ceiling division
-            for i in range(num_chunks):
-                start_idx = 1 + 4 * i
-                end_idx = min(start_idx + 4, t)
-
-                if start_idx >= t:
-                    break
-
-                chunk = x[:, start_idx:end_idx, ...]
-                # Prepare chunk for scan: shape (T_chunk, B, H, W, C)
-                x_scan = jnp.swapaxes(chunk, 0, 1)
-
-                # *** Cache Reset for EACH CHUNK ***
-                init_cache_chunk = self.encoder.init_cache(b, h, w, x.dtype)
-
-                _, encoded_frames_chunk = jax.lax.scan(scan_fn_chunk, init_cache_chunk, x_scan)
-                # encoded_frames_chunk shape is (T_chunk, B, H', W', C')
-
-                # Transpose back to (B, T_chunk, H', W', C')
-                out_chunk = jnp.swapaxes(encoded_frames_chunk, 0, 1)
-                all_outs.append(out_chunk)
-
-            # Concatenate results from all chunks along the time dimension
-            encoded = jnp.concatenate(all_outs, axis=1)
+            # Process the remaining frames in a single chunk.
+            # The encoder also returns an updated cache, but the old
+            # chunked logic never carried the cache between chunks.
+
+            # To mirror that behavior, reset the cache for this chunk
+            # instead of reusing the cache from the first frame.
+            init_cache_rest = self.encoder.init_cache(b, h, w, x.dtype)
+            out_rest, _ = self.encoder(x[:, 1:, ...], init_cache_rest)
+            encoded = jnp.concatenate([out1, out_rest], axis=1)
+        else:
+            encoded = out1
 
-        # Apply quant_conv - this layer also has a cache, but the old code didn't pipe it.
         enc, _ = self.quant_conv(encoded, cache_x=None)
-
         mu, logvar = enc[:, :, :, :, :self.z_dim], enc[:, :, :, :, self.z_dim:]
         h = jnp.concatenate([mu, logvar], axis=-1)
 
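For context on the cache handling discussed in the new comments above, here is a minimal, self-contained sketch of the alternative those comments mention: chunked encoding that carries the cache from one chunk into the next instead of resetting it per chunk. This is an illustration only; toy_encoder, chunked_encode, chunk_size, and the single-frame cache layout are hypothetical stand-ins for self.encoder and the real init_cache structure, which this commit does not show.

import jax.numpy as jnp

def toy_encoder(x, cache):
    # Stand-in for self.encoder: takes a (B, T, H, W, C) chunk plus a
    # cache, returns the encoded chunk and an updated cache. Here the
    # "cache" is simply the chunk's last frame, mimicking the trailing
    # context a causal temporal convolution would keep.
    out = x * 2.0                     # pretend encoding
    new_cache = x[:, -1:, ...]        # carry the trailing frame forward
    return out, new_cache

def chunked_encode(encoder, x, init_cache, chunk_size=4):
    # Encode frame 0 alone, then the rest in chunks of chunk_size,
    # threading the cache from each chunk into the next (the variant
    # the commit comments describe but intentionally do not implement).
    t = x.shape[1]
    out, cache = encoder(x[:, :1, ...], init_cache)
    outs = [out]
    for start in range(1, t, chunk_size):
        out, cache = encoder(x[:, start:start + chunk_size, ...], cache)
        outs.append(out)
    return jnp.concatenate(outs, axis=1)

x = jnp.ones((2, 9, 8, 8, 3))            # (B, T, H, W, C) video batch
init_cache = jnp.zeros((2, 1, 8, 8, 3))  # toy cache: one frame of context
print(chunked_encode(toy_encoder, x, init_cache).shape)  # (2, 9, 8, 8, 3)

Whether carrying or resetting the cache between chunks is correct depends on how the encoder's causal padding was trained; this commit deliberately keeps the reset-per-chunk behavior of the code it replaces.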