Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements_with_jax_ai_image.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ orbax-checkpoint
tokenizers==0.21.0
huggingface_hub>=0.30.2
transformers==4.48.1
tokamax
einops==0.8.0
sentencepiece
aqtp
Expand Down
40 changes: 36 additions & 4 deletions src/maxdiffusion/generate_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,11 @@ def inference_generate_video(config, pipeline, filename_prefix=""):
def run(config, pipeline=None, filename_prefix=""):
print("seed: ", config.seed)
Comment thread
prishajain1 marked this conversation as resolved.
Outdated
model_key = config.model_name
tensorboard_dir = os.path.join(config.output_dir, "tensorboard")
Comment thread
prishajain1 marked this conversation as resolved.
Outdated
# Initialize TensorBoard writer
writer = max_utils.initialize_summary_writer(config)
if jax.process_index() == 0 and writer:
max_logging.log(f"TensorBoard logs will be written to: {tensorboard_dir}")

checkpointer_lib = get_checkpointer(model_key)
WanCheckpointer = checkpointer_lib.WanCheckpointer
Expand All @@ -163,8 +168,19 @@ def run(config, pipeline=None, filename_prefix=""):
)

videos = call_pipeline(config, pipeline, prompt, negative_prompt)

print("compile time: ", (time.perf_counter() - s0))
print("===================== Model details =======================")
Comment thread
prishajain1 marked this conversation as resolved.
Outdated
print("model name: ", config.model_name)
print("model path: ", config.pretrained_model_name_or_path)
print("model type: t2v")
print("hardware: ", jax.devices()[0].platform)
print("number of devices: ", jax.device_count())
print("per_device_batch_size: ", config.per_device_batch_size)
print("============================================================")

compile_time = time.perf_counter() - s0
print("compile_time: ", compile_time)
if writer and jax.process_index() == 0:
writer.add_scalar("inference/compile_time", compile_time, global_step=0)
saved_video_path = []
for i in range(len(videos)):
video_path = f"{filename_prefix}wan_output_{config.seed}_{i}.mp4"
Expand All @@ -175,14 +191,30 @@ def run(config, pipeline=None, filename_prefix=""):

s0 = time.perf_counter()
videos = call_pipeline(config, pipeline, prompt, negative_prompt)
print("generation time: ", (time.perf_counter() - s0))
generation_time = time.perf_counter() - s0
print("generation_time: ", generation_time)
if writer and jax.process_index() == 0:
writer.add_scalar("inference/generation_time", generation_time, global_step=0)
num_devices = jax.device_count()
num_videos = num_devices * config.per_device_batch_size
if num_videos > 0:
generation_time_per_video = generation_time / num_videos
writer.add_scalar("inference/generation_time_per_video", generation_time_per_video, global_step=0)
print(f"generation time per video: {generation_time_per_video}")
else:
max_logging.log("Warning: Number of videos is zero, cannot calculate generation_time_per_video.")


s0 = time.perf_counter()
if config.enable_profiler:
max_utils.activate_profiler(config)
videos = call_pipeline(config, pipeline, prompt, negative_prompt)
max_utils.deactivate_profiler(config)
print("generation time: ", (time.perf_counter() - s0))
generation_time_with_profiler = time.perf_counter() - s0
print("generation_time_with_profiler: ", generation_time_with_profiler)
if writer and jax.process_index() == 0:
writer.add_scalar("inference/generation_time_with_profiler", generation_time_with_profiler, global_step=0)

return saved_video_path


Expand Down
Loading