@@ -79,6 +79,25 @@ func.func @masked_reduce_add_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -
79
79
// CHECK: "llvm.intr.vp.reduce.fadd"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (f32, vector<16xf32>, vector<16xi1>, i32) -> f32
80
80
81
81
82
+ // -----
83
+
84
+ func.func @masked_reduce_add_f32_scalable (%arg0: vector <[16 ]xf32 >, %mask : vector <[16 ]xi1 >) -> f32 { // Test input: masked `add` reduction of a scalable [16]xf32 vector down to a single f32.
85
+ %0 = vector.mask %mask { vector.reduction <add >, %arg0 : vector <[16 ]xf32 > into f32 } : vector <[16 ]xi1 > -> f32 // The expectations after this function look for llvm.intr.vp.reduce.fadd with a 0.0 start value, %mask as the mask operand, and a vector length of 16 * vscale.
86
+ return %0 : f32
87
+ }
88
+
89
+ // CHECK-LABEL: func.func @masked_reduce_add_f32_scalable(
90
+ // CHECK-SAME: %[[INPUT:.*]]: vector<[16]xf32>,
91
+ // CHECK-SAME: %[[MASK:.*]]: vector<[16]xi1>) -> f32 {
92
+ // CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
93
+ // CHECK: %[[VL_BASE:.*]] = llvm.mlir.constant(16 : i32) : i32
94
+ // CHECK: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
95
+ // CHECK: %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
96
+ // CHECK: %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32
97
+ // CHECK: %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32
98
+ // CHECK: "llvm.intr.vp.reduce.fadd"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (f32, vector<[16]xf32>, vector<[16]xi1>, i32) -> f32
99
+
100
+
82
101
// -----
83
102
84
103
func.func @masked_reduce_mul_f32 (%arg0: vector <16 xf32 >, %mask : vector <16 xi1 >) -> f32 {
@@ -110,6 +129,24 @@ func.func @masked_reduce_minf_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>)
110
129
111
130
// -----
112
131
132
+ func.func @masked_reduce_minf_f32_scalable (%arg0: vector <[16 ]xf32 >, %mask : vector <[16 ]xi1 >) -> f32 { // Test input: masked `minnumf` reduction of a scalable [16]xf32 vector down to a single f32.
133
+ %0 = vector.mask %mask { vector.reduction <minnumf >, %arg0 : vector <[16 ]xf32 > into f32 } : vector <[16 ]xi1 > -> f32 // The expectations after this function look for llvm.intr.vp.reduce.fmin with start constant 0xFFC00000, %mask as the mask operand, and a vector length of 16 * vscale.
134
+ return %0 : f32
135
+ }
136
+
137
+ // CHECK-LABEL: func.func @masked_reduce_minf_f32_scalable(
138
+ // CHECK-SAME: %[[INPUT:.*]]: vector<[16]xf32>,
139
+ // CHECK-SAME: %[[MASK:.*]]: vector<[16]xi1>) -> f32 {
140
+ // CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0xFFC00000 : f32) : f32
141
+ // CHECK: %[[VL_BASE:.*]] = llvm.mlir.constant(16 : i32) : i32
142
+ // CHECK: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
143
+ // CHECK: %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
144
+ // CHECK: %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32
145
+ // CHECK: %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32
146
+ // CHECK: "llvm.intr.vp.reduce.fmin"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (f32, vector<[16]xf32>, vector<[16]xi1>, i32) -> f32
147
+
148
+ // -----
149
+
113
150
func.func @masked_reduce_maxf_f32 (%arg0: vector <16 xf32 >, %mask : vector <16 xi1 >) -> f32 {
114
151
%0 = vector.mask %mask { vector.reduction <maxnumf >, %arg0 : vector <16 xf32 > into f32 } : vector <16 xi1 > -> f32
115
152
return %0 : f32
@@ -167,6 +204,25 @@ func.func @masked_reduce_add_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) ->
167
204
// CHECK: "llvm.intr.vp.reduce.add"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
168
205
169
206
207
+ // -----
208
+
209
+ func.func @masked_reduce_add_i8_scalable (%arg0: vector <[32 ]xi8 >, %mask : vector <[32 ]xi1 >) -> i8 { // Test input: masked integer `add` reduction of a scalable [32]xi8 vector down to a single i8.
210
+ %0 = vector.mask %mask { vector.reduction <add >, %arg0 : vector <[32 ]xi8 > into i8 } : vector <[32 ]xi1 > -> i8 // The expectations after this function look for llvm.intr.vp.reduce.add with start value 0, %mask as the mask operand, and a vector length of 32 * vscale.
211
+ return %0 : i8
212
+ }
213
+
214
+ // CHECK-LABEL: func.func @masked_reduce_add_i8_scalable(
215
+ // CHECK-SAME: %[[INPUT:.*]]: vector<[32]xi8>,
216
+ // CHECK-SAME: %[[MASK:.*]]: vector<[32]xi1>) -> i8 {
217
+ // CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0 : i8) : i8
218
+ // CHECK: %[[VL_BASE:.*]] = llvm.mlir.constant(32 : i32) : i32
219
+ // CHECK: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
220
+ // CHECK: %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
221
+ // CHECK: %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32
222
+ // CHECK: %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32
223
+ // CHECK: "llvm.intr.vp.reduce.add"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (i8, vector<[32]xi8>, vector<[32]xi1>, i32) -> i8
224
+
225
+
170
226
// -----
171
227
172
228
func.func @masked_reduce_mul_i8 (%arg0: vector <32 xi8 >, %mask : vector <32 xi1 >) -> i8 {
@@ -197,6 +253,24 @@ func.func @masked_reduce_minui_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -
197
253
198
254
// -----
199
255
256
+ func.func @masked_reduce_minui_i8_scalable (%arg0: vector <[32 ]xi8 >, %mask : vector <[32 ]xi1 >) -> i8 { // Test input: masked unsigned-minimum (`minui`) reduction of a scalable [32]xi8 vector down to a single i8.
257
+ %0 = vector.mask %mask { vector.reduction <minui >, %arg0 : vector <[32 ]xi8 > into i8 } : vector <[32 ]xi1 > -> i8 // The expectations after this function look for llvm.intr.vp.reduce.umin with start value -1 (all bits set in i8), %mask as the mask operand, and a vector length of 32 * vscale.
258
+ return %0 : i8
259
+ }
260
+
261
+ // CHECK-LABEL: func.func @masked_reduce_minui_i8_scalable(
262
+ // CHECK-SAME: %[[INPUT:.*]]: vector<[32]xi8>,
263
+ // CHECK-SAME: %[[MASK:.*]]: vector<[32]xi1>) -> i8 {
264
+ // CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(-1 : i8) : i8
265
+ // CHECK: %[[VL_BASE:.*]] = llvm.mlir.constant(32 : i32) : i32
266
+ // CHECK: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
267
+ // CHECK: %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
268
+ // CHECK: %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32
269
+ // CHECK: %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32
270
+ // CHECK: "llvm.intr.vp.reduce.umin"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (i8, vector<[32]xi8>, vector<[32]xi1>, i32) -> i8
271
+
272
+ // -----
273
+
200
274
func.func @masked_reduce_maxui_i8 (%arg0: vector <32 xi8 >, %mask : vector <32 xi1 >) -> i8 {
201
275
%0 = vector.mask %mask { vector.reduction <maxui >, %arg0 : vector <32 xi8 > into i8 } : vector <32 xi1 > -> i8
202
276
return %0 : i8
@@ -239,6 +313,24 @@ func.func @masked_reduce_maxsi_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -
239
313
240
314
// -----
241
315
316
+ func.func @masked_reduce_maxsi_i8_scalable (%arg0: vector <[32 ]xi8 >, %mask : vector <[32 ]xi1 >) -> i8 { // Test input: masked signed-maximum (`maxsi`) reduction of a scalable [32]xi8 vector down to a single i8.
317
+ %0 = vector.mask %mask { vector.reduction <maxsi >, %arg0 : vector <[32 ]xi8 > into i8 } : vector <[32 ]xi1 > -> i8 // The expectations after this function look for llvm.intr.vp.reduce.smax with start value -128 (smallest signed i8), %mask as the mask operand, and a vector length of 32 * vscale.
318
+ return %0 : i8
319
+ }
320
+
321
+ // CHECK-LABEL: func.func @masked_reduce_maxsi_i8_scalable(
322
+ // CHECK-SAME: %[[INPUT:.*]]: vector<[32]xi8>,
323
+ // CHECK-SAME: %[[MASK:.*]]: vector<[32]xi1>) -> i8 {
324
+ // CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(-128 : i8) : i8
325
+ // CHECK: %[[VL_BASE:.*]] = llvm.mlir.constant(32 : i32) : i32
326
+ // CHECK: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
327
+ // CHECK: %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
328
+ // CHECK: %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32
329
+ // CHECK: %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32
330
+ // CHECK: "llvm.intr.vp.reduce.smax"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (i8, vector<[32]xi8>, vector<[32]xi1>, i32) -> i8
331
+
332
+ // -----
333
+
242
334
func.func @masked_reduce_or_i8 (%arg0: vector <32 xi8 >, %mask : vector <32 xi1 >) -> i8 {
243
335
%0 = vector.mask %mask { vector.reduction <or >, %arg0 : vector <32 xi8 > into i8 } : vector <32 xi1 > -> i8
244
336
return %0 : i8
@@ -280,4 +372,22 @@ func.func @masked_reduce_xor_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) ->
280
372
// CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
281
373
// CHECK: "llvm.intr.vp.reduce.xor"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
282
374
375
+ // -----
376
+
377
+ func.func @masked_reduce_xor_i8_scalable (%arg0: vector <[32 ]xi8 >, %mask : vector <[32 ]xi1 >) -> i8 { // Test input: masked bitwise `xor` reduction of a scalable [32]xi8 vector down to a single i8.
378
+ %0 = vector.mask %mask { vector.reduction <xor >, %arg0 : vector <[32 ]xi8 > into i8 } : vector <[32 ]xi1 > -> i8 // The expectations after this function look for llvm.intr.vp.reduce.xor with start value 0, %mask as the mask operand, and a vector length of 32 * vscale.
379
+ return %0 : i8
380
+ }
381
+
382
+ // CHECK-LABEL: func.func @masked_reduce_xor_i8_scalable(
383
+ // CHECK-SAME: %[[INPUT:.*]]: vector<[32]xi8>,
384
+ // CHECK-SAME: %[[MASK:.*]]: vector<[32]xi1>) -> i8 {
385
+ // CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0 : i8) : i8
386
+ // CHECK: %[[VL_BASE:.*]] = llvm.mlir.constant(32 : i32) : i32
387
+ // CHECK: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
388
+ // CHECK: %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
389
+ // CHECK: %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32
390
+ // CHECK: %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32
391
+ // CHECK: "llvm.intr.vp.reduce.xor"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (i8, vector<[32]xi8>, vector<[32]xi1>, i32) -> i8
392
+
283
393
0 commit comments