@@ -62,14 +62,8 @@ def _(*exprs: TypedExpr, op: ops.AIIf) -> sge.Expression:
6262
6363@register_nary_op (ops .AIClassify , pass_op = True )
6464def _ (* exprs : TypedExpr , op : ops .AIClassify ) -> sge .Expression :
65- category_literals = [sge .Literal .string (cat ) for cat in op .categories ]
66- categories_arg = sge .Kwarg (
67- this = "categories" , expression = sge .array (* category_literals )
68- )
69-
7065 args = [
7166 _construct_prompt (exprs , op .prompt_context , param_name = "input" ),
72- categories_arg ,
7367 ] + _construct_named_args (op )
7468
7569 return sge .func ("AI.CLASSIFY" , * args )
@@ -105,44 +99,33 @@ def _construct_named_args(op: ops.NaryOp) -> list[sge.Kwarg]:
10599
106100 op_args = asdict (op )
107101
108- connection_id = op_args .get ("connection_id" , None )
109- if connection_id is not None :
110- args .append (
111- sge .Kwarg (
112- this = "connection_id" , expression = sge .Literal .string (connection_id )
113- )
114- )
115-
116- endpoint = op_args .get ("endpoint" , None )
117- if endpoint is not None :
118- args .append (sge .Kwarg (this = "endpoint" , expression = sge .Literal .string (endpoint )))
102+ for field , value in op_args .items ():
103+ if value is None or field == "prompt_context" :
104+ continue
119105
120- request_type = op_args .get ("request_type" , None )
121- if request_type is not None :
122- args .append (
123- sge .Kwarg (
124- this = "request_type" , expression = sge .Literal .string (request_type .upper ())
106+ if field == "categories" :
107+ category_literals = [sge .Literal .string (cat ) for cat in value ]
108+ categories_arg = sge .Kwarg (
109+ this = "categories" , expression = sge .array (* category_literals )
125110 )
126- )
127-
128- model_params = op_args . get ( " model_params" , None )
129- if model_params is not None :
130- args . append (
131- sge . Kwarg (
132- this = "model_params" ,
133- # sge.JSON requires the SQLGlot version to be at least 25.18.0
134- # PARSE_JSON won't work as the function requires a JSON literal.
135- expression = sge . JSON ( this = sge . Literal . string ( model_params )),
111+ args . append ( categories_arg )
112+ elif field == "model_params" :
113+ # model_params is a JSON string, so we need to use the JSON function to pass it as a named argument.
114+ args . append (
115+ sge . Kwarg (
116+ this = "model_params" ,
117+ # sge.JSON requires the SQLGlot version to be at least 25.18.0
118+ # PARSE_JSON won't work as the function requires a JSON literal.
119+ expression = sge . JSON ( this = sge . Literal . string ( value )),
120+ )
136121 )
137- )
138-
139- output_schema = op_args .get ("output_schema" , None )
140- if output_schema is not None :
141- args .append (
142- sge .Kwarg (
143- this = "output_schema" ,
144- expression = sge .Literal .string (output_schema ),
122+ elif field == "request_type" :
123+ args .append (
124+ sge .Kwarg (this = field , expression = sge .Literal .string (value .upper ()))
125+ )
126+ else :
127+ args .append (
128+ sge .Kwarg (this = field , expression = sge .Literal .string (str (value )))
145129 )
146- )
147130
148131 return args
0 commit comments