Skip to content

Commit 979db72

Browse files
committed
fix: implement more robust retries for Watch
1 parent 80884b3 commit 979db72

4 files changed

Lines changed: 177 additions & 37 deletions

File tree

internal/client/client.go

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -44,6 +44,7 @@ const (
4444
defaultRetryJitterFraction = 0.5
4545
importBulkRoute = "/authzed.api.v1.PermissionsService/ImportBulkRelationships"
4646
exportBulkRoute = "/authzed.api.v1.PermissionsService/ExportBulkRelationships"
47+
watchRoute = "/authzed.api.v1.WatchService/Watch"
4748
)
4849

4950
// NewClient defines an (overridable) means of creating a new client.
@@ -235,7 +236,7 @@ func DialOptsFromFlags(cmd *cobra.Command, token storage.Token) ([]grpc.DialOpti
235236
// retrying the bulk import in backup/restore logic is handled manually.
236237
// retrying bulk export is also handled manually, because the default behavior is
237238
// to start at the beginning of the stream, which produces duplicate values.
238-
selector.StreamClientInterceptor(retry.StreamClientInterceptor(retryOpts...), selector.MatchFunc(isNoneOf(importBulkRoute, exportBulkRoute))),
239+
selector.StreamClientInterceptor(retry.StreamClientInterceptor(retryOpts...), selector.MatchFunc(isNoneOf(importBulkRoute, exportBulkRoute, watchRoute))),
239240
}
240241

