Skip to content

Commit e167c31

Browse files
authored
Reland [mlir] ArithToLLVM: fix memref bitcast lowering (#125148) (#126939)
Reland #125148 Limiting vector pattern caused issues with `select` of complex lowering, which wasn't caught as it was missing lit tests. Keep the pattern as is for now and instead set a higher benefit to `IdentityBitcastLowering` so it will always run before the vector pattern.
1 parent b04a980 commit e167c31

File tree

2 files changed

+52
-1
lines changed

2 files changed

+52
-1
lines changed

mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,25 @@ struct ConstrainedVectorConvertToLLVMPattern
5454
}
5555
};
5656

57+
/// No-op bitcast. Propagate type input arg if converted source and dest types
58+
/// are the same.
59+
struct IdentityBitcastLowering final
60+
: public OpConversionPattern<arith::BitcastOp> {
61+
using OpConversionPattern::OpConversionPattern;
62+
63+
LogicalResult
64+
matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor,
65+
ConversionPatternRewriter &rewriter) const final {
66+
Value src = adaptor.getIn();
67+
Type resultType = getTypeConverter()->convertType(op.getType());
68+
if (src.getType() != resultType)
69+
return rewriter.notifyMatchFailure(op, "Types are different");
70+
71+
rewriter.replaceOp(op, src);
72+
return success();
73+
}
74+
};
75+
5776
//===----------------------------------------------------------------------===//
5877
// Straightforward Op Lowerings
5978
//===----------------------------------------------------------------------===//
@@ -524,6 +543,12 @@ void mlir::arith::registerConvertArithToLLVMInterface(
524543

525544
void mlir::arith::populateArithToLLVMConversionPatterns(
526545
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
546+
547+
// Set a higher pattern benefit for IdentityBitcastLowering so it will run
548+
// before BitcastOpLowering.
549+
patterns.add<IdentityBitcastLowering>(converter, patterns.getContext(),
550+
/*patternBenefit*/ 10);
551+
527552
// clang-format off
528553
patterns.add<
529554
AddFOpLowering,

mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -577,12 +577,26 @@ func.func @cmpi_2dvector(%arg0 : vector<4x3xi32>, %arg1 : vector<4x3xi32>) {
577577
// -----
578578

579579
// CHECK-LABEL: @select
580+
// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
580581
func.func @select(%arg0 : i1, %arg1 : i32, %arg2 : i32) -> i32 {
581-
// CHECK: = llvm.select %arg0, %arg1, %arg2 : i1, i32
582+
// CHECK: %[[RES:.*]] = llvm.select %[[ARG0]], %[[ARG1]], %[[ARG2]] : i1, i32
583+
// CHECK: return %[[RES]]
582584
%0 = arith.select %arg0, %arg1, %arg2 : i32
583585
return %0 : i32
584586
}
585587

588+
// CHECK-LABEL: @select_complex
589+
// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: complex<f32>, %[[ARG2:.*]]: complex<f32>)
590+
func.func @select_complex(%arg0 : i1, %arg1 : complex<f32>, %arg2 : complex<f32>) -> complex<f32> {
591+
// CHECK-DAG: %[[ARGC1:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : complex<f32> to !llvm.struct<(f32, f32)>
592+
// CHECK-DAG: %[[ARGC2:.*]] = builtin.unrealized_conversion_cast %[[ARG2]] : complex<f32> to !llvm.struct<(f32, f32)>
593+
// CHECK: %[[RES:.*]] = llvm.select %[[ARG0]], %[[ARGC1]], %[[ARGC2]] : i1, !llvm.struct<(f32, f32)>
594+
// CHECK: %[[RESC:.*]] = builtin.unrealized_conversion_cast %[[RES]] : !llvm.struct<(f32, f32)> to complex<f32>
595+
// CHECK: return %[[RESC]]
596+
%0 = arith.select %arg0, %arg1, %arg2 : complex<f32>
597+
return %0 : complex<f32>
598+
}
599+
586600
// -----
587601

588602
// CHECK-LABEL: @ceildivsi
@@ -727,3 +741,15 @@ func.func @ops_supporting_overflow(%arg0: i64, %arg1: i64) {
727741
%3 = arith.shli %arg0, %arg1 overflow<nsw, nuw> : i64
728742
return
729743
}
744+
745+
// -----
746+
747+
// CHECK-LABEL: func @memref_bitcast
748+
// CHECK-SAME: (%[[ARG:.*]]: memref<?xi16>)
749+
// CHECK: %[[V1:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<?xi16> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
750+
// CHECK: %[[V2:.*]] = builtin.unrealized_conversion_cast %[[V1]] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<?xbf16>
751+
// CHECK: return %[[V2]]
752+
func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> {
753+
%2 = arith.bitcast %1 : memref<?xi16> to memref<?xbf16>
754+
func.return %2 : memref<?xbf16>
755+
}

0 commit comments

Comments
 (0)