Skip to content

Commit 7c8412a

Browse files
authored
refactor(bigframes): update SQLGlot compiler to process AI func params better (googleapis#16757)
1 parent f728bd6 commit 7c8412a

1 file changed

Lines changed: 24 additions & 41 deletions

File tree

  • packages/bigframes/bigframes/core/compile/sqlglot/expressions

packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py

Lines changed: 24 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,8 @@ def _(*exprs: TypedExpr, op: ops.AIIf) -> sge.Expression:
6262

6363
@register_nary_op(ops.AIClassify, pass_op=True)
6464
def _(*exprs: TypedExpr, op: ops.AIClassify) -> sge.Expression:
65-
category_literals = [sge.Literal.string(cat) for cat in op.categories]
66-
categories_arg = sge.Kwarg(
67-
this="categories", expression=sge.array(*category_literals)
68-
)
69-
7065
args = [
7166
_construct_prompt(exprs, op.prompt_context, param_name="input"),
72-
categories_arg,
7367
] + _construct_named_args(op)
7468

7569
return sge.func("AI.CLASSIFY", *args)
@@ -105,44 +99,33 @@ def _construct_named_args(op: ops.NaryOp) -> list[sge.Kwarg]:
10599

106100
op_args = asdict(op)
107101

108-
connection_id = op_args.get("connection_id", None)
109-
if connection_id is not None:
110-
args.append(
111-
sge.Kwarg(
112-
this="connection_id", expression=sge.Literal.string(connection_id)
113-
)
114-
)
115-
116-
endpoint = op_args.get("endpoint", None)
117-
if endpoint is not None:
118-
args.append(sge.Kwarg(this="endpoint", expression=sge.Literal.string(endpoint)))
102+
for field, value in op_args.items():
103+
if value is None or field == "prompt_context":
104+
continue
119105

120-
request_type = op_args.get("request_type", None)
121-
if request_type is not None:
122-
args.append(
123-
sge.Kwarg(
124-
this="request_type", expression=sge.Literal.string(request_type.upper())
106+
if field == "categories":
107+
category_literals = [sge.Literal.string(cat) for cat in value]
108+
categories_arg = sge.Kwarg(
109+
this="categories", expression=sge.array(*category_literals)
125110
)
126-
)
127-
128-
model_params = op_args.get("model_params", None)
129-
if model_params is not None:
130-
args.append(
131-
sge.Kwarg(
132-
this="model_params",
133-
# sge.JSON requires the SQLGlot version to be at least 25.18.0
134-
# PARSE_JSON won't work as the function requires a JSON literal.
135-
expression=sge.JSON(this=sge.Literal.string(model_params)),
111+
args.append(categories_arg)
112+
elif field == "model_params":
113+
# model_params is a JSON string, so we need to use the JSON function to pass it as a named argument.
114+
args.append(
115+
sge.Kwarg(
116+
this="model_params",
117+
# sge.JSON requires the SQLGlot version to be at least 25.18.0
118+
# PARSE_JSON won't work as the function requires a JSON literal.
119+
expression=sge.JSON(this=sge.Literal.string(value)),
120+
)
136121
)
137-
)
138-
139-
output_schema = op_args.get("output_schema", None)
140-
if output_schema is not None:
141-
args.append(
142-
sge.Kwarg(
143-
this="output_schema",
144-
expression=sge.Literal.string(output_schema),
122+
elif field == "request_type":
123+
args.append(
124+
sge.Kwarg(this=field, expression=sge.Literal.string(value.upper()))
125+
)
126+
else:
127+
args.append(
128+
sge.Kwarg(this=field, expression=sge.Literal.string(str(value)))
145129
)
146-
)
147130

148131
return args

0 commit comments

Comments
 (0)