@@ -1151,6 +1151,60 @@ struct test_glu : public test_case {
1151
1151
}
1152
1152
};
1153
1153
1154
+ struct test_glu_split : public test_case {
1155
+ const ggml_glu_op op;
1156
+ const ggml_type type;
1157
+ const std::array<int64_t , 4 > ne_a;
1158
+ int v; // view (1 : non-contiguous a)
1159
+
1160
+ std::string vars () override {
1161
+ return VARS_TO_STR3 (type, ne_a, v);
1162
+ }
1163
+
1164
+ test_glu_split (ggml_glu_op op,
1165
+ ggml_type type = GGML_TYPE_F32,
1166
+ std::array<int64_t , 4 > ne_a = {128 , 2 , 2 , 2 },
1167
+ int v = 0 )
1168
+ : op(op), type(type), ne_a(ne_a), v(v) {}
1169
+
1170
+ ggml_tensor * build_graph (ggml_context * ctx) override {
1171
+ ggml_tensor * a;
1172
+ ggml_tensor * b;
1173
+ if (v & 1 ) {
1174
+ auto ne = ne_a; ne[0 ] *= 3 ;
1175
+ a = ggml_new_tensor (ctx, type, 4 , ne.data ());
1176
+ ggml_set_name (a, " a" );
1177
+
1178
+ a = ggml_view_4d (ctx, a, ne_a[0 ], ne_a[1 ], ne_a[2 ], ne_a[3 ], a->nb [1 ], a->nb [2 ], a->nb [3 ], 0 );
1179
+ ggml_set_name (a, " view_of_a" );
1180
+
1181
+ b = ggml_new_tensor (ctx, type, 4 , ne.data ());
1182
+ ggml_set_name (b, " b" );
1183
+
1184
+ b = ggml_view_4d (ctx, b, ne_a[0 ], ne_a[1 ], ne_a[2 ], ne_a[3 ], b->nb [1 ], b->nb [2 ], b->nb [3 ], 0 );
1185
+ ggml_set_name (a, " view_of_b" );
1186
+ } else {
1187
+ a = ggml_new_tensor (ctx, type, 4 , ne_a.data ());
1188
+ ggml_set_name (a, " a" );
1189
+
1190
+ b = ggml_new_tensor (ctx, type, 4 , ne_a.data ());
1191
+ ggml_set_name (b, " b" );
1192
+ }
1193
+
1194
+ ggml_tensor * out = ggml_glu_split (ctx, a, b, op);
1195
+ ggml_set_name (out, " out" );
1196
+
1197
+ return out;
1198
+ }
1199
+
1200
+ void initialize_tensors (ggml_context * ctx) override {
1201
+ for (ggml_tensor * t = ggml_get_first_tensor (ctx); t != NULL ; t = ggml_get_next_tensor (ctx, t)) {
1202
+ // test extended range of values to check for NaNs in GELU
1203
+ init_tensor_uniform (t, -150 .f , 150 .f );
1204
+ }
1205
+ }
1206
+ };
1207
+
1154
1208
// GGML_OP_GET_ROWS
1155
1209
struct test_get_rows : public test_case {
1156
1210
const ggml_type type;
@@ -3986,6 +4040,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
3986
4040
test_cases.emplace_back (new test_glu ((ggml_glu_op) op, type, { 128 , 2 , 2 , 2 }, v, swapped));
3987
4041
test_cases.emplace_back (new test_glu ((ggml_glu_op) op, type, { 5 , 7 , 11 , 13 }, v, swapped));
3988
4042
}
4043
+
4044
+ test_cases.emplace_back (new test_glu_split ((ggml_glu_op) op, type, { 128 , 2 , 2 , 2 }, v));
4045
+ test_cases.emplace_back (new test_glu_split ((ggml_glu_op) op, type, { 5 , 7 , 11 , 13 }, v));
3989
4046
}
3990
4047
}
3991
4048
}
0 commit comments