Skip to content

Commit 33fd50e

Browse files
committed
Add a train step name, and make setup.sh install Transformer Engine (JAX) from pip
Signed-off-by: kunjan <kunjanp@google.com>
1 parent a9069d6 commit 33fd50e

2 files changed

Lines changed: 12 additions & 8 deletions

File tree

setup.sh

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,7 @@ if [[ "$MODE" == "stable" || ! -v MODE ]]; then
7777
echo "Installing stable jax, jaxlib, libtpu for NVIDIA gpu"
7878
pip3 install "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
7979
fi
80-
export NVTE_FRAMEWORK=jax
81-
pip3 install git+https://github.com/NVIDIA/TransformerEngine.git@stable
80+
pip install "transformer_engine[jax]"
8281
fi
8382

8483
elif [[ $MODE == "nightly" ]]; then
@@ -88,8 +87,7 @@ elif [[ $MODE == "nightly" ]]; then
8887
# Install jax-nightly
8988
pip install -U --pre jax jaxlib jax-cuda12-plugin[with_cuda] jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
9089
# Install Transformer Engine
91-
export NVTE_FRAMEWORK=jax
92-
pip3 install git+https://github.com/NVIDIA/TransformerEngine.git@stable
90+
pip install "transformer_engine[jax]"
9391
elif [[ $DEVICE == "tpu" ]]; then
9492
echo "Installing jax-nightly,jaxlib-nightly"
9593
# Install jax-nightly

src/maxdiffusion/trainers/sdxl_trainer.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,15 +109,19 @@ def load_dataset(self, pipeline, params, train_states):
109109
p_vae_apply = None
110110
rng = None
111111
if config.dataset_type == "tf" and config.cache_latents_text_encoder_outputs:
112-
p_encode = jax.jit(
113-
partial(
112+
text_encoder_partial = partial(
114113
maxdiffusion_utils.encode_xl,
115114
text_encoders=[pipeline.text_encoder, pipeline.text_encoder_2],
116115
text_encoder_params=[train_states["text_encoder_state"].params, train_states["text_encoder_2_state"].params],
117116
)
117+
text_encoder_partial.__name__="Text encoder"
118+
p_encode = jax.jit(
119+
text_encoder_partial
118120
)
121+
vae_partial = partial(maxdiffusion_utils.vae_apply, vae=pipeline.vae, vae_params=train_states["vae_state"].params)
122+
vae_partial.__name__="VAE Partial"
119123
p_vae_apply = jax.jit(
120-
partial(maxdiffusion_utils.vae_apply, vae=pipeline.vae, vae_params=train_states["vae_state"].params)
124+
vae_partial
121125
)
122126
rng = self.rng
123127

@@ -152,8 +156,10 @@ def compile_train_step(self, pipeline, params, train_states, state_shardings, da
152156

153157
self.rng, train_rngs = jax.random.split(self.rng)
154158
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
159+
train_step_partial = partial(_train_step, pipeline=pipeline, params=params, config=self.config)
160+
train_step_partial.__name__ = "Train Step"
155161
p_train_step = jax.jit(
156-
partial(_train_step, pipeline=pipeline, params=params, config=self.config),
162+
train_step_partial,
157163
in_shardings=(
158164
state_shardings["unet_state_shardings"],
159165
state_shardings["vae_state_shardings"],

0 commit comments

Comments
 (0)