Skip to content

Commit 67cfc0f

Browse files
committed
abstracting tpu type logic
1 parent d57b0e5 commit 67cfc0f

2 files changed

Lines changed: 24 additions & 2 deletions

File tree

src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import jax.numpy as jnp
2121
from ... import common_types
2222
from ..attention_flax import NNXAttentionOp
23+
from maxdiffusion.tpu_utils import get_tpu_type, TpuType
2324

2425
Array = common_types.Array
2526
Mesh = common_types.Mesh
@@ -360,8 +361,8 @@ def __init__(
360361
self.dropout_rate = dropout
361362

362363
# Auto-detect hardware for sharding specs if not overridden
363-
device_kind = jax.devices()[0].device_kind
364-
is_ironwood = "7x" in device_kind
364+
tpu_type = get_tpu_type()
365+
is_ironwood = tpu_type == TpuType.TPU_7X
365366

366367
if qkv_sharding_spec is None:
367368
qkv_sharding_spec = (None, "heads") if is_ironwood else ("embed", "heads")

src/maxdiffusion/tpu_utils.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""
1616

1717
import jax
18+
from enum import Enum
1819

1920

2021
def print_device_memory_info(devices):
@@ -42,3 +43,23 @@ def print_array_info(array, name):
4243
for device_idx in num_devices:
4344
jax.debug.print("shape on device {x} : {y}", x=device_idx, y=array.device_buffers[0].shape)
4445
jax.debug.print("size on device {x} : {y}", x=device_idx, y=array.device_buffers[device_idx].size / array.size)
46+
47+
48+
class TpuType(Enum):
49+
TPU_V6_LITE = "v6e"
50+
TPU_7X = "v7x"
51+
UNKNOWN = "unknown"
52+
53+
54+
def get_tpu_type() -> TpuType:
55+
"""Detects the current TPU hardware generation."""
56+
try:
57+
device_kind = jax.devices()[0].device_kind
58+
if "7x" in device_kind:
59+
return TpuType.TPU_7X
60+
elif "v6 lite" in device_kind:
61+
return TpuType.TPU_V6_LITE
62+
else:
63+
return TpuType.UNKNOWN
64+
except Exception:
65+
return TpuType.UNKNOWN

0 commit comments

Comments
 (0)