|
34 | 34 | D_KV, |
35 | 35 | AxisNames, |
36 | 36 | AxisIdxes, |
37 | | - LENGTH, |
38 | | - LENGTH_NO_EXP, |
| 37 | + ATTN_LENGTH, |
| 38 | + ATTN_LENGTH_NO_EXP, |
39 | 39 | DType, |
40 | 40 | Config, |
41 | 41 | Array, |
|
46 | 46 | KV_HEAD_DIM, |
47 | 47 | KV_BATCH, |
48 | 48 | KV_BATCH_NO_EXP, |
49 | | - EMBED, |
| 49 | + ATTN_EMBED, |
50 | 50 | MODEL_MODE_AUTOREGRESSIVE, |
51 | 51 | MODEL_MODE_TRAIN, |
52 | 52 | MODEL_MODE_PREFILL, |
@@ -141,18 +141,18 @@ def attention_as_linen( |
141 | 141 | prefill_query_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), |
142 | 142 | prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), |
143 | 143 | prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), |
144 | | - query_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), |
145 | | - key_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), |
146 | | - value_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), |
147 | | - ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM), |
148 | | - ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM), |
149 | | - ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM), |
150 | | - input_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, EMBED), |
151 | | - ep_input_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, EMBED), |
152 | | - out_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, HEAD, D_KV), |
153 | | - ep_out_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, HEAD, D_KV), |
154 | | - prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, EMBED), |
155 | | - decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, EMBED), |
| 144 | + query_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), |
| 145 | + key_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), |
| 146 | + value_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), |
| 147 | + ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM), |
| 148 | + ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM), |
| 149 | + ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM), |
| 150 | + input_axis_names: AxisNames = (BATCH, ATTN_LENGTH_NO_EXP, ATTN_EMBED), |
| 151 | + ep_input_axis_names: AxisNames = (BATCH_NO_EXP, ATTN_LENGTH, ATTN_EMBED), |
| 152 | + out_axis_names: AxisNames = (BATCH, ATTN_LENGTH_NO_EXP, HEAD, D_KV), |
| 153 | + ep_out_axis_names: AxisNames = (BATCH_NO_EXP, ATTN_LENGTH, HEAD, D_KV), |
| 154 | + prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, ATTN_EMBED), |
| 155 | + decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, ATTN_EMBED), |
156 | 156 | prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV), |
157 | 157 | decode_out_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV), |
158 | 158 | prefill_cache_axis_order: AxisIdxes = (1, 2, 0, 3), |
@@ -300,18 +300,18 @@ def __init__( |
300 | 300 | prefill_query_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), |
301 | 301 | prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), |
302 | 302 | prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), |
303 | | - query_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), |
304 | | - key_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), |
305 | | - value_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), |
306 | | - ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM), |
307 | | - ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM), |
308 | | - ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM), |
309 | | - input_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, EMBED), |
310 | | - ep_input_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, EMBED), |
311 | | - out_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, HEAD, D_KV), |
312 | | - ep_out_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, HEAD, D_KV), |
313 | | - prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, EMBED), |
314 | | - decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, EMBED), |
| 303 | + query_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), |
| 304 | + key_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), |
| 305 | + value_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), |
| 306 | + ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM), |
| 307 | + ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM), |
| 308 | + ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM), |
| 309 | + input_axis_names: AxisNames = (BATCH, ATTN_LENGTH_NO_EXP, ATTN_EMBED), |
| 310 | + ep_input_axis_names: AxisNames = (BATCH_NO_EXP, ATTN_LENGTH, ATTN_EMBED), |
| 311 | + out_axis_names: AxisNames = (BATCH, ATTN_LENGTH_NO_EXP, HEAD, D_KV), |
| 312 | + ep_out_axis_names: AxisNames = (BATCH_NO_EXP, ATTN_LENGTH, HEAD, D_KV), |
| 313 | + prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, ATTN_EMBED), |
| 314 | + decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, ATTN_EMBED), |
315 | 315 | prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV), |
316 | 316 | decode_out_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV), |
317 | 317 | prefill_cache_axis_order: AxisIdxes = (1, 2, 0, 3), |
|
0 commit comments