Skip to content

Commit 1b36fa5

Browse files
committed
test_batch_logic.py file added
1 parent 2724462 commit 1b36fa5

1 file changed

Lines changed: 35 additions & 0 deletions

File tree

test_batch_logic.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
class MockPipeline:
2+
def __init__(self):
3+
pass
4+
5+
def run_logic(self, prompt, num_videos_per_prompt=1, guidance_scale=3.0):
6+
# 2. Encode inputs (Text)
7+
if isinstance(prompt, str):
8+
_bs0 = 1
9+
elif isinstance(prompt, list):
10+
_bs0 = len(prompt)
11+
12+
# Simulate encode_prompt output shapes
13+
prompt_embeds_shape = (_bs0 * num_videos_per_prompt, 10)
14+
negative_prompt_embeds_shape = (_bs0 * num_videos_per_prompt, 10)
15+
16+
# 3. Prepare latents
17+
_bs = prompt_embeds_shape[0]
18+
batch_size = _bs // 2 if guidance_scale > 1.0 else _bs
19+
print(f"Evaluated true batch_size: {batch_size}")
20+
21+
# 6. Prepare JAX State
22+
latents_jax_shape = (batch_size, 5)
23+
prompt_embeds_jax_shape = prompt_embeds_shape
24+
negative_prompt_embeds_jax_shape = negative_prompt_embeds_shape
25+
26+
if guidance_scale > 1.0:
27+
prompt_embeds_jax_shape = (negative_prompt_embeds_jax_shape[0] + prompt_embeds_jax_shape[0], 10)
28+
latents_jax_shape = (latents_jax_shape[0] * 2, 5)
29+
30+
print(f"latents_jax shape during generation: {latents_jax_shape}")
31+
print(f"prompt_embeds_jax shape during generation: {prompt_embeds_jax_shape}")
32+
print(f"Videos Decoded: {latents_jax_shape[0] - batch_size}")
33+
34+
p = MockPipeline()
35+
p.run_logic(["Prompt"] * 8)

0 commit comments

Comments
 (0)