Skip to content

Commit 8cd7dbe

Browse files
committed
conversion done
1 parent 0eb3303 commit 8cd7dbe

3 files changed

Lines changed: 1 addition & 445 deletions

File tree

src/maxdiffusion/generate_ltx_video.py

Lines changed: 0 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -28,50 +28,6 @@ def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise
2828
print("segment_ids.shape: ", segment_ids.shape, segment_ids.dtype)
2929
print("encoder_attention_segment_ids.shape: ", encoder_attention_segment_ids.shape, encoder_attention_segment_ids.dtype)
3030

31-
32-
def loop_body(
33-
step,
34-
args,
35-
transformer,
36-
fractional_cords,
37-
prompt_embeds,
38-
segment_ids,
39-
encoder_attention_segment_ids
40-
):
41-
latents, state, noise_cond = args
42-
noise_pred = transformer.apply(
43-
{"params": state.params},
44-
hidden_states=latents,
45-
indices_grid=fractional_cords,
46-
encoder_hidden_states=prompt_embeds,
47-
timestep=noise_cond,
48-
segment_ids=segment_ids,
49-
encoder_attention_segment_ids=encoder_attention_segment_ids
50-
)
51-
import pdb; pdb.set_trace()
52-
return noise_pred, state, noise_cond #need to make changes here? latents need to be changed based on noise_pred, but needs scheduler, return noise_pred for now
53-
54-
55-
56-
def run_inference(
57-
states, transformer, config, mesh, latents, fractional_cords, prompt_embeds, timestep, segment_ids, encoder_attention_segment_ids
58-
):
59-
transformer_state = states["transformer"]
60-
loop_body_p = functools.partial(
61-
loop_body,
62-
transformer=transformer,
63-
fractional_cords=fractional_cords,
64-
prompt_embeds=prompt_embeds,
65-
segment_ids=segment_ids,
66-
encoder_attention_segment_ids=encoder_attention_segment_ids
67-
)
68-
## TODO: add vae decode step
69-
## TODO: add loop
70-
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
71-
latents, transformer_state, _ = jax.lax.fori_loop(0, 1, loop_body_p, (latents, transformer_state, timestep))
72-
return latents
73-
74-
7531
def run(config):
7632
key = jax.random.PRNGKey(0)
7733

@@ -119,92 +75,6 @@ def run(config):
11975
)
12076

12177

122-
123-
124-
transformer_state = jax.device_put(transformer_state, transformer_state_shardings)
125-
get_memory_allocations()
126-
127-
states = {}
128-
state_shardings = {}
129-
130-
state_shardings["transformer"] = transformer_state_shardings
131-
states["transformer"] = transformer_state
132-
133-
#create dummy inputs:
134-
example_inputs = {}
135-
batch_size, num_tokens = 4, 256
136-
input_shapes = {
137-
"latents": (batch_size, num_tokens, in_channels),
138-
"fractional_coords": (batch_size, 3, num_tokens),
139-
"prompt_embeds": (batch_size, 128, model_config["caption_channels"]),
140-
"timestep": (batch_size, 256), #TODO: add in the segment id stuff
141-
"segment_ids": (batch_size, 256),
142-
"encoder_attention_segment_ids": (batch_size, 128),
143-
}
144-
for name, shape in input_shapes.items():
145-
example_inputs[name] = jnp.ones(
146-
shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool
147-
)
148-
149-
data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding))
150-
latents = jax.device_put(example_inputs["latents"], data_sharding)
151-
prompt_embeds = jax.device_put(example_inputs["prompt_embeds"], data_sharding)
152-
fractional_coords = jax.device_put(example_inputs["fractional_coords"], data_sharding)
153-
noise_cond = jax.device_put(example_inputs["timestep"], data_sharding)
154-
segment_ids = jax.device_put(example_inputs["segment_ids"], data_sharding)
155-
encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"], data_sharding)
156-
157-
validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids)
158-
p_run_inference = jax.jit(
159-
functools.partial(
160-
run_inference,
161-
transformer=transformer,
162-
config=config,
163-
mesh=mesh,
164-
latents=latents,
165-
fractional_cords=fractional_coords,
166-
prompt_embeds=prompt_embeds,
167-
timestep = noise_cond,
168-
segment_ids=segment_ids,
169-
encoder_attention_segment_ids=encoder_attention_segment_ids
170-
),
171-
in_shardings=(state_shardings,),
172-
out_shardings=None,
173-
)
174-
noise_pred = p_run_inference(states).block_until_ready()
175-
print(noise_pred) #(4, 256, 128)
176-
177-
178-
179-
180-
181-
182-
183-
184-
185-
186-
187-
188-
189-
190-
191-
192-
193-
194-
195-
196-
197-
198-
199-
200-
201-
202-
203-
204-
205-
206-
207-
20878

20979

21080

@@ -219,12 +89,4 @@ def main(argv: Sequence[str]) -> None:
21989

22090

22191

222-
###setup_initial_state, can optionally load from checkpoint
223-
224-
225-
226-
227-
228-
22992

230-
#end to end steps from ltx repo: pipeline_ltx_video.py

src/maxdiffusion/models/attention_flax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1188,4 +1188,4 @@ def setup(self):
11881188
def __call__(self, hidden_states, deterministic=True):
11891189
hidden_states = self.proj(hidden_states)
11901190
hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
1191-
return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
1191+
return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)

0 commit comments

Comments
 (0)