Skip to content

Commit ffe7fc7

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
fix: Add serializer to async_create_session to address 500 error in Agent Engine (AgentServerMode.EXPERIMENTAL).
PiperOrigin-RevId: 903556708
1 parent ac5a5e4 commit ffe7fc7

4 files changed

Lines changed: 106 additions & 47 deletions

File tree

tests/unit/vertex_adk/test_agent_engine_templates_adk.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -546,23 +546,23 @@ async def test_streaming_agent_run_with_events_force_flush_otel(
546546
async def test_async_create_session(self, get_project_id_mock: mock.Mock):
547547
app = agent_engines.AdkApp(agent=_TEST_AGENT)
548548
session1 = await app.async_create_session(user_id=_TEST_USER_ID)
549-
assert session1.user_id == _TEST_USER_ID
549+
assert session1["user_id"] == _TEST_USER_ID
550550
session2 = await app.async_create_session(
551551
user_id=_TEST_USER_ID, session_id="test_session_id"
552552
)
553-
assert session2.user_id == _TEST_USER_ID
554-
assert session2.id == "test_session_id"
553+
assert session2["user_id"] == _TEST_USER_ID
554+
assert session2["id"] == "test_session_id"
555555

556556
@pytest.mark.asyncio
557557
async def test_async_get_session(self, get_project_id_mock: mock.Mock):
558558
app = agent_engines.AdkApp(agent=_TEST_AGENT)
559559
session1 = await app.async_create_session(user_id=_TEST_USER_ID)
560560
session2 = await app.async_get_session(
561561
user_id=_TEST_USER_ID,
562-
session_id=session1.id,
562+
session_id=session1["id"],
563563
)
564564
assert session2.user_id == _TEST_USER_ID
565-
assert session1.id == session2.id
565+
assert session1["id"] == session2.id
566566

567567
@pytest.mark.asyncio
568568
async def test_async_list_sessions(self, get_project_id_mock: mock.Mock):
@@ -572,12 +572,12 @@ async def test_async_list_sessions(self, get_project_id_mock: mock.Mock):
572572
session = await app.async_create_session(user_id=_TEST_USER_ID)
573573
response1 = await app.async_list_sessions(user_id=_TEST_USER_ID)
574574
assert len(response1.sessions) == 1
575-
assert response1.sessions[0].id == session.id
575+
assert response1.sessions[0].id == session["id"]
576576
session2 = await app.async_create_session(user_id=_TEST_USER_ID)
577577
response2 = await app.async_list_sessions(user_id=_TEST_USER_ID)
578578
assert len(response2.sessions) == 2
579-
assert response2.sessions[0].id == session.id
580-
assert response2.sessions[1].id == session2.id
579+
assert response2.sessions[0].id == session["id"]
580+
assert response2.sessions[1].id == session2["id"]
581581

582582
@pytest.mark.asyncio
583583
async def test_async_delete_session(self, get_project_id_mock: mock.Mock):
@@ -592,30 +592,30 @@ async def test_async_delete_session(self, get_project_id_mock: mock.Mock):
592592
assert len(response1.sessions) == 1
593593
await app.async_delete_session(
594594
user_id=_TEST_USER_ID,
595-
session_id=session.id,
595+
session_id=session["id"],
596596
)
597597
response0 = await app.async_list_sessions(user_id=_TEST_USER_ID)
598598
assert not response0.sessions
599599

600600
def test_create_session(self, get_project_id_mock: mock.Mock):
601601
app = agent_engines.AdkApp(agent=_TEST_AGENT)
602602
session1 = app.create_session(user_id=_TEST_USER_ID)
603-
assert session1.user_id == _TEST_USER_ID
603+
assert session1["user_id"] == _TEST_USER_ID
604604
session2 = app.create_session(
605605
user_id=_TEST_USER_ID, session_id="test_session_id"
606606
)
607-
assert session2.user_id == _TEST_USER_ID
608-
assert session2.id == "test_session_id"
607+
assert session2["user_id"] == _TEST_USER_ID
608+
assert session2["id"] == "test_session_id"
609609

610610
def test_get_session(self, get_project_id_mock: mock.Mock):
611611
app = agent_engines.AdkApp(agent=_TEST_AGENT)
612612
session1 = app.create_session(user_id=_TEST_USER_ID)
613613
session2 = app.get_session(
614614
user_id=_TEST_USER_ID,
615-
session_id=session1.id,
615+
session_id=session1["id"],
616616
)
617617
assert session2.user_id == _TEST_USER_ID
618-
assert session1.id == session2.id
618+
assert session1["id"] == session2.id
619619

620620
def test_list_sessions(self, get_project_id_mock: mock.Mock):
621621
app = agent_engines.AdkApp(agent=_TEST_AGENT)
@@ -624,12 +624,12 @@ def test_list_sessions(self, get_project_id_mock: mock.Mock):
624624
session = app.create_session(user_id=_TEST_USER_ID)
625625
response1 = app.list_sessions(user_id=_TEST_USER_ID)
626626
assert len(response1.sessions) == 1
627-
assert response1.sessions[0].id == session.id
627+
assert response1.sessions[0].id == session["id"]
628628
session2 = app.create_session(user_id=_TEST_USER_ID)
629629
response2 = app.list_sessions(user_id=_TEST_USER_ID)
630630
assert len(response2.sessions) == 2
631-
assert response2.sessions[0].id == session.id
632-
assert response2.sessions[1].id == session2.id
631+
assert response2.sessions[0].id == session["id"]
632+
assert response2.sessions[1].id == session2["id"]
633633

634634
def test_delete_session(self, get_project_id_mock: mock.Mock):
635635
app = agent_engines.AdkApp(agent=_TEST_AGENT)
@@ -638,7 +638,7 @@ def test_delete_session(self, get_project_id_mock: mock.Mock):
638638
session = app.create_session(user_id=_TEST_USER_ID)
639639
response1 = app.list_sessions(user_id=_TEST_USER_ID)
640640
assert len(response1.sessions) == 1
641-
app.delete_session(user_id=_TEST_USER_ID, session_id=session.id)
641+
app.delete_session(user_id=_TEST_USER_ID, session_id=session["id"])
642642
response0 = app.list_sessions(user_id=_TEST_USER_ID)
643643
assert not response0.sessions
644644

@@ -817,7 +817,8 @@ def test_tracing_setup(
817817
"uuid.uuid4", lambda: uuid.UUID("12345678123456781234567812345678")
818818
)
819819
monkeypatch.setattr("os.getpid", lambda: 123123123)
820-
app = agent_engines.AdkApp(agent=_TEST_AGENT, enable_tracing=True)
820+
with mock.patch.object(initializer.global_config, "_project", _TEST_PROJECT):
821+
app = agent_engines.AdkApp(agent=_TEST_AGENT, enable_tracing=True)
821822
app.set_up()
822823

823824
otlp_span_exporter_mock.assert_called_once_with(
@@ -826,7 +827,7 @@ def test_tracing_setup(
826827
headers=mock.ANY,
827828
)
828829

829-
get_project_id_mock.assert_called_with(_TEST_PROJECT_ID)
830+
get_project_id_mock.assert_called_with(_TEST_PROJECT)
830831

831832
user_agent = otlp_span_exporter_mock.call_args.kwargs["headers"]["User-Agent"]
832833
assert (

tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py

Lines changed: 54 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -251,13 +251,18 @@ def adk_version_mock():
251251
yield adk_version_mock
252252

253253

254-
@pytest.fixture
254+
@pytest.fixture(autouse=True)
255255
def get_project_id_mock():
256256
with mock.patch(
257257
"google.cloud.aiplatform.aiplatform.utils.resource_manager_utils.get_project_id"
258258
) as get_project_id_mock:
259259
get_project_id_mock.return_value = _TEST_PROJECT_ID
260-
yield get_project_id_mock
260+
with mock.patch.object(initializer.global_config, "_project", _TEST_PROJECT):
261+
with mock.patch(
262+
"google.cloud.aiplatform.vertexai.preview.reasoning_engines.templates.adk.AdkApp._warn_if_telemetry_api_disabled",
263+
return_value=None,
264+
):
265+
yield get_project_id_mock
261266

262267

263268
class _MockRunner:
@@ -376,7 +381,7 @@ def test_initialization(self):
376381
app = reasoning_engines.AdkApp(
377382
agent=Agent(name=_TEST_AGENT_NAME, model=_TEST_MODEL),
378383
)
379-
assert app._tmpl_attrs.get("project") == _TEST_PROJECT
384+
assert app._tmpl_attrs.get("project") == _TEST_PROJECT_ID
380385
assert app._tmpl_attrs.get("location") == _TEST_LOCATION
381386
assert app._tmpl_attrs.get("runner") is None
382387

@@ -568,7 +573,17 @@ def test_streaming_agent_run_with_events(self):
568573
"artifacts": [
569574
{
570575
"file_name": "test_file_name",
571-
"versions": [{"version": "v1", "data": "v1data"}],
576+
"versions": [
577+
{
578+
"version": "v1",
579+
"data": {
580+
"inline_data": {
581+
"data": "djFkYXRh",
582+
"mime_type": "text/plain",
583+
}
584+
},
585+
}
586+
],
572587
}
573588
],
574589
"authorizations": {
@@ -606,7 +621,17 @@ async def test_streaming_agent_run_with_events_force_flush_otel(
606621
"artifacts": [
607622
{
608623
"file_name": "test_file_name",
609-
"versions": [{"version": "v1", "data": "v1data"}],
624+
"versions": [
625+
{
626+
"version": "v1",
627+
"data": {
628+
"inline_data": {
629+
"data": "djFkYXRh",
630+
"mime_type": "text/plain",
631+
}
632+
},
633+
}
634+
],
610635
}
611636
],
612637
"authorizations": {
@@ -682,12 +707,12 @@ def test_create_session(self):
682707
agent=Agent(name=_TEST_AGENT_NAME, model=_TEST_MODEL)
683708
)
684709
session1 = app.create_session(user_id=_TEST_USER_ID)
685-
assert session1.user_id == _TEST_USER_ID
710+
assert session1["user_id"] == _TEST_USER_ID
686711
session2 = app.create_session(
687712
user_id=_TEST_USER_ID, session_id="test_session_id"
688713
)
689-
assert session2.user_id == _TEST_USER_ID
690-
assert session2.id == "test_session_id"
714+
assert session2["user_id"] == _TEST_USER_ID
715+
assert session2["id"] == "test_session_id"
691716

692717
def test_get_session(self):
693718
app = reasoning_engines.AdkApp(
@@ -696,10 +721,10 @@ def test_get_session(self):
696721
session1 = app.create_session(user_id=_TEST_USER_ID)
697722
session2 = app.get_session(
698723
user_id=_TEST_USER_ID,
699-
session_id=session1.id,
724+
session_id=session1["id"],
700725
)
701726
assert session2.user_id == _TEST_USER_ID
702-
assert session1.id == session2.id
727+
assert session1["id"] == session2.id
703728

704729
def test_list_sessions(self):
705730
app = reasoning_engines.AdkApp(
@@ -710,12 +735,12 @@ def test_list_sessions(self):
710735
session = app.create_session(user_id=_TEST_USER_ID)
711736
response1 = app.list_sessions(user_id=_TEST_USER_ID)
712737
assert len(response1.sessions) == 1
713-
assert response1.sessions[0].id == session.id
738+
assert response1.sessions[0].id == session["id"]
714739
session2 = app.create_session(user_id=_TEST_USER_ID)
715740
response2 = app.list_sessions(user_id=_TEST_USER_ID)
716741
assert len(response2.sessions) == 2
717-
assert response2.sessions[0].id == session.id
718-
assert response2.sessions[1].id == session2.id
742+
assert response2.sessions[0].id == session["id"]
743+
assert response2.sessions[1].id == session2["id"]
719744

720745
def test_delete_session(self):
721746
app = reasoning_engines.AdkApp(
@@ -726,7 +751,7 @@ def test_delete_session(self):
726751
session = app.create_session(user_id=_TEST_USER_ID)
727752
response1 = app.list_sessions(user_id=_TEST_USER_ID)
728753
assert len(response1.sessions) == 1
729-
app.delete_session(user_id=_TEST_USER_ID, session_id=session.id)
754+
app.delete_session(user_id=_TEST_USER_ID, session_id=session["id"])
730755
response0 = app.list_sessions(user_id=_TEST_USER_ID)
731756
assert not response0.sessions
732757

@@ -740,14 +765,14 @@ async def test_async_add_session_to_memory(self):
740765
list(
741766
app.stream_query(
742767
user_id=_TEST_USER_ID,
743-
session_id=session.id,
768+
session_id=session["id"],
744769
message="My cat's name is Garfield",
745770
)
746771
)
747772
await app.async_add_session_to_memory(
748773
session=app.get_session(
749774
user_id=_TEST_USER_ID,
750-
session_id=session.id,
775+
session_id=session["id"],
751776
)
752777
)
753778
response = await app.async_search_memory(
@@ -838,7 +863,7 @@ def test_default_instrumentor_enablement(
838863

839864
# Assert
840865
default_instrumentor_builder_mock.assert_called_once_with(
841-
_TEST_PROJECT,
866+
_TEST_PROJECT_ID,
842867
enable_tracing=want_tracing_setup,
843868
enable_logging=want_logging_setup,
844869
)
@@ -863,11 +888,16 @@ def test_tracing_setup(
863888
monkeypatch.setattr("os.getpid", lambda: 123123123)
864889
app = reasoning_engines.AdkApp(agent=_TEST_AGENT, enable_tracing=True)
865890
app._warn_if_telemetry_api_disabled = lambda: None
866-
app.set_up()
891+
with mock.patch(
892+
"google.cloud.aiplatform.vertexai.agent_engines._utils.is_noop_or_proxy_tracer_provider",
893+
return_value=True,
894+
):
895+
app.set_up()
867896

868897
expected_attributes = {
869898
"cloud.account.id": _TEST_PROJECT_ID,
870899
"cloud.platform": "gcp.agent_engine",
900+
"cloud.provider": "gcp",
871901
"cloud.region": "us-central1",
872902
"cloud.resource_id": "//aiplatform.googleapis.com/projects/test-project-id/locations/us-central1/reasoningEngines/test_agent_id",
873903
"gcp.project_id": _TEST_PROJECT_ID,
@@ -876,7 +906,7 @@ def test_tracing_setup(
876906
"some-attribute": "some-value",
877907
"telemetry.sdk.language": "python",
878908
"telemetry.sdk.name": "opentelemetry",
879-
"telemetry.sdk.version": "1.36.0",
909+
"telemetry.sdk.version": "1.39.0",
880910
"some-attribute": "some-value",
881911
}
882912

@@ -886,7 +916,11 @@ def test_tracing_setup(
886916
headers=mock.ANY,
887917
)
888918

889-
get_project_id_mock.assert_called_once_with(_TEST_PROJECT)
919+
calls = [
920+
mock.call(project_number=_TEST_PROJECT_ID, credentials=mock.ANY),
921+
mock.call(_TEST_PROJECT_ID),
922+
]
923+
get_project_id_mock.assert_has_calls(calls)
890924

891925
user_agent = otlp_span_exporter_mock.call_args.kwargs["headers"]["User-Agent"]
892926
assert (

vertexai/agent_engines/templates/adk.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -657,6 +657,18 @@ def __init__(
657657
),
658658
}
659659

660+
def _serialize(self, obj: Any) -> Any:
661+
"""Serializes an object to be JSON compatible."""
662+
if hasattr(obj, "model_dump"):
663+
return obj.model_dump(mode="json")
664+
elif hasattr(obj, "dict"):
665+
return self._serialize(obj.dict())
666+
elif isinstance(obj, dict):
667+
return {k: self._serialize(v) for k, v in obj.items()}
668+
elif isinstance(obj, list):
669+
return [self._serialize(v) for v in obj]
670+
return obj
671+
660672
def _app_name(self) -> str:
661673
"""Returns the app name."""
662674
app = self._tmpl_attrs.get("app")
@@ -1062,7 +1074,7 @@ async def async_stream_query(
10621074
)
10631075
if not session_id:
10641076
session = await self.async_create_session(user_id=user_id)
1065-
session_id = session.id
1077+
session_id = session["id"]
10661078
if session_events is not None:
10671079
# We allow for session_events to be an empty list.
10681080
from google.adk.events.event import Event
@@ -1163,7 +1175,7 @@ def stream_query(
11631175
self.set_up()
11641176
if not session_id:
11651177
session = self.create_session(user_id=user_id)
1166-
session_id = session.id
1178+
session_id = session["id"]
11671179
run_config = _validate_run_config(run_config)
11681180
if run_config:
11691181
for event in self._tmpl_attrs.get("runner").run(
@@ -1469,7 +1481,7 @@ async def async_create_session(
14691481
state=state,
14701482
**kwargs,
14711483
)
1472-
return session
1484+
return self._serialize(session)
14731485

14741486
def create_session(
14751487
self,

0 commit comments

Comments
 (0)