@@ -104,10 +104,10 @@ def _reshape_heads_to_head_dim(tensor):
     return jax.lax.with_sharding_constraint(reshaped_tensor, PartitionSpec("data", "fsdp", "tensor"))


-def _unflatten_heads(tensor, heads):
+def _unflatten_heads(tensor, heads, divisor=1):
     # reshapes from [b, s, h * d] to [b, h, s, d] (input format to flash format)
     batch, seq, heads_and_dim_head = tensor.shape
-    tensor = tensor.reshape(batch, seq, heads, heads_and_dim_head // heads)
+    tensor = tensor.reshape(batch, seq, heads // divisor, divisor * heads_and_dim_head // heads)
     # Transpose to ('batch', 'heads', 'length', 'kv')
     tensor = jnp.transpose(tensor, (0, 2, 1, 3))
     return tensor
@@ -120,7 +120,7 @@ def _reshape_data_for_flash(tensor, heads):
     blocks is divisible by the number of shards.
     """
     if tensor.ndim != 4:
-        tensor = _unflatten_heads(tensor, heads)
+        tensor = _unflatten_heads(tensor, heads, divisor=2)
     return tensor


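For context, here is a minimal, self-contained sketch of the patched `_unflatten_heads` with a shape check (the batch/seq/head sizes below are made up for illustration). With the default `divisor=1` the behavior is unchanged; with `divisor=2` adjacent pairs of heads are fused along the head-dim axis, so the output is `[b, h // divisor, s, divisor * d]` rather than `[b, h, s, d]`, halving the head axis so that, per the docstring above, the blocks stay divisible by the number of shards:

```python
import jax.numpy as jnp

def _unflatten_heads(tensor, heads, divisor=1):
    # [b, s, h * d] -> [b, s, h // divisor, divisor * d]:
    # groups of `divisor` heads are fused into one wider head,
    # keeping the total element count unchanged.
    batch, seq, heads_and_dim_head = tensor.shape
    tensor = tensor.reshape(batch, seq, heads // divisor, divisor * heads_and_dim_head // heads)
    # Transpose to ('batch', 'heads', 'length', 'kv')
    return jnp.transpose(tensor, (0, 2, 1, 3))

# 4 heads of dim 16 flattened into the last axis: [2, 8, 64]
x = jnp.zeros((2, 8, 4 * 16))
print(_unflatten_heads(x, heads=4).shape)             # (2, 4, 8, 16)
print(_unflatten_heads(x, heads=4, divisor=2).shape)  # (2, 2, 8, 32)
```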