@@ -37,10 +37,16 @@ def setUp(self):
   def test_splash_attention(self):
     """Test numerics of splash attention are equivalent to dot_product"""

-    pyconfig.initialize([None, os.path.join(THIS_DIR, "..", "configs", "base21.yml"),
-        'flash_block_sizes={"block_q" : 512, "block_kv_compute": 512, "block_kv": 512,'
-        '"block_q_dkv": 512, "block_kv_dkv": 512, "block_kv_dkv_compute": 512,'
-        '"block_q_dq": 512, "block_kv_dq": 512}',], unittest=True)
+    pyconfig.initialize(
+        [
+            None,
+            os.path.join(THIS_DIR, "..", "configs", "base21.yml"),
+            'flash_block_sizes={"block_q" : 512, "block_kv_compute": 512, "block_kv": 512,'
+            '"block_q_dkv": 512, "block_kv_dkv": 512, "block_kv_dkv_compute": 512,'
+            '"block_q_dq": 512, "block_kv_dq": 512}',
+        ],
+        unittest=True,
+    )
     config = pyconfig.config

     batch = 8
@@ -57,7 +63,7 @@ def test_splash_attention(self):
         split_head_dim=True,
         attention_kernel="dot_product",
         mesh=None,
-        dtype=jnp.bfloat16
+        dtype=jnp.bfloat16,
     )

     params = dot_product_attention.init(key2, x)["params"]
@@ -75,7 +81,7 @@ def test_splash_attention(self):
7581 attention_kernel = "flash" ,
7682 mesh = mesh ,
7783 dtype = jnp .bfloat16 ,
78- flash_block_sizes = flash_block_sizes
84+ flash_block_sizes = flash_block_sizes ,
7985 )
8086
8187 params = splash_attention .init (key2 , x )["params" ]