Skip to content

Commit 4e3051d

Browse files
authored
[common-go] Rate limiting on booleans and composite keys (#17026)
* [common-go] Rate limiting on booleans and composite keys * Optimize FieldAccessKey
1 parent 4edc0ef commit 4e3051d

File tree

2 files changed

+131
-6
lines changed

2 files changed

+131
-6
lines changed

components/common-go/grpc/ratelimit.go

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,12 @@ type RateLimit struct {
3333
// so effectively this is the rate at which requests can be made.
3434
RefillInterval util.Duration `json:"refillInterval"`
3535

36-
KeyCacheSize uint `json:"keyCacheSize,omitempty"`
37-
Key string `json:"key,omitempty"`
36+
// Key is the proto field name to rate limit on. Each unique value of this
37+
// field gets its own rate limit bucket. Must be a String, Enum, or Boolean field.
38+
// Can be a composite key by separating fields by comma, e.g. `foo.bar,foo.baz`
39+
Key string `json:"key,omitempty"`
40+
// KeyCacheSize is the max number of buckets kept in a LRU cache.
41+
KeyCacheSize uint `json:"keyCacheSize,omitempty"`
3842
}
3943

4044
func (r RateLimit) Limiter() *rate.Limiter {
@@ -82,17 +86,31 @@ func NewRatelimitingInterceptor(f map[string]RateLimit) RatelimitingInterceptor
8286
}
8387

8488
func fieldAccessKey(key string) keyFunc {
89+
fields := strings.Split(key, ",")
90+
paths := make([][]string, len(fields))
91+
for i, field := range fields {
92+
paths[i] = strings.Split(field, ".")
93+
}
8594
return func(req interface{}) (string, error) {
8695
msg, ok := req.(proto.Message)
8796
if !ok {
8897
return "", status.Errorf(codes.Internal, "request was not a protobuf message")
8998
}
9099

91-
val, ok := getFieldValue(msg.ProtoReflect(), strings.Split(key, "."))
92-
if !ok {
93-
return "", status.Errorf(codes.Internal, "Field %s does not exist in message. This is a rate limiting configuration error.", key)
100+
var composite string
101+
for i, field := range fields {
102+
val, ok := getFieldValue(msg.ProtoReflect(), paths[i])
103+
if !ok {
104+
return "", status.Errorf(codes.Internal, "Field %s does not exist in message. This is a rate limiting configuration error.", field)
105+
}
106+
// It's technically possible that `|` is part of one of the field values, and therefore could cause collisions
107+
// in composite keys, e.g. values (`a|`, `b`), and (`a`, `|b`) would result in the same composite key `a||b`
108+
// and share the rate limit. This is highly unlikely though given the current fields we rate limit on and
109+
// otherwise unlikely to cause issues.
110+
composite += "|" + val
94111
}
95-
return val, nil
112+
113+
return composite, nil
96114
}
97115
}
98116

@@ -120,6 +138,13 @@ func getFieldValue(msg protoreflect.Message, path []string) (val string, ok bool
120138
case protoreflect.EnumKind:
121139
enumNum := msg.Get(field).Enum()
122140
return strconv.Itoa(int(enumNum)), true
141+
case protoreflect.BoolKind:
142+
if msg.Get(field).Bool() {
143+
return "t", true
144+
} else {
145+
return "f", true
146+
}
147+
123148
default:
124149
// we only support string and enum fields
125150
return "", false

components/common-go/grpc/ratelimit_test.go

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,18 @@ func TestGetFieldValue(t *testing.T) {
5151
Path: "syntax",
5252
Expectation: Expectation{Found: true, Val: "0"},
5353
},
54+
{
55+
Name: "bool field",
56+
Message: &apipb.Method{RequestStreaming: true},
57+
Path: "request_streaming",
58+
Expectation: Expectation{Found: true, Val: "t"},
59+
},
60+
{
61+
Name: "empty bool field",
62+
Message: &apipb.Method{},
63+
Path: "request_streaming",
64+
Expectation: Expectation{Found: true, Val: "f"},
65+
},
5466
{
5567
Name: "non-existent field",
5668
Message: &apipb.Api{},
@@ -79,6 +91,41 @@ func TestGetFieldValue(t *testing.T) {
7991
}
8092
}
8193

94+
func TestFieldAccessKey(t *testing.T) {
95+
type Expectation struct {
96+
Val string
97+
Err error
98+
}
99+
tests := []struct {
100+
Name string
101+
Message proto.Message
102+
Key string
103+
Expectation Expectation
104+
}{
105+
{
106+
Name: "composite key",
107+
Message: &apipb.Api{
108+
SourceContext: &sourcecontextpb.SourceContext{
109+
FileName: "bar",
110+
},
111+
Syntax: typepb.Syntax_SYNTAX_PROTO3,
112+
},
113+
Key: "source_context.file_name,syntax",
114+
Expectation: Expectation{Val: "|bar|1", Err: nil},
115+
},
116+
}
117+
for _, test := range tests {
118+
t.Run(test.Name, func(t *testing.T) {
119+
var act Expectation
120+
keyFn := fieldAccessKey(test.Key)
121+
act.Val, act.Err = keyFn(test.Message)
122+
if diff := cmp.Diff(test.Expectation, act); diff != "" {
123+
t.Errorf("unexpected fieldAccessKey (-want +got):\n%s", diff)
124+
}
125+
})
126+
}
127+
}
128+
82129
func BenchmarkGetFieldValue(b *testing.B) {
83130
msg := apipb.Api{
84131
SourceContext: &sourcecontextpb.SourceContext{
@@ -92,3 +139,56 @@ func BenchmarkGetFieldValue(b *testing.B) {
92139
getFieldValue(msgr, path)
93140
}
94141
}
142+
143+
func BenchmarkFieldAccessKey_String(b *testing.B) {
144+
msg := &apipb.Api{
145+
Name: "bar",
146+
}
147+
keyFn := fieldAccessKey("name")
148+
for n := 0; n < b.N; n++ {
149+
if _, err := keyFn(msg); err != nil {
150+
b.Logf("failed to access key: %v", err)
151+
b.Fail()
152+
}
153+
}
154+
}
155+
156+
func BenchmarkFieldAccessKey_Enum(b *testing.B) {
157+
msg := &apipb.Api{
158+
Syntax: typepb.Syntax_SYNTAX_PROTO3,
159+
}
160+
keyFn := fieldAccessKey("syntax")
161+
for n := 0; n < b.N; n++ {
162+
if _, err := keyFn(msg); err != nil {
163+
b.Logf("failed to access key: %v", err)
164+
b.Fail()
165+
}
166+
}
167+
}
168+
169+
func BenchmarkFieldAccessKey_Bool(b *testing.B) {
170+
msg := &apipb.Method{
171+
RequestStreaming: true,
172+
}
173+
keyFn := fieldAccessKey("request_streaming")
174+
for n := 0; n < b.N; n++ {
175+
if _, err := keyFn(msg); err != nil {
176+
b.Logf("failed to access key: %v", err)
177+
b.Fail()
178+
}
179+
}
180+
}
181+
182+
func BenchmarkFieldAccessKey_Composite(b *testing.B) {
183+
msg := &apipb.Method{
184+
Name: "bar",
185+
RequestStreaming: true,
186+
}
187+
keyFn := fieldAccessKey("name,request_streaming")
188+
for n := 0; n < b.N; n++ {
189+
if _, err := keyFn(msg); err != nil {
190+
b.Logf("failed to access key: %v", err)
191+
b.Fail()
192+
}
193+
}
194+
}

0 commit comments

Comments
 (0)