Skip to content

Commit c9cf4b7

Browse files
committed
fix: update AI tests and library logic
1 parent 3340a72 commit c9cf4b7

2 files changed

Lines changed: 14 additions & 12 deletions

File tree

  • packages/bigframes

packages/bigframes/bigframes/bigquery/_operations/ai.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1172,6 +1172,8 @@ def forecast(
11721172
return ml_core.BaseBqml(df._session).ai_forecast(input_data=df, options=options)
11731173

11741174

1175+
1176+
11751177
def _separate_context_and_series(
11761178
prompt: PROMPT_TYPE,
11771179
) -> Tuple[List[str | None], List[series.Series]]:
@@ -1189,9 +1191,6 @@ def _separate_context_and_series(
11891191
return [None], [series.Series([prompt])]
11901192

11911193
if isinstance(prompt, series.Series):
1192-
if prompt.dtype == dtypes.OBJ_REF_DTYPE:
1193-
# Multi-model support
1194-
return [None], [prompt._blob._read_url()]
11951194
return [None], [prompt]
11961195

11971196
prompt_context: List[str | None] = []
@@ -1226,9 +1225,6 @@ def _convert_series(
12261225
) -> series.Series:
12271226
result = convert.to_bf_series(s, default_index=None, session=session)
12281227

1229-
if result.dtype == dtypes.OBJ_REF_DTYPE:
1230-
# Support multimodel
1231-
return result._blob._read_url()
12321228
return result
12331229

12341230

packages/bigframes/tests/system/small/bigquery/test_ai.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,8 @@ def test_ai_generate_bool_multi_model(session):
188188
session, ["gs://cloud-samples-data/vision/ocr/sign.jpg"], name="image"
189189
)
190190

191-
result = bbq.ai.generate_bool((df["image"], " contains an animal"))
191+
image_runtime = bbq.obj.get_access_url(df["image"], mode="R")
192+
result = bbq.ai.generate_bool((image_runtime, " contains an animal"))
192193

193194
assert _contains_no_nulls(result)
194195
assert result.dtype == pd.ArrowDtype(
@@ -225,8 +226,9 @@ def test_ai_generate_int_multi_model(session):
225226
session, ["gs://cloud-samples-data/vision/ocr/sign.jpg"], name="image"
226227
)
227228

229+
image_runtime = bbq.obj.get_access_url(df["image"], mode="R")
228230
result = bbq.ai.generate_int(
229-
("How many animals are there in the picture ", df["image"])
231+
("How many animals are there in the picture ", image_runtime)
230232
)
231233

232234
assert _contains_no_nulls(result)
@@ -264,8 +266,9 @@ def test_ai_generate_double_multi_model(session):
264266
session, ["gs://cloud-samples-data/vision/ocr/sign.jpg"], name="image"
265267
)
266268

269+
image_runtime = bbq.obj.get_access_url(df["image"], mode="R")
267270
result = bbq.ai.generate_double(
268-
("How many animals are there in the picture ", df["image"])
271+
("How many animals are there in the picture ", image_runtime)
269272
)
270273

271274
assert _contains_no_nulls(result)
@@ -359,7 +362,8 @@ def test_ai_if_multi_model(session, bq_connection):
359362
session, ["gs://cloud-samples-data/vision/ocr/sign.jpg"], name="image"
360363
)
361364

362-
result = bbq.ai.if_((df["image"], " contains an animal"))
365+
image_runtime = bbq.obj.get_access_url(df["image"], mode="R")
366+
result = bbq.ai.if_((image_runtime, " contains an animal"))
363367

364368
assert _contains_no_nulls(result)
365369
assert result.dtype == dtypes.BOOL_DTYPE
@@ -379,7 +383,8 @@ def test_ai_classify_multi_model(session, bq_connection):
379383
session, ["gs://cloud-samples-data/vision/ocr/sign.jpg"], name="image"
380384
)
381385

382-
result = bbq.ai.classify(df["image"], ["photo", "cartoon"])
386+
image_runtime = bbq.obj.get_access_url(df["image"], mode="R")
387+
result = bbq.ai.classify(image_runtime, ["photo", "cartoon"])
383388

384389
assert _contains_no_nulls(result)
385390
assert result.dtype == dtypes.STRING_DTYPE
@@ -399,7 +404,8 @@ def test_ai_score_multi_model(session):
399404
df = _create_mock_obj_ref_df(
400405
session, ["gs://cloud-samples-data/vision/ocr/sign.jpg"], name="image"
401406
)
402-
prompt = ("Rank the liveliness of ", df["image"], "on the scale from 1 to 3")
407+
image_runtime = bbq.obj.get_access_url(df["image"], mode="R")
408+
prompt = ("Rank the liveliness of ", image_runtime, "on the scale from 1 to 3")
403409

404410
result = bbq.ai.score(prompt)
405411

0 commit comments

Comments
 (0)