Skip to content

Commit 54759ce

Browse files
committed
[mlir] [VectorOps] changes to printing support for integers
(1) simplify integer printing logic by always using 64-bit print (2) add index support (since vector<16xindex> is planned to be added) (3) adjust naming convention print_x -> printX Reviewed By: bkramer Differential Revision: https://reviews.llvm.org/D88436
1 parent 83dc53d commit 54759ce

File tree

8 files changed

+168
-184
lines changed

8 files changed

+168
-184
lines changed

mlir/include/mlir/ExecutionEngine/CRunnerUtils.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -200,14 +200,14 @@ class DynamicMemRefType {
200200
//===----------------------------------------------------------------------===//
201201
// Small runtime support "lib" for vector.print lowering during codegen.
202202
//===----------------------------------------------------------------------===//
203-
extern "C" MLIR_CRUNNERUTILS_EXPORT void print_i32(int32_t i);
204-
extern "C" MLIR_CRUNNERUTILS_EXPORT void print_i64(int64_t l);
205-
extern "C" MLIR_CRUNNERUTILS_EXPORT void print_f32(float f);
206-
extern "C" MLIR_CRUNNERUTILS_EXPORT void print_f64(double d);
207-
extern "C" MLIR_CRUNNERUTILS_EXPORT void print_open();
208-
extern "C" MLIR_CRUNNERUTILS_EXPORT void print_close();
209-
extern "C" MLIR_CRUNNERUTILS_EXPORT void print_comma();
210-
extern "C" MLIR_CRUNNERUTILS_EXPORT void print_newline();
203+
extern "C" MLIR_CRUNNERUTILS_EXPORT void printI64(int64_t i);
204+
extern "C" MLIR_CRUNNERUTILS_EXPORT void printU64(uint64_t u);
205+
extern "C" MLIR_CRUNNERUTILS_EXPORT void printF32(float f);
206+
extern "C" MLIR_CRUNNERUTILS_EXPORT void printF64(double d);
207+
extern "C" MLIR_CRUNNERUTILS_EXPORT void printOpen();
208+
extern "C" MLIR_CRUNNERUTILS_EXPORT void printClose();
209+
extern "C" MLIR_CRUNNERUTILS_EXPORT void printComma();
210+
extern "C" MLIR_CRUNNERUTILS_EXPORT void printNewline();
211211

212212
#endif // EXECUTIONENGINE_CRUNNERUTILS_H_
213213

mlir/integration_test/Dialect/LLVMIR/CPU/test-vector-reductions-fp.mlir

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44

55
// End-to-end test of all fp reduction intrinsics (not exhaustive unit tests).
66
module {
7-
llvm.func @print_newline()
8-
llvm.func @print_f32(!llvm.float)
7+
llvm.func @printNewline()
8+
llvm.func @printF32(!llvm.float)
99
llvm.func @entry() {
1010
// Setup (1,2,3,4).
1111
%0 = llvm.mlir.constant(1.000000e+00 : f32) : !llvm.float
@@ -26,62 +26,62 @@ module {
2626

2727
%max = "llvm.intr.experimental.vector.reduce.fmax"(%v)
2828
: (!llvm.vec<4 x float>) -> !llvm.float
29-
llvm.call @print_f32(%max) : (!llvm.float) -> ()
30-
llvm.call @print_newline() : () -> ()
29+
llvm.call @printF32(%max) : (!llvm.float) -> ()
30+
llvm.call @printNewline() : () -> ()
3131
// CHECK: 4
3232

3333
%min = "llvm.intr.experimental.vector.reduce.fmin"(%v)
3434
: (!llvm.vec<4 x float>) -> !llvm.float
35-
llvm.call @print_f32(%min) : (!llvm.float) -> ()
36-
llvm.call @print_newline() : () -> ()
35+
llvm.call @printF32(%min) : (!llvm.float) -> ()
36+
llvm.call @printNewline() : () -> ()
3737
// CHECK: 1
3838

3939
%add1 = "llvm.intr.experimental.vector.reduce.v2.fadd"(%0, %v)
4040
: (!llvm.float, !llvm.vec<4 x float>) -> !llvm.float
41-
llvm.call @print_f32(%add1) : (!llvm.float) -> ()
42-
llvm.call @print_newline() : () -> ()
41+
llvm.call @printF32(%add1) : (!llvm.float) -> ()
42+
llvm.call @printNewline() : () -> ()
4343
// CHECK: 11
4444

4545
%add1r = "llvm.intr.experimental.vector.reduce.v2.fadd"(%0, %v)
4646
{reassoc = true} : (!llvm.float, !llvm.vec<4 x float>) -> !llvm.float
47-
llvm.call @print_f32(%add1r) : (!llvm.float) -> ()
48-
llvm.call @print_newline() : () -> ()
47+
llvm.call @printF32(%add1r) : (!llvm.float) -> ()
48+
llvm.call @printNewline() : () -> ()
4949
// CHECK: 11
5050

5151
%add2 = "llvm.intr.experimental.vector.reduce.v2.fadd"(%1, %v)
5252
: (!llvm.float, !llvm.vec<4 x float>) -> !llvm.float
53-
llvm.call @print_f32(%add2) : (!llvm.float) -> ()
54-
llvm.call @print_newline() : () -> ()
53+
llvm.call @printF32(%add2) : (!llvm.float) -> ()
54+
llvm.call @printNewline() : () -> ()
5555
// CHECK: 12
5656

5757
%add2r = "llvm.intr.experimental.vector.reduce.v2.fadd"(%1, %v)
5858
{reassoc = true} : (!llvm.float, !llvm.vec<4 x float>) -> !llvm.float
59-
llvm.call @print_f32(%add2r) : (!llvm.float) -> ()
60-
llvm.call @print_newline() : () -> ()
59+
llvm.call @printF32(%add2r) : (!llvm.float) -> ()
60+
llvm.call @printNewline() : () -> ()
6161
// CHECK: 12
6262

6363
%mul1 = "llvm.intr.experimental.vector.reduce.v2.fmul"(%0, %v)
6464
: (!llvm.float, !llvm.vec<4 x float>) -> !llvm.float
65-
llvm.call @print_f32(%mul1) : (!llvm.float) -> ()
66-
llvm.call @print_newline() : () -> ()
65+
llvm.call @printF32(%mul1) : (!llvm.float) -> ()
66+
llvm.call @printNewline() : () -> ()
6767
// CHECK: 24
6868

6969
%mul1r = "llvm.intr.experimental.vector.reduce.v2.fmul"(%0, %v)
7070
{reassoc = true} : (!llvm.float, !llvm.vec<4 x float>) -> !llvm.float
71-
llvm.call @print_f32(%mul1r) : (!llvm.float) -> ()
72-
llvm.call @print_newline() : () -> ()
71+
llvm.call @printF32(%mul1r) : (!llvm.float) -> ()
72+
llvm.call @printNewline() : () -> ()
7373
// CHECK: 24
7474

7575
%mul2 = "llvm.intr.experimental.vector.reduce.v2.fmul"(%1, %v)
7676
: (!llvm.float, !llvm.vec<4 x float>) -> !llvm.float
77-
llvm.call @print_f32(%mul2) : (!llvm.float) -> ()
78-
llvm.call @print_newline() : () -> ()
77+
llvm.call @printF32(%mul2) : (!llvm.float) -> ()
78+
llvm.call @printNewline() : () -> ()
7979
// CHECK: 48
8080

8181
%mul2r = "llvm.intr.experimental.vector.reduce.v2.fmul"(%1, %v)
8282
{reassoc = true} : (!llvm.float, !llvm.vec<4 x float>) -> !llvm.float
83-
llvm.call @print_f32(%mul2r) : (!llvm.float) -> ()
84-
llvm.call @print_newline() : () -> ()
83+
llvm.call @printF32(%mul2r) : (!llvm.float) -> ()
84+
llvm.call @printNewline() : () -> ()
8585
// CHECK: 48
8686

8787
llvm.return

mlir/integration_test/Dialect/LLVMIR/CPU/test-vector-reductions-int.mlir

Lines changed: 40 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -4,78 +4,78 @@
44

55
// End-to-end test of all int reduction intrinsics (not exhaustive unit tests).
66
module {
7-
llvm.func @print_newline()
8-
llvm.func @print_i32(!llvm.i32)
7+
llvm.func @printNewline()
8+
llvm.func @printI64(!llvm.i64)
99
llvm.func @entry() {
1010
// Setup (1,2,3,4).
11-
%0 = llvm.mlir.constant(1 : i32) : !llvm.i32
12-
%1 = llvm.mlir.constant(2 : i32) : !llvm.i32
13-
%2 = llvm.mlir.constant(3 : i32) : !llvm.i32
14-
%3 = llvm.mlir.constant(4 : i32) : !llvm.i32
15-
%4 = llvm.mlir.undef : !llvm.vec<4 x i32>
11+
%0 = llvm.mlir.constant(1 : i64) : !llvm.i64
12+
%1 = llvm.mlir.constant(2 : i64) : !llvm.i64
13+
%2 = llvm.mlir.constant(3 : i64) : !llvm.i64
14+
%3 = llvm.mlir.constant(4 : i64) : !llvm.i64
15+
%4 = llvm.mlir.undef : !llvm.vec<4 x i64>
1616
%5 = llvm.mlir.constant(0 : index) : !llvm.i64
17-
%6 = llvm.insertelement %0, %4[%5 : !llvm.i64] : !llvm.vec<4 x i32>
18-
%7 = llvm.shufflevector %6, %4 [0 : i32, 0 : i32, 0 : i32, 0 : i32]
19-
: !llvm.vec<4 x i32>, !llvm.vec<4 x i32>
17+
%6 = llvm.insertelement %0, %4[%5 : !llvm.i64] : !llvm.vec<4 x i64>
18+
%7 = llvm.shufflevector %6, %4 [0 : i64, 0 : i64, 0 : i64, 0 : i64]
19+
: !llvm.vec<4 x i64>, !llvm.vec<4 x i64>
2020
%8 = llvm.mlir.constant(1 : i64) : !llvm.i64
21-
%9 = llvm.insertelement %1, %7[%8 : !llvm.i64] : !llvm.vec<4 x i32>
21+
%9 = llvm.insertelement %1, %7[%8 : !llvm.i64] : !llvm.vec<4 x i64>
2222
%10 = llvm.mlir.constant(2 : i64) : !llvm.i64
23-
%11 = llvm.insertelement %2, %9[%10 : !llvm.i64] : !llvm.vec<4 x i32>
23+
%11 = llvm.insertelement %2, %9[%10 : !llvm.i64] : !llvm.vec<4 x i64>
2424
%12 = llvm.mlir.constant(3 : i64) : !llvm.i64
25-
%v = llvm.insertelement %3, %11[%12 : !llvm.i64] : !llvm.vec<4 x i32>
25+
%v = llvm.insertelement %3, %11[%12 : !llvm.i64] : !llvm.vec<4 x i64>
2626

2727
%add = "llvm.intr.experimental.vector.reduce.add"(%v)
28-
: (!llvm.vec<4 x i32>) -> !llvm.i32
29-
llvm.call @print_i32(%add) : (!llvm.i32) -> ()
30-
llvm.call @print_newline() : () -> ()
28+
: (!llvm.vec<4 x i64>) -> !llvm.i64
29+
llvm.call @printI64(%add) : (!llvm.i64) -> ()
30+
llvm.call @printNewline() : () -> ()
3131
// CHECK: 10
3232

3333
%and = "llvm.intr.experimental.vector.reduce.and"(%v)
34-
: (!llvm.vec<4 x i32>) -> !llvm.i32
35-
llvm.call @print_i32(%and) : (!llvm.i32) -> ()
36-
llvm.call @print_newline() : () -> ()
34+
: (!llvm.vec<4 x i64>) -> !llvm.i64
35+
llvm.call @printI64(%and) : (!llvm.i64) -> ()
36+
llvm.call @printNewline() : () -> ()
3737
// CHECK: 0
3838

3939
%mul = "llvm.intr.experimental.vector.reduce.mul"(%v)
40-
: (!llvm.vec<4 x i32>) -> !llvm.i32
41-
llvm.call @print_i32(%mul) : (!llvm.i32) -> ()
42-
llvm.call @print_newline() : () -> ()
40+
: (!llvm.vec<4 x i64>) -> !llvm.i64
41+
llvm.call @printI64(%mul) : (!llvm.i64) -> ()
42+
llvm.call @printNewline() : () -> ()
4343
// CHECK: 24
4444

4545
%or = "llvm.intr.experimental.vector.reduce.or"(%v)
46-
: (!llvm.vec<4 x i32>) -> !llvm.i32
47-
llvm.call @print_i32(%or) : (!llvm.i32) -> ()
48-
llvm.call @print_newline() : () -> ()
46+
: (!llvm.vec<4 x i64>) -> !llvm.i64
47+
llvm.call @printI64(%or) : (!llvm.i64) -> ()
48+
llvm.call @printNewline() : () -> ()
4949
// CHECK: 7
5050

5151
%smax = "llvm.intr.experimental.vector.reduce.smax"(%v)
52-
: (!llvm.vec<4 x i32>) -> !llvm.i32
53-
llvm.call @print_i32(%smax) : (!llvm.i32) -> ()
54-
llvm.call @print_newline() : () -> ()
52+
: (!llvm.vec<4 x i64>) -> !llvm.i64
53+
llvm.call @printI64(%smax) : (!llvm.i64) -> ()
54+
llvm.call @printNewline() : () -> ()
5555
// CHECK: 4
5656

5757
%smin = "llvm.intr.experimental.vector.reduce.smin"(%v)
58-
: (!llvm.vec<4 x i32>) -> !llvm.i32
59-
llvm.call @print_i32(%smin) : (!llvm.i32) -> ()
60-
llvm.call @print_newline() : () -> ()
58+
: (!llvm.vec<4 x i64>) -> !llvm.i64
59+
llvm.call @printI64(%smin) : (!llvm.i64) -> ()
60+
llvm.call @printNewline() : () -> ()
6161
// CHECK: 1
6262

6363
%umax = "llvm.intr.experimental.vector.reduce.umax"(%v)
64-
: (!llvm.vec<4 x i32>) -> !llvm.i32
65-
llvm.call @print_i32(%umax) : (!llvm.i32) -> ()
66-
llvm.call @print_newline() : () -> ()
64+
: (!llvm.vec<4 x i64>) -> !llvm.i64
65+
llvm.call @printI64(%umax) : (!llvm.i64) -> ()
66+
llvm.call @printNewline() : () -> ()
6767
// CHECK: 4
6868

6969
%umin = "llvm.intr.experimental.vector.reduce.umin"(%v)
70-
: (!llvm.vec<4 x i32>) -> !llvm.i32
71-
llvm.call @print_i32(%umin) : (!llvm.i32) -> ()
72-
llvm.call @print_newline() : () -> ()
70+
: (!llvm.vec<4 x i64>) -> !llvm.i64
71+
llvm.call @printI64(%umin) : (!llvm.i64) -> ()
72+
llvm.call @printNewline() : () -> ()
7373
// CHECK: 1
7474

7575
%xor = "llvm.intr.experimental.vector.reduce.xor"(%v)
76-
: (!llvm.vec<4 x i32>) -> !llvm.i32
77-
llvm.call @print_i32(%xor) : (!llvm.i32) -> ()
78-
llvm.call @print_newline() : () -> ()
76+
: (!llvm.vec<4 x i64>) -> !llvm.i64
77+
llvm.call @printI64(%xor) : (!llvm.i64) -> ()
78+
llvm.call @printNewline() : () -> ()
7979
// CHECK: 4
8080

8181
llvm.return

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 13 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1328,17 +1328,15 @@ class VectorPrintOpConversion : public ConvertToLLVMPattern {
13281328
printer = getPrintFloat(op);
13291329
} else if (eltType.isF64()) {
13301330
printer = getPrintDouble(op);
1331+
} else if (eltType.isIndex()) {
1332+
printer = getPrintU64(op);
13311333
} else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
13321334
// Integers need a zero or sign extension on the operand
13331335
// (depending on the source type) as well as a signed or
13341336
// unsigned print method. Up to 64-bit is supported.
13351337
unsigned width = intTy.getWidth();
13361338
if (intTy.isUnsigned()) {
1337-
if (width <= 32) {
1338-
if (width < 32)
1339-
conversion = PrintConversion::ZeroExt32;
1340-
printer = getPrintU32(op);
1341-
} else if (width <= 64) {
1339+
if (width <= 64) {
13421340
if (width < 64)
13431341
conversion = PrintConversion::ZeroExt64;
13441342
printer = getPrintU64(op);
@@ -1347,16 +1345,12 @@ class VectorPrintOpConversion : public ConvertToLLVMPattern {
13471345
}
13481346
} else {
13491347
assert(intTy.isSignless() || intTy.isSigned());
1350-
if (width <= 32) {
1348+
if (width <= 64) {
13511349
// Note that we *always* zero extend booleans (1-bit integers),
13521350
// so that true/false is printed as 1/0 rather than -1/0.
13531351
if (width == 1)
1354-
conversion = PrintConversion::ZeroExt32;
1355-
else if (width < 32)
1356-
conversion = PrintConversion::SignExt32;
1357-
printer = getPrintI32(op);
1358-
} else if (width <= 64) {
1359-
if (width < 64)
1352+
conversion = PrintConversion::ZeroExt64;
1353+
else if (width < 64)
13601354
conversion = PrintConversion::SignExt64;
13611355
printer = getPrintI64(op);
13621356
} else {
@@ -1379,8 +1373,6 @@ class VectorPrintOpConversion : public ConvertToLLVMPattern {
13791373
private:
13801374
enum class PrintConversion {
13811375
None,
1382-
ZeroExt32,
1383-
SignExt32,
13841376
ZeroExt64,
13851377
SignExt64
13861378
};
@@ -1391,14 +1383,6 @@ class VectorPrintOpConversion : public ConvertToLLVMPattern {
13911383
Location loc = op->getLoc();
13921384
if (rank == 0) {
13931385
switch (conversion) {
1394-
case PrintConversion::ZeroExt32:
1395-
value = rewriter.create<ZeroExtendIOp>(
1396-
loc, value, LLVM::LLVMType::getInt32Ty(rewriter.getContext()));
1397-
break;
1398-
case PrintConversion::SignExt32:
1399-
value = rewriter.create<SignExtendIOp>(
1400-
loc, value, LLVM::LLVMType::getInt32Ty(rewriter.getContext()));
1401-
break;
14021386
case PrintConversion::ZeroExt64:
14031387
value = rewriter.create<ZeroExtendIOp>(
14041388
loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext()));
@@ -1455,41 +1439,33 @@ class VectorPrintOpConversion : public ConvertToLLVMPattern {
14551439
}
14561440

14571441
// Helpers for method names.
1458-
Operation *getPrintI32(Operation *op) const {
1459-
return getPrint(op, "print_i32",
1460-
LLVM::LLVMType::getInt32Ty(op->getContext()));
1461-
}
14621442
Operation *getPrintI64(Operation *op) const {
1463-
return getPrint(op, "print_i64",
1443+
return getPrint(op, "printI64",
14641444
LLVM::LLVMType::getInt64Ty(op->getContext()));
14651445
}
1466-
Operation *getPrintU32(Operation *op) const {
1467-
return getPrint(op, "printU32",
1468-
LLVM::LLVMType::getInt32Ty(op->getContext()));
1469-
}
14701446
Operation *getPrintU64(Operation *op) const {
14711447
return getPrint(op, "printU64",
14721448
LLVM::LLVMType::getInt64Ty(op->getContext()));
14731449
}
14741450
Operation *getPrintFloat(Operation *op) const {
1475-
return getPrint(op, "print_f32",
1451+
return getPrint(op, "printF32",
14761452
LLVM::LLVMType::getFloatTy(op->getContext()));
14771453
}
14781454
Operation *getPrintDouble(Operation *op) const {
1479-
return getPrint(op, "print_f64",
1455+
return getPrint(op, "printF64",
14801456
LLVM::LLVMType::getDoubleTy(op->getContext()));
14811457
}
14821458
Operation *getPrintOpen(Operation *op) const {
1483-
return getPrint(op, "print_open", {});
1459+
return getPrint(op, "printOpen", {});
14841460
}
14851461
Operation *getPrintClose(Operation *op) const {
1486-
return getPrint(op, "print_close", {});
1462+
return getPrint(op, "printClose", {});
14871463
}
14881464
Operation *getPrintComma(Operation *op) const {
1489-
return getPrint(op, "print_comma", {});
1465+
return getPrint(op, "printComma", {});
14901466
}
14911467
Operation *getPrintNewline(Operation *op) const {
1492-
return getPrint(op, "print_newline", {});
1468+
return getPrint(op, "printNewline", {});
14931469
}
14941470
};
14951471

mlir/lib/ExecutionEngine/CRunnerUtils.cpp

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,13 @@
2323
// By providing elementary printing methods only, this
2424
// library can remain fully unaware of low-level implementation
2525
// details of our vectors. Also useful for direct LLVM IR output.
26-
extern "C" void print_i32(int32_t i) { fprintf(stdout, "%" PRId32, i); }
27-
extern "C" void print_i64(int64_t l) { fprintf(stdout, "%" PRId64, l); }
28-
extern "C" void printU32(uint32_t i) { fprintf(stdout, "%" PRIu32, i); }
29-
extern "C" void printU64(uint64_t l) { fprintf(stdout, "%" PRIu64, l); }
30-
extern "C" void print_f32(float f) { fprintf(stdout, "%g", f); }
31-
extern "C" void print_f64(double d) { fprintf(stdout, "%lg", d); }
32-
extern "C" void print_open() { fputs("( ", stdout); }
33-
extern "C" void print_close() { fputs(" )", stdout); }
34-
extern "C" void print_comma() { fputs(", ", stdout); }
35-
extern "C" void print_newline() { fputc('\n', stdout); }
26+
extern "C" void printI64(int64_t i) { fprintf(stdout, "%" PRId64, i); }
27+
extern "C" void printU64(uint64_t u) { fprintf(stdout, "%" PRIu64, u); }
28+
extern "C" void printF32(float f) { fprintf(stdout, "%g", f); }
29+
extern "C" void printF64(double d) { fprintf(stdout, "%lg", d); }
30+
extern "C" void printOpen() { fputs("( ", stdout); }
31+
extern "C" void printClose() { fputs(" )", stdout); }
32+
extern "C" void printComma() { fputs(", ", stdout); }
33+
extern "C" void printNewline() { fputc('\n', stdout); }
3634

3735
#endif

0 commit comments

Comments
 (0)