
Commit b4fbdf5

force 1 block for flux training.

1 parent b84fc34 · commit b4fbdf5

2 files changed: 9 additions and 7 deletions

src/maxdiffusion/checkpointing/flux_checkpointer.py (8 additions, 7 deletions)
@@ -87,7 +87,7 @@ def create_flux_state(self, pipeline, params, checkpoint_item_name, is_training)
         rngs=self.rng, max_sequence_length=self.config.max_sequence_length, eval_only=True
     )

-    transformer_params = load_flow_model(self.config.flux_name, transformer_eval_params, "cpu")
+    # transformer_params = load_flow_model(self.config.flux_name, transformer_eval_params, "cpu")

     weights_init_fn = functools.partial(
         pipeline.flux.init_weights, rngs=self.rng, max_sequence_length=self.config.max_sequence_length
@@ -103,9 +103,9 @@ def create_flux_state(self, pipeline, params, checkpoint_item_name, is_training)
         checkpoint_item=checkpoint_item_name,
         training=is_training,
     )
-    if not self.config.train_new_flux:
-      flux_state = flux_state.replace(params=transformer_params)
-    flux_state = jax.device_put(flux_state, state_mesh_shardings)
+    # if not self.config.train_new_flux:
+    #   flux_state = flux_state.replace(params=transformer_params)
+    # flux_state = jax.device_put(flux_state, state_mesh_shardings)
     return flux_state, state_mesh_shardings, learning_rate_scheduler

   def create_vae_state(self, pipeline, params, checkpoint_item_name, is_training=False):
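Note that commenting out `jax.device_put(flux_state, state_mesh_shardings)` in the hunk above also removes the explicit placement of the train state onto the device mesh; the state now stays wherever initialization produced it. For reference, a minimal standalone illustration of what that placement call does (the 1-D mesh and the "data" axis name are assumptions for this sketch, not the repo's actual mesh setup):

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    # Assumed 1-D mesh over all local devices; the axis name is illustrative.
    mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
    sharding = NamedSharding(mesh, PartitionSpec("data"))

    # jax.device_put places (and shards) the value according to the sharding,
    # which is what the disabled line did for the whole Flux train state.
    x = jax.device_put(jnp.arange(float(2 * jax.device_count())), sharding)
    print(x.sharding)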
@@ -217,12 +217,13 @@ def load_diffusers_checkpoint(self):
         dtype=self.config.activations_dtype,
         weights_dtype=self.config.weights_dtype,
         precision=max_utils.get_precision(self.config),
+        num_layers=1
     )
-    transformer_eval_params = transformer.init_weights(
-        rngs=self.rng, max_sequence_length=self.config.max_sequence_length, eval_only=True
+    transformer_params = transformer.init_weights(
+        rngs=self.rng, max_sequence_length=self.config.max_sequence_length, eval_only=False
     )

-    transformer_params = load_flow_model(self.config.flux_name, transformer_eval_params, "cpu")
+    # transformer_params = load_flow_model(self.config.flux_name, transformer_eval_params, "cpu")

     pipeline = FluxPipeline(
         t5_encoder,
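Together, `num_layers=1` and `eval_only=False` make `load_diffusers_checkpoint` build a single-block Flux transformer with freshly initialized trainable weights, and the commented-out `load_flow_model` call means no pretrained weights are restored. If the pretrained path needs to stay reachable, one option is gating it on a config flag rather than commenting it out; a minimal sketch of that pattern (the `Config` class, the `load_pretrained_flux` flag, and the function names are hypothetical, not part of maxdiffusion):

    from dataclasses import dataclass
    from typing import Callable

    @dataclass
    class Config:
        load_pretrained_flux: bool = False  # hypothetical flag, off for 1-block debugging

    def init_transformer_params(config: Config, init_fn: Callable, load_fn: Callable):
        """Start from random init; optionally overwrite with pretrained weights."""
        params = init_fn()  # random init, what the commit now does unconditionally
        if config.load_pretrained_flux:
            params = load_fn(params)  # the load_flow_model path this commit disables
        return params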

src/maxdiffusion/models/attention_flax.py (1 addition, 0 deletions)
@@ -524,6 +524,7 @@ class AttentionOp(nn.Module):
   quant: Quant = None

   def setup(self):
+    self.dpa_layer = None
     if self.attention_kernel == "cudnn_flash_te":
       from transformer_engine.jax.flax.transformer import DotProductAttention  # pytype: disable=import-error

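The one-line change in `attention_flax.py` pre-assigns `self.dpa_layer = None` so the attribute exists on every code path; without it, `setup` only defines `dpa_layer` inside the `cudnn_flash_te` branch, and any later access to `self.dpa_layer` under a different kernel raises `AttributeError`. A standalone sketch of the pattern (the module body is illustrative, not the repo's actual class):

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class AttentionOpSketch(nn.Module):
        attention_kernel: str = "dot_product"

        def setup(self):
            # Define the attribute unconditionally so non-TE kernels can still
            # test `self.dpa_layer is None` without an AttributeError.
            self.dpa_layer = None
            if self.attention_kernel == "cudnn_flash_te":
                pass  # the Transformer Engine DotProductAttention would be built here

        def __call__(self, x):
            if self.dpa_layer is not None:
                return self.dpa_layer(x)  # fast path, only when the TE kernel was built
            return x  # fallback path for this sketch

    mod = AttentionOpSketch()
    variables = mod.init(jax.random.PRNGKey(0), jnp.ones((2, 4)))
    y = mod.apply(variables, jnp.ones((2, 4)))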