Skip to content

Commit e8734f4

Browse files
committed
reformatted
1 parent c1446ba commit e8734f4

6 files changed

Lines changed: 2317 additions & 2268 deletions

File tree

src/maxdiffusion/models/embeddings_flax.py

Lines changed: 28 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,7 @@ def __call__(self, timestep, guidance, pooled_projection):
504504

505505

506506
class NNXTimesteps(nnx.Module):
507+
507508
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
508509
self.num_channels = num_channels
509510
self.flip_sin_to_cos = flip_sin_to_cos
@@ -516,75 +517,64 @@ def __call__(self, timesteps: jax.Array) -> jax.Array:
516517
embedding_dim=self.num_channels,
517518
freq_shift=self.downscale_freq_shift,
518519
flip_sin_to_cos=self.flip_sin_to_cos,
519-
scale=self.scale
520+
scale=self.scale,
520521
)
521522

522523

523524
class NNXPixArtAlphaCombinedTimestepSizeEmbeddings(nnx.Module):
  """Combined timestep / size conditioning embeddings for PixArt-Alpha (NNX).

  The diffusion timestep is projected with a sinusoidal `NNXTimesteps`
  projection and then passed through an `NNXTimestepEmbedding` MLP. When
  `use_additional_conditions` is True, image resolution and aspect ratio are
  embedded the same way and their concatenation is added to the timestep
  embedding.

  NOTE(review): mirrors diffusers' `PixArtAlphaCombinedTimestepSizeEmbeddings`
  — confirm against the upstream PyTorch reference when changing shapes.
  """

  def __init__(
      self,
      rngs: nnx.Rngs,
      embedding_dim: int,
      size_emb_dim: int,
      use_additional_conditions: bool = False,
      dtype: jnp.dtype = jnp.float32,
      weights_dtype: jnp.dtype = jnp.float32,
  ):
    """Build the timestep embedder and, optionally, the size embedders.

    Args:
      rngs: NNX RNG container used to initialize the embedding MLPs.
      embedding_dim: Output width of the timestep embedding MLP.
      size_emb_dim: Output width of each size (resolution / aspect-ratio)
        embedding MLP; only used when `use_additional_conditions` is True.
      use_additional_conditions: If True, also embed resolution and
        aspect ratio and add them to the conditioning vector.
      dtype: Computation dtype for the embedding MLPs.
      weights_dtype: Parameter dtype for the embedding MLPs.
    """
    self.outdim = size_emb_dim
    self.use_additional_conditions = use_additional_conditions

    # 256-channel sinusoidal projection feeding the timestep MLP.
    self.time_proj = NNXTimesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
    self.timestep_embedder = NNXTimestepEmbedding(
        rngs=rngs, in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, weights_dtype=weights_dtype
    )

    if use_additional_conditions:
      # One shared sinusoidal projection for both resolution and aspect ratio.
      self.additional_condition_proj = NNXTimesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
      self.resolution_embedder = NNXTimestepEmbedding(
          rngs=rngs, in_channels=256, time_embed_dim=size_emb_dim, dtype=dtype, weights_dtype=weights_dtype
      )
      self.aspect_ratio_embedder = NNXTimestepEmbedding(
          rngs=rngs, in_channels=256, time_embed_dim=size_emb_dim, dtype=dtype, weights_dtype=weights_dtype
      )

  def __call__(
      self,
      timestep: jax.Array,
      resolution: Optional[jax.Array] = None,
      aspect_ratio: Optional[jax.Array] = None,
      hidden_dtype: jnp.dtype = jnp.float32,
  ) -> jax.Array:
    """Return the conditioning embedding for a batch of timesteps.

    Args:
      timestep: Timestep values; first axis is the batch dimension.
      resolution: Per-sample resolution values; required when
        `use_additional_conditions` is True.
      aspect_ratio: Per-sample aspect-ratio values; required when
        `use_additional_conditions` is True.
      hidden_dtype: dtype the projected features are cast to before the MLPs.

    Returns:
      The timestep embedding, plus the concatenated resolution/aspect-ratio
      embeddings when additional conditions are enabled.

    Raises:
      ValueError: If additional conditions are enabled but `resolution` or
        `aspect_ratio` is None.
    """
    timesteps_proj = self.time_proj(timestep)
    timesteps_emb = self.timestep_embedder(timesteps_proj.astype(hidden_dtype))

    if self.use_additional_conditions:
      if resolution is None or aspect_ratio is None:
        raise ValueError("resolution and aspect_ratio must be provided when use_additional_conditions is True")

      resolution_emb = self.additional_condition_proj(resolution.flatten()).astype(hidden_dtype)
      resolution_emb = self.resolution_embedder(resolution_emb)
      # Reshape to (batch_size, -1) matching PyTorch's reshape(batch_size, -1)
      # assuming resolution input was (batch_size, ...) so flatten logic holds.
      resolution_emb = resolution_emb.reshape(timestep.shape[0], -1)

      aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).astype(hidden_dtype)
      aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb)
      aspect_ratio_emb = aspect_ratio_emb.reshape(timestep.shape[0], -1)

      conditioning = timesteps_emb + jnp.concatenate([resolution_emb, aspect_ratio_emb], axis=1)
    else:
      conditioning = timesteps_emb

    return conditioning

0 commit comments

Comments
 (0)