@@ -1213,6 +1213,78 @@ struct test_get_rows_back : public test_case {
1213
1213
}
1214
1214
};
1215
1215
1216
+ // GGML_OP_SET_ROWS: test case that builds one ggml_set_rows(dst, src, row_idxs) node; dst is [n, m, b0, b1] of `type`, src is F32 [n, r, b0, b1], row_idxs is I64 [r, bs, b1]
1217
+ struct test_set_rows : public test_case {
1218
+ const ggml_type type; // element type of dst (src is always created as F32)
1219
+ const int n; // cols (dim 0 of dst and src)
1220
+ const int m; // rows (dim 1 of dst); also the exclusive upper bound of the generated row indices
1221
+ const int r; // rows to set (dim 1 of src, dim 0 of row_idxs)
1222
+ const int b0; // batch size along dim 2 of dst and src
1223
+ const int b1; // batch size along dim 3 of dst and src (dim 2 of row_idxs); the constructor always sets it to 3
1224
+ const int bs; // dim 1 of row_idxs (for testing broadcast); the constructor asserts that it divides b0
1225
+ const bool v; // view: when true, src and row_idxs are replaced by views that keep only the first r/2 rows / indices
1226
+ // Returns the string produced by VARS_TO_STR7 for the listed fields; note that b1 is not among them.
1227
+ std::string vars () override {
1228
+ return VARS_TO_STR7 (type, n, m, r, b0, bs, v);
1229
+ }
1230
+ // The constructor argument `b` initializes b0; b1 is hard-coded to 3 and cannot be chosen by the caller.
1231
+ test_set_rows (ggml_type type = GGML_TYPE_F32, int n = 10 , int m = 5 , int r = 3 , int b = 1 , int bs = 1 , bool v = false )
1232
+ : type(type), n(n), m(m), r(r), b0(b), b1(3 ), bs(bs), v(v) {
1233
+ GGML_ASSERT (b0 % bs == 0 && " b0 must be a multiple of bs" );
1234
+ GGML_ASSERT (r <= m && " r must be less than or equal to m" ); // initialize_tensors takes the r indices from a shuffle of 0..m-1
1235
+ }
1236
+ // Creates dst, src and row_idxs (narrowed to views when v is set) and returns the ggml_set_rows node built from them.
1237
+ ggml_tensor * build_graph (ggml_context * ctx) override {
1238
+ ggml_tensor * dst = ggml_new_tensor_4d (ctx, type, n, m, b0, b1);
1239
+ ggml_set_name (dst, " dst" );
1240
+
1241
+ ggml_tensor * src = ggml_new_tensor_4d (ctx, GGML_TYPE_F32, n, r, b0, b1);
1242
+ ggml_set_name (src, " src" );
1243
+
1244
+ ggml_tensor * row_idxs = ggml_new_tensor_3d (ctx, GGML_TYPE_I64, r, bs, b1);
1245
+ ggml_set_name (row_idxs, " row_idxs" );
1246
+ // With v set, both inputs shrink to r/2 along the row dimension while reusing the parent tensors' byte strides at offset 0.
1247
+ if (v) {
1248
+ src = ggml_view_4d (ctx, src, n, r/2 , b0, b1, src->nb [1 ], src->nb [2 ], src->nb [3 ], 0 );
1249
+ row_idxs = ggml_view_3d (ctx, row_idxs, r/2 , bs, b1, row_idxs->nb [1 ], row_idxs->nb [2 ], 0 );
1250
+ ggml_set_name (row_idxs, " view_of_rows" ); // only the index view is given a name here; the src view is not
1251
+ }
1252
+
1253
+ ggml_tensor * out = ggml_set_rows (ctx, dst, src, row_idxs);
1254
+ ggml_set_name (out, " out" );
1255
+
1256
+ return out;
1257
+ }
1258
+ // Fills every non-view I64 tensor in ctx with row indices and hands every other tensor to init_tensor_uniform.
1259
+ void initialize_tensors (ggml_context * ctx) override {
1260
+ std::random_device rd; // the shuffle engine is seeded from std::random_device rather than from a fixed value
1261
+ std::default_random_engine rng (rd ());
1262
+ for (ggml_tensor * t = ggml_get_first_tensor (ctx); t != NULL ; t = ggml_get_next_tensor (ctx, t)) {
1263
+ if (t->type == GGML_TYPE_I64) {
1264
+ if (ggml_is_view_op (t->op )) { // I64 views are not written directly - presumably filling the viewed tensor covers them; TODO confirm
1265
+ continue ;
1266
+ }
1267
+ // For each (i1, i2) slice: shuffle 0..m-1 and keep the first ne[0] values, so (given ne[0] <= m) the indices in a slice are distinct and < m.
1268
+ for (int i2 = 0 ; i2 < t->ne [2 ]; i2++) {
1269
+ for (int i1 = 0 ; i1 < t->ne [1 ]; i1++) {
1270
+ std::vector<int64_t > data (m);
1271
+ for (int i = 0 ; i < m; i++) {
1272
+ data[i] = i;
1273
+ }
1274
+ std::shuffle (data.begin (), data.end (), rng);
1275
+ data.resize (t->ne [0 ]);
1276
+ // Byte offset of slice (i1, i2) inside t; each call below writes one slice of ne[0] int64 values.
1277
+ const size_t offs = i1*t->nb [1 ] + i2*t->nb [2 ];
1278
+ ggml_backend_tensor_set (t, data.data (), offs, t->ne [0 ]*sizeof (int64_t ));
1279
+ }
1280
+ }
1281
+ } else {
1282
+ init_tensor_uniform (t);
1283
+ }
1284
+ }
1285
+ }
1286
+ };
1287
+
1216
1288
// GGML_OP_ARGMAX
1217
1289
struct test_argmax : public test_case {
1218
1290
const ggml_type type;
@@ -3984,6 +4056,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
3984
4056
test_cases.emplace_back (new test_get_rows_back (GGML_TYPE_I32, 256 , 5 , 4 , 1 , v));
3985
4057
}
3986
4058
4059
+ test_cases.emplace_back (new test_set_rows (GGML_TYPE_F32, 1 , 8 , 2 , 1 , 1 , false ));
4060
+ for (ggml_type type : all_types) {
4061
+ for (int b : {1 , 7 }) {
4062
+ for (bool v : {false , true }) {
4063
+ test_cases.emplace_back (new test_set_rows (type, 256 , 5 , 4 , b, 1 , v));
4064
+ }
4065
+ }
4066
+ }
4067
+
3987
4068
for (ggml_type type_input : {GGML_TYPE_F32}) {
3988
4069
for (ggml_op_pool pool_type : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) {
3989
4070
for (int k0 : {1 , 3 }) {
0 commit comments