@@ -46,46 +46,65 @@ google::protobuf::Message* ToBaseMessage(Message& message) {
4646 }
4747}
4848
49+ ugrpc::impl::AsyncMethodInvocation::WaitStatus WaitAndTryCancelIfNeeded (
50+ ugrpc::impl::AsyncMethodInvocation& invocation,
51+ engine::Deadline deadline,
52+ grpc::ClientContext& context
53+ ) noexcept ;
54+
55+ ugrpc::impl::AsyncMethodInvocation::WaitStatus
56+ WaitAndTryCancelIfNeeded (ugrpc::impl::AsyncMethodInvocation& invocation, grpc::ClientContext& context) noexcept ;
57+
4958void CheckOk (CallState& state, AsyncMethodInvocation::WaitStatus status, std::string_view stage);
5059
5160template <typename GrpcStream>
5261void StartCall (GrpcStream& stream, CallState& state) {
5362 AsyncMethodInvocation start_call;
5463 stream.StartCall (start_call.GetTag ());
55- CheckOk (state, Wait (start_call, state.GetContext ()), " StartCall" );
64+ CheckOk (state, WaitAndTryCancelIfNeeded (start_call, state.GetContext ()), " StartCall" );
5665}
5766
5867void PrepareFinish (CallState& state);
5968
6069void ProcessFinish (CallState& state, google::protobuf::Message* final_response);
6170
62- void CheckFinishStatus (CallState& state);
71+ void ProcessFinishCancelled (CallState& state);
6372
64- void ProcessFinishResult (
65- CallState& state,
66- AsyncMethodInvocation::WaitStatus wait_status,
67- google::protobuf::Message* final_response,
68- bool throw_on_error
69- );
73+ void CheckFinishStatus (CallState& state);
7074
7175template <typename GrpcStream>
7276void Finish (GrpcStream& stream, CallState& state, google::protobuf::Message* final_response, bool throw_on_error) {
7377 PrepareFinish (state);
7478
75- FinishAsyncMethodInvocation finish (state) ;
79+ FinishAsyncMethodInvocation finish;
7680 auto & status = state.GetStatus ();
7781 stream.Finish (&status, finish.GetTag ());
7882
79- const auto wait_status = Wait (finish, state.GetContext ());
80- ProcessFinishResult (state, wait_status, final_response, throw_on_error);
83+ const auto wait_status = WaitAndTryCancelIfNeeded (finish, state.GetContext ());
84+ if (wait_status == impl::AsyncMethodInvocation::WaitStatus::kCancelled ) {
85+ ProcessFinishCancelled (state);
86+ // Finish AsyncMethodInvocation will be awaited in its destructor.
87+ if (throw_on_error) {
88+ throw RpcCancelledError (state.GetCallName (), " Finish" );
89+ }
90+ }
91+
92+ UINVARIANT (
93+ impl::AsyncMethodInvocation::WaitStatus::kOk == wait_status, " Client-side Finish: ok should always be true"
94+ );
95+ state.GetStatsScope ().SetFinishTime (finish.GetFinishTime ());
96+ ProcessFinish (state, final_response);
97+ if (throw_on_error) {
98+ CheckFinishStatus (state);
99+ }
81100}
82101
83102template <typename GrpcStream, typename Response>
84103[[nodiscard]] bool Read (GrpcStream& stream, Response& response, CallState& state) {
85104 UINVARIANT (state.IsReadAvailable (), " 'impl::Read' called on a finished call" );
86105 AsyncMethodInvocation read;
87106 stream.Read (&response, read.GetTag ());
88- const auto wait_status = Wait (read, state.GetContext ());
107+ const auto wait_status = WaitAndTryCancelIfNeeded (read, state.GetContext ());
89108 if (wait_status == impl::AsyncMethodInvocation::WaitStatus::kCancelled ) {
90109 state.GetStatsScope ().OnCancelled ();
91110 }
@@ -105,7 +124,7 @@ bool Write(GrpcStream& stream, const Request& request, grpc::WriteOptions option
105124 UINVARIANT (state.IsWriteAvailable (), " 'impl::Write' called on a stream that is closed for writes" );
106125 AsyncMethodInvocation write;
107126 stream.Write (request, options, write.GetTag ());
108- const auto result = Wait (write, state.GetContext ());
127+ const auto result = WaitAndTryCancelIfNeeded (write, state.GetContext ());
109128 if (result == impl::AsyncMethodInvocation::WaitStatus::kCancelled ) {
110129 state.GetStatsScope ().OnCancelled ();
111130 }
@@ -120,7 +139,7 @@ void WriteAndCheck(GrpcStream& stream, const Request& request, grpc::WriteOption
120139 UINVARIANT (state.IsWriteAndCheckAvailable (), " 'impl::WriteAndCheck' called on a finished or closed stream" );
121140 AsyncMethodInvocation write;
122141 stream.Write (request, options, write.GetTag ());
123- CheckOk (state, Wait (write, state.GetContext ()), " WriteAndCheck" );
142+ CheckOk (state, WaitAndTryCancelIfNeeded (write, state.GetContext ()), " WriteAndCheck" );
124143}
125144
126145template <typename GrpcStream>
@@ -129,7 +148,7 @@ bool WritesDone(GrpcStream& stream, CallState& state) {
129148 state.SetWritesFinished ();
130149 AsyncMethodInvocation writes_done;
131150 stream.WritesDone (writes_done.GetTag ());
132- const auto wait_status = Wait (writes_done, state.GetContext ());
151+ const auto wait_status = WaitAndTryCancelIfNeeded (writes_done, state.GetContext ());
133152 if (wait_status == impl::AsyncMethodInvocation::WaitStatus::kCancelled ) {
134153 state.GetStatsScope ().OnCancelled ();
135154 }
0 commit comments