diff --git a/internal/servers/plugin/v3/plugin.go b/internal/servers/plugin/v3/plugin.go index c84771597b..83dcc9a0c6 100644 --- a/internal/servers/plugin/v3/plugin.go +++ b/internal/servers/plugin/v3/plugin.go @@ -427,16 +427,23 @@ func (s *Server) Transform(stream pb.Plugin_TransformServer) error { // The reading never closes the writer, because it's up to the Plugin to decide when to finish // writing, regardless of if the reading finished. eg.Go(func() error { + var sendErr error for record := range sendRecords { + // We cannot terminate the stream here, because the plugin may still be sending records. So if error was returned channel has to be drained + if sendErr != nil { + continue + } recordBytes, err := pb.RecordToBytes(record) if err != nil { - return status.Errorf(codes.Internal, "failed to convert record to bytes: %v", err) + sendErr = status.Errorf(codes.Internal, "failed to convert record to bytes: %v", err) + continue } if err := stream.Send(&pb.Transform_Response{Record: recordBytes}); err != nil { - return status.Errorf(codes.Internal, "error sending response: %v", err) + sendErr = status.Errorf(codes.Internal, "error sending response: %v", err) + continue } } - return nil + return sendErr }) // Read records from source to transformer diff --git a/internal/servers/plugin/v3/plugin_test.go b/internal/servers/plugin/v3/plugin_test.go index b6cf41302a..5ad3c47e13 100644 --- a/internal/servers/plugin/v3/plugin_test.go +++ b/internal/servers/plugin/v3/plugin_test.go @@ -2,8 +2,12 @@ package plugin import ( "context" + "errors" "io" + "strings" + "sync/atomic" "testing" + "time" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/array" @@ -248,3 +252,143 @@ func (*mockSourceColumnAdderPluginClient) TransformSchema(_ context.Context, old return old.AddField(1, arrow.Field{Name: "source", Type: arrow.BinaryTypes.String}) } func (*mockSourceColumnAdderPluginClient) Close(context.Context) error { return nil } + +type testTransformPluginClient struct { + plugin.UnimplementedDestination + plugin.UnimplementedSource + recordsSent int32 +} + +func (c *testTransformPluginClient) Transform(ctx context.Context, recvRecords <-chan arrow.Record, sendRecords chan<- arrow.Record) error { + for record := range recvRecords { + select { + default: + time.Sleep(1 * time.Second) + sendRecords <- record + atomic.AddInt32(&c.recordsSent, 1) + case <-ctx.Done(): + return ctx.Err() + } + } + return nil +} + +func (*testTransformPluginClient) TransformSchema(_ context.Context, old *arrow.Schema) (*arrow.Schema, error) { + return old, nil +} + +func (*testTransformPluginClient) Close(context.Context) error { + return nil +} + +func TestTransformNoDeadlockOnSendError(t *testing.T) { + client := &testTransformPluginClient{} + p := plugin.NewPlugin("test", "development", func(context.Context, zerolog.Logger, []byte, plugin.NewClientOptions) (plugin.Client, error) { + return client, nil + }) + s := Server{ + Plugin: p, + } + _, err := s.Init(context.Background(), &pb.Init_Request{}) + require.NoError(t, err) + + // Create a channel to signal when Send was called + sendCalled := make(chan struct{}) + // Create a channel to signal when we should return from the test + done := make(chan struct{}) + defer close(done) + + stream := &mockTransformServerWithBlockingSend{ + incomingMessages: makeRequests(3), // Multiple messages to ensure Transform tries to keep sending + sendCalled: sendCalled, + done: done, + } + + // Run Transform in a goroutine with a timeout + errCh := make(chan error) + go func() { + errCh <- s.Transform(stream) + }() + + // Wait for the first Send to be called + select { + case <-sendCalled: + // Send was called, good + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for Send to be called") + } + + // Now wait for Transform to complete or timeout + select { + case err := <-errCh: + require.Error(t, err) + // Check for either the simulated error or context cancellation + if !strings.Contains(err.Error(), "simulated stream send error") && + !strings.Contains(err.Error(), "context canceled") { + t.Fatalf("unexpected error: %v", err) + } + case <-time.After(5 * time.Second): + t.Fatal("Transform got deadlocked") + } +} + +type mockTransformServerWithBlockingSend struct { + grpc.ServerStream + incomingMessages []*pb.Transform_Request + sendCalled chan struct{} + done chan struct{} + sendCount int32 +} + +func (s *mockTransformServerWithBlockingSend) Recv() (*pb.Transform_Request, error) { + if len(s.incomingMessages) > 0 { + msg := s.incomingMessages[0] + s.incomingMessages = s.incomingMessages[1:] + return msg, nil + } + return nil, io.EOF +} + +func (s *mockTransformServerWithBlockingSend) Send(*pb.Transform_Response) error { + // Signal that Send was called + select { + case s.sendCalled <- struct{}{}: + default: + } + + // Return error on first send + if atomic.AddInt32(&s.sendCount, 1) == 1 { + return errors.New("simulated stream send error") + } + + // Block until test is done + <-s.done + return nil +} + +func (*mockTransformServerWithBlockingSend) Context() context.Context { + return context.Background() +} + +func makeRequests(i int) []*pb.Transform_Request { + requests := make([]*pb.Transform_Request, i) + for i := range i { + requests[i] = makeRequestFromString("test") + } + return requests +} + +func makeRequestFromString(s string) *pb.Transform_Request { + record := makeRecordFromString(s) + bs, _ := pb.RecordToBytes(record) + return &pb.Transform_Request{Record: bs} +} + +func makeRecordFromString(s string) arrow.Record { + str := array.NewStringBuilder(memory.DefaultAllocator) + str.AppendString(s) + arr := str.NewStringArray() + sch := arrow.NewSchema([]arrow.Field{{Name: "col1", Type: arrow.BinaryTypes.String}}, nil) + + return array.NewRecord(sch, []arrow.Array{arr}, 1) +}