@@ -156,7 +156,7 @@ def decorator(*args, **kwargs):
156156 x_list = []
157157 const_args = [self .arg_names [i ] for i in self .constexprs ]
158158
159- decalare_arg_exclude_constexpr = list (self .arg_exclude_constexpr )
159+ declare_arg_exclude_constexpr = list (self .arg_exclude_constexpr )
160160 passed_arg_exclude_constexpr = list (self .arg_exclude_constexpr )
161161
162162 const_hint_dict = {}
@@ -177,8 +177,8 @@ def decorator(*args, **kwargs):
177177 else :
178178 dtypes .append (paddle .int8 )
179179 passed_arg_exclude_constexpr [i ] = "(CUdeviceptr)(nullptr)"
180- decalare_arg_exclude_constexpr [i ] = (
181- "const paddle::optional<paddle::Tensor>&" + decalare_arg_exclude_constexpr [i ]
180+ declare_arg_exclude_constexpr [i ] = (
181+ "const paddle::optional<paddle::Tensor>&" + declare_arg_exclude_constexpr [i ]
182182 )
183183 elif i in self .constexprs :
184184 if isinstance (ele , bool ):
@@ -193,9 +193,9 @@ def decorator(*args, **kwargs):
193193 else :
194194 x_list .append (ele )
195195 if isinstance (ele , int ):
196- decalare_arg_exclude_constexpr [i ] = "const int64_t " + decalare_arg_exclude_constexpr [i ]
196+ declare_arg_exclude_constexpr [i ] = "const int64_t " + declare_arg_exclude_constexpr [i ]
197197 elif isinstance (ele , float ):
198- decalare_arg_exclude_constexpr [i ] = "const float " + decalare_arg_exclude_constexpr [i ]
198+ declare_arg_exclude_constexpr [i ] = "const float " + declare_arg_exclude_constexpr [i ]
199199 else :
200200 assert False , f"Unsupported arg type: { type (ele )} for arg '{ self .arg_names [i ]} '"
201201
@@ -215,9 +215,9 @@ def decorator(*args, **kwargs):
215215 const_args = [f"{{{ ele } }}" for ele in const_args ]
216216 const_args = "," .join (const_args )
217217
218- lanuch_grid = list (self .grid )
219- for i in range (len (lanuch_grid )):
220- ele = lanuch_grid [i ]
218+ launch_grid = list (self .grid )
219+ for i in range (len (launch_grid )):
220+ ele = launch_grid [i ]
221221 if isinstance (ele , str ):
222222 keys = list (const_hint_dict .keys ())
223223 keys .sort (key = len , reverse = True )
@@ -226,15 +226,15 @@ def decorator(*args, **kwargs):
226226 ele = ele .replace (key , f"{ const_hint_dict [key ]} " )
227227 else :
228228 ele = str (ele )
229- lanuch_grid [i ] = ele
229+ launch_grid [i ] = ele
230230
231- if len (lanuch_grid ) < 3 :
232- lanuch_grid += ["1" ] * (3 - len (lanuch_grid ))
233- lanuch_grid = "," .join (lanuch_grid )
231+ if len (launch_grid ) < 3 :
232+ launch_grid += ["1" ] * (3 - len (launch_grid ))
233+ launch_grid = "," .join (launch_grid )
234234
235235 op_dict = {"op_name" : op_name }
236236 op_dict ["triton_kernel_args" ] = "," .join (passed_arg_exclude_constexpr )
237- op_dict ["tensor_and_attr" ] = "," .join (decalare_arg_exclude_constexpr )
237+ op_dict ["tensor_and_attr" ] = "," .join (declare_arg_exclude_constexpr )
238238
239239 paddle_custom_op_file_path = f"{ generated_dir } /{ op_name } .cu"
240240 so_path = find_so_path (generated_dir , python_package_name )
@@ -257,7 +257,7 @@ def decorator(*args, **kwargs):
257257 + f"""--out-name { op_name } _kernel """
258258 + """ -w {num_warps} -ns {num_stages} """
259259 + f""" -s"{ address_hint } { value_hint } { const_args } " """
260- + f""" -g "{ lanuch_grid } " """
260+ + f""" -g "{ launch_grid } " """
261261 )
262262
263263 all_tune_config = [const_hint_dict ]
0 commit comments