241242
if !cobrautil.MustGetBool(cmd, "skip-version-check") {

internal/client/client_test.go

Lines changed: 36 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -133,30 +133,36 @@ func TestGetCurrentTokenWithCLIOverrideWithoutSecretFile(t *testing.T) {
133133
require.Equal(&bTrue, token.Insecure)
134134
}
135135

136-
type fakeSchemaServer struct {
136+
type fakeServer struct {
137137
v1.UnimplementedSchemaServiceServer
138138
v1.UnimplementedExperimentalServiceServer
139+
v1.UnimplementedWatchServiceServer
139140
v1.UnimplementedPermissionsServiceServer
140141
testFunc func()
141142
}
142143

143-
func (fss *fakeSchemaServer) ReadSchema(_ context.Context, _ *v1.ReadSchemaRequest) (*v1.ReadSchemaResponse, error) {
144+
func (fss *fakeServer) ReadSchema(_ context.Context, _ *v1.ReadSchemaRequest) (*v1.ReadSchemaResponse, error) {
144145
fss.testFunc()
145146
return nil, status.Error(codes.Unavailable, "")
146147
}
147148

148-
func (fss *fakeSchemaServer) ImportBulkRelationships(grpc.ClientStreamingServer[v1.ImportBulkRelationshipsRequest, v1.ImportBulkRelationshipsResponse]) error {
149+
func (fss *fakeServer) ImportBulkRelationships(grpc.ClientStreamingServer[v1.ImportBulkRelationshipsRequest, v1.ImportBulkRelationshipsResponse]) error {
149150
fss.testFunc()
150151
return status.Errorf(codes.Aborted, "")
151152
}
152153

154+
func (fss *fakeServer) Watch(*v1.WatchRequest, grpc.ServerStreamingServer[v1.WatchResponse]) error {
155+
fss.testFunc()
156+
return status.Errorf(codes.Unavailable, "")
157+
}
158+
153159
func TestRetries(t *testing.T) {
154160
ctx := t.Context()
155161
var callCount uint
156162
lis := bufconn.Listen(1024 * 1024)
157163
s := grpc.NewServer()
158164

159-
fakeServer := &fakeSchemaServer{testFunc: func() {
165+
fakeServer := &fakeServer{testFunc: func() {
160166
callCount++
161167
}}
162168
v1.RegisterSchemaServiceServer(s, fakeServer)
@@ -185,22 +191,25 @@ func TestRetries(t *testing.T) {
185191
c, err := authzed.NewClient("passthrough://bufnet", dialOpts...)
186192
require.NoError(t, err)
187193

188-
_, err = c.ReadSchema(ctx, &v1.ReadSchemaRequest{})
189-
grpcutil.RequireStatus(t, codes.Unavailable, err)
190-
require.Equal(t, retries, callCount)
194+
t.Run("read_schema", func(t *testing.T) {
195+
_, err = c.ReadSchema(ctx, &v1.ReadSchemaRequest{})
196+
grpcutil.RequireStatus(t, codes.Unavailable, err)
197+
require.Equal(t, retries, callCount)
198+
})
191199
}
192200

193-
func TestDoesNotRetryBackupRestore(t *testing.T) {
201+
func TestDoesNotRetry(t *testing.T) {
194202
ctx := t.Context()
195203
var callCount uint
196204
lis := bufconn.Listen(1024 * 1024)
197205
s := grpc.NewServer()
198206

199-
fakeServer := &fakeSchemaServer{testFunc: func() {
207+
fakeServer := &fakeServer{testFunc: func() {
200208
callCount++
201209
}}
202210
v1.RegisterPermissionsServiceServer(s, fakeServer)
203211
v1.RegisterExperimentalServiceServer(s, fakeServer)
212+
v1.RegisterWatchServiceServer(s, fakeServer)
204213

205214
go func() {
206215
_ = s.Serve(lis)
@@ -226,20 +235,23 @@ func TestDoesNotRetryBackupRestore(t *testing.T) {
226235
c, err := authzed.NewClientWithExperimentalAPIs("passthrough://bufnet", dialOpts...)
227236
require.NoError(t, err)
228237

229-
ibc, err := c.ImportBulkRelationships(ctx)
230-
require.NoError(t, err)
231-
err = ibc.SendMsg(&v1.ImportBulkRelationshipsRequest{})
232-
require.NoError(t, err)
233-
_, err = ibc.CloseAndRecv()
234-
grpcutil.RequireStatus(t, codes.Aborted, err)
235-
require.Equal(t, uint(1), callCount)
238+
t.Run("import_bulk", func(t *testing.T) {
239+
ibc, err := c.ImportBulkRelationships(ctx)
240+
require.NoError(t, err)
241+
err = ibc.SendMsg(&v1.ImportBulkRelationshipsRequest{})
242+
require.NoError(t, err)
243+
_, err = ibc.CloseAndRecv()
244+
grpcutil.RequireStatus(t, codes.Aborted, err)
245+
require.Equal(t, uint(1), callCount)
246+
})
236247

237-
callCount = 0
238-
bic, err := c.ImportBulkRelationships(ctx)
239-
require.NoError(t, err)
240-
err = bic.SendMsg(&v1.ImportBulkRelationshipsRequest{})
241-
require.NoError(t, err)
242-
_, err = bic.CloseAndRecv()
243-
grpcutil.RequireStatus(t, codes.Aborted, err)
244-
require.Equal(t, uint(1), callCount)
248+
t.Run("watch", func(t *testing.T) {
249+
callCount = 0
250+
watchReq, err := c.Watch(ctx, &v1.WatchRequest{})
251+
require.NoError(t, err)
252+
resp, err := watchReq.Recv()
253+
require.Nil(t, resp)
254+
grpcutil.RequireStatus(t, codes.Unavailable, err)
255+
require.Equal(t, uint(1), callCount)
256+
})
245257
}

internal/commands/watch.go

Lines changed: 40 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,10 @@ import (
99
"syscall"
1010
"time"
1111

12+
"github.com/rs/zerolog/log"
1213
"github.com/spf13/cobra"
14+
"google.golang.org/grpc/codes"
15+
"google.golang.org/grpc/status"
1316

1417
v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
1518

@@ -74,26 +77,32 @@ func watchCmdFunc(cmd *cobra.Command, _ []string) error {
7477
relFilters = append(relFilters, relFilter)
7578
}
7679

77-
req := &v1.WatchRequest{
78-
OptionalObjectTypes: watchObjectTypes,
79-
OptionalRelationshipFilters: relFilters,
80-
}
81-
if watchRevision != "" {
82-
req.OptionalStartCursor = &v1.ZedToken{Token: watchRevision}
83-
}
84-
8580
ctx, cancel := context.WithCancel(cmd.Context())
8681
defer cancel()
8782

8883
signalctx, interruptCancel := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM, syscall.SIGINT)
8984
defer interruptCancel()
9085

91-
watchStream, err := cli.Watch(ctx, req)
92-
if err != nil {
93-
return err
86+
req := &v1.WatchRequest{
87+
OptionalObjectTypes: watchObjectTypes,
88+
OptionalRelationshipFilters: relFilters,
89+
OptionalUpdateKinds: []v1.WatchKind{
90+
v1.WatchKind_WATCH_KIND_INCLUDE_CHECKPOINTS,
91+
v1.WatchKind_WATCH_KIND_INCLUDE_SCHEMA_UPDATES,
92+
},
9493
}
9594

9695
for {
96+
log.Trace().Msg("calling watch with token " + watchRevision)
97+
if watchRevision != "" {
98+
req.OptionalStartCursor = &v1.ZedToken{Token: watchRevision}
99+
}
100+
101+
watchStream, err := cli.Watch(ctx, req)
102+
if err != nil {
103+
return err
104+
}
105+
97106
select {
98107
case <-signalctx.Done():
99108
console.Errorf("stream interrupted after program termination\n")
@@ -104,7 +113,26 @@ func watchCmdFunc(cmd *cobra.Command, _ []string) error {
104113
default:
105114
resp, err := watchStream.Recv()
106115
if err != nil {
107-
return err
116+
statusErr, ok := status.FromError(err)
117+
if !ok || (ok && statusErr.Code() != codes.Unavailable) {
118+
// cannot retry; return
119+
return err
120+
}
121+
122+
log.Trace().Err(err).Msg("will retry")
123+
continue
124+
}
125+
126+
if resp.ChangesThrough != nil {
127+
watchRevision = resp.ChangesThrough.Token
128+
log.Trace().Msg("updated watch revision to " + watchRevision)
129+
}
130+
131+
if resp.SchemaUpdated {
132+
console.Println("SCHEMA UPDATED")
133+
}
134+
if resp.IsCheckpoint {
135+
console.Println("CHECKPOINT REACHED")
108136
}
109137

110138
for _, update := range resp.Updates {

internal/commands/watch_test.go

Lines changed: 99 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,21 @@
11
package commands
22

33
import (
4+
"context"
5+
"net"
46
"reflect"
7+
"sync"
58
"testing"
9+
"time"
10+
11+
"github.com/stretchr/testify/require"
12+
"google.golang.org/grpc"
13+
"google.golang.org/grpc/test/bufconn"
614

715
v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
16+
17+
"github.com/authzed/zed/internal/client"
18+
zedtesting "github.com/authzed/zed/internal/testing"
819
)
920

1021
func TestParseRelationshipFilter(t *testing.T) {
@@ -108,3 +119,91 @@ func TestParseRelationshipFilter(t *testing.T) {
108119
}
109120
}
110121
}
122+
123+
type mockWatchServer struct {
124+
v1.UnimplementedWatchServiceServer
125+
sendOnce uint
126+
}
127+
128+
func (mws *mockWatchServer) Watch(_ *v1.WatchRequest, stream grpc.ServerStreamingServer[v1.WatchResponse]) error {
129+
update := &v1.RelationshipUpdate{
130+
Operation: v1.RelationshipUpdate_OPERATION_CREATE,
131+
Relationship: &v1.Relationship{
132+
Resource: &v1.ObjectReference{
133+
ObjectType: "document",
134+
ObjectId: "1",
135+
},
136+
Relation: "viewer",
137+
Subject: &v1.SubjectReference{
138+
Object: &v1.ObjectReference{
139+
ObjectType: "user",
140+
ObjectId: "alice",
141+
},
142+
},
143+
},
144+
}
145+
146+
response := &v1.WatchResponse{
147+
Updates: []*v1.RelationshipUpdate{update},
148+
ChangesThrough: &v1.ZedToken{Token: "revision1"},
149+
}
150+
151+
if mws.sendOnce == 0 {
152+
mws.sendOnce++
153+
return stream.Send(response)
154+
}
155+
156+
return nil
157+
}
158+
159+
func TestWatchCmdFunc(t *testing.T) {
160+
lis := bufconn.Listen(1024 * 1024)
161+
s := grpc.NewServer()
162+
163+
mockServer := &mockWatchServer{}
164+
v1.RegisterWatchServiceServer(s, mockServer)
165+
166+
go func() {
167+
_ = s.Serve(lis)
168+
}()
169+
t.Cleanup(s.Stop)
170+
171+
conn, err := grpc.NewClient("passthrough://bufnet",
172+
grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) {
173+
return lis.Dial()
174+
}),
175+
grpc.WithInsecure(), // nolint:staticcheck
176+
)
177+
require.NoError(t, err)
178+
t.Cleanup(func() { conn.Close() })
179+
180+
client.NewClient = zedtesting.ClientFromConn(conn)
181+
182+
ctx, cancel := context.WithCancel(context.Background())
183+
defer cancel()
184+
185+
cmd := zedtesting.CreateTestCobraCommandWithFlagValue(t,
186+
zedtesting.StringFlag{FlagName: "log-level", FlagValue: "trace", Changed: true},
187+
)
188+
cmd.SetContext(ctx)
189+
190+
var wg sync.WaitGroup
191+
wg.Add(1)
192+
193+
watchErr := make(chan error, 1)
194+
195+
go func() {
196+
defer wg.Done()
197+
watchErr <- watchCmdFunc(cmd, []string{})
198+
// NOTE: no assertions on the changes themselves, as the cmd just prints them out and we have no way of capturing that
199+
}()
200+
201+
time.Sleep(100 * time.Millisecond)
202+
203+
cancel()
204+
205+
wg.Wait()
206+
207+
err = <-watchErr
208+
require.ErrorContains(t, err, "EOF")
209+
}

0 commit comments

Comments
 (0)