
Commit 055644c

[X86][AMX] Prohibit pointer cast on load.
Load and store instructions will be transformed to AMX intrinsics in the AMX type lowering pass. Prohibiting the pointer cast makes that pass happy.

Differential Revision: https://reviews.llvm.org/D94372
1 parent c0f3ea8 commit 055644c
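
For context, the fold being blocked: combineLoadToOperationType canonicalizes a load whose only use is a bitcast by loading the destination type directly. For x86_amx that would produce a plain x86_amx load, which the AMX type lowering pass cannot convert to tile intrinsics. A minimal IR sketch; the folded form is hypothetical and is exactly what no longer appears after this commit:

    ; kept intact after this change
    %vec = load <256 x i32>, <256 x i32>* %src, align 64
    %bc  = bitcast <256 x i32> %vec to x86_amx

    ; the now-prohibited fold would have produced something like
    %cast = bitcast <256 x i32>* %src to x86_amx*
    %bc   = load x86_amx, x86_amx* %cast, align 64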

File tree: 2 files changed, +53 -3 lines

llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp

Lines changed: 15 additions & 3 deletions
@@ -589,7 +589,16 @@ static Instruction *combineLoadToOperationType(InstCombinerImpl &IC,
   // Fold away bit casts of the loaded value by loading the desired type.
   // Note that we should not do this for pointer<->integer casts,
   // because that would result in type punning.
-  if (LI.hasOneUse())
+  if (LI.hasOneUse()) {
+    // Don't transform when the type is x86_amx, it makes the pass that lower
+    // x86_amx type happy.
+    if (auto *BC = dyn_cast<BitCastInst>(LI.user_back())) {
+      assert(!LI.getType()->isX86_AMXTy() &&
+             "load from x86_amx* should not happen!");
+      if (BC->getType()->isX86_AMXTy())
+        return nullptr;
+    }
+
     if (auto* CI = dyn_cast<CastInst>(LI.user_back()))
       if (CI->isNoopCast(DL) && LI.getType()->isPtrOrPtrVectorTy() ==
                                     CI->getDestTy()->isPtrOrPtrVectorTy())
@@ -599,6 +608,7 @@ static Instruction *combineLoadToOperationType(InstCombinerImpl &IC,
         IC.eraseInstFromFunction(*CI);
         return &LI;
       }
+  }
 
   // FIXME: We should also canonicalize loads of vectors when their elements are
   // cast to other types.
@@ -1114,10 +1124,12 @@ static bool combineStoreToValueType(InstCombinerImpl &IC, StoreInst &SI) {
 
   // Fold away bit casts of the stored value by storing the original type.
   if (auto *BC = dyn_cast<BitCastInst>(V)) {
+    assert(!BC->getType()->isX86_AMXTy() &&
+           "store to x86_amx* should not happen!");
     V = BC->getOperand(0);
-    // Don't transform when the type is x86_amx, it make the pass that lower
+    // Don't transform when the type is x86_amx, it makes the pass that lower
     // x86_amx type happy.
-    if (BC->getType()->isX86_AMXTy() || V->getType()->isX86_AMXTy())
+    if (V->getType()->isX86_AMXTy())
       return false;
     if (!SI.isAtomic() || isSupportedAtomicType(V->getType())) {
       combineStoreToNewValue(IC, SI, V);
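
The store hunk is the mirror case: combineStoreToValueType folds away a bitcast of the stored value by storing the original type. The old check on BC->getType() becomes an assert, presumably because nothing upstream creates direct x86_amx memory accesses, so a store to x86_amx* is never expected to reach this point. A sketch of the fold the remaining guard prevents; the folded form is hypothetical output that this commit rules out:

    ; kept intact after this change
    %bc = bitcast x86_amx %amx to <256 x i32>
    store <256 x i32> %bc, <256 x i32>* %dst

    ; the prohibited fold would have produced something like
    %cast = bitcast <256 x i32>* %dst to x86_amx*
    store x86_amx %amx, x86_amx* %cast
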
Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -instcombine -S < %s | FileCheck %s
+; RUN: opt -passes=instcombine -S < %s | FileCheck %s
+
+; Prohibit poiter cast for amx.
+define dso_local void @test_amx_load_store(<256 x i32>* %src, i8* %dst) {
+; CHECK-LABEL: @test_amx_load_store(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[VEC:%.*]] = load <256 x i32>, <256 x i32>* [[SRC:%.*]], align 64
+; CHECK-NEXT:    [[BC:%.*]] = bitcast <256 x i32> [[VEC]] to x86_amx
+; CHECK-NEXT:    tail call void @llvm.x86.tilestored64.internal(i16 16, i16 16, i8* [[DST:%.*]], i64 64, x86_amx [[BC]])
+; CHECK-NEXT:    ret void
+;
+entry:
+  %vec = load <256 x i32>, <256 x i32>* %src, align 64
+  %bc = bitcast <256 x i32> %vec to x86_amx
+  tail call void @llvm.x86.tilestored64.internal(i16 16, i16 16, i8* %dst, i64 64, x86_amx %bc)
+  ret void
+}
+
+; Prohibit poiter cast for amx.
+define dso_local void @test_amx_load_store2(<256 x i32>* %dst, i8* %src) {
+; CHECK-LABEL: @test_amx_load_store2(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[AMX:%.*]] = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 16, i8* [[SRC:%.*]], i64 64)
+; CHECK-NEXT:    [[BC:%.*]] = bitcast x86_amx [[AMX]] to <256 x i32>
+; CHECK-NEXT:    store <256 x i32> [[BC]], <256 x i32>* [[DST:%.*]], align 1024
+; CHECK-NEXT:    ret void
+;
+entry:
+  %amx = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 16, i8* %src, i64 64)
+  %bc = bitcast x86_amx %amx to <256 x i32>
+  store <256 x i32> %bc, <256 x i32>* %dst
+  ret void
+}
+
+declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
+declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx)
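
Why the intrinsic form must survive InstCombine: a tile load or store carries explicit shape and stride operands that a plain load or store of x86_amx has no way to express. A minimal sketch reusing the shapes from the tests above; the operand-role comments reflect the usual AMX intrinsic convention rather than anything stated in this commit:

    ; i16 16 = rows, i16 16 = columns, i64 64 = stride in bytes (assumed roles)
    %amx = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 16, i8* %src, i64 64)
    call void @llvm.x86.tilestored64.internal(i16 16, i16 16, i8* %dst, i64 64, x86_amx %amx)

    ; a shape-less form like the following is what the lowering pass cannot handle
    ; %t = load x86_amx, x86_amx* %p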
