Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions internal/servers/plugin/v3/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -427,16 +427,23 @@ func (s *Server) Transform(stream pb.Plugin_TransformServer) error {
// The reading never closes the writer, because it's up to the Plugin to decide when to finish
// writing, regardless of if the reading finished.
eg.Go(func() error {
var sendErr error
for record := range sendRecords {
// We cannot terminate the stream here, because the plugin may still be sending records. So if error was returned channel has to be drained
if sendErr != nil {
continue
}
recordBytes, err := pb.RecordToBytes(record)
if err != nil {
return status.Errorf(codes.Internal, "failed to convert record to bytes: %v", err)
sendErr = status.Errorf(codes.Internal, "failed to convert record to bytes: %v", err)
continue
}
if err := stream.Send(&pb.Transform_Response{Record: recordBytes}); err != nil {
return status.Errorf(codes.Internal, "error sending response: %v", err)
sendErr = status.Errorf(codes.Internal, "error sending response: %v", err)
continue
}
}
return nil
return sendErr
})

// Read records from source to transformer
Expand Down
144 changes: 144 additions & 0 deletions internal/servers/plugin/v3/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@ package plugin

import (
"context"
"errors"
"io"
"strings"
"sync/atomic"
"testing"
"time"

"github.com/apache/arrow-go/v18/arrow"
"github.com/apache/arrow-go/v18/arrow/array"
Expand Down Expand Up @@ -248,3 +252,143 @@ func (*mockSourceColumnAdderPluginClient) TransformSchema(_ context.Context, old
return old.AddField(1, arrow.Field{Name: "source", Type: arrow.BinaryTypes.String})
}
func (*mockSourceColumnAdderPluginClient) Close(context.Context) error { return nil }

type testTransformPluginClient struct {
plugin.UnimplementedDestination
plugin.UnimplementedSource
recordsSent int32
}

func (c *testTransformPluginClient) Transform(ctx context.Context, recvRecords <-chan arrow.Record, sendRecords chan<- arrow.Record) error {
for record := range recvRecords {
select {
default:
time.Sleep(1 * time.Second)
sendRecords <- record
atomic.AddInt32(&c.recordsSent, 1)
case <-ctx.Done():
return ctx.Err()
}
}
return nil
}

func (*testTransformPluginClient) TransformSchema(_ context.Context, old *arrow.Schema) (*arrow.Schema, error) {
return old, nil
}

func (*testTransformPluginClient) Close(context.Context) error {
return nil
}

func TestTransformNoDeadlockOnSendError(t *testing.T) {
client := &testTransformPluginClient{}
p := plugin.NewPlugin("test", "development", func(context.Context, zerolog.Logger, []byte, plugin.NewClientOptions) (plugin.Client, error) {
return client, nil
})
s := Server{
Plugin: p,
}
_, err := s.Init(context.Background(), &pb.Init_Request{})
require.NoError(t, err)

// Create a channel to signal when Send was called
sendCalled := make(chan struct{})
// Create a channel to signal when we should return from the test
done := make(chan struct{})
defer close(done)

stream := &mockTransformServerWithBlockingSend{
incomingMessages: makeRequests(3), // Multiple messages to ensure Transform tries to keep sending
sendCalled: sendCalled,
done: done,
}

// Run Transform in a goroutine with a timeout
errCh := make(chan error)
go func() {
errCh <- s.Transform(stream)
}()

// Wait for the first Send to be called
select {
case <-sendCalled:
// Send was called, good
case <-time.After(5 * time.Second):
t.Fatal("timeout waiting for Send to be called")
}

// Now wait for Transform to complete or timeout
select {
case err := <-errCh:
require.Error(t, err)
// Check for either the simulated error or context cancellation
if !strings.Contains(err.Error(), "simulated stream send error") &&
!strings.Contains(err.Error(), "context canceled") {
t.Fatalf("unexpected error: %v", err)
}
case <-time.After(5 * time.Second):
t.Fatal("Transform got deadlocked")
}
}

type mockTransformServerWithBlockingSend struct {
grpc.ServerStream
incomingMessages []*pb.Transform_Request
sendCalled chan struct{}
done chan struct{}
sendCount int32
}

func (s *mockTransformServerWithBlockingSend) Recv() (*pb.Transform_Request, error) {
if len(s.incomingMessages) > 0 {
msg := s.incomingMessages[0]
s.incomingMessages = s.incomingMessages[1:]
return msg, nil
}
return nil, io.EOF
}

func (s *mockTransformServerWithBlockingSend) Send(*pb.Transform_Response) error {
// Signal that Send was called
select {
case s.sendCalled <- struct{}{}:
default:
}

// Return error on first send
if atomic.AddInt32(&s.sendCount, 1) == 1 {
return errors.New("simulated stream send error")
}

// Block until test is done
<-s.done
return nil
}

func (*mockTransformServerWithBlockingSend) Context() context.Context {
return context.Background()
}

func makeRequests(i int) []*pb.Transform_Request {
requests := make([]*pb.Transform_Request, i)
for i := range i {
requests[i] = makeRequestFromString("test")
}
return requests
}

func makeRequestFromString(s string) *pb.Transform_Request {
record := makeRecordFromString(s)
bs, _ := pb.RecordToBytes(record)
return &pb.Transform_Request{Record: bs}
}

func makeRecordFromString(s string) arrow.Record {
str := array.NewStringBuilder(memory.DefaultAllocator)
str.AppendString(s)
arr := str.NewStringArray()
sch := arrow.NewSchema([]arrow.Field{{Name: "col1", Type: arrow.BinaryTypes.String}}, nil)

return array.NewRecord(sch, []arrow.Array{arr}, 1)
}
Loading