Skip to content

Commit 2d76274

Browse files
committed
[mlir][VectorOps] Loosen restrictions on vector.reduction types
LLVM can deal with any integer or float type, don't arbitrarily restrict it to f32/f64/i32/i64. Differential Revision: https://reviews.llvm.org/D88010
1 parent f4c5cad commit 2d76274

File tree

3 files changed

+25
-6
lines changed

3 files changed

+25
-6
lines changed

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -561,7 +561,7 @@ class VectorReductionOpConversion : public ConvertToLLVMPattern {
561561
auto kind = reductionOp.kind();
562562
Type eltType = reductionOp.dest().getType();
563563
Type llvmType = typeConverter.convertType(eltType);
564-
if (eltType.isSignlessInteger(32) || eltType.isSignlessInteger(64)) {
564+
if (eltType.isSignlessInteger()) {
565565
// Integer reductions: add/mul/min/max/and/or/xor.
566566
if (kind == "add")
567567
rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_add>(
@@ -588,7 +588,7 @@ class VectorReductionOpConversion : public ConvertToLLVMPattern {
588588
return failure();
589589
return success();
590590

591-
} else if (eltType.isF32() || eltType.isF64()) {
591+
} else if (eltType.isa<FloatType>()) {
592592
// Floating-point reductions: add/mul/min/max
593593
if (kind == "add") {
594594
// Optional accumulator (or zero).

mlir/lib/Dialect/Vector/VectorOps.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -132,11 +132,10 @@ static LogicalResult verify(ReductionOp op) {
132132
auto kind = op.kind();
133133
Type eltType = op.dest().getType();
134134
if (kind == "add" || kind == "mul" || kind == "min" || kind == "max") {
135-
if (!eltType.isF32() && !eltType.isF64() &&
136-
!eltType.isSignlessInteger(32) && !eltType.isSignlessInteger(64))
135+
if (!eltType.isSignlessIntOrFloat())
137136
return op.emitOpError("unsupported reduction type");
138137
} else if (kind == "and" || kind == "or" || kind == "xor") {
139-
if (!eltType.isSignlessInteger(32) && !eltType.isSignlessInteger(64))
138+
if (!eltType.isSignlessInteger())
140139
return op.emitOpError("unsupported reduction type");
141140
} else {
142141
return op.emitOpError("unknown reduction kind: ") << kind;
@@ -146,7 +145,7 @@ static LogicalResult verify(ReductionOp op) {
146145
if (!op.acc().empty()) {
147146
if (kind != "add" && kind != "mul")
148147
return op.emitOpError("no accumulator for reduction kind: ") << kind;
149-
if (!eltType.isF32() && !eltType.isF64())
148+
if (!eltType.isa<FloatType>())
150149
return op.emitOpError("no accumulator for type: ") << eltType;
151150
}
152151

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,17 @@ func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>) -> (vector<8xf32>, vect
678678
return %0, %1: vector<8xf32>, vector<2x4xf32>
679679
}
680680

681+
func @reduce_f16(%arg0: vector<16xf16>) -> f16 {
682+
%0 = vector.reduction "add", %arg0 : vector<16xf16> into f16
683+
return %0 : f16
684+
}
685+
// CHECK-LABEL: llvm.func @reduce_f16(
686+
// CHECK-SAME: %[[A:.*]]: !llvm.vec<16 x half>)
687+
// CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f16) : !llvm.half
688+
// CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.v2.fadd"(%[[C]], %[[A]])
689+
// CHECK-SAME: {reassoc = false} : (!llvm.half, !llvm.vec<16 x half>) -> !llvm.half
690+
// CHECK: llvm.return %[[V]] : !llvm.half
691+
681692
func @reduce_f32(%arg0: vector<16xf32>) -> f32 {
682693
%0 = vector.reduction "add", %arg0 : vector<16xf32> into f32
683694
return %0 : f32
@@ -700,6 +711,15 @@ func @reduce_f64(%arg0: vector<16xf64>) -> f64 {
700711
// CHECK-SAME: {reassoc = false} : (!llvm.double, !llvm.vec<16 x double>) -> !llvm.double
701712
// CHECK: llvm.return %[[V]] : !llvm.double
702713

714+
func @reduce_i8(%arg0: vector<16xi8>) -> i8 {
715+
%0 = vector.reduction "add", %arg0 : vector<16xi8> into i8
716+
return %0 : i8
717+
}
718+
// CHECK-LABEL: llvm.func @reduce_i8(
719+
// CHECK-SAME: %[[A:.*]]: !llvm.vec<16 x i8>)
720+
// CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.add"(%[[A]])
721+
// CHECK: llvm.return %[[V]] : !llvm.i8
722+
703723
func @reduce_i32(%arg0: vector<16xi32>) -> i32 {
704724
%0 = vector.reduction "add", %arg0 : vector<16xi32> into i32
705725
return %0 : i32

0 commit comments

Comments
 (0)