Skip to content

Commit 79010e2

Browse files
authored
[mlir] ArithToLLVM: fix memref bitcast lowering (#125148)
`arith.bitcast` is allowed on memrefs and such code can actually be generated by IREE `ConvertBf16ArithToF32Pass`. `LLVM::detail::vectorOneToOneRewrite` doesn't properly check its types and will generate bitcast between structs which is illegal. With the opaque pointers this is a no-op operation for memref so we can just add type check in `LLVM::detail::vectorOneToOneRewrite` and add a separate pattern which removes op if converted types are the same.
1 parent bf7af2d commit 79010e2

File tree

3 files changed

+43
-1
lines changed

3 files changed

+43
-1
lines changed

mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,25 @@ struct ConstrainedVectorConvertToLLVMPattern
5454
}
5555
};
5656

57+
/// No-op bitcast. Propagate type input arg if converted source and dest types
58+
/// are the same.
59+
struct IdentityBitcastLowering final
60+
: public OpConversionPattern<arith::BitcastOp> {
61+
using OpConversionPattern::OpConversionPattern;
62+
63+
LogicalResult
64+
matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor,
65+
ConversionPatternRewriter &rewriter) const final {
66+
Value src = adaptor.getIn();
67+
Type resultType = getTypeConverter()->convertType(op.getType());
68+
if (src.getType() != resultType)
69+
return rewriter.notifyMatchFailure(op, "Types are different");
70+
71+
rewriter.replaceOp(op, src);
72+
return success();
73+
}
74+
};
75+
5776
//===----------------------------------------------------------------------===//
5877
// Straightforward Op Lowerings
5978
//===----------------------------------------------------------------------===//
@@ -524,6 +543,9 @@ void mlir::arith::registerConvertArithToLLVMInterface(
524543

525544
void mlir::arith::populateArithToLLVMConversionPatterns(
526545
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
546+
547+
patterns.add<IdentityBitcastLowering>(converter, patterns.getContext());
548+
527549
// clang-format off
528550
patterns.add<
529551
AddFOpLowering,

mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,14 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors(
103103
return success();
104104
}
105105

106+
static bool isVectorCompatibleType(Type type) {
107+
// Limit `vectorOneToOneRewrite` to scalar and vector types (and to
108+
// `LLVM::LLVMArrayType` which have a special handling).
109+
return isa<LLVM::LLVMArrayType, LLVM::LLVMPointerType, VectorType,
110+
IntegerType, FloatType>(type) &&
111+
LLVM::isCompatibleType(type);
112+
}
113+
106114
LogicalResult LLVM::detail::vectorOneToOneRewrite(
107115
Operation *op, StringRef targetOp, ValueRange operands,
108116
ArrayRef<NamedAttribute> targetAttrs,
@@ -111,7 +119,7 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
111119
assert(!operands.empty());
112120

113121
// Cannot convert ops if their operands are not of LLVM type.
114-
if (!llvm::all_of(operands.getTypes(), isCompatibleType))
122+
if (!llvm::all_of(operands.getTypes(), isVectorCompatibleType))
115123
return failure();
116124

117125
auto llvmNDVectorTy = operands[0].getType();

mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -727,3 +727,15 @@ func.func @ops_supporting_overflow(%arg0: i64, %arg1: i64) {
727727
%3 = arith.shli %arg0, %arg1 overflow<nsw, nuw> : i64
728728
return
729729
}
730+
731+
// -----
732+
733+
// CHECK-LABEL: func @memref_bitcast
734+
// CHECK-SAME: (%[[ARG:.*]]: memref<?xi16>)
735+
// CHECK: %[[V1:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<?xi16> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
736+
// CHECK: %[[V2:.*]] = builtin.unrealized_conversion_cast %[[V1]] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<?xbf16>
737+
// CHECK: return %[[V2]]
738+
func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> {
739+
%2 = arith.bitcast %1 : memref<?xi16> to memref<?xbf16>
740+
func.return %2 : memref<?xbf16>
741+
}

0 commit comments

Comments
 (0)