Skip to content

Commit 7be8cfd

Browse files
committed
fix: implement more robust retries for Watch
1 parent 80884b3 commit 7be8cfd

4 files changed

Lines changed: 239 additions & 63 deletions

File tree

internal/client/client.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ const (
4444
defaultRetryJitterFraction = 0.5
4545
importBulkRoute = "/authzed.api.v1.PermissionsService/ImportBulkRelationships"
4646
exportBulkRoute = "/authzed.api.v1.PermissionsService/ExportBulkRelationships"
47+
watchRoute = "/authzed.api.v1.WatchService/Watch"
4748
)
4849

4950
// NewClient defines an (overridable) means of creating a new client.
@@ -235,7 +236,7 @@ func DialOptsFromFlags(cmd *cobra.Command, token storage.Token) ([]grpc.DialOpti
235236
// retrying the bulk import in backup/restore logic is handled manually.
236237
// retrying bulk export is also handled manually, because the default behavior is
237238
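// to start at the beginning of the stream, which produces duplicate values.
// retrying watch is handled manually as well, so that a new stream can resume
// from the last known revision.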
238-
selector.StreamClientInterceptor(retry.StreamClientInterceptor(retryOpts...), selector.MatchFunc(isNoneOf(importBulkRoute, exportBulkRoute))),
239+
selector.StreamClientInterceptor(retry.StreamClientInterceptor(retryOpts...), selector.MatchFunc(isNoneOf(importBulkRoute, exportBulkRoute, watchRoute))),
239240
}
240241

241242
if !cobrautil.MustGetBool(cmd, "skip-version-check") {

internal/client/client_test.go

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -133,30 +133,36 @@ func TestGetCurrentTokenWithCLIOverrideWithoutSecretFile(t *testing.T) {
133133
require.Equal(&bTrue, token.Insecure)
134134
}
135135

136-
type fakeSchemaServer struct {
136+
type fakeServer struct {
137137
v1.UnimplementedSchemaServiceServer
138138
v1.UnimplementedExperimentalServiceServer
139+
v1.UnimplementedWatchServiceServer
139140
v1.UnimplementedPermissionsServiceServer
140141
testFunc func()
141142
}
142143

143-
func (fss *fakeSchemaServer) ReadSchema(_ context.Context, _ *v1.ReadSchemaRequest) (*v1.ReadSchemaResponse, error) {
144+
// ReadSchema notifies testFunc of the call and then always fails with an
// empty-message codes.Unavailable status; no response is ever produced.
func (fss *fakeServer) ReadSchema(_ context.Context, _ *v1.ReadSchemaRequest) (*v1.ReadSchemaResponse, error) {
	fss.testFunc()

	var noResponse *v1.ReadSchemaResponse
	return noResponse, status.Error(codes.Unavailable, "")
}
147148

148-
func (fss *fakeSchemaServer) ImportBulkRelationships(grpc.ClientStreamingServer[v1.ImportBulkRelationshipsRequest, v1.ImportBulkRelationshipsResponse]) error {
149+
// ImportBulkRelationships notifies testFunc of the call and then always fails
// the client stream with an empty-message codes.Aborted status.
func (fss *fakeServer) ImportBulkRelationships(grpc.ClientStreamingServer[v1.ImportBulkRelationshipsRequest, v1.ImportBulkRelationshipsResponse]) error {
	fss.testFunc()

	// No formatting verbs are involved, so the plain Error constructor
	// produces the same status as Errorf would.
	return status.Error(codes.Aborted, "")
}
152153

154+
func (fss *fakeServer) Watch(*v1.WatchRequest, grpc.ServerStreamingServer[v1.WatchResponse]) error {
155+
fss.testFunc()
156+
return status.Errorf(codes.Unavailable, "")
157+
}
158+
153159
func TestRetries(t *testing.T) {
154160
ctx := t.Context()
155161
var callCount uint
156162
lis := bufconn.Listen(1024 * 1024)
157163
s := grpc.NewServer()
158164

159-
fakeServer := &fakeSchemaServer{testFunc: func() {
165+
fakeServer := &fakeServer{testFunc: func() {
160166
callCount++
161167
}}
162168
v1.RegisterSchemaServiceServer(s, fakeServer)
@@ -185,22 +191,25 @@ func TestRetries(t *testing.T) {
185191
c, err := authzed.NewClient("passthrough://bufnet", dialOpts...)
186192
require.NoError(t, err)
187193

188-
_, err = c.ReadSchema(ctx, &v1.ReadSchemaRequest{})
189-
grpcutil.RequireStatus(t, codes.Unavailable, err)
190-
require.Equal(t, retries, callCount)
194+
t.Run("read_schema", func(t *testing.T) {
195+
_, err = c.ReadSchema(ctx, &v1.ReadSchemaRequest{})
196+
grpcutil.RequireStatus(t, codes.Unavailable, err)
197+
require.Equal(t, retries, callCount)
198+
})
191199
}
192200

193-
func TestDoesNotRetryBackupRestore(t *testing.T) {
201+
func TestDoesNotRetry(t *testing.T) {
194202
ctx := t.Context()
195203
var callCount uint
196204
lis := bufconn.Listen(1024 * 1024)
197205
s := grpc.NewServer()
198206

199-
fakeServer := &fakeSchemaServer{testFunc: func() {
207+
fakeServer := &fakeServer{testFunc: func() {
200208
callCount++
201209
}}
202210
v1.RegisterPermissionsServiceServer(s, fakeServer)
203211
v1.RegisterExperimentalServiceServer(s, fakeServer)
212+
v1.RegisterWatchServiceServer(s, fakeServer)
204213

205214
go func() {
206215
_ = s.Serve(lis)
@@ -226,20 +235,23 @@ func TestDoesNotRetryBackupRestore(t *testing.T) {
226235
c, err := authzed.NewClientWithExperimentalAPIs("passthrough://bufnet", dialOpts...)
227236
require.NoError(t, err)
228237

229-
ibc, err := c.ImportBulkRelationships(ctx)
230-
require.NoError(t, err)
231-
err = ibc.SendMsg(&v1.ImportBulkRelationshipsRequest{})
232-
require.NoError(t, err)
233-
_, err = ibc.CloseAndRecv()
234-
grpcutil.RequireStatus(t, codes.Aborted, err)
235-
require.Equal(t, uint(1), callCount)
238+
t.Run("import_bulk", func(t *testing.T) {
239+
ibc, err := c.ImportBulkRelationships(ctx)
240+
require.NoError(t, err)
241+
err = ibc.SendMsg(&v1.ImportBulkRelationshipsRequest{})
242+
require.NoError(t, err)
243+
_, err = ibc.CloseAndRecv()
244+
grpcutil.RequireStatus(t, codes.Aborted, err)
245+
require.Equal(t, uint(1), callCount)
246+
})
236247

237-
callCount = 0
238-
bic, err := c.ImportBulkRelationships(ctx)
239-
require.NoError(t, err)
240-
err = bic.SendMsg(&v1.ImportBulkRelationshipsRequest{})
241-
require.NoError(t, err)
242-
_, err = bic.CloseAndRecv()
243-
grpcutil.RequireStatus(t, codes.Aborted, err)
244-
require.Equal(t, uint(1), callCount)
248+
t.Run("watch", func(t *testing.T) {
249+
callCount = 0
250+
watchReq, err := c.Watch(ctx, &v1.WatchRequest{})
251+
require.NoError(t, err)
252+
resp, err := watchReq.Recv()
253+
require.Nil(t, resp)
254+
grpcutil.RequireStatus(t, codes.Unavailable, err)
255+
require.Equal(t, uint(1), callCount)
256+
})
245257
}

internal/commands/watch.go

Lines changed: 83 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@ import (
99
"syscall"
1010
"time"
1111

12+
"github.com/rs/zerolog/log"
1213
"github.com/spf13/cobra"
14+
"google.golang.org/grpc/codes"
15+
"google.golang.org/grpc/status"
1316

1417
v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
1518

@@ -44,26 +47,29 @@ func RegisterWatchRelationshipCmd(parentCmd *cobra.Command) *cobra.Command {
4447

4548
var watchCmd = &cobra.Command{
4649
Use: "watch [object_types, ...] [start_cursor]",
47-
Short: "Watches the stream of relationship updates from the server",
50+
Short: "Watches the stream of relationship updates and schema updates from the server",
4851
Args: ValidationWrapper(cobra.RangeArgs(0, 2)),
4952
RunE: watchCmdFunc,
5053
Deprecated: "please use `zed relationships watch` instead",
5154
}
5255

5356
var watchRelationshipsCmd = &cobra.Command{
5457
Use: "watch [object_types, ...] [start_cursor]",
55-
Short: "Watches the stream of relationship updates from the server",
58+
Short: "Watches the stream of relationship updates and schema updates from the server",
5659
Args: ValidationWrapper(cobra.RangeArgs(0, 2)),
5760
RunE: watchCmdFunc,
5861
}
5962

6063
func watchCmdFunc(cmd *cobra.Command, _ []string) error {
61-
console.Printf("starting watch stream over types %v and revision %v\n", watchObjectTypes, watchRevision)
62-
63-
cli, err := client.NewClient(cmd)
64+
client, err := client.NewClient(cmd)
6465
if err != nil {
6566
return err
6667
}
68+
return watchCmdFuncImpl(cmd, client, processResponse)
69+
}
70+
71+
func watchCmdFuncImpl(cmd *cobra.Command, watchClient v1.WatchServiceClient, processResponse func(resp *v1.WatchResponse)) error {
72+
console.Printf("starting watch stream over types %v and revision %v\n", watchObjectTypes, watchRevision)
6773

6874
relFilters := make([]*v1.RelationshipFilter, 0, len(watchRelationshipFilters))
6975
for _, filter := range watchRelationshipFilters {
@@ -74,21 +80,26 @@ func watchCmdFunc(cmd *cobra.Command, _ []string) error {
7480
relFilters = append(relFilters, relFilter)
7581
}
7682

83+
ctx, cancel := context.WithCancel(cmd.Context())
84+
defer cancel()
85+
86+
signalctx, interruptCancel := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM, syscall.SIGINT)
87+
defer interruptCancel()
88+
7789
req := &v1.WatchRequest{
7890
OptionalObjectTypes: watchObjectTypes,
7991
OptionalRelationshipFilters: relFilters,
92+
OptionalUpdateKinds: []v1.WatchKind{
93+
v1.WatchKind_WATCH_KIND_INCLUDE_CHECKPOINTS, // keeps connection open during quiet periods
94+
v1.WatchKind_WATCH_KIND_INCLUDE_SCHEMA_UPDATES,
95+
},
8096
}
97+
8198
if watchRevision != "" {
8299
req.OptionalStartCursor = &v1.ZedToken{Token: watchRevision}
83100
}
84101

85-
ctx, cancel := context.WithCancel(cmd.Context())
86-
defer cancel()
87-
88-
signalctx, interruptCancel := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM, syscall.SIGINT)
89-
defer interruptCancel()
90-
91-
watchStream, err := cli.Watch(ctx, req)
102+
watchStream, err := watchClient.Watch(ctx, req)
92103
if err != nil {
93104
return err
94105
}
@@ -104,40 +115,74 @@ func watchCmdFunc(cmd *cobra.Command, _ []string) error {
104115
default:
105116
resp, err := watchStream.Recv()
106117
if err != nil {
107-
return err
108-
}
118+
ok, err := isRetryable(err)
119+
if !ok {
120+
return err
121+
}
109122

110-
for _, update := range resp.Updates {
111-
if watchTimestamps {
112-
console.Printf("%v: ", time.Now())
123+
log.Trace().Err(err).Msg("will retry from the last known revision " + watchRevision)
124+
req.OptionalStartCursor = &v1.ZedToken{Token: watchRevision}
125+
watchStream, err = watchClient.Watch(ctx, req)
126+
if err != nil {
127+
return err
113128
}
129+
continue
130+
}
114131

115-
switch update.Operation {
116-
case v1.RelationshipUpdate_OPERATION_CREATE:
117-
console.Printf("CREATED ")
132+
processResponse(resp)
133+
}
134+
}
135+
}
118136

119-
case v1.RelationshipUpdate_OPERATION_DELETE:
120-
console.Printf("DELETED ")
137+
func isRetryable(err error) (bool, error) {
138+
statusErr, ok := status.FromError(err)
139+
if !ok || (statusErr.Code() != codes.Unavailable) {
140+
return false, err
141+
}
142+
return true, nil
143+
}
121144

122-
case v1.RelationshipUpdate_OPERATION_TOUCH:
123-
console.Printf("TOUCHED ")
124-
}
145+
func processResponse(resp *v1.WatchResponse) {
146+
if resp.ChangesThrough != nil {
147+
watchRevision = resp.ChangesThrough.Token
148+
}
125149

126-
subjectRelation := ""
127-
if update.Relationship.Subject.OptionalRelation != "" {
128-
subjectRelation = " " + update.Relationship.Subject.OptionalRelation
129-
}
150+
if resp.SchemaUpdated {
151+
if watchTimestamps {
152+
console.Printf("%v: ", time.Now())
153+
}
154+
console.Println("SCHEMA UPDATED")
155+
}
130156

131-
console.Printf("%s:%s %s %s:%s%s\n",
132-
update.Relationship.Resource.ObjectType,
133-
update.Relationship.Resource.ObjectId,
134-
update.Relationship.Relation,
135-
update.Relationship.Subject.Object.ObjectType,
136-
update.Relationship.Subject.Object.ObjectId,
137-
subjectRelation,
138-
)
139-
}
157+
for _, update := range resp.Updates {
158+
if watchTimestamps {
159+
console.Printf("%v: ", time.Now())
140160
}
161+
162+
switch update.Operation {
163+
case v1.RelationshipUpdate_OPERATION_CREATE:
164+
console.Printf("CREATED ")
165+
166+
case v1.RelationshipUpdate_OPERATION_DELETE:
167+
console.Printf("DELETED ")
168+
169+
case v1.RelationshipUpdate_OPERATION_TOUCH:
170+
console.Printf("TOUCHED ")
171+
}
172+
173+
subjectRelation := ""
174+
if update.Relationship.Subject.OptionalRelation != "" {
175+
subjectRelation = " " + update.Relationship.Subject.OptionalRelation
176+
}
177+
178+
console.Printf("%s:%s %s %s:%s%s\n",
179+
update.Relationship.Resource.ObjectType,
180+
update.Relationship.Resource.ObjectId,
181+
update.Relationship.Relation,
182+
update.Relationship.Subject.Object.ObjectType,
183+
update.Relationship.Subject.Object.ObjectId,
184+
subjectRelation,
185+
)
141186
}
142187
}
143188

0 commit comments

Comments
 (0)