Skip to content

Commit 9865e83

Browse files
committed
Add integration with the new Prediction API
Remove the integration with Tokenization and Serving APIs Remove unused env-vars Update tests to reflect changes Update plugin config
1 parent a1511a8 commit 9865e83

File tree

6 files changed

+29
-115
lines changed

6 files changed

+29
-115
lines changed

gatewayd_plugin.yaml

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,7 @@ plugins:
2727
- METRICS_ENABLED=True
2828
- METRICS_UNIX_DOMAIN_SOCKET=/tmp/gatewayd-plugin-sql-ids-ips.sock
2929
- METRICS_PATH=/metrics
30-
- TOKENIZER_API_ADDRESS=http://localhost:8000
31-
- SERVING_API_ADDRESS=http://localhost:8501
32-
- MODEL_NAME=sqli_model
33-
- MODEL_VERSION=3
30+
- PREDICTION_API_ADDRESS=http://localhost:8000
3431
# Threshold determine the minimum prediction confidence
3532
# required to detect an SQL injection attack. Any value
3633
# between 0 and 1 is valid, and it is inclusive.

main.go

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,7 @@ func main() {
5454
pluginInstance.Impl.EnableLibinjection = cast.ToBool(cfg["enableLibinjection"])
5555
pluginInstance.Impl.LibinjectionPermissiveMode = cast.ToBool(
5656
cfg["libinjectionPermissiveMode"])
57-
pluginInstance.Impl.TokenizerAPIAddress = cast.ToString(cfg["tokenizerAPIAddress"])
58-
pluginInstance.Impl.ServingAPIAddress = cast.ToString(cfg["servingAPIAddress"])
59-
pluginInstance.Impl.ModelName = cast.ToString(cfg["modelName"])
60-
pluginInstance.Impl.ModelVersion = cast.ToString(cfg["modelVersion"])
57+
pluginInstance.Impl.PredictionAPIAddress = cast.ToString(cfg["predictionAPIAddress"])
6158

6259
pluginInstance.Impl.ResponseType = cast.ToString(cfg["responseType"])
6360
pluginInstance.Impl.ErrorMessage = cast.ToString(cfg["errorMessage"])

plugin/constants.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,11 @@ package plugin
33
const (
44
DecodedQueryField string = "decodedQuery"
55
DetectorField string = "detector"
6-
ScoreField string = "score"
76
QueryField string = "query"
87
ErrorField string = "error"
98
IsInjectionField string = "is_injection"
109
ResponseField string = "response"
11-
OutputsField string = "outputs"
10+
ConfidenceField string = "confidence"
1211
TokensField string = "tokens"
1312
StringField string = "String"
1413
ResponseTypeField string = "response_type"
@@ -23,6 +22,5 @@ const (
2322
ErrorDetail string = "Back off, you're not welcome here."
2423
LogLevel string = "error"
2524

26-
TokenizeAndSequencePath string = "/tokenize_and_sequence"
27-
PredictPath string = "/v1/models/%s/versions/%s:predict"
25+
PredictPath string = "/predict"
2826
)

plugin/module.go

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,8 @@ var (
3636
"metricsUnixDomainSocket": sdkConfig.GetEnv(
3737
"METRICS_UNIX_DOMAIN_SOCKET", "/tmp/gatewayd-plugin-sql-ids-ips.sock"),
3838
"metricsEndpoint": sdkConfig.GetEnv("METRICS_ENDPOINT", "/metrics"),
39-
"tokenizerAPIAddress": sdkConfig.GetEnv(
40-
"TOKENIZER_API_ADDRESS", "http://localhost:8000"),
41-
"servingAPIAddress": sdkConfig.GetEnv(
42-
"SERVING_API_ADDRESS", "http://localhost:8501"),
43-
"modelName": sdkConfig.GetEnv("MODEL_NAME", "sqli_model"),
44-
"modelVersion": sdkConfig.GetEnv("MODEL_VERSION", "1"),
39+
"predictionAPIAddress": sdkConfig.GetEnv(
40+
"PREDICTION_API_ADDRESS", "http://localhost:8000"),
4541
"threshold": sdkConfig.GetEnv("THRESHOLD", "0.8"),
4642
"enableLibinjection": sdkConfig.GetEnv("ENABLE_LIBINJECTION", "true"),
4743
"libinjectionPermissiveMode": sdkConfig.GetEnv("LIBINJECTION_MODE", "true"),

plugin/plugin.go

Lines changed: 12 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"context"
55
"encoding/base64"
66
"encoding/json"
7-
"fmt"
87

98
"github.com/carlmjohnson/requests"
109
"github.com/corazawaf/libinjection-go"
@@ -28,10 +27,7 @@ type Plugin struct {
2827
Threshold float32
2928
EnableLibinjection bool
3029
LibinjectionPermissiveMode bool
31-
TokenizerAPIAddress string
32-
ServingAPIAddress string
33-
ModelName string
34-
ModelVersion string
30+
PredictionAPIAddress string
3531
ResponseType string
3632
ErrorMessage string
3733
ErrorSeverity string
@@ -111,36 +107,12 @@ func (p *Plugin) OnTrafficFromClient(ctx context.Context, req *v1.Struct) (*v1.S
111107
}
112108
queryString := cast.ToString(queryMap[StringField])
113109

114-
var tokens map[string]any
115-
err = requests.
116-
URL(p.TokenizerAPIAddress).
117-
Path(TokenizeAndSequencePath).
118-
BodyJSON(map[string]any{
119-
QueryField: queryString,
120-
}).
121-
ToJSON(&tokens).
122-
Fetch(context.Background())
123-
if err != nil {
124-
p.Logger.Error("Failed to make POST request", ErrorField, err)
125-
if p.isSQLi(queryString) && !p.LibinjectionPermissiveMode {
126-
return p.prepareResponse(
127-
req,
128-
map[string]any{
129-
QueryField: queryString,
130-
DetectorField: Libinjection,
131-
ErrorField: "Failed to make POST request to tokenizer API",
132-
},
133-
), nil
134-
}
135-
return req, nil
136-
}
137-
138110
var output map[string]any
139111
err = requests.
140-
URL(p.ServingAPIAddress).
141-
Path(fmt.Sprintf(PredictPath, p.ModelName, p.ModelVersion)).
112+
URL(p.PredictionAPIAddress).
113+
Path(PredictPath).
142114
BodyJSON(map[string]any{
143-
"inputs": []any{cast.ToSlice(tokens[TokensField])},
115+
QueryField: queryString,
144116
}).
145117
ToJSON(&output).
146118
Fetch(context.Background())
@@ -152,34 +124,32 @@ func (p *Plugin) OnTrafficFromClient(ctx context.Context, req *v1.Struct) (*v1.S
152124
map[string]any{
153125
QueryField: queryString,
154126
DetectorField: Libinjection,
155-
ErrorField: "Failed to make POST request to serving API",
127+
ErrorField: "Failed to make POST request to tokenizer API",
156128
},
157129
), nil
158130
}
159131
return req, nil
160132
}
161133

162-
predictions := cast.ToSlice(output[OutputsField])
163-
scores := cast.ToSlice(predictions[0])
164-
score := cast.ToFloat32(scores[0])
165-
p.Logger.Trace("Deep learning model prediction", ScoreField, score)
134+
confidence := cast.ToFloat32(output[ConfidenceField])
135+
p.Logger.Trace("Deep learning model prediction", ConfidenceField, confidence)
166136

167137
// Check the prediction against the threshold,
168138
// otherwise check if the query is an SQL injection using libinjection.
169139
injection := p.isSQLi(queryString)
170-
if score >= p.Threshold {
140+
if confidence >= p.Threshold {
171141
if p.EnableLibinjection && !injection {
172142
p.Logger.Debug("False positive detected", DetectorField, Libinjection)
173143
}
174144

175145
Detections.With(map[string]string{DetectorField: DeepLearningModel}).Inc()
176-
p.Logger.Warn(p.ErrorMessage, ScoreField, score, DetectorField, DeepLearningModel)
146+
p.Logger.Warn(p.ErrorMessage, ConfidenceField, confidence, DetectorField, DeepLearningModel)
177147
return p.prepareResponse(
178148
req,
179149
map[string]any{
180-
QueryField: queryString,
181-
ScoreField: score,
182-
DetectorField: DeepLearningModel,
150+
QueryField: queryString,
151+
ConfidenceField: confidence,
152+
DetectorField: DeepLearningModel,
183153
},
184154
), nil
185155
} else if p.EnableLibinjection && injection && !p.LibinjectionPermissiveMode {

plugin/plugin_test.go

Lines changed: 11 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package plugin
33
import (
44
"context"
55
"encoding/json"
6-
"fmt"
76
"net/http"
87
"net/http/httptest"
98
"testing"
@@ -71,28 +70,13 @@ func Test_errorResponse(t *testing.T) {
7170

7271
func Test_OnTrafficFromClinet(t *testing.T) {
7372
p := &Plugin{
74-
Logger: hclog.NewNullLogger(),
75-
ModelName: "sqli_model",
76-
ModelVersion: "2",
73+
Logger: hclog.NewNullLogger(),
7774
}
7875

7976
server := httptest.NewServer(
8077
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
8178
switch r.URL.Path {
82-
case TokenizeAndSequencePath:
83-
w.WriteHeader(http.StatusOK)
84-
w.Header().Set("Content-Type", "application/json")
85-
// This is the tokenized query:
86-
// {"query":"select * from users where id = 1 or 1=1"}
87-
resp := map[string][]float32{
88-
"tokens": {
89-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 6, 5, 73, 7, 68, 4, 11, 12,
90-
},
91-
}
92-
data, _ := json.Marshal(resp)
93-
_, err := w.Write(data)
94-
require.NoError(t, err)
95-
case fmt.Sprintf(PredictPath, p.ModelName, p.ModelVersion):
79+
case PredictPath:
9680
w.WriteHeader(http.StatusOK)
9781
w.Header().Set("Content-Type", "application/json")
9882
// This is the output of the deep learning model.
@@ -107,8 +91,7 @@ func Test_OnTrafficFromClinet(t *testing.T) {
10791
)
10892
defer server.Close()
10993

110-
p.TokenizerAPIAddress = server.URL
111-
p.ServingAPIAddress = server.URL
94+
p.PredictionAPIAddress = server.URL
11295

11396
query := pgproto3.Query{String: "SELECT * FROM users WHERE id = 1 OR 1=1"}
11497
queryBytes, err := query.Encode(nil)
@@ -136,17 +119,13 @@ func Test_OnTrafficFromClinet(t *testing.T) {
136119
func Test_OnTrafficFromClinetFailedTokenization(t *testing.T) {
137120
plugins := []*Plugin{
138121
{
139-
Logger: hclog.NewNullLogger(),
140-
ModelName: "sqli_model",
141-
ModelVersion: "2",
122+
Logger: hclog.NewNullLogger(),
142123
// If libinjection is enabled, the response should contain the "response" field,
143124
// and the "signals" field, which means the plugin will terminate the request.
144125
EnableLibinjection: true,
145126
},
146127
{
147-
Logger: hclog.NewNullLogger(),
148-
ModelName: "sqli_model",
149-
ModelVersion: "2",
128+
Logger: hclog.NewNullLogger(),
150129
// If libinjection is disabled, the response should not contain the "response" field,
151130
// and the "signals" field, which means the plugin will not terminate the request.
152131
EnableLibinjection: false,
@@ -156,7 +135,7 @@ func Test_OnTrafficFromClinetFailedTokenization(t *testing.T) {
156135
server := httptest.NewServer(
157136
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
158137
switch r.URL.Path {
159-
case TokenizeAndSequencePath:
138+
case PredictPath:
160139
w.WriteHeader(http.StatusInternalServerError)
161140
default:
162141
w.WriteHeader(http.StatusNotFound)
@@ -166,8 +145,7 @@ func Test_OnTrafficFromClinetFailedTokenization(t *testing.T) {
166145
defer server.Close()
167146

168147
for i := range plugins {
169-
plugins[i].TokenizerAPIAddress = server.URL
170-
plugins[i].ServingAPIAddress = server.URL
148+
plugins[i].PredictionAPIAddress = server.URL
171149

172150
query := pgproto3.Query{String: "SELECT * FROM users WHERE id = 1 OR 1=1"}
173151
queryBytes, err := query.Encode(nil)
@@ -204,43 +182,22 @@ func Test_OnTrafficFromClinetFailedTokenization(t *testing.T) {
204182
func Test_OnTrafficFromClinetFailedPrediction(t *testing.T) {
205183
plugins := []*Plugin{
206184
{
207-
Logger: hclog.NewNullLogger(),
208-
ModelName: "sqli_model",
209-
ModelVersion: "2",
185+
Logger: hclog.NewNullLogger(),
210186
// If libinjection is disabled, the response should not contain the "response" field,
211187
// and the "signals" field, which means the plugin will not terminate the request.
212188
EnableLibinjection: false,
213189
},
214190
{
215-
Logger: hclog.NewNullLogger(),
216-
ModelName: "sqli_model",
217-
ModelVersion: "2",
191+
Logger: hclog.NewNullLogger(),
218192
// If libinjection is enabled, the response should contain the "response" field,
219193
// and the "signals" field, which means the plugin will terminate the request.
220194
EnableLibinjection: true,
221195
},
222196
}
223-
224-
// This is the same for both plugins.
225-
predictPath := fmt.Sprintf(PredictPath, plugins[0].ModelName, plugins[1].ModelVersion)
226-
227197
server := httptest.NewServer(
228198
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
229199
switch r.URL.Path {
230-
case TokenizeAndSequencePath:
231-
w.WriteHeader(http.StatusOK)
232-
w.Header().Set("Content-Type", "application/json")
233-
// This is the tokenized query:
234-
// {"query":"select * from users where id = 1 or 1=1"}
235-
resp := map[string][]float32{
236-
"tokens": {
237-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 6, 5, 73, 7, 68, 4, 11, 12,
238-
},
239-
}
240-
data, _ := json.Marshal(resp)
241-
_, err := w.Write(data)
242-
require.NoError(t, err)
243-
case predictPath:
200+
case PredictPath:
244201
w.WriteHeader(http.StatusInternalServerError)
245202
default:
246203
w.WriteHeader(http.StatusNotFound)
@@ -250,8 +207,7 @@ func Test_OnTrafficFromClinetFailedPrediction(t *testing.T) {
250207
defer server.Close()
251208

252209
for i := range plugins {
253-
plugins[i].TokenizerAPIAddress = server.URL
254-
plugins[i].ServingAPIAddress = server.URL
210+
plugins[i].PredictionAPIAddress = server.URL
255211

256212
query := pgproto3.Query{String: "SELECT * FROM users WHERE id = 1 OR 1=1"}
257213
queryBytes, err := query.Encode(nil)

0 commit comments

Comments
 (0)