Skip to content

Fix for XLML tests failing with jax nightly#242

Merged
entrpn merged 1 commit intomainfrom
rbierneni-fix-xlml-test
Sep 8, 2025
Merged

Fix for XLML tests failing with jax nightly#242
entrpn merged 1 commit intomainfrom
rbierneni-fix-xlml-test

Conversation

@Rohan-Bierneni
Copy link
Copy Markdown
Collaborator

Nightly XLML tests have been failing with this message:

jax._src.dtypes.InvalidInputException: Argument 'ShapeDtypeStruct(...)' of type <class 'jax._src.core.ShapeDtypeStruct'> is not a valid JAX type.

These are the logs: https://paste.googleplex.com/4875429911199744

There is a call to x.unbox() within unbox_logicallypartioned_trainstate during the setup_initial_state. The code is doing sharding on an abstract representation of an array, which jax doesn't allow. The fix is to only do the unboxing on concrete arrays and not abstract arrays.

Link to successfull workload run after the change: https://c3aed121fd4247389a3ca4b9c5878779-dot-us-east4.composer.googleusercontent.com/dags/jax_ai_image_candidate_tpu_e2e/grid?tab=graph&dag_run_id=manual__2025-09-05T10%3A19%3A29.084042%2B00%3A00&task_id=maxdiffusion-jax-stable-stack-sdxl-stable-v5-8-1x-v5p-8

@github-actions
Copy link
Copy Markdown

github-actions Bot commented Sep 5, 2025

@suexu1025
Copy link
Copy Markdown
Collaborator

suexu1025 commented Sep 5, 2025 via email

@entrpn entrpn merged commit 8fdf3c2 into main Sep 8, 2025
3 of 4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants