Skip to content

Commit ca6b45e

Browse files
cleop-googlecopybara-github
authored andcommitted
fix: GenAI SDK client(multimodal) - Fix Pydantic validation errors when using create_* in some cases
PiperOrigin-RevId: 906376276
1 parent f5c4f8f commit ca6b45e

2 files changed

Lines changed: 33 additions & 16 deletions

File tree

vertexai/_genai/_datasets_utils.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
import google.auth.credentials
2323
from vertexai._genai.types import common
24-
from pydantic import BaseModel
24+
from google.genai import _common
2525

2626

2727
METADATA_SCHEMA_URI = (
@@ -31,18 +31,27 @@
3131
_DEFAULT_BQ_DATASET_PREFIX = "vertex_datasets"
3232
_DEFAULT_BQ_TABLE_PREFIX = "multimodal_dataset"
3333

34-
T = TypeVar("T", bound=BaseModel)
34+
T = TypeVar("T", bound=_common.BaseModel)
3535

3636

37-
def create_from_response(model_type: Type[T], response: dict[str, Any]) -> T:
37+
def create_from_response(
38+
model_type: Type[T],
39+
response: dict[str, Any],
40+
config: Any | None = None,
41+
) -> T:
3842
"""Creates a model from a response."""
39-
model_field_names = model_type.model_fields.keys()
40-
filtered_response = {}
41-
for key, value in response.items():
42-
snake_key = common.camel_to_snake(key)
43-
if snake_key in model_field_names:
44-
filtered_response[snake_key] = value
45-
return model_type(**filtered_response)
43+
kwargs = (
44+
{
45+
"config": {
46+
"response_schema": getattr(config, "response_schema", None),
47+
"response_json_schema": getattr(config, "response_json_schema", None),
48+
"include_all_fields": getattr(config, "include_all_fields", None),
49+
}
50+
}
51+
if config
52+
else {}
53+
)
54+
return model_type._from_response(response=response, kwargs=kwargs)
4655

4756

4857
def validate_multimodal_dataset_bigquery_uri(

vertexai/_genai/datasets.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -963,7 +963,9 @@ def create_from_bigquery(
963963
operation=multimodal_dataset_operation,
964964
timeout_seconds=config.timeout,
965965
)
966-
return _datasets_utils.create_from_response(types.MultimodalDataset, response)
966+
return _datasets_utils.create_from_response(
967+
types.MultimodalDataset, response, config
968+
)
967969

968970
def create_from_pandas(
969971
self,
@@ -1302,6 +1304,7 @@ def assess_tuning_resources(
13021304
return _datasets_utils.create_from_response(
13031305
types.TuningResourceUsageAssessmentResult,
13041306
response["tuningResourceUsageAssessmentResult"],
1307+
config,
13051308
)
13061309

13071310
def assess_tuning_validity(
@@ -1368,6 +1371,7 @@ def assess_tuning_validity(
13681371
return _datasets_utils.create_from_response(
13691372
types.TuningValidationAssessmentResult,
13701373
response["tuningValidationAssessmentResult"],
1374+
config,
13711375
)
13721376

13731377
def assess_batch_prediction_resources(
@@ -1430,7 +1434,7 @@ def assess_batch_prediction_resources(
14301434
)
14311435
result = response["batchPredictionResourceUsageAssessmentResult"]
14321436
return _datasets_utils.create_from_response(
1433-
types.BatchPredictionResourceUsageAssessmentResult, result
1437+
types.BatchPredictionResourceUsageAssessmentResult, result, config
14341438
)
14351439

14361440
def assess_batch_prediction_validity(
@@ -1493,7 +1497,7 @@ def assess_batch_prediction_validity(
14931497
)
14941498
result = response["batchPredictionValidationAssessmentResult"]
14951499
return _datasets_utils.create_from_response(
1496-
types.BatchPredictionValidationAssessmentResult, result
1500+
types.BatchPredictionValidationAssessmentResult, result, config
14971501
)
14981502

14991503

@@ -2231,7 +2235,9 @@ async def create_from_bigquery(
22312235
operation=multimodal_dataset_operation,
22322236
timeout_seconds=config.timeout,
22332237
)
2234-
return _datasets_utils.create_from_response(types.MultimodalDataset, response)
2238+
return _datasets_utils.create_from_response(
2239+
types.MultimodalDataset, response, config
2240+
)
22352241

22362242
async def create_from_pandas(
22372243
self,
@@ -2568,6 +2574,7 @@ async def assess_tuning_resources(
25682574
return _datasets_utils.create_from_response(
25692575
types.TuningResourceUsageAssessmentResult,
25702576
response["tuningResourceUsageAssessmentResult"],
2577+
config,
25712578
)
25722579

25732580
async def assess_tuning_validity(
@@ -2634,6 +2641,7 @@ async def assess_tuning_validity(
26342641
return _datasets_utils.create_from_response(
26352642
types.TuningValidationAssessmentResult,
26362643
response["tuningValidationAssessmentResult"],
2644+
config,
26372645
)
26382646

26392647
async def assess_batch_prediction_resources(
@@ -2696,7 +2704,7 @@ async def assess_batch_prediction_resources(
26962704
)
26972705
result = response["batchPredictionResourceUsageAssessmentResult"]
26982706
return _datasets_utils.create_from_response(
2699-
types.BatchPredictionResourceUsageAssessmentResult, result
2707+
types.BatchPredictionResourceUsageAssessmentResult, result, config
27002708
)
27012709

27022710
async def assess_batch_prediction_validity(
@@ -2759,5 +2767,5 @@ async def assess_batch_prediction_validity(
27592767
)
27602768
result = response["batchPredictionValidationAssessmentResult"]
27612769
return _datasets_utils.create_from_response(
2762-
types.BatchPredictionValidationAssessmentResult, result
2770+
types.BatchPredictionValidationAssessmentResult, result, config
27632771
)

0 commit comments

Comments
 (0)