4040from datafusion ._internal import DataFrame as DataFrameInternal
4141from datafusion ._internal import ParquetColumnOptions as ParquetColumnOptionsInternal
4242from datafusion ._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
43- from datafusion .expr import (
44- _EXPR_TYPE_ERROR ,
45- Expr ,
46- SortExpr ,
47- expr_list_to_raw_expr_list ,
48- sort_or_default ,
49- )
43+ from datafusion .expr import Expr , SortExpr , sort_or_default
5044from datafusion .plan import ExecutionPlan , LogicalPlan
5145from datafusion .record_batch import RecordBatchStream
5246
@@ -400,7 +394,9 @@ def select(self, *exprs: Expr | str) -> DataFrame:
400394 df = df.select("a", col("b"), col("a").alias("alternate_a"))
401395
402396 """
403- exprs_internal = expr_list_to_raw_expr_list (exprs )
397+ exprs_internal = [
398+ Expr .column (arg ).expr if isinstance (arg , str ) else arg .expr for arg in exprs
399+ ]
404400 return DataFrame (self .df .select (* exprs_internal ))
405401
406402 def drop (self , * columns : str ) -> DataFrame :
@@ -430,9 +426,7 @@ def filter(self, *predicates: Expr) -> DataFrame:
430426 """
431427 df = self .df
432428 for p in predicates :
433- if isinstance (p , str ) or not isinstance (p , Expr ):
434- raise TypeError (_EXPR_TYPE_ERROR )
435- df = df .filter (expr_list_to_raw_expr_list (p )[0 ])
429+ df = df .filter (p .expr )
436430 return DataFrame (df )
437431
438432 def with_column (self , name : str , expr : Expr ) -> DataFrame :
@@ -445,8 +439,6 @@ def with_column(self, name: str, expr: Expr) -> DataFrame:
445439 Returns:
446440 DataFrame with the new column.
447441 """
448- if not isinstance (expr , Expr ):
449- raise TypeError (_EXPR_TYPE_ERROR )
450442 return DataFrame (self .df .with_column (name , expr .expr ))
451443
452444 def with_columns (
@@ -478,18 +470,14 @@ def _simplify_expression(
478470 ) -> list [expr_internal .Expr ]:
479471 expr_list = []
480472 for expr in exprs :
481- if isinstance (expr , str ):
482- raise TypeError (_EXPR_TYPE_ERROR )
483- if isinstance (expr , Iterable ) and not isinstance (expr , Expr ):
484- if any (not isinstance (inner_expr , Expr ) for inner_expr in expr ):
485- raise TypeError (_EXPR_TYPE_ERROR )
486- elif not isinstance (expr , Expr ):
487- raise TypeError (_EXPR_TYPE_ERROR )
488- expr_list .extend (expr_list_to_raw_expr_list (expr ))
473+ if isinstance (expr , Expr ):
474+ expr_list .append (expr .expr )
475+ elif isinstance (expr , Iterable ):
476+ expr_list .extend (inner_expr .expr for inner_expr in expr )
477+ else :
478+ raise NotImplementedError
489479 if named_exprs :
490480 for alias , expr in named_exprs .items ():
491- if not isinstance (expr , Expr ):
492- raise TypeError (_EXPR_TYPE_ERROR )
493481 expr_list .append (expr .alias (alias ).expr )
494482 return expr_list
495483
@@ -515,56 +503,37 @@ def with_column_renamed(self, old_name: str, new_name: str) -> DataFrame:
515503 return DataFrame (self .df .with_column_renamed (old_name , new_name ))
516504
517505 def aggregate (
518- self ,
519- group_by : list [Expr | str ] | Expr | str ,
520- aggs : list [Expr ] | Expr ,
506+ self , group_by : list [Expr ] | Expr , aggs : list [Expr ] | Expr
521507 ) -> DataFrame :
522508 """Aggregates the rows of the current DataFrame.
523509
524510 Args:
525- group_by: List of expressions or column names to group by.
511+ group_by: List of expressions to group by.
526512 aggs: List of expressions to aggregate.
527513
528514 Returns:
529515 DataFrame after aggregation.
530516 """
531- group_by_list = group_by if isinstance (group_by , list ) else [group_by ]
532- aggs_list = aggs if isinstance (aggs , list ) else [aggs ]
517+ group_by = group_by if isinstance (group_by , list ) else [group_by ]
518+ aggs = aggs if isinstance (aggs , list ) else [aggs ]
533519
534- group_by_exprs = expr_list_to_raw_expr_list (group_by_list )
535- aggs_exprs = []
536- for agg in aggs_list :
537- if not isinstance (agg , Expr ):
538- raise TypeError (_EXPR_TYPE_ERROR )
539- aggs_exprs .append (agg .expr )
540- return DataFrame (self .df .aggregate (group_by_exprs , aggs_exprs ))
520+ group_by = [e .expr for e in group_by ]
521+ aggs = [e .expr for e in aggs ]
522+ return DataFrame (self .df .aggregate (group_by , aggs ))
541523
542- def sort (self , * exprs : Expr | SortExpr | str ) -> DataFrame :
543- """Sort the DataFrame by the specified sorting expressions or column names .
524+ def sort (self , * exprs : Expr | SortExpr ) -> DataFrame :
525+ """Sort the DataFrame by the specified sorting expressions.
544526
545527 Note that any expression can be turned into a sort expression by
546- calling its ``sort`` method.
528+ calling its ``sort`` method.
547529
548530 Args:
549- exprs: Sort expressions or column names , applied in order.
531+ exprs: Sort expressions, applied in order.
550532
551533 Returns:
552534 DataFrame after sorting.
553535 """
554- exprs_raw = []
555- for e in exprs :
556- if isinstance (e , SortExpr ):
557- exprs_raw .append (sort_or_default (e ))
558- elif isinstance (e , str ):
559- exprs_raw .append (sort_or_default (Expr .column (e )))
560- elif isinstance (e , Expr ):
561- exprs_raw .append (sort_or_default (e ))
562- else :
563- error = (
564- "Expected Expr or column name, found:"
565- f" { type (e ).__name__ } . { _EXPR_TYPE_ERROR } ."
566- )
567- raise TypeError (error )
536+ exprs_raw = [sort_or_default (expr ) for expr in exprs ]
568537 return DataFrame (self .df .sort (* exprs_raw ))
569538
570539 def cast (self , mapping : dict [str , pa .DataType [Any ]]) -> DataFrame :
@@ -788,11 +757,7 @@ def join_on(
788757 Returns:
789758 DataFrame after join.
790759 """
791- exprs = []
792- for expr in on_exprs :
793- if not isinstance (expr , Expr ):
794- raise TypeError (_EXPR_TYPE_ERROR )
795- exprs .append (expr .expr )
760+ exprs = [expr .expr for expr in on_exprs ]
796761 return DataFrame (self .df .join_on (right .df , exprs , how ))
797762
798763 def explain (self , verbose : bool = False , analyze : bool = False ) -> None :
0 commit comments