Skip to content

[common-go] Rate limiting on booleans and composite keys #17026

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 31 additions & 6 deletions components/common-go/grpc/ratelimit.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,12 @@ type RateLimit struct {
// so effectively this is the rate at which requests can be made.
RefillInterval util.Duration `json:"refillInterval"`

KeyCacheSize uint `json:"keyCacheSize,omitempty"`
Key string `json:"key,omitempty"`
// Key is the proto field name to rate limit on. Each unique value of this
// field gets its own rate limit bucket. Must be a String, Enum, or Boolean field.
// Can be a composite key by separating fields by comma, e.g. `foo.bar,foo.baz`
Key string `json:"key,omitempty"`
// KeyCacheSize is the max number of buckets kept in a LRU cache.
KeyCacheSize uint `json:"keyCacheSize,omitempty"`
}

func (r RateLimit) Limiter() *rate.Limiter {
Expand Down Expand Up @@ -82,17 +86,31 @@ func NewRatelimitingInterceptor(f map[string]RateLimit) RatelimitingInterceptor
}

func fieldAccessKey(key string) keyFunc {
fields := strings.Split(key, ",")
paths := make([][]string, len(fields))
for i, field := range fields {
paths[i] = strings.Split(field, ".")
}
Comment on lines +89 to +93
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Extracting these string splits to outside the returned function improves performance by ~60%:

// OLD:
// BenchmarkFieldAccessKey_String-32                5800192               215.1 ns/op
// BenchmarkFieldAccessKey_Composite-32             3239696               371.9 ns/op
// NEW:
// BenchmarkFieldAccessKey_String-32                9338356               127.8 ns/op
// BenchmarkFieldAccessKey_Composite-32             5505349               217.8 ns/op

return func(req interface{}) (string, error) {
msg, ok := req.(proto.Message)
if !ok {
return "", status.Errorf(codes.Internal, "request was not a protobuf message")
}

val, ok := getFieldValue(msg.ProtoReflect(), strings.Split(key, "."))
if !ok {
return "", status.Errorf(codes.Internal, "Field %s does not exist in message. This is a rate limiting configuration error.", key)
var composite string
for i, field := range fields {
val, ok := getFieldValue(msg.ProtoReflect(), paths[i])
if !ok {
return "", status.Errorf(codes.Internal, "Field %s does not exist in message. This is a rate limiting configuration error.", field)
}
// It's technically possible that `|` is part of one of the field values, and therefore could cause collisions
// in composite keys, e.g. values (`a|`, `b`), and (`a`, `|b`) would result in the same composite key `a||b`
// and share the rate limit. This is highly unlikely though given the current fields we rate limit on and
// otherwise unlikely to cause issues.
composite += "|" + val
}
return val, nil

return composite, nil
}
}

Expand Down Expand Up @@ -120,6 +138,13 @@ func getFieldValue(msg protoreflect.Message, path []string) (val string, ok bool
case protoreflect.EnumKind:
enumNum := msg.Get(field).Enum()
return strconv.Itoa(int(enumNum)), true
case protoreflect.BoolKind:
if msg.Get(field).Bool() {
return "t", true
} else {
return "f", true
}

default:
// we only support string and enum fields
return "", false
Expand Down
100 changes: 100 additions & 0 deletions components/common-go/grpc/ratelimit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,18 @@ func TestGetFieldValue(t *testing.T) {
Path: "syntax",
Expectation: Expectation{Found: true, Val: "0"},
},
{
Name: "bool field",
Message: &apipb.Method{RequestStreaming: true},
Path: "request_streaming",
Expectation: Expectation{Found: true, Val: "t"},
},
{
Name: "empty bool field",
Message: &apipb.Method{},
Path: "request_streaming",
Expectation: Expectation{Found: true, Val: "f"},
},
{
Name: "non-existent field",
Message: &apipb.Api{},
Expand Down Expand Up @@ -79,6 +91,41 @@ func TestGetFieldValue(t *testing.T) {
}
}

func TestFieldAccessKey(t *testing.T) {
type Expectation struct {
Val string
Err error
}
tests := []struct {
Name string
Message proto.Message
Key string
Expectation Expectation
}{
{
Name: "composite key",
Message: &apipb.Api{
SourceContext: &sourcecontextpb.SourceContext{
FileName: "bar",
},
Syntax: typepb.Syntax_SYNTAX_PROTO3,
},
Key: "source_context.file_name,syntax",
Expectation: Expectation{Val: "|bar|1", Err: nil},
},
}
for _, test := range tests {
t.Run(test.Name, func(t *testing.T) {
var act Expectation
keyFn := fieldAccessKey(test.Key)
act.Val, act.Err = keyFn(test.Message)
if diff := cmp.Diff(test.Expectation, act); diff != "" {
t.Errorf("unexpected fieldAccessKey (-want +got):\n%s", diff)
}
})
}
}

func BenchmarkGetFieldValue(b *testing.B) {
msg := apipb.Api{
SourceContext: &sourcecontextpb.SourceContext{
Expand All @@ -92,3 +139,56 @@ func BenchmarkGetFieldValue(b *testing.B) {
getFieldValue(msgr, path)
}
}

func BenchmarkFieldAccessKey_String(b *testing.B) {
msg := &apipb.Api{
Name: "bar",
}
keyFn := fieldAccessKey("name")
for n := 0; n < b.N; n++ {
if _, err := keyFn(msg); err != nil {
b.Logf("failed to access key: %v", err)
b.Fail()
}
}
}

func BenchmarkFieldAccessKey_Enum(b *testing.B) {
msg := &apipb.Api{
Syntax: typepb.Syntax_SYNTAX_PROTO3,
}
keyFn := fieldAccessKey("syntax")
for n := 0; n < b.N; n++ {
if _, err := keyFn(msg); err != nil {
b.Logf("failed to access key: %v", err)
b.Fail()
}
}
}

func BenchmarkFieldAccessKey_Bool(b *testing.B) {
msg := &apipb.Method{
RequestStreaming: true,
}
keyFn := fieldAccessKey("request_streaming")
for n := 0; n < b.N; n++ {
if _, err := keyFn(msg); err != nil {
b.Logf("failed to access key: %v", err)
b.Fail()
}
}
}

func BenchmarkFieldAccessKey_Composite(b *testing.B) {
msg := &apipb.Method{
Name: "bar",
RequestStreaming: true,
}
keyFn := fieldAccessKey("name,request_streaming")
for n := 0; n < b.N; n++ {
if _, err := keyFn(msg); err != nil {
b.Logf("failed to access key: %v", err)
b.Fail()
}
}
}