We started getting a pspec error when running Maxdiffusion with newer JAX versions, such as 0.6:
  File "/opt/maxdiffusion/src/maxdiffusion/train_flux.py", line 36, in train
    trainer.start_training()
  File "/opt/maxdiffusion/src/maxdiffusion/trainers/flux_trainer.py", line 138, in start_training
    p_train_step = self.compile_train_step(pipeline, params, train_states, state_shardings, data_shardings)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/maxdiffusion/src/maxdiffusion/trainers/flux_trainer.py", line 309, in compile_train_step
    p_train_step = p_train_step.lower(train_states[FLUX_STATE_KEY], dummy_batch, train_rngs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: pspec PartitionSpec(('data', 'fsdp', 'tensor'), None) contains a manual axes ('data', 'fsdp', 'tensor') of mesh which is not allowed. If you are using a with_sharding_constraint under a shard_map, only use the mesh axis in PartitionSpec which are not manual.
This is likely caused by an API change in newer JAX releases, so Maxdiffusion may need to be updated accordingly.
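The error message itself states the rule: inside a `shard_map`, a `with_sharding_constraint` may only name mesh axes that are not manual, and here all of `('data', 'fsdp', 'tensor')` are reported as manual. One possible workaround, sketched below under stated assumptions, is to strip the manual axes from the spec before applying the constraint and to skip the constraint entirely when no axis remains. The helper names `drop_manual_axes` and `names_any_axis` are hypothetical (not part of JAX or Maxdiffusion). The sketch works on plain tuples of PartitionSpec entries so it does not depend on JAX internals, and since the traceback does not show where Maxdiffusion applies the constraint, this is an illustration rather than the actual fix.

```python
from typing import Iterable, Optional, Tuple, Union

# One entry of a PartitionSpec: None (not sharded), a single mesh-axis name,
# or a tuple of mesh-axis names that jointly shard one array dimension.
Entry = Optional[Union[str, Tuple[str, ...]]]


def drop_manual_axes(spec: Iterable[Entry], manual_axes: Iterable[str]) -> Tuple[Entry, ...]:
    """Hypothetical helper: remove every manual mesh axis from a spec's entries.

    A dimension whose axes are all manual becomes None (not constrained).
    """
    manual = set(manual_axes)
    out = []
    for entry in spec:
        if isinstance(entry, str):
            out.append(None if entry in manual else entry)
        elif isinstance(entry, tuple):
            kept = tuple(axis for axis in entry if axis not in manual)
            out.append(kept if kept else None)
        else:
            # None is left untouched.
            out.append(entry)
    return tuple(out)


def names_any_axis(entries: Iterable[Entry]) -> bool:
    """Hypothetical helper: True if at least one entry still refers to a mesh axis."""
    return any(isinstance(e, (str, tuple)) for e in entries)


if __name__ == "__main__":
    # The spec from the traceback, with all three axes manual inside shard_map.
    entries = drop_manual_axes((("data", "fsdp", "tensor"), None), ("data", "fsdp", "tensor"))
    print(entries)                  # (None, None)
    print(names_any_axis(entries))  # False: nothing left to constrain, so skip it

    # If only 'data' were manual, the remaining axes could still be constrained.
    print(drop_manual_axes((("data", "fsdp", "tensor"), None), ["data"]))  # (('fsdp', 'tensor'), None)
```

At the call site, one would then call `jax.lax.with_sharding_constraint(x, PartitionSpec(*entries))` only when `names_any_axis(entries)` is true. Depending on how the train step is structured, moving the constraint outside the `shard_map` body may be a cleaner alternative.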