Commit 2eb2008

jax checkpointing removed
1 parent b1f1b6d commit 2eb2008

1 file changed: 4 additions & 58 deletions

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

@@ -166,7 +166,6 @@ def __call__(
 
     out = self.conv(x_padded)
    new_cache = new_cache.astype(jnp.bfloat16)
-    print(f"Exiting WanCausalConv3d: {out.shape}")
    return out, new_cache
 
 
@@ -194,13 +193,11 @@ def __init__(
     self.bias = 0
 
   def __call__(self, x: jax.Array) -> jax.Array:
-    print(f"Entering WanRMS_norm: {x.shape}")
     normalized = jnp.linalg.norm(x, ord=2, axis=(1 if self.channel_first else -1), keepdims=True)
     normalized = x / jnp.maximum(normalized, self.eps)
     normalized = normalized * self.scale * self.gamma
     if self.bias:
       return normalized + self.bias.value
-    print(f"Exiting WanRMS_norm: {normalized.shape}")
     return normalized
 
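
Note: apart from the deleted logging, WanRMS_norm divides each channel vector by its L2 norm and rescales the result. A minimal standalone sketch of the same math, assuming a channels-last input; the concrete shapes, eps value, and scale = sqrt(C) are illustrative assumptions, not taken from this file:

    import jax.numpy as jnp

    # Hypothetical values: C = 16, eps = 1e-12, gamma a learned per-channel gain.
    x = jnp.ones((2, 8, 8, 16))                               # (N, H, W, C)
    eps, gamma, scale = 1e-12, jnp.ones((16,)), jnp.sqrt(16.0)

    norm = jnp.linalg.norm(x, ord=2, axis=-1, keepdims=True)  # L2 over channels
    out = x / jnp.maximum(norm, eps) * scale * gamma
    # Dividing by the L2 norm and multiplying by sqrt(C) is the same as
    # dividing x by its root-mean-square over the channel axis.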

@@ -213,7 +210,6 @@ def __init__(self, scale_factor: Tuple[float, float], method: str = "nearest"):
     self.method = method
 
   def __call__(self, x: jax.Array) -> jax.Array:
-    print(f"Entering WanUpsample: {x.shape}")
     input_dtype = x.dtype
     in_shape = x.shape
     assert len(in_shape) == 4, "This module only takes tensors with shape of 4."
@@ -222,14 +218,11 @@ def __call__(self, x: jax.Array) -> jax.Array:
     target_w = int(w * self.scale_factor[1])
     out = jax.image.resize(x.astype(jnp.float32), (n, target_h, target_w, c), method=self.method)
     out = out.astype(input_dtype)
-    print(f"Exiting WanUpsample: {out.shape}")
     return out
 
 
 class Identity(nnx.Module):
   def __call__(self, x, cache=None):
-    print(f"Entering Identity: {x.shape}")
-    print(f"Exiting Identity: {x.shape}")
     return x, cache
 
 
@@ -264,9 +257,7 @@ def __init__(
     )
 
   def __call__(self, x, cache=None):
-    print(f"Entering ZeroPaddedConv2D: {x.shape}")
     out = self.conv(x)
-    print(f"Exiting ZeroPaddedConv2D: {out.shape}")
     return out, cache
 
 
@@ -378,81 +369,65 @@ def initialize_cache(self, batch_size, height, width, dtype):
   def __call__(
       self, x: jax.Array, cache: Dict[str, Any] = None
   ) -> Tuple[jax.Array, Dict[str, Any]]:
-    print(f"Entering WanResample ({self.mode}): {x.shape}")
     if cache is None:
       cache = {}
     new_cache = {}
 
     if self.mode == "upsample2d":
       b, t, h, w, c = x.shape
       x = x.reshape(b * t, h, w, c)
-      print(f"WanResample ({self.mode}) reshaped for resample: {x.shape}")
       x = self.resample(x)
-      print(f"WanResample ({self.mode}) after resample: {x.shape}")
       h_new, w_new, c_new = x.shape[1:]
       x = x.reshape(b, t, h_new, w_new, c_new)
 
     elif self.mode == "upsample3d":
       x, tc_cache = self.time_conv(x, cache.get("time_conv"))
       new_cache["time_conv"] = tc_cache
-      print(f"WanResample ({self.mode}) after time_conv: {x.shape}")
 
       b, t, h, w, c = x.shape
       x = x.reshape(b, t, h, w, 2, c // 2)
       x = jnp.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=1)
       x = x.reshape(b, t * 2, h, w, c // 2)
-      print(f"WanResample ({self.mode}) after time dim expand: {x.shape}")
-
 
       b, t, h, w, c = x.shape
       x = x.reshape(b * t, h, w, c)
-      print(f"WanResample ({self.mode}) reshaped for resample: {x.shape}")
       x = self.resample(x)
-      print(f"WanResample ({self.mode}) after resample: {x.shape}")
       h_new, w_new, c_new = x.shape[1:]
       x = x.reshape(b, t, h_new, w_new, c_new)
 
     elif self.mode == "downsample2d":
       b, t, h, w, c = x.shape
       x = x.reshape(b * t, h, w, c)
-      print(f"WanResample ({self.mode}) reshaped for resample: {x.shape}")
       x, _ = self.resample(x, None)
-      print(f"WanResample ({self.mode}) after resample: {x.shape}")
       h_new, w_new, c_new = x.shape[1:]
       x = x.reshape(b, t, h_new, w_new, c_new)
 
     elif self.mode == "downsample3d":
       if x.shape[1] >= self.time_conv.kernel_size[0]:
         x, tc_cache = self.time_conv(x, cache.get("time_conv"))
         new_cache["time_conv"] = tc_cache
-        print(f"WanResample ({self.mode}) after time_conv: {x.shape}")
       else:
         # Skip temporal downsampling if not enough frames
-        print(f"WanResample ({self.mode}): Skipping time_conv, input time dim {x.shape[1]} < kernel {self.time_conv.kernel_size[0]}")
         new_cache["time_conv"] = cache.get("time_conv")  # Pass through cache
 
       b, t, h, w, c = x.shape
       if b * t > 0:
         x = x.reshape(b * t, h, w, c)
-        print(f"WanResample ({self.mode}) reshaped for resample: {x.shape}")
         x, _ = self.resample(x, None)
-        print(f"WanResample ({self.mode}) after resample: {x.shape}")
         h_new, w_new, c_new = x.shape[1:]
         x = x.reshape(b, t, h_new, w_new, c_new)
       else:
         # If time dimension became 0, spatial shape changes, but batch and time are still 0
         h_new, w_new = h // self.resample.conv.strides[0], w // self.resample.conv.strides[1]
         c_new = self.resample.conv.out_features
         x = jnp.zeros((b, t, h_new, w_new, c_new), dtype=x.dtype)
-        print(f"WanResample ({self.mode}): Spatial downsample output shape {x.shape} (due to t=0)")
     else:
       if hasattr(self, "resample"):
         if isinstance(self.resample, Identity):
           x, _ = self.resample(x, None)
         else:
           x = self.resample(x)
 
-    print(f"Exiting WanResample ({self.mode}): {x.shape}")
     return x, new_cache
 
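
Note: the upsample3d branch doubles the time axis by folding a factor of two out of the channel axis. A shape-only sketch of that reshape/stack/reshape sequence, with all dimensions hypothetical:

    import jax.numpy as jnp

    # Hypothetical (B, T, H, W, C) tensor; C must be even for the split.
    b, t, h, w, c = 1, 2, 4, 4, 6
    x = jnp.arange(b * t * h * w * c, dtype=jnp.float32).reshape(b, t, h, w, c)

    # Split channels into two halves, stack the halves on a new axis in
    # front of time, then flatten that axis into time, mirroring the
    # three lines in the hunk above.
    x = x.reshape(b, t, h, w, 2, c // 2)
    x = jnp.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=1)
    x = x.reshape(b, t * 2, h, w, c // 2)
    print(x.shape)  # (1, 4, 4, 4, 3): time doubled, channels halved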

@@ -526,33 +501,26 @@ def initialize_cache(self, batch_size, height, width, dtype):
     return cache
 
   def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
-    print(f"Entering WanResidualBlock (in={self.conv1.conv.in_features}, out={self.conv1.conv.out_features}): {x.shape}")
     if cache is None:
       cache = {}
     new_cache = {}
 
     h, sc_cache = self.conv_shortcut(x, cache.get("shortcut"))
     new_cache["shortcut"] = sc_cache
-    print(f"WanResidualBlock after shortcut: {h.shape}")
 
     x = self.norm1(x)
     x = self.nonlinearity(x)
-    print(f"WanResidualBlock after norm1/nl: {x.shape}")
 
     x, c1 = self.conv1(x, cache.get("conv1"))
     new_cache["conv1"] = c1
-    print(f"WanResidualBlock after conv1: {x.shape}")
 
     x = self.norm2(x)
     x = self.nonlinearity(x)
-    print(f"WanResidualBlock after norm2/nl: {x.shape}")
 
     x, c2 = self.conv2(x, cache.get("conv2"))
     new_cache["conv2"] = c2
-    print(f"WanResidualBlock after conv2: {x.shape}")
 
     x = x + h
-    print(f"Exiting WanResidualBlock: {x.shape}")
     return x, new_cache
 
 
@@ -591,17 +559,14 @@ def __init__(
     )
 
   def __call__(self, x: jax.Array):
-    print(f"Entering WanAttentionBlock: {x.shape}")
     identity = x
     batch_size, time, height, width, channels = x.shape
 
     x = x.reshape(batch_size * time, height, width, channels)
-    print(f"WanAttentionBlock reshaped for norm: {x.shape}")
     x = self.norm(x)
 
     qkv = self.to_qkv(x)  # Output: (N*D, H, W, C * 3)
     # qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
-    print(f"WanAttentionBlock qkv shape: {qkv.shape}")
     qkv = qkv.reshape(batch_size * time, 1, -1, channels * 3)
     qkv = jnp.transpose(qkv, (0, 1, 3, 2))
     q, k, v = jnp.split(qkv, 3, axis=-2)
@@ -616,7 +581,6 @@ def __call__(self, x: jax.Array):
     # Reshape back
     x = x.reshape(batch_size, time, height, width, channels)
     out = x + identity
-    print(f"Exiting WanAttentionBlock: {out.shape}")
     return out
 
 
@@ -685,26 +649,19 @@ def initialize_cache(self, batch_size, height, width, dtype):
     return cache
 
   def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
-    print(f"Entering WanMidBlock: {x.shape}")
     if cache is None:
       cache = {}
     new_cache = {"resnets": []}
 
     x, c = self.resnets[0](x, cache.get("resnets", [None])[0])
     new_cache["resnets"].append(c)
-    print(f"WanMidBlock after resnets[0]: {x.shape}")
 
     for i, (attn, resnet) in enumerate(zip(self.attentions, self.resnets[1:])):
       if attn is not None:
-        print(f"WanMidBlock before attn {i}: {x.shape}")
         x = attn(x)
-        print(f"WanMidBlock after attn {i}: {x.shape}")
-      print(f"WanMidBlock before resnets[{i + 1}]: {x.shape}")
       x, c = resnet(x, cache.get("resnets", [None] * len(self.resnets))[i + 1])
       new_cache["resnets"].append(c)
-      print(f"WanMidBlock after resnets[{i + 1}]: {x.shape}")
 
-    print(f"Exiting WanMidBlock: {x.shape}")
     return x, new_cache
 
 
@@ -922,41 +879,32 @@ def init_cache(self, batch_size, height, width, dtype):
     return cache
 
   def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
-    print(f"Entering WanEncoder3d: {x.shape}")
     if cache is None:
       cache = {}
     new_cache = {}
 
     x, c = self.conv_in(x, cache.get("conv_in"))
     new_cache["conv_in"] = c
-    print(f"WanEncoder3d after conv_in: {x.shape}")
 
     new_cache["down_blocks"] = []
     current_down_caches = cache.get("down_blocks", [None] * len(self.down_blocks))
 
     for i, layer in enumerate(self.down_blocks):
-      print(f"WanEncoder3d before down_block {i} ({type(layer).__name__}): {x.shape}")
       if isinstance(layer, (WanResidualBlock, WanResample)):
         x, c = layer(x, current_down_caches[i])
         new_cache["down_blocks"].append(c)
       else:
         x = layer(x)
         new_cache["down_blocks"].append(None)
-      print(f"WanEncoder3d after down_block {i}: {x.shape}")
-
 
     x, c = self.mid_block(x, cache.get("mid_block"))
     new_cache["mid_block"] = c
-    print(f"WanEncoder3d after mid_block: {x.shape}")
 
     x = self.norm_out(x)
-    print(f"WanEncoder3d after norm_out: {x.shape}")
     x = self.nonlinearity(x)
-    print(f"WanEncoder3d after nonlinearity: {x.shape}")
 
     x, c = self.conv_out(x, cache.get("conv_out"))
     new_cache["conv_out"] = c
-    print(f"Exiting WanEncoder3d: {x.shape}")
 
     return x, new_cache
 
@@ -1270,8 +1218,7 @@ def _encode_jit(self, x: jax.Array) -> jax.Array:
     # Process the first frame (Time=1)
     x_first = x[:, :1, ...]
     init_cache_first = self.encoder.init_cache(b, h, w, x_first.dtype)
-    encoder_checkpointed = jax.checkpoint(self.encoder)
-    out1, state_carry = encoder_checkpointed(x_first, init_cache_first)
+    out1, state_carry = self.encoder(x_first, init_cache_first)
     all_outs.append(out1)
 
     # Process the remaining frames using scan over chunks of 4
@@ -1295,7 +1242,7 @@ def _encode_jit(self, x: jax.Array) -> jax.Array:
     x_scannable = jnp.swapaxes(x_reshaped, 0, 1)
 
     def scan_fn(carry_state, x_chunk):
-      out_chunk, new_state = encoder_checkpointed(x_chunk, carry_state)
+      out_chunk, new_state = self.encoder(x_chunk, carry_state)
       return new_state, out_chunk
 
     _, encoded_chunks = jax.lax.scan(scan_fn, state_carry, x_scannable)
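
Note: jax.checkpoint (also exposed as jax.remat) is what this commit removes around the encoder and decoder calls. It tells JAX not to store the wrapped function's intermediate activations for the backward pass and to recompute them instead, trading compute for peak memory; forward values and gradients are unchanged. A minimal sketch of the mechanism with a toy layer standing in for the encoder (the layer and shapes are illustrative, not from this file):

    import jax
    import jax.numpy as jnp

    def layer(params, x):
      # Toy stand-in for the encoder; its intermediates would normally be
      # kept alive for the backward pass.
      return jnp.tanh(x @ params)

    # jax.checkpoint discards intermediates on the forward pass and
    # recomputes them during differentiation.
    layer_remat = jax.checkpoint(layer)

    params = jnp.ones((8, 8))
    x = jnp.ones((4, 8))

    # Identical values and gradients either way; only the memory/compute
    # trade-off under jax.grad/jax.jit differs.
    g_plain = jax.grad(lambda p: layer(p, x).sum())(params)
    g_remat = jax.grad(lambda p: layer_remat(p, x).sum())(params)
    assert jnp.allclose(g_plain, g_remat)
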
@@ -1331,19 +1278,18 @@ def _decode_jit(self, z: jax.Array) -> jax.Array:
 
     b, t, h, w, c = x.shape
     init_cache = self.decoder.init_cache(b, h, w, x.dtype)
-    decoder_checkpointed = jax.checkpoint(self.decoder)
 
     all_decoded = []
     x_first = x[:, :1, ...]
-    out_first, state_carry = decoder_checkpointed(x_first, init_cache)
+    out_first, state_carry = self.decoder(x_first, init_cache)
     all_decoded.append(out_first)
     if t > 1:
       x_rest = x[:, 1:, ...]
       x_scan = jnp.swapaxes(x_rest, 0, 1)
 
       def scan_fn(carry, input_slice):
         input_slice = jnp.expand_dims(input_slice, 1)
-        out_slice, new_carry = decoder_checkpointed(input_slice, carry)
+        out_slice, new_carry = self.decoder(input_slice, carry)
         out_swapped = out_slice[:, jnp.array([0, 2, 1, 3]), ...]
 
         return new_carry, out_swapped
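
Note: both _encode_jit and _decode_jit thread the causal cache through jax.lax.scan, whose step function maps (carry, x) to (new_carry, y). A minimal sketch of that carry pattern with a hypothetical stand-in for the decoder (names and shapes are assumptions for illustration):

    import jax
    import jax.numpy as jnp

    # Hypothetical stand-in for the decoder: consumes one frame plus a
    # carried cache and returns the decoded frame and the updated cache.
    def decode_frame(frame, cache):
      new_cache = cache + frame.sum()  # placeholder cache update
      return frame * 2.0, new_cache

    # scan iterates over the leading axis, so (B, T, ...) is swapped to
    # (T, B, ...) first, mirroring jnp.swapaxes(x_rest, 0, 1) above.
    frames = jnp.swapaxes(jnp.ones((2, 5, 3)), 0, 1)  # (T=5, B=2, C=3)

    def scan_fn(carry, frame):
      out, new_carry = decode_frame(frame, carry)
      return new_carry, out

    final_cache, decoded = jax.lax.scan(scan_fn, jnp.zeros(()), frames)
    print(decoded.shape)  # (5, 2, 3): one decoded slice per scanned frame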
