Skip to content

Commit 0d65000

Browse files
committed
[MLIR] Add llvm.mlir.cast op for semantic preserving cast between dialect types.
Summary: See discussion here: https://llvm.discourse.group/t/rfc-dialect-type-cast-op/538/11 Reviewers: ftynse Subscribers: bixia, sanjoy.google, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, Joonsoo, llvm-commits Differential Revision: https://reviews.llvm.org/D75141
1 parent 2a00ae3 commit 0d65000

File tree

5 files changed

+147
-2
lines changed

5 files changed

+147
-2
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -686,6 +686,25 @@ def LLVM_ConstantOp
686686
let assemblyFormat = "`(` $value `)` attr-dict `:` type($res)";
687687
}
688688

689+
def LLVM_DialectCastOp : LLVM_Op<"mlir.cast", [NoSideEffect]>,
690+
Results<(outs AnyType:$res)>,
691+
Arguments<(ins AnyType:$in)> {
692+
let summary = "Type cast between LLVM dialect and Standard.";
693+
let description = [{
694+
llvm.mlir.cast op casts between Standard and LLVM dialects. It only changes
695+
the dialect, but does not change compile-time or runtime semantics.
696+
697+
Notice that index type is not supported, as it's Standard-specific.
698+
699+
Example:
700+
llvm.mlir.cast %v : f16 to llvm.half
701+
llvm.mlir.cast %v : llvm.float to f32
702+
llvm.mlir.cast %v : !llvm<"<2 x float>"> to vector<2xf32>
703+
}];
704+
let assemblyFormat = "$in attr-dict `:` type($in) `to` type($res)";
705+
let verifier = "return ::verify(*this);";
706+
}
707+
689708
// Operations that correspond to LLVM intrinsics. With MLIR operation set being
690709
// extendable, there is no reason to introduce a hard boundary between "core"
691710
// operations and intrinsics. However, we systematically prefix them with

mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1807,6 +1807,24 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
18071807
}
18081808
};
18091809

1810+
struct DialectCastOpLowering
1811+
: public LLVMLegalizationPattern<LLVM::DialectCastOp> {
1812+
using LLVMLegalizationPattern<LLVM::DialectCastOp>::LLVMLegalizationPattern;
1813+
1814+
PatternMatchResult
1815+
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1816+
ConversionPatternRewriter &rewriter) const override {
1817+
auto castOp = cast<LLVM::DialectCastOp>(op);
1818+
OperandAdaptor<LLVM::DialectCastOp> transformed(operands);
1819+
if (transformed.in().getType() !=
1820+
typeConverter.convertType(castOp.getType())) {
1821+
return matchFailure();
1822+
}
1823+
rewriter.replaceOp(op, transformed.in());
1824+
return matchSuccess();
1825+
}
1826+
};
1827+
18101828
// A `dim` is converted to a constant for static sizes and to an access to the
18111829
// size stored in the memref descriptor for dynamic sizes.
18121830
struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
@@ -2772,6 +2790,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
27722790
CopySignOpLowering,
27732791
CosOpLowering,
27742792
ConstLLVMOpLowering,
2793+
DialectCastOpLowering,
27752794
DivFOpLowering,
27762795
ExpOpLowering,
27772796
LogOpLowering,
@@ -2988,6 +3007,7 @@ struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
29883007
mlir::LLVMConversionTarget::LLVMConversionTarget(MLIRContext &ctx)
29893008
: ConversionTarget(ctx) {
29903009
this->addLegalDialect<LLVM::LLVMDialect>();
3010+
this->addIllegalOp<LLVM::DialectCastOp>();
29913011
}
29923012

29933013
std::unique_ptr<OpPassBase<ModuleOp>>

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -890,6 +890,45 @@ static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) {
890890
p.printRegion(initializer, /*printEntryBlockArgs=*/false);
891891
}
892892

893+
//===----------------------------------------------------------------------===//
894+
// Verifier for LLVM::DialectCastOp.
895+
//===----------------------------------------------------------------------===//
896+
897+
static LogicalResult verify(DialectCastOp op) {
898+
auto verifyMLIRCastType = [&op](Type type) -> LogicalResult {
899+
if (auto llvmType = type.dyn_cast<LLVM::LLVMType>()) {
900+
if (llvmType.isVectorTy())
901+
llvmType = llvmType.getVectorElementType();
902+
if (llvmType.isIntegerTy() || llvmType.isHalfTy() ||
903+
llvmType.isFloatTy() || llvmType.isDoubleTy()) {
904+
return success();
905+
}
906+
return op.emitOpError("type must be non-index integer types, float "
907+
"types, or vector of mentioned types.");
908+
}
909+
if (auto vectorType = type.dyn_cast<VectorType>()) {
910+
if (vectorType.getShape().size() > 1)
911+
return op.emitOpError("only 1-d vector is allowed");
912+
type = vectorType.getElementType();
913+
}
914+
if (type.isSignlessIntOrFloat())
915+
return success();
916+
// Note that memrefs are not supported. We currently don't have a use case
917+
// for it, but even if we do, there are challenges:
918+
// * if we allow memrefs to cast from/to memref descriptors, then the
919+
// semantics of the cast op depends on the implementation detail of the
920+
// descriptor.
921+
// * if we allow memrefs to cast from/to bare pointers, some users might
922+
// alternatively want metadata that only present in the descriptor.
923+
//
924+
// TODO(timshen): re-evaluate the memref cast design when it's needed.
925+
return op.emitOpError("type must be non-index integer types, float types, "
926+
"or vector of mentioned types.");
927+
};
928+
return failure(failed(verifyMLIRCastType(op.in().getType())) ||
929+
failed(verifyMLIRCastType(op.getType())));
930+
}
931+
893932
// Parses one of the keywords provided in the list `keywords` and returns the
894933
// position of the parsed keyword in the list. If none of the keywords from the
895934
// list is parsed, returns -1.

mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -910,3 +910,39 @@ func @assume_alignment(%0 : memref<4x4xf16>) {
910910
assume_alignment %0, 16 : memref<4x4xf16>
911911
return
912912
}
913+
914+
// -----
915+
916+
// CHECK-LABEL: func @mlir_cast_to_llvm
917+
// CHECK-SAME: %[[ARG:.*]]:
918+
func @mlir_cast_to_llvm(%0 : vector<2xf16>) -> !llvm<"<2 x half>"> {
919+
%1 = llvm.mlir.cast %0 : vector<2xf16> to !llvm<"<2 x half>">
920+
// CHECK-NEXT: llvm.return %[[ARG]]
921+
return %1 : !llvm<"<2 x half>">
922+
}
923+
924+
// CHECK-LABEL: func @mlir_cast_from_llvm
925+
// CHECK-SAME: %[[ARG:.*]]:
926+
func @mlir_cast_from_llvm(%0 : !llvm<"<2 x half>">) -> vector<2xf16> {
927+
%1 = llvm.mlir.cast %0 : !llvm<"<2 x half>"> to vector<2xf16>
928+
// CHECK-NEXT: llvm.return %[[ARG]]
929+
return %1 : vector<2xf16>
930+
}
931+
932+
// -----
933+
934+
// CHECK-LABEL: func @mlir_cast_to_llvm
935+
// CHECK-SAME: %[[ARG:.*]]:
936+
func @mlir_cast_to_llvm(%0 : f16) -> !llvm.half {
937+
%1 = llvm.mlir.cast %0 : f16 to !llvm.half
938+
// CHECK-NEXT: llvm.return %[[ARG]]
939+
return %1 : !llvm.half
940+
}
941+
942+
// CHECK-LABEL: func @mlir_cast_from_llvm
943+
// CHECK-SAME: %[[ARG:.*]]:
944+
func @mlir_cast_from_llvm(%0 : !llvm.half) -> f16 {
945+
%1 = llvm.mlir.cast %0 : !llvm.half to f16
946+
// CHECK-NEXT: llvm.return %[[ARG]]
947+
return %1 : f16
948+
}
Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,44 @@
1-
// RUN: mlir-opt %s -verify-diagnostics -split-input-file
1+
// RUN: mlir-opt %s -convert-std-to-llvm -verify-diagnostics -split-input-file
22

33
#map1 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
44

55
func @invalid_memref_cast(%arg0: memref<?x?xf64>) {
66
%c1 = constant 1 : index
77
%c0 = constant 0 : index
8-
// expected-error@+1: 'std.memref_cast' op operand #0 must be unranked.memref of any type values or memref of any type values,
8+
// expected-error@+1 {{'std.memref_cast' op operand #0 must be unranked.memref of any type values or memref of any type values, but got '!llvm<"{ double*, double*, i64, [2 x i64], [2 x i64] }">'}}
99
%5 = memref_cast %arg0 : memref<?x?xf64> to memref<?x?xf64, #map1>
1010
%25 = std.subview %5[%c0, %c0][%c1, %c1][] : memref<?x?xf64, #map1> to memref<?x?xf64, #map1>
1111
return
1212
}
1313

14+
// -----
15+
16+
func @mlir_cast_to_llvm(%0 : index) -> !llvm.i64 {
17+
// expected-error@+1 {{'llvm.mlir.cast' op type must be non-index integer types, float types, or vector of mentioned types}}
18+
%1 = llvm.mlir.cast %0 : index to !llvm.i64
19+
return %1 : !llvm.i64
20+
}
21+
22+
// -----
23+
24+
func @mlir_cast_from_llvm(%0 : !llvm.i64) -> index {
25+
// expected-error@+1 {{'llvm.mlir.cast' op type must be non-index integer types, float types, or vector of mentioned types}}
26+
%1 = llvm.mlir.cast %0 : !llvm.i64 to index
27+
return %1 : index
28+
}
29+
30+
// -----
31+
32+
func @mlir_cast_to_llvm_int(%0 : i32) -> !llvm.i64 {
33+
// expected-error@+1 {{failed to legalize operation 'llvm.mlir.cast' that was explicitly marked illegal}}
34+
%1 = llvm.mlir.cast %0 : i32 to !llvm.i64
35+
return %1 : !llvm.i64
36+
}
37+
38+
// -----
39+
40+
func @mlir_cast_to_llvm_vec(%0 : vector<1x1xf32>) -> !llvm<"<1 x float>"> {
41+
// expected-error@+1 {{'llvm.mlir.cast' op only 1-d vector is allowed}}
42+
%1 = llvm.mlir.cast %0 : vector<1x1xf32> to !llvm<"<1 x float>">
43+
return %1 : !llvm<"<1 x float>">
44+
}

0 commit comments

Comments
 (0)