@@ -458,12 +458,12 @@ def __call__(self, position_ids, max_len_lst, cumsum_seqlens):
458458
459459 # Build position_ids_3d: [bsz, max_position, 3]
460460 position_ids_3d = paddle .tile (
461- paddle .arange (self .max_position , dtype = "int64 " ).unsqueeze (0 ).unsqueeze (- 1 ),
461+ paddle .arange (self .max_position , dtype = "float32 " ).unsqueeze (0 ).unsqueeze (- 1 ),
462462 [bsz , 1 , 3 ],
463463 )
464464 for i in range (bsz ):
465465 position_ids_cur = position_ids [cumsum_seqlens [i ] : cumsum_seqlens [i + 1 ]]
466- prefix_max_position_ids = paddle .max (position_ids_cur ) + 1
466+ prefix_max_position_ids = paddle .max (position_ids_cur [..., 0 ] ) + 1
467467 dec_pos_ids = paddle .tile (
468468 paddle .arange (max_len_lst [i ], dtype = "int64" ).unsqueeze (- 1 ),
469469 [1 , 3 ],
@@ -530,12 +530,12 @@ def __call__(self, position_ids, max_len_lst, cumsum_seqlens):
530530 bsz = len (cumsum_seqlens ) - 1
531531 # position_ids_3d: [bsz, seq_len, 3]
532532 position_ids_3d = paddle .tile (
533- paddle .arange (self .max_position , dtype = "int64 " ).unsqueeze (0 ).unsqueeze (- 1 ),
533+ paddle .arange (self .max_position , dtype = "float32 " ).unsqueeze (0 ).unsqueeze (- 1 ),
534534 [bsz , 1 , 3 ],
535535 )
536536 for i in range (bsz ):
537537 position_ids_cur = position_ids [cumsum_seqlens [i ] : cumsum_seqlens [i + 1 ]]
538- prefix_max_position_ids = paddle .max (position_ids_cur ) + 1
538+ prefix_max_position_ids = paddle .max (position_ids_cur [..., 0 ] ) + 1
539539 dec_pos_ids = paddle .tile (
540540 paddle .arange (max_len_lst [i ], dtype = "int64" ).unsqueeze (- 1 ),
541541 [1 , 3 ],
0 commit comments