@@ -494,6 +494,30 @@ def py_flatten(arr):
494494 lambda col : f .list_slice (col , literal (- 1 ), literal (2 )),
495495 lambda data : [arr [- 1 :2 ] for arr in data ],
496496 ),
497+ (
498+ lambda col : col [:3 ],
499+ lambda data : [arr [:3 ] for arr in data ],
500+ ),
501+ (
502+ lambda col : col [1 :3 ],
503+ lambda data : [arr [1 :3 ] for arr in data ],
504+ ),
505+ (
506+ lambda col : col [1 :4 :2 ],
507+ lambda data : [arr [1 :4 :2 ] for arr in data ],
508+ ),
509+ (
510+ lambda col : col [literal (1 ) : literal (4 )],
511+ lambda data : [arr [1 :4 ] for arr in data ],
512+ ),
513+ (
514+ lambda col : col [column ("indices" ) : column ("indices" ) + literal (2 )],
515+ lambda data : [[2.0 , 3.0 ], [], [6.0 ]],
516+ ),
517+ (
518+ lambda col : col [literal (1 ) : literal (4 ) : literal (2 )],
519+ lambda data : [arr [1 :4 :2 ] for arr in data ],
520+ ),
497521 (
498522 lambda col : f .array_intersect (col , literal ([3.0 , 4.0 ])),
499523 lambda data : [np .intersect1d (arr , [3.0 , 4.0 ]) for arr in data ],
@@ -534,8 +558,11 @@ def py_flatten(arr):
534558)
535559def test_array_functions (stmt , py_expr ):
536560 data = [[1.0 , 2.0 , 3.0 , 3.0 ], [4.0 , 5.0 , 3.0 ], [6.0 ]]
561+ indices = [1 , 3 , 0 ]
537562 ctx = SessionContext ()
538- batch = pa .RecordBatch .from_arrays ([np .array (data , dtype = object )], names = ["arr" ])
563+ batch = pa .RecordBatch .from_arrays (
564+ [np .array (data , dtype = object ), indices ], names = ["arr" , "indices" ]
565+ )
539566 df = ctx .create_dataframe ([[batch ]])
540567
541568 col = column ("arr" )
0 commit comments