File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 2020import jax .numpy as jnp
2121from ... import common_types
2222from ..attention_flax import NNXAttentionOp
23+ from maxdiffusion .tpu_utils import get_tpu_type , TpuType
2324
2425Array = common_types .Array
2526Mesh = common_types .Mesh
@@ -360,8 +361,8 @@ def __init__(
360361 self .dropout_rate = dropout
361362
362363 # Auto-detect hardware for sharding specs if not overridden
363- device_kind = jax . devices ()[ 0 ]. device_kind
364- is_ironwood = "7x" in device_kind
364+ tpu_type = get_tpu_type ()
365+ is_ironwood = tpu_type == TpuType . TPU_7X
365366
366367 if qkv_sharding_spec is None :
367368 qkv_sharding_spec = (None , "heads" ) if is_ironwood else ("embed" , "heads" )
Original file line number Diff line number Diff line change 1515"""
1616
1717import jax
18+ from enum import Enum
1819
1920
2021def print_device_memory_info (devices ):
@@ -42,3 +43,23 @@ def print_array_info(array, name):
4243 for device_idx in num_devices :
4344 jax .debug .print ("shape on device {x} : {y}" , x = device_idx , y = array .device_buffers [0 ].shape )
4445 jax .debug .print ("size on device {x} : {y}" , x = device_idx , y = array .device_buffers [device_idx ].size / array .size )
46+
47+
class TpuType(Enum):
  """Closed set of TPU hardware generations this module distinguishes.

  Each member's value is a short string tag for the generation. `UNKNOWN`
  is the fallback member for anything that is not one of the listed
  generations.
  """

  # Short tag "v6e".
  TPU_V6_LITE = "v6e"
  # Short tag "v7x".
  TPU_7X = "v7x"
  # Fallback when the generation is not one of the above.
  UNKNOWN = "unknown"
52+
53+
def get_tpu_type() -> TpuType:
  """Classify the accelerator generation from JAX's first reported device.

  Reads `device_kind` of `jax.devices()[0]` and looks for known substrings
  in it, in a fixed order ("7x" first, then "v6 lite").

  Returns:
    `TpuType.TPU_7X` or `TpuType.TPU_V6_LITE` when the corresponding marker
    is found, otherwise `TpuType.UNKNOWN`. `TpuType.UNKNOWN` is also
    returned if querying the device or matching the string raises any
    `Exception`; this function is best-effort and does not propagate
    detection errors.
  """
  try:
    # Ordered (substring, generation) pairs; the first hit wins, so the
    # "7x" check keeps priority over "v6 lite".
    kind_markers = (
        ("7x", TpuType.TPU_7X),
        ("v6 lite", TpuType.TPU_V6_LITE),
    )
    kind = jax.devices()[0].device_kind
    return next(
        (generation for marker, generation in kind_markers if marker in kind),
        TpuType.UNKNOWN,
    )
  except Exception:
    # Deliberate catch-all: any failure while detecting hardware degrades
    # to UNKNOWN rather than raising.
    return TpuType.UNKNOWN
You can’t perform that action at this time.
0 commit comments