Skip to content

Commit 8637ea4

Browse files
committed
feat: add --extra-header flag for custom gRPC headers
Signed-off-by: Erik Hennings <erik.hennings@freda.com>
1 parent 937e0b3 commit 8637ea4

File tree

4 files changed

+61
-0
lines changed

4 files changed

+61
-0
lines changed

internal/client/client.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"google.golang.org/grpc"
2222
"google.golang.org/grpc/codes"
2323
"google.golang.org/grpc/credentials/insecure"
24+
"google.golang.org/grpc/metadata"
2425

2526
v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
2627
"github.com/authzed/authzed-go/v1"
@@ -214,6 +215,43 @@ func isNoneOf(routes ...string) func(_ context.Context, c interceptors.CallMeta)
214215
}
215216
}
216217

218+
func extraHeadersUnaryInterceptor(headers map[string]string) grpc.UnaryClientInterceptor {
219+
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
220+
if len(headers) > 0 {
221+
md := metadata.New(headers)
222+
ctx = metadata.NewOutgoingContext(ctx, md)
223+
}
224+
return invoker(ctx, method, req, reply, cc, opts...)
225+
}
226+
}
227+
228+
func extraHeadersStreamInterceptor(headers map[string]string) grpc.StreamClientInterceptor {
229+
return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
230+
if len(headers) > 0 {
231+
md := metadata.New(headers)
232+
ctx = metadata.NewOutgoingContext(ctx, md)
233+
}
234+
return streamer(ctx, desc, cc, method, opts...)
235+
}
236+
}
237+
238+
func parseExtraHeaders(headerStrings []string) (map[string]string, error) {
239+
headers := make(map[string]string)
240+
for _, headerStr := range headerStrings {
241+
parts := strings.SplitN(headerStr, "=", 2)
242+
if len(parts) != 2 {
243+
return nil, fmt.Errorf("invalid header format '%s': expected 'key=value'", headerStr)
244+
}
245+
key := strings.TrimSpace(parts[0])
246+
value := strings.TrimSpace(parts[1])
247+
if key == "" {
248+
return nil, fmt.Errorf("invalid header format '%s': key cannot be empty", headerStr)
249+
}
250+
headers[key] = value
251+
}
252+
return headers, nil
253+
}
254+
217255
// DialOptsFromFlags returns the dial options from the CLI-specified flags.
218256
func DialOptsFromFlags(cmd *cobra.Command, token storage.Token) ([]grpc.DialOption, error) {
219257
maxRetries := cobrautil.MustGetUint(cmd, "max-retries")
@@ -239,6 +277,17 @@ func DialOptsFromFlags(cmd *cobra.Command, token storage.Token) ([]grpc.DialOpti
239277
selector.StreamClientInterceptor(retry.StreamClientInterceptor(retryOpts...), selector.MatchFunc(isNoneOf(importBulkRoute, exportBulkRoute, watchRoute))),
240278
}
241279

280+
// Parse and add extra headers if provided
281+
extraHeaderStrings := cobrautil.MustGetStringSlice(cmd, "extra-header")
282+
if len(extraHeaderStrings) > 0 {
283+
headers, err := parseExtraHeaders(extraHeaderStrings)
284+
if err != nil {
285+
return nil, fmt.Errorf("failed to parse extra headers: %w", err)
286+
}
287+
unaryInterceptors = append(unaryInterceptors, extraHeadersUnaryInterceptor(headers))
288+
streamInterceptors = append(streamInterceptors, extraHeadersStreamInterceptor(headers))
289+
}
290+
242291
if !cobrautil.MustGetBool(cmd, "skip-version-check") {
243292
unaryInterceptors = append(unaryInterceptors, zgrpcutil.CheckServerVersion)
244293
}

internal/client/client_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ func TestRetries(t *testing.T) {
180180
zedtesting.StringFlag{FlagName: "proxy", FlagValue: "", Changed: true},
181181
zedtesting.StringFlag{FlagName: "hostname-override", FlagValue: "", Changed: true},
182182
zedtesting.IntFlag{FlagName: "max-message-size", FlagValue: 1000, Changed: true},
183+
zedtesting.StringSliceFlag{FlagName: "extra-header", FlagValue: []string{}, Changed: false},
183184
)
184185
dialOpts, err := client.DialOptsFromFlags(cmd, storage.Token{Insecure: &secure})
185186
require.NoError(t, err)
@@ -224,6 +225,7 @@ func TestDoesNotRetry(t *testing.T) {
224225
zedtesting.StringFlag{FlagName: "proxy", FlagValue: "", Changed: true},
225226
zedtesting.StringFlag{FlagName: "hostname-override", FlagValue: "", Changed: true},
226227
zedtesting.IntFlag{FlagName: "max-message-size", FlagValue: 1000, Changed: true},
228+
zedtesting.StringSliceFlag{FlagName: "extra-header", FlagValue: []string{}, Changed: false},
227229
)
228230
dialOpts, err := client.DialOptsFromFlags(cmd, storage.Token{Insecure: &secure})
229231
require.NoError(t, err)

internal/cmd/cmd.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ zed permission check --explain document:firstdoc writer user:emilia
9191
rootCmd.PersistentFlags().Int("max-message-size", 0, "maximum size *in bytes* (defaults to 4_194_304 bytes ~= 4MB) of a gRPC message that can be sent or received by zed")
9292
rootCmd.PersistentFlags().String("proxy", "", "specify a SOCKS5 proxy address")
9393
rootCmd.PersistentFlags().Uint("max-retries", 10, "maximum number of sequential retries to attempt when a request fails")
94+
rootCmd.PersistentFlags().StringSlice("extra-header", []string{}, "extra header(s) to add to gRPC requests in the format 'key=value' (can be specified multiple times)")
9495
_ = rootCmd.PersistentFlags().MarkHidden("debug") // This cannot return its error.
9596

9697
versionCmd := &cobra.Command{

internal/testing/test_helpers.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,12 @@ type StringFlag struct {
6868
Changed bool
6969
}
7070

71+
type StringSliceFlag struct {
72+
FlagName string
73+
FlagValue []string
74+
Changed bool
75+
}
76+
7177
type BoolFlag struct {
7278
FlagName string
7379
FlagValue bool
@@ -107,6 +113,9 @@ func CreateTestCobraCommandWithFlagValue(t *testing.T, flagAndValues ...any) *co
107113
case StringFlag:
108114
c.Flags().String(f.FlagName, f.FlagValue, "")
109115
c.Flag(f.FlagName).Changed = f.Changed
116+
case StringSliceFlag:
117+
c.Flags().StringSlice(f.FlagName, f.FlagValue, "")
118+
c.Flag(f.FlagName).Changed = f.Changed
110119
case BoolFlag:
111120
c.Flags().Bool(f.FlagName, f.FlagValue, "")
112121
c.Flag(f.FlagName).Changed = f.Changed

0 commit comments

Comments
 (0)