Skip to content

Commit e3696bb

Browse files
committed
fix: implement more robust retries for Watch
1 parent a2cb94f commit e3696bb

4 files changed

Lines changed: 170 additions & 38 deletions

File tree

internal/client/client.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ const (
4444
defaultRetryJitterFraction = 0.5
4545
bulkImportRoute = "/authzed.api.v1.ExperimentalService/BulkImportRelationships"
4646
importBulkRoute = "/authzed.api.v1.PermissionsService/ImportBulkRelationships"
47+
watchRoute = "/authzed.api.v1.WatchService/Watch"
4748
)
4849

4950
// NewClient defines an (overridable) means of creating a new client.
@@ -232,7 +233,7 @@ func DialOptsFromFlags(cmd *cobra.Command, token storage.Token) ([]grpc.DialOpti
232233

233234
streamInterceptors := []grpc.StreamClientInterceptor{
234235
zgrpcutil.StreamLogDispatchTrailers,
235-
selector.StreamClientInterceptor(retry.StreamClientInterceptor(retryOpts...), selector.MatchFunc(isNoneOf(bulkImportRoute, importBulkRoute))),
236+
selector.StreamClientInterceptor(retry.StreamClientInterceptor(retryOpts...), selector.MatchFunc(isNoneOf(bulkImportRoute, importBulkRoute, watchRoute))),
236237
}
237238

238239
if !cobrautil.MustGetBool(cmd, "skip-version-check") {

internal/client/client_test.go

Lines changed: 48 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -133,35 +133,41 @@ func TestGetCurrentTokenWithCLIOverrideWithoutSecretFile(t *testing.T) {
133133
require.Equal(&bTrue, token.Insecure)
134134
}
135135

136-
type fakeSchemaServer struct {
136+
// fakeServer is a test double that embeds the Unimplemented*Server stubs for
// the Schema, Experimental, Watch and Permissions services, so it can be
// registered for any of them.
type fakeServer struct {
137137
v1.UnimplementedSchemaServiceServer
138138
v1.UnimplementedExperimentalServiceServer
139+
v1.UnimplementedWatchServiceServer
139140
v1.UnimplementedPermissionsServiceServer
140141
// testFunc is called by each handler this type overrides before the
// handler returns its error; the tests use it to count invocations.
testFunc func()
141142
}
142143

143-
func (fss *fakeSchemaServer) ReadSchema(_ context.Context, _ *v1.ReadSchemaRequest) (*v1.ReadSchemaResponse, error) {
144+
// ReadSchema calls the testFunc hook and then always fails with
// codes.Unavailable and a nil response.
func (fss *fakeServer) ReadSchema(_ context.Context, _ *v1.ReadSchemaRequest) (*v1.ReadSchemaResponse, error) {
144145
fss.testFunc()
145146
return nil, status.Error(codes.Unavailable, "")
146147
}
147148

148-
func (fss *fakeSchemaServer) BulkImportRelationships(grpc.ClientStreamingServer[v1.BulkImportRelationshipsRequest, v1.BulkImportRelationshipsResponse]) error {
149+
// BulkImportRelationships calls the testFunc hook and then always fails with
// codes.Aborted; the incoming client stream is never read.
func (fss *fakeServer) BulkImportRelationships(grpc.ClientStreamingServer[v1.BulkImportRelationshipsRequest, v1.BulkImportRelationshipsResponse]) error {
149150
fss.testFunc()
150151
return status.Errorf(codes.Aborted, "")
151152
}
152153

153-
func (fss *fakeSchemaServer) ImportBulkRelationships(grpc.ClientStreamingServer[v1.ImportBulkRelationshipsRequest, v1.ImportBulkRelationshipsResponse]) error {
154+
// ImportBulkRelationships calls the testFunc hook and then always fails with
// codes.Aborted; the incoming client stream is never read.
func (fss *fakeServer) ImportBulkRelationships(grpc.ClientStreamingServer[v1.ImportBulkRelationshipsRequest, v1.ImportBulkRelationshipsResponse]) error {
154155
fss.testFunc()
155156
return status.Errorf(codes.Aborted, "")
156157
}
157158

159+
// Watch calls the testFunc hook and then always fails with codes.Unavailable
// without sending anything on the stream.
func (fss *fakeServer) Watch(*v1.WatchRequest, grpc.ServerStreamingServer[v1.WatchResponse]) error {
160+
fss.testFunc()
161+
return status.Errorf(codes.Unavailable, "")
162+
}
163+
158164
func TestRetries(t *testing.T) {
159165
ctx := t.Context()
160166
var callCount uint
161167
lis := bufconn.Listen(1024 * 1024)
162168
s := grpc.NewServer()
163169

164-
fakeServer := &fakeSchemaServer{testFunc: func() {
170+
fakeServer := &fakeServer{testFunc: func() {
165171
callCount++
166172
}}
167173
v1.RegisterSchemaServiceServer(s, fakeServer)
@@ -190,22 +196,25 @@ func TestRetries(t *testing.T) {
190196
c, err := authzed.NewClient("passthrough://bufnet", dialOpts...)
191197
require.NoError(t, err)
192198

193-
_, err = c.ReadSchema(ctx, &v1.ReadSchemaRequest{})
194-
grpcutil.RequireStatus(t, codes.Unavailable, err)
195-
require.Equal(t, retries, callCount)
199+
t.Run("read_schema", func(t *testing.T) {
200+
_, err = c.ReadSchema(ctx, &v1.ReadSchemaRequest{})
201+
grpcutil.RequireStatus(t, codes.Unavailable, err)
202+
require.Equal(t, retries, callCount)
203+
})
196204
}
197205

198-
func TestDoesNotRetryBackupRestore(t *testing.T) {
206+
func TestDoesNotRetry(t *testing.T) {
199207
ctx := t.Context()
200208
var callCount uint
201209
lis := bufconn.Listen(1024 * 1024)
202210
s := grpc.NewServer()
203211

204-
fakeServer := &fakeSchemaServer{testFunc: func() {
212+
fakeServer := &fakeServer{testFunc: func() {
205213
callCount++
206214
}}
207215
v1.RegisterPermissionsServiceServer(s, fakeServer)
208216
v1.RegisterExperimentalServiceServer(s, fakeServer)
217+
v1.RegisterWatchServiceServer(s, fakeServer)
209218

210219
go func() {
211220
_ = s.Serve(lis)
@@ -231,20 +240,34 @@ func TestDoesNotRetryBackupRestore(t *testing.T) {
231240
c, err := authzed.NewClientWithExperimentalAPIs("passthrough://bufnet", dialOpts...)
232241
require.NoError(t, err)
233242

234-
ibc, err := c.ImportBulkRelationships(ctx)
235-
require.NoError(t, err)
236-
err = ibc.SendMsg(&v1.ImportBulkRelationshipsRequest{})
237-
require.NoError(t, err)
238-
_, err = ibc.CloseAndRecv()
239-
grpcutil.RequireStatus(t, codes.Aborted, err)
240-
require.Equal(t, uint(1), callCount)
243+
t.Run("import_bulk", func(t *testing.T) {
244+
ibc, err := c.ImportBulkRelationships(ctx)
245+
require.NoError(t, err)
246+
err = ibc.SendMsg(&v1.ImportBulkRelationshipsRequest{})
247+
require.NoError(t, err)
248+
_, err = ibc.CloseAndRecv()
249+
grpcutil.RequireStatus(t, codes.Aborted, err)
250+
require.Equal(t, uint(1), callCount)
251+
})
241252

242-
callCount = 0
243-
bic, err := c.BulkImportRelationships(ctx)
244-
require.NoError(t, err)
245-
err = bic.SendMsg(&v1.BulkImportRelationshipsRequest{})
246-
require.NoError(t, err)
247-
_, err = bic.CloseAndRecv()
248-
grpcutil.RequireStatus(t, codes.Aborted, err)
249-
require.Equal(t, uint(1), callCount)
253+
t.Run("bulk_import", func(t *testing.T) {
254+
callCount = 0
255+
bic, err := c.BulkImportRelationships(ctx)
256+
require.NoError(t, err)
257+
err = bic.SendMsg(&v1.BulkImportRelationshipsRequest{})
258+
require.NoError(t, err)
259+
_, err = bic.CloseAndRecv()
260+
grpcutil.RequireStatus(t, codes.Aborted, err)
261+
require.Equal(t, uint(1), callCount)
262+
})
263+
264+
t.Run("watch", func(t *testing.T) {
265+
callCount = 0
266+
watchReq, err := c.Watch(ctx, &v1.WatchRequest{})
267+
require.NoError(t, err)
268+
resp, err := watchReq.Recv()
269+
require.Nil(t, resp)
270+
grpcutil.RequireStatus(t, codes.Unavailable, err)
271+
require.Equal(t, uint(1), callCount)
272+
})
250273
}

internal/commands/watch.go

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"syscall"
1010
"time"
1111

12+
"github.com/rs/zerolog/log"
1213
"github.com/spf13/cobra"
1314

1415
v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
@@ -74,26 +75,27 @@ func watchCmdFunc(cmd *cobra.Command, _ []string) error {
7475
relFilters = append(relFilters, relFilter)
7576
}
7677

77-
req := &v1.WatchRequest{
78-
OptionalObjectTypes: watchObjectTypes,
79-
OptionalRelationshipFilters: relFilters,
80-
}
81-
if watchRevision != "" {
82-
req.OptionalStartCursor = &v1.ZedToken{Token: watchRevision}
83-
}
84-
8578
ctx, cancel := context.WithCancel(cmd.Context())
8679
defer cancel()
8780

8881
signalctx, interruptCancel := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM, syscall.SIGINT)
8982
defer interruptCancel()
9083

91-
watchStream, err := cli.Watch(ctx, req)
92-
if err != nil {
93-
return err
84+
req := &v1.WatchRequest{
85+
OptionalObjectTypes: watchObjectTypes,
86+
OptionalRelationshipFilters: relFilters,
9487
}
9588

9689
for {
90+
if watchRevision != "" {
91+
req.OptionalStartCursor = &v1.ZedToken{Token: watchRevision}
92+
}
93+
94+
watchStream, err := cli.Watch(ctx, req)
95+
if err != nil {
96+
return err
97+
}
98+
9799
select {
98100
case <-signalctx.Done():
99101
console.Errorf("stream interrupted after program termination\n")
@@ -104,7 +106,15 @@ func watchCmdFunc(cmd *cobra.Command, _ []string) error {
104106
default:
105107
resp, err := watchStream.Recv()
106108
if err != nil {
107-
return err
109+
if !strings.Contains(err.Error(), "stream timeout") && !strings.Contains(err.Error(), "RST_STREAM closed stream") {
110+
return err
111+
}
112+
log.Trace().Err(err).Msg("error receiving from watch stream. will retry")
113+
continue
114+
}
115+
116+
if resp.ChangesThrough != nil {
117+
watchRevision = resp.ChangesThrough.Token
108118
}
109119

110120
for _, update := range resp.Updates {

internal/commands/watch_test.go

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,21 @@
11
package commands
22

33
import (
4+
"context"
5+
"net"
46
"reflect"
7+
"sync"
58
"testing"
9+
"time"
10+
11+
"github.com/stretchr/testify/require"
12+
"google.golang.org/grpc"
13+
"google.golang.org/grpc/test/bufconn"
614

715
v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
16+
17+
"github.com/authzed/zed/internal/client"
18+
zedtesting "github.com/authzed/zed/internal/testing"
819
)
920

1021
func TestParseRelationshipFilter(t *testing.T) {
@@ -108,3 +119,90 @@ func TestParseRelationshipFilter(t *testing.T) {
108119
}
109120
}
110121
}
122+
123+
// mockWatchServer is a WatchService test double whose Watch handler sends a
// single canned response the first time it is called.
type mockWatchServer struct {
124+
v1.UnimplementedWatchServiceServer
125+
// sendOnce is zero until the canned response has been sent, and is
// incremented when it is.
sendOnce uint
126+
}
127+
128+
// Watch ignores the request. On the first call it sends one WatchResponse
// holding a single CREATE update (document:1#viewer@user:alice) with the
// ChangesThrough token "revision1", and returns the result of stream.Send.
// Every later call returns nil without sending anything.
func (mws *mockWatchServer) Watch(_ *v1.WatchRequest, stream grpc.ServerStreamingServer[v1.WatchResponse]) error {
129+
update := &v1.RelationshipUpdate{
130+
Operation: v1.RelationshipUpdate_OPERATION_CREATE,
131+
Relationship: &v1.Relationship{
132+
Resource: &v1.ObjectReference{
133+
ObjectType: "document",
134+
ObjectId: "1",
135+
},
136+
Relation: "viewer",
137+
Subject: &v1.SubjectReference{
138+
Object: &v1.ObjectReference{
139+
ObjectType: "user",
140+
ObjectId: "alice",
141+
},
142+
},
143+
},
144+
}
145+
146+
response := &v1.WatchResponse{
147+
Updates: []*v1.RelationshipUpdate{update},
148+
ChangesThrough: &v1.ZedToken{Token: "revision1"},
149+
}
150+
151+
// NOTE(review): sendOnce is read and written here without any
// synchronization; confirm that Watch handlers cannot run concurrently in
// the tests that use this mock.
if mws.sendOnce == 0 {
152+
mws.sendOnce++
153+
return stream.Send(response)
154+
}
155+
156+
return nil
157+
}
158+
159+
// TestWatchCmdFunc runs watchCmdFunc against a mockWatchServer served over a
// bufconn listener, cancels the command's context after a short delay, and
// asserts that the returned error contains "EOF".
func TestWatchCmdFunc(t *testing.T) {
160+
lis := bufconn.Listen(1024 * 1024)
161+
s := grpc.NewServer()
162+
163+
mockServer := &mockWatchServer{}
164+
v1.RegisterWatchServiceServer(s, mockServer)
165+
166+
go func() {
167+
_ = s.Serve(lis)
168+
}()
169+
t.Cleanup(s.Stop)
170+
171+
// Dial the bufconn listener directly instead of going over the network.
conn, err := grpc.NewClient("passthrough://bufnet",
172+
grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) {
173+
return lis.Dial()
174+
}),
175+
grpc.WithInsecure(), // nolint:staticcheck
176+
)
177+
require.NoError(t, err)
178+
t.Cleanup(func() { conn.Close() })
179+
180+
// Override the package-level client constructor; presumably watchCmdFunc
// builds its client through it and therefore ends up on conn.
// NOTE(review): the previous value of client.NewClient is not restored
// when the test ends — confirm this cannot leak into other tests.
client.NewClient = zedtesting.ClientFromConn(conn)
181+
182+
ctx, cancel := context.WithCancel(context.Background())
183+
defer cancel()
184+
185+
cmd := zedtesting.CreateTestCobraCommandWithFlagValue(t,
186+
zedtesting.StringFlag{FlagName: "log-level", FlagValue: "trace", Changed: true},
187+
)
188+
cmd.SetContext(ctx)
189+
190+
var wg sync.WaitGroup
191+
wg.Add(1)
192+
193+
// Buffered with capacity 1 so the goroutine can deliver its result without
// a receiver being ready.
watchErr := make(chan error, 1)
194+
195+
go func() {
196+
defer wg.Done()
197+
watchErr <- watchCmdFunc(cmd, []string{})
198+
}()
199+
200+
// NOTE(review): the test relies on a fixed 100ms delay before cancelling
// rather than on an explicit signal from the command.
time.Sleep(100 * time.Millisecond)
201+
202+
cancel()
203+
204+
wg.Wait()
205+
206+
err = <-watchErr
207+
// Presumably the EOF comes from the mock returning nil (closing the
// stream) on a Watch call after the first — TODO confirm.
require.ErrorContains(t, err, "EOF")
208+
}

0 commit comments

Comments
 (0)