Skip to content

[MLIR][LLVM] Remove bitcast pattern from type consistency pass #87755

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,6 @@ class SplitStores : public OpRewritePattern<StoreOp> {
PatternRewriter &rewrite) const override;
};

/// Transforms type-inconsistent stores, aka stores where the type hint of
/// the address contradicts the value stored, by inserting a bitcast if
/// possible.
class BitcastStores : public OpRewritePattern<StoreOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(StoreOp store,
PatternRewriter &rewriter) const override;
};

/// Splits GEPs with more than two indices into multiple GEPs with exactly
/// two indices. The created GEPs are then guaranteed to index into only
/// one aggregate at a time.
Expand Down
28 changes: 0 additions & 28 deletions mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,6 @@ static Type isElementTypeInconsistent(Value addr, Type expectedType) {
return elemType;
}

/// Checks that two types are the same or can be bitcast into one another.
static bool areBitcastCompatible(DataLayout &layout, Type lhs, Type rhs) {
return lhs == rhs || (!isa<LLVMStructType, LLVMArrayType>(lhs) &&
!isa<LLVMStructType, LLVMArrayType>(rhs) &&
layout.getTypeSize(lhs) == layout.getTypeSize(rhs));
}

//===----------------------------------------------------------------------===//
// CanonicalizeAlignedGep
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -518,26 +511,6 @@ LogicalResult SplitStores::matchAndRewrite(StoreOp store,
return success();
}

LogicalResult BitcastStores::matchAndRewrite(StoreOp store,
PatternRewriter &rewriter) const {
Type sourceType = store.getValue().getType();
Type typeHint = isElementTypeInconsistent(store.getAddr(), sourceType);
if (!typeHint) {
// Nothing to do, since it is already consistent.
return failure();
}

auto dataLayout = DataLayout::closest(store);
if (!areBitcastCompatible(dataLayout, typeHint, sourceType))
return failure();

auto bitcastOp =
rewriter.create<BitcastOp>(store.getLoc(), typeHint, store.getValue());
rewriter.modifyOpInPlace(store,
[&] { store.getValueMutable().assign(bitcastOp); });
return success();
}

LogicalResult SplitGEP::matchAndRewrite(GEPOp gepOp,
PatternRewriter &rewriter) const {
FailureOr<Type> typeHint = getRequiredConsistentGEPType(gepOp);
Expand Down Expand Up @@ -588,7 +561,6 @@ struct LLVMTypeConsistencyPass
RewritePatternSet rewritePatterns(&getContext());
rewritePatterns.add<CanonicalizeAlignedGep>(&getContext());
rewritePatterns.add<SplitStores>(&getContext(), maxVectorSplitSize);
rewritePatterns.add<BitcastStores>(&getContext());
rewritePatterns.add<SplitGEP>(&getContext());
FrozenRewritePatternSet frozen(std::move(rewritePatterns));

Expand Down
18 changes: 1 addition & 17 deletions mlir/test/Dialect/LLVMIR/type-consistency.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,7 @@ llvm.func @coalesced_store_floats(%arg: i64) {
// CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST32]] : i64
// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
// CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (f32, f32)>
// CHECK: %[[BIT_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i32 to f32
// CHECK: llvm.store %[[BIT_CAST]], %[[GEP]]
// CHECK: llvm.store %[[TRUNC]], %[[GEP]]
llvm.store %arg, %1 : i64, !llvm.ptr
// CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
llvm.return
Expand Down Expand Up @@ -327,21 +326,6 @@ llvm.func @vector_write_split_struct(%arg: vector<2xi64>) {

// -----

// CHECK-LABEL: llvm.func @bitcast_insertion
// CHECK-SAME: %[[ARG:.*]]: i32
llvm.func @bitcast_insertion(%arg: i32) {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x f32
%1 = llvm.alloca %0 x f32 : (i32) -> !llvm.ptr
// CHECK: %[[BIT_CAST:.*]] = llvm.bitcast %[[ARG]] : i32 to f32
// CHECK: llvm.store %[[BIT_CAST]], %[[ALLOCA]]
llvm.store %arg, %1 : i32, !llvm.ptr
// CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
llvm.return
}

// -----

// CHECK-LABEL: llvm.func @gep_split
// CHECK-SAME: %[[ARG:.*]]: i64
llvm.func @gep_split(%arg: i64) {
Expand Down