@@ -1198,17 +1198,10 @@ def wrap_splash_kernel(single_head_mask, shard_head_size=1):
         segment_axis_names_splash_kernel = self._logical_to_mesh_axes((Q_LENGTH,))
       else:
         segment_axis_names_splash_kernel = self._logical_to_mesh_axes((Q_LENGTH_NO_EXP,))
-    elif (
-        self.config.use_jax_splash
-        and self.config.expert_shard_attention_option == EP_AS_FSDP
-    ):
+    elif self.config.use_jax_splash and self.config.expert_shard_attention_option == EP_AS_FSDP:
       if self.config.use_max_logit_estimate > 0:
-        sa_config = dataclasses.replace(
-            sa_config, max_logit_const=self.config.use_max_logit_estimate
-        )
-        segment_axis_names_splash_kernel = nn.logical_to_mesh_axes((
-            Q_LENGTH_NO_EXP,
-        ))
+        sa_config = dataclasses.replace(sa_config, max_logit_const=self.config.use_max_logit_estimate)
+        segment_axis_names_splash_kernel = nn.logical_to_mesh_axes((Q_LENGTH_NO_EXP,))
     else:
       # Create multi-head mask
       multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
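Note on the first hunk: `dataclasses.replace` returns an updated copy of the splash-attention config rather than mutating it in place, which is why the diff rebinds `sa_config`. A minimal, self-contained sketch of that pattern, using a hypothetical config dataclass (the real type of `sa_config` is not shown in this diff):

import dataclasses

@dataclasses.dataclass(frozen=True)
class SAConfig:  # hypothetical stand-in for the actual splash-attention config
  block_q: int = 512
  max_logit_const: float | None = None  # field name mirrors the diff

base = SAConfig()
# replace() builds a new instance with only the named field changed;
# `base` itself is untouched.
tuned = dataclasses.replace(base, max_logit_const=30.0)
assert base.max_logit_const is None and tuned.max_logit_const == 30.0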
@@ -1327,7 +1320,13 @@ def _maybe_shard_with_pspec(inputs, pspec: jax.sharding.PartitionSpec | None):
       if pspec is None:
         return None
       sharding = NamedSharding(self.mesh, pspec)
-      return maybe_shard_with_name(inputs, sharding, shard_mode=self.config.shard_mode)
+      return maybe_shard_with_name(
+          inputs,
+          sharding,
+          shard_mode=self.config.shard_mode,
+          debug_sharding=self.config.debug_sharding,
+          extra_stack_level=1,
+      )
 
     query = _maybe_shard_with_pspec(query, axis_names_q)
     key = _maybe_shard_with_pspec(key, axis_names_kv)
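The definition of `maybe_shard_with_name` is outside this diff. As a rough sketch only, assuming the helper wraps `jax.lax.with_sharding_constraint` and that `debug_sharding` toggles logging (with `extra_stack_level` presumably adjusting which caller frame such logs attribute the constraint to), a stand-in with the same call shape:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

def maybe_shard_with_name(inputs, sharding, *, shard_mode="explicit",
                          debug_sharding=False, extra_stack_level=0):
  # Hypothetical stand-in, not the repository implementation.
  # extra_stack_level is accepted but unused here; in real logging code it
  # would offset the reported stack frame to point at the caller.
  if debug_sharding:
    # Runs at trace time when called under jit.
    print(f"sharding constraint ({shard_mode}): {sharding.spec}")
  return jax.lax.with_sharding_constraint(inputs, sharding)

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
x = jnp.ones((jax.device_count() * 2, 4))
y = jax.jit(
    lambda a: maybe_shard_with_name(
        a, NamedSharding(mesh, PartitionSpec("data", None)), debug_sharding=True
    )
)(x)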