|
18 | 18 | from functools import partial |
19 | 19 | import pathlib |
20 | 20 | import shutil |
| 21 | +import subprocess |
21 | 22 | import unittest |
22 | 23 | from absl.testing import absltest |
23 | 24 |
|
@@ -425,13 +426,68 @@ def test_make_pokemon_iterator_sdxl_cache(self): |
425 | 426 | config.resolution // vae_scale_factor, |
426 | 427 | ) |
427 | 428 |
|
| 429 | + def test_make_laion_grain_iterator(self): |
| 430 | + try: |
| 431 | + subprocess.check_output( |
| 432 | + [ |
| 433 | + "bash", |
| 434 | + "setup_gcsfuse.sh", |
| 435 | + "DATASET_GCS_BUCKET=maxdiffusion-github-runner-test-assets", |
| 436 | + "MOUNT_PATH=/tmp/gcsfuse", |
| 437 | + ], |
| 438 | + stderr=subprocess.STDOUT, |
| 439 | + ) |
| 440 | + except subprocess.CalledProcessError as e: |
| 441 | + raise ValueError(f"setup_gcsfuse failed with error: {e.output}") from e |
| 442 | + pyconfig.initialize( |
| 443 | + [ |
| 444 | + None, |
| 445 | + os.path.join(THIS_DIR, "..", "configs", "base_2_base.yml"), |
| 446 | + "grain_train_files=/tmp/gcsfuse/datasets/array-record/laion400m/tf_records_512_encoder_state_fp32/*.arrayrecord", |
| 447 | + "dataset_type=grain", |
| 448 | + ], |
| 449 | + unittest=True, |
| 450 | + ) |
| 451 | + config = pyconfig.config |
| 452 | + global_batch_size = config.per_device_batch_size * jax.device_count() |
| 453 | + devices_array = max_utils.create_device_mesh(config) |
| 454 | + mesh = Mesh(devices_array, config.mesh_axes) |
| 455 | + |
| 456 | + pipeline, _ = FlaxStableDiffusionPipeline.from_pretrained( |
| 457 | + config.pretrained_model_name_or_path, |
| 458 | + revision=config.revision, |
| 459 | + dtype=config.activations_dtype, |
| 460 | + safety_checker=None, |
| 461 | + feature_extractor=None, |
| 462 | + from_pt=config.from_pt, |
| 463 | + ) |
| 464 | + |
| 465 | + train_iterator = make_data_iterator(config, jax.process_index(), jax.process_count(), mesh, global_batch_size) |
| 466 | + data = next(train_iterator) |
| 467 | + device_count = jax.device_count() |
| 468 | + |
| 469 | + vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1) |
| 470 | + encoder_hidden_states = data["input_ids"] |
| 471 | + |
| 472 | + # TODO - laion dataset was prepared with an extra dim. |
| 473 | + # need to preprocess the dataset with dim removed. |
| 474 | + if len(encoder_hidden_states.shape) == 4: |
| 475 | + encoder_hidden_states = jnp.squeeze(encoder_hidden_states) |
| 476 | + |
| 477 | + assert encoder_hidden_states.shape == (device_count, 77, 1024) |
| 478 | + assert data["pixel_values"].shape == ( |
| 479 | + config.total_train_batch_size, |
| 480 | + config.resolution // vae_scale_factor, |
| 481 | + config.resolution // vae_scale_factor, |
| 482 | + 8, |
| 483 | + ) |
| 484 | + |
428 | 485 | def test_make_laion_tfrecord_iterator(self): |
429 | 486 | pyconfig.initialize( |
430 | 487 | [ |
431 | 488 | None, |
432 | 489 | os.path.join(THIS_DIR, "..", "configs", "base_2_base.yml"), |
433 | | - "cache_latents_text_encoder_outputs=True", |
434 | | - "train_data_dir=gs://jfacevedo-maxdiffusion/laion400m/processed/laion400m_tfrec", |
| 490 | + "train_data_dir=gs://jfacevedo-maxdiffusion/laion400m/raw_data/tf_records_512_encoder_state_fp32", |
435 | 491 | "dataset_type=tfrecord", |
436 | 492 | ], |
437 | 493 | unittest=True, |
@@ -464,10 +520,10 @@ def test_make_laion_tfrecord_iterator(self): |
464 | 520 |
|
465 | 521 | assert encoder_hidden_states.shape == (device_count, 77, 1024) |
466 | 522 | assert data["pixel_values"].shape == ( |
467 | | - device_count, |
468 | | - pipeline.unet.config.in_channels, |
| 523 | + config.total_train_batch_size, |
469 | 524 | config.resolution // vae_scale_factor, |
470 | 525 | config.resolution // vae_scale_factor, |
| 526 | + 8, |
471 | 527 | ) |
472 | 528 |
|
473 | 529 | def test_tfrecord(self): |
|
0 commit comments