Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions packages/bigframes/bigframes/bigquery/_operations/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -941,6 +941,8 @@ def score(
prompt: PROMPT_TYPE,
*,
connection_id: str | None = None,
endpoint: str | None = None,
max_error_ratio: float | None = None,
) -> series.Series:
"""
Computes a score based on rubrics described in natural language. It will return a double value.
Expand All @@ -958,20 +960,21 @@ def score(
2 3.0
dtype: Float64

.. note::

This product or feature is subject to the "Pre-GA Offerings Terms" in the General Service Terms section of the
Service Specific Terms(https://cloud.google.com/terms/service-terms#1). Pre-GA products and features are available "as is"
and might have limited support. For more information, see the launch stage descriptions
(https://cloud.google.com/products#product-launch-stages).

Args:
prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]):
A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
or pandas Series.
connection_id (str, optional):
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
If not provided, the query uses your end-user credential.
endpoint (str, optional):
Specifies the Vertex AI endpoint to use for the model. For example `"gemini-2.5-flash"`. You can specify any
generally available or preview Gemini model. If you specify the model name, BigQuery ML automatically identifies and
uses the full endpoint of the model. If you don't specify an ENDPOINT value, BigQuery ML dynamically chooses a model
Comment thread
sycai marked this conversation as resolved.
Outdated
based on your query to have the best cost to quality tradeoff for the task.
max_error_ratio (float, optional):
A value between `0.0` and `1.0` that contains the maximum acceptable ratio of row-level inference failures to
rows processed on this function. If this value is exceeded, then the query fails.

Returns:
bigframes.series.Series: A new series of double (float) values.
Expand All @@ -983,6 +986,8 @@ def score(
operator = ai_ops.AIScore(
prompt_context=tuple(prompt_context),
connection_id=connection_id,
endpoint=endpoint,
max_error_ratio=max_error_ratio,
)

return series_list[0]._apply_nary_op(operator, series_list[1:])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2005,6 +2005,8 @@ def ai_score(*values: ibis_types.Value, op: ops.AIScore) -> ibis_types.StructVal
return ai_ops.AIScore(
_construct_prompt(values, op.prompt_context), # type: ignore
op.connection_id, # type: ignore
op.endpoint, # type: ignore
op.max_error_ratio, # type: ignore
).to_expr()


Expand Down
2 changes: 2 additions & 0 deletions packages/bigframes/bigframes/operations/ai_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ class AIScore(base_ops.NaryOp):

prompt_context: Tuple[str | None, ...]
connection_id: str | None
endpoint: str | None
max_error_ratio: float | None

def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
return dtypes.FLOAT_DTYPE
Expand Down
2 changes: 1 addition & 1 deletion packages/bigframes/bigframes/pandas/io/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,8 +654,8 @@ def from_glob_path(
def _get_bqclient_and_project() -> Tuple[bigquery.Client, str]:
# Address circular imports in doctest due to bigframes/session/__init__.py
# containing a lot of logic and samples.
from bigframes.session import clients
import bigframes._config.auth
from bigframes.session import clients

credentials, project = bigframes._config.auth.resolve_credentials_and_project(
config.options.bigquery
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
SELECT
AI.SCORE(
prompt => (`string_col`, ' is the same as ', `string_col`),
endpoint => 'gemini-2.5-flash',
max_error_ratio => 0.5
) AS `result`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,27 @@ def test_ai_score(scalar_types_df: dataframe.DataFrame, snapshot, connection_id)
op = ops.AIScore(
prompt_context=(None, " is the same as ", None),
connection_id=connection_id,
endpoint=None,
max_error_ratio=None,
)

sql = utils._apply_ops_to_sql(
scalar_types_df, [op.as_expr(col_name, col_name)], ["result"]
)

snapshot.assert_match(sql, "out.sql")


def test_ai_score_with_endpoint_and_max_error_ratio(
scalar_types_df: dataframe.DataFrame, snapshot
):
col_name = "string_col"

op = ops.AIScore(
prompt_context=(None, " is the same as ", None),
connection_id=None,
endpoint="gemini-2.5-flash",
max_error_ratio=0.5,
)

sql = utils._apply_ops_to_sql(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,9 @@ class AIIf(Value):

prompt: Value
connection_id: Optional[Value[dt.String]]
endpoint: Optional[Value[dt.String]] = None
optimization_mode: Optional[Value[dt.String]] = None
max_error_ratio: Optional[Value[dt.Float64]] = None
endpoint: Optional[Value[dt.String]]
optimization_mode: Optional[Value[dt.String]]
max_error_ratio: Optional[Value[dt.Float64]]

shape = rlz.shape_like("prompt")

Expand All @@ -151,7 +151,7 @@ def dtype(self) -> dt.Struct:

@public
class AIClassify(Value):
"""Generate True/False based on the prompt"""
"""Generate categories based on the prompt"""

input: Value
categories: Value[dt.Array[dt.String]]
Expand All @@ -166,13 +166,19 @@ def dtype(self) -> dt.Struct:

@public
class AIScore(Value):
"""Generate doubles based on the prompt"""
"""Generate scores based on the prompt"""

prompt: Value
connection_id: Optional[Value[dt.String]]
endpoint: Optional[Value[dt.String]]
max_error_ratio: Optional[Value[dt.Float64]]

shape = rlz.shape_like("prompt")

@attribute
def dtype(self) -> dt.Struct:
Comment thread
sycai marked this conversation as resolved.
Outdated
return dt.float64


@public
class AISimilarity(Value):
Expand Down
Loading