Skip to content

Commit a456d10

Browse files
committed
changes for batch size being correct
1 parent 1b36fa5 commit a456d10

2 files changed

Lines changed: 5 additions & 6 deletions

File tree

src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,9 @@ dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
6161
dcn_fsdp_parallelism: -1
6262
dcn_context_parallelism: 1
6363
dcn_tensor_parallelism: 1
64-
ici_data_parallelism: 1
65-
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
66-
ici_context_parallelism: 1
64+
ici_data_parallelism: 2
65+
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
66+
ici_context_parallelism: 4
6767
ici_tensor_parallelism: 1
6868
enable_profiler: False
6969

@@ -81,7 +81,7 @@ dataset_name: 'diffusers/pokemon-gpt4-captions'
8181
train_split: 'train'
8282
dataset_type: 'tfrecord'
8383
cache_latents_text_encoder_outputs: True
84-
per_device_batch_size: 1
84+
per_device_batch_size: 0.25
8585
compile_topology_num_slices: -1
8686
quantization_local_shard_count: -1
8787
use_qwix_quantization: False

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1110,8 +1110,7 @@ def __call__(
11101110
)
11111111

11121112
# 3. Prepare latents
1113-
_bs = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, list) else prompt_embeds.shape[0]
1114-
batch_size = _bs // 2 if guidance_scale > 1.0 else _bs
1113+
batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, list) else prompt_embeds.shape[0]
11151114

11161115
# Prepare generators
11171116
if generator is None:

0 commit comments

Comments
 (0)