1616
1717from typing import Any , Callable , Dict , List , Optional , Union , Sequence
1818from absl import app
19-
19+ import functools
2020import numpy as np
2121import jax
22+ from jax .sharding import Mesh , PositionalSharding
2223import jax .numpy as jnp
2324from chex import Array
2425from transformers import (
2526 CLIPTokenizer ,
2627 FlaxCLIPTextModel ,
2728 T5TokenizerFast ,
28- T5EncoderModel
29+ T5EncoderModel ,
30+ FlaxT5EncoderModel
2931)
3032
3133from maxdiffusion import FlaxAutoencoderKL
3234from maxdiffusion .models .flux .transformers .transformer_flux_flax import FluxTransformer2DModel
33-
3435from maxdiffusion import pyconfig
36+ from max_utils import (
37+ device_put_replicated ,
38+ get_memory_allocations ,
39+ create_device_mesh ,
40+ get_flash_block_sizes ,
41+ get_precision
42+ )
3543
3644def prepare_latent_image_ids (height , width ):
3745 latent_image_ids = jnp .zeros ((height , width , 3 ))
@@ -133,19 +141,17 @@ def get_t5_prompt_embeds(
133141 truncation = True ,
134142 return_length = False ,
135143 return_overflowing_tokens = False ,
136- return_tensors = "pt "
144+ return_tensors = "np "
137145 )
138146 text_input_ids = text_inputs .input_ids
139-
140147 prompt_embeds = text_encoder (text_input_ids , output_hidden_states = False )[0 ]
141148 dtype = text_encoder .dtype
142- prompt_embeds = prompt_embeds .to ( dtype = dtype )
149+ prompt_embeds = prompt_embeds .astype ( dtype )
143150
144151 _ , seq_len , _ = prompt_embeds .shape
145-
146152 # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
147- prompt_embeds = prompt_embeds . repeat ( 1 , num_images_per_prompt , 1 )
148- prompt_embeds = prompt_embeds . view ( batch_size * num_images_per_prompt , seq_len , - 1 )
153+ prompt_embeds = jnp . tile ( prompt_embeds , ( 1 , num_images_per_prompt , 1 ) )
154+ prompt_embeds = jnp . reshape ( prompt_embeds , ( batch_size * num_images_per_prompt , seq_len , - 1 ) )
149155
150156 return prompt_embeds
151157
@@ -178,15 +184,16 @@ def encode_prompt(
178184 tokenizer = t5_tokenizer ,
179185 text_encoder = t5_text_encoder
180186 )
181- prompt_embeds = jnp .asarray (prompt_embeds .detach ().numpy ())
182187
183188 text_ids = jnp .zeros ((prompt_embeds .shape [0 ], prompt_embeds .shape [1 ], 3 )).astype (jnp .bfloat16 )
184189 return prompt_embeds , pooled_prompt_embeds , text_ids
185190
186191def run (config ):
187192 from maxdiffusion .models .flux .util import load_flow_model
188193
189- rng = jax .random .PRNGKey (config .seed )
194+ rng = jax .random .key (config .seed )
195+ devices_array = create_device_mesh (config )
196+ mesh = Mesh (devices_array , config .mesh_axes )
190197
191198 per_host_number_of_images = 1 #config.per_device_batch_size * jax.local_device_count()
192199
@@ -201,11 +208,18 @@ def run(config):
201208 )
202209 vae_scale_factor = 2 ** (len (vae .config .block_out_channels ) - 1 )
203210
204- # LOAD UNET
205-
211+ # LOAD TRANSFORMER
212+ flash_block_sizes = get_flash_block_sizes ( config )
206213 transformer = FluxTransformer2DModel .from_config (
207214 config .pretrained_model_name_or_path ,
208- subfolder = "transformer"
215+ subfolder = "transformer" ,
216+ mesh = mesh ,
217+ split_head_dim = config .split_head_dim ,
218+ attention_kernel = config .attention ,
219+ flash_block_sizes = flash_block_sizes ,
220+ dtype = config .activations_dtype ,
221+ weights_dtype = config .weights_dtype ,
222+ precision = get_precision (config )
209223 )
210224
211225 num_channels_latents = transformer .in_channels // 4
@@ -242,34 +256,40 @@ def run(config):
242256 dtype = config .weights_dtype
243257 )
244258
245- t5_encoder_pt = T5EncoderModel .from_pretrained (
246- config .pretrained_model_name_or_path ,
247- subfolder = "text_encoder_2" ,
259+ t5_encoder = FlaxT5EncoderModel .from_pretrained (
260+ config .clip_model_name_or_path ,
261+ dtype = config . weights_dtype
248262 )
249-
250263 t5_tokenizer = T5TokenizerFast .from_pretrained (
251264 config .pretrained_model_name_or_path ,
252265 subfolder = "tokenizer_2" ,
253266 )
254267
268+ encoders_sharding = PositionalSharding (devices_array ).replicate ()
269+ partial_device_put_replicated = functools .partial (device_put_replicated , sharding = encoders_sharding )
270+ clip_text_encoder .params = jax .tree_util .tree_map (lambda x : x .astype (jnp .bfloat16 ), clip_text_encoder .params )
271+ clip_text_encoder .params = jax .tree_util .tree_map (partial_device_put_replicated , clip_text_encoder .params )
272+ t5_encoder .params = jax .tree_util .tree_map (lambda x : x .astype (jnp .bfloat16 ), t5_encoder .params )
273+ t5_encoder .params = jax .tree_util .tree_map (partial_device_put_replicated , t5_encoder .params )
274+
255275 prompt_embeds , pooled_prompt_embeds , text_ids = encode_prompt (
256276 prompt = config .prompt ,
257277 prompt_2 = config .prompt_2 ,
258278 clip_tokenizer = clip_tokenizer ,
259279 clip_text_encoder = clip_text_encoder ,
260280 t5_tokenizer = t5_tokenizer ,
261- t5_text_encoder = t5_encoder_pt ,
281+ t5_text_encoder = t5_encoder ,
262282 num_images_per_prompt = per_host_number_of_images
263283 )
264284
265285 def validate_inputs (latents , latent_image_ids , prompt_embeds , text_ids , timesteps , guidance , pooled_prompt_embeds ):
266- print ("latents.shape: " , latents .shape )
267- print ("latent_image_ids.shape: " , latent_image_ids .shape )
268- print ("text_ids.shape: " , text_ids .shape )
269- print ("prompt_embeds: " , prompt_embeds .shape )
270- print ("timesteps.shape: " , timesteps .shape )
271- print ("guidance.shape: " , guidance .shape )
272- print ("pooled_prompt_embeds.shape: " , pooled_prompt_embeds .shape )
286+ print ("latents.shape: " , latents .shape , latents . dtype )
287+ print ("latent_image_ids.shape: " , latent_image_ids .shape , latent_image_ids . dtype )
288+ print ("text_ids.shape: " , text_ids .shape , text_ids . dtype )
289+ print ("prompt_embeds: " , prompt_embeds .shape , prompt_embeds . dtype )
290+ print ("timesteps.shape: " , timesteps .shape , timesteps . dtype )
291+ print ("guidance.shape: " , guidance .shape , guidance . dtype )
292+ print ("pooled_prompt_embeds.shape: " , pooled_prompt_embeds .shape , pooled_prompt_embeds . dtype )
273293
274294 timesteps = jnp .asarray ([1.0 ], dtype = jnp .bfloat16 )
275295 guidance = jnp .asarray ([3.5 ], dtype = jnp .bfloat16 )
@@ -282,17 +302,19 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
282302 guidance ,
283303 pooled_prompt_embeds
284304 )
285-
286- transformer_params = transformer .init (
287- {"params" : rng },
288- img = latents ,
289- img_ids = latent_image_ids ,
290- txt = prompt_embeds ,
291- txt_ids = text_ids ,
292- timesteps = timesteps ,
293- guidance = guidance ,
294- y = pooled_prompt_embeds
295- )["params" ]
305+ get_memory_allocations ()
306+ transformer_params = transformer .init_weights (rng , True )
307+ # transformer_params = transformer.init(
308+ # {"params" : rng},
309+ # img=latents,
310+ # img_ids=latent_image_ids,
311+ # txt=prompt_embeds,
312+ # txt_ids=text_ids,
313+ # timesteps=timesteps,
314+ # guidance=guidance,
315+ # y=pooled_prompt_embeds
316+ # )["params"]
317+ get_memory_allocations ()
296318 breakpoint ()
297319
298320
0 commit comments