@@ -3,7 +3,6 @@ package plugin
3
3
import (
4
4
"context"
5
5
"encoding/json"
6
- "fmt"
7
6
"net/http"
8
7
"net/http/httptest"
9
8
"testing"
@@ -71,28 +70,13 @@ func Test_errorResponse(t *testing.T) {
71
70
72
71
func Test_OnTrafficFromClinet (t * testing.T ) {
73
72
p := & Plugin {
74
- Logger : hclog .NewNullLogger (),
75
- ModelName : "sqli_model" ,
76
- ModelVersion : "2" ,
73
+ Logger : hclog .NewNullLogger (),
77
74
}
78
75
79
76
server := httptest .NewServer (
80
77
http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
81
78
switch r .URL .Path {
82
- case TokenizeAndSequencePath :
83
- w .WriteHeader (http .StatusOK )
84
- w .Header ().Set ("Content-Type" , "application/json" )
85
- // This is the tokenized query:
86
- // {"query":"select * from users where id = 1 or 1=1"}
87
- resp := map [string ][]float32 {
88
- "tokens" : {
89
- 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 3 , 6 , 5 , 73 , 7 , 68 , 4 , 11 , 12 ,
90
- },
91
- }
92
- data , _ := json .Marshal (resp )
93
- _ , err := w .Write (data )
94
- require .NoError (t , err )
95
- case fmt .Sprintf (PredictPath , p .ModelName , p .ModelVersion ):
79
+ case PredictPath :
96
80
w .WriteHeader (http .StatusOK )
97
81
w .Header ().Set ("Content-Type" , "application/json" )
98
82
// This is the output of the deep learning model.
@@ -107,8 +91,7 @@ func Test_OnTrafficFromClinet(t *testing.T) {
107
91
)
108
92
defer server .Close ()
109
93
110
- p .TokenizerAPIAddress = server .URL
111
- p .ServingAPIAddress = server .URL
94
+ p .PredictionAPIAddress = server .URL
112
95
113
96
query := pgproto3.Query {String : "SELECT * FROM users WHERE id = 1 OR 1=1" }
114
97
queryBytes , err := query .Encode (nil )
@@ -136,17 +119,13 @@ func Test_OnTrafficFromClinet(t *testing.T) {
136
119
func Test_OnTrafficFromClinetFailedTokenization (t * testing.T ) {
137
120
plugins := []* Plugin {
138
121
{
139
- Logger : hclog .NewNullLogger (),
140
- ModelName : "sqli_model" ,
141
- ModelVersion : "2" ,
122
+ Logger : hclog .NewNullLogger (),
142
123
// If libinjection is enabled, the response should contain the "response" field,
143
124
// and the "signals" field, which means the plugin will terminate the request.
144
125
EnableLibinjection : true ,
145
126
},
146
127
{
147
- Logger : hclog .NewNullLogger (),
148
- ModelName : "sqli_model" ,
149
- ModelVersion : "2" ,
128
+ Logger : hclog .NewNullLogger (),
150
129
// If libinjection is disabled, the response should not contain the "response" field,
151
130
// and the "signals" field, which means the plugin will not terminate the request.
152
131
EnableLibinjection : false ,
@@ -156,7 +135,7 @@ func Test_OnTrafficFromClinetFailedTokenization(t *testing.T) {
156
135
server := httptest .NewServer (
157
136
http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
158
137
switch r .URL .Path {
159
- case TokenizeAndSequencePath :
138
+ case PredictPath :
160
139
w .WriteHeader (http .StatusInternalServerError )
161
140
default :
162
141
w .WriteHeader (http .StatusNotFound )
@@ -166,8 +145,7 @@ func Test_OnTrafficFromClinetFailedTokenization(t *testing.T) {
166
145
defer server .Close ()
167
146
168
147
for i := range plugins {
169
- plugins [i ].TokenizerAPIAddress = server .URL
170
- plugins [i ].ServingAPIAddress = server .URL
148
+ plugins [i ].PredictionAPIAddress = server .URL
171
149
172
150
query := pgproto3.Query {String : "SELECT * FROM users WHERE id = 1 OR 1=1" }
173
151
queryBytes , err := query .Encode (nil )
@@ -204,43 +182,22 @@ func Test_OnTrafficFromClinetFailedTokenization(t *testing.T) {
204
182
func Test_OnTrafficFromClinetFailedPrediction (t * testing.T ) {
205
183
plugins := []* Plugin {
206
184
{
207
- Logger : hclog .NewNullLogger (),
208
- ModelName : "sqli_model" ,
209
- ModelVersion : "2" ,
185
+ Logger : hclog .NewNullLogger (),
210
186
// If libinjection is disabled, the response should not contain the "response" field,
211
187
// and the "signals" field, which means the plugin will not terminate the request.
212
188
EnableLibinjection : false ,
213
189
},
214
190
{
215
- Logger : hclog .NewNullLogger (),
216
- ModelName : "sqli_model" ,
217
- ModelVersion : "2" ,
191
+ Logger : hclog .NewNullLogger (),
218
192
// If libinjection is enabled, the response should contain the "response" field,
219
193
// and the "signals" field, which means the plugin will terminate the request.
220
194
EnableLibinjection : true ,
221
195
},
222
196
}
223
-
224
- // This is the same for both plugins.
225
- predictPath := fmt .Sprintf (PredictPath , plugins [0 ].ModelName , plugins [1 ].ModelVersion )
226
-
227
197
server := httptest .NewServer (
228
198
http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
229
199
switch r .URL .Path {
230
- case TokenizeAndSequencePath :
231
- w .WriteHeader (http .StatusOK )
232
- w .Header ().Set ("Content-Type" , "application/json" )
233
- // This is the tokenized query:
234
- // {"query":"select * from users where id = 1 or 1=1"}
235
- resp := map [string ][]float32 {
236
- "tokens" : {
237
- 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 3 , 6 , 5 , 73 , 7 , 68 , 4 , 11 , 12 ,
238
- },
239
- }
240
- data , _ := json .Marshal (resp )
241
- _ , err := w .Write (data )
242
- require .NoError (t , err )
243
- case predictPath :
200
+ case PredictPath :
244
201
w .WriteHeader (http .StatusInternalServerError )
245
202
default :
246
203
w .WriteHeader (http .StatusNotFound )
@@ -250,8 +207,7 @@ func Test_OnTrafficFromClinetFailedPrediction(t *testing.T) {
250
207
defer server .Close ()
251
208
252
209
for i := range plugins {
253
- plugins [i ].TokenizerAPIAddress = server .URL
254
- plugins [i ].ServingAPIAddress = server .URL
210
+ plugins [i ].PredictionAPIAddress = server .URL
255
211
256
212
query := pgproto3.Query {String : "SELECT * FROM users WHERE id = 1 OR 1=1" }
257
213
queryBytes , err := query .Encode (nil )
0 commit comments