Commit 7d59b33
[X86][AMX] Fix the shape dependency issue.
The AMX shape must be defined before the AMX intrinsic that uses it. However, in the case below, the shape a.row is defined after the tile load of b. If we transform `load b` into the `@llvm.x86.tileloadd64` intrinsic, the shape dependency is not satisfied.

```
void test_tile_dpbsud(__tile1024i a, __tile1024i b, __tile1024i c) {
  __tile_dpbsud(&c, a, b);
}
```

This patch stores tile b to the stack and reloads it after the def of b.row. That introduces a redundant store/load pair, but it is a simple way to avoid generating invalid IR. A better approach might be to hoist the def of b.row above the tile load instruction, but recursively hoisting its operands is more complicated.

Differential Revision: https://reviews.llvm.org/D137923
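At the IR level, the change only affects where the pointer operand of the new `@llvm.x86.tileloadd64.internal` call comes from. The sketch below condenses that decision into a standalone helper; the helper name and signature are invented here for illustration and error handling is omitted, so treat it as a sketch of the approach rather than the pass code (the actual change is in the combineLoadCast diff further down).

```
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"

using namespace llvm;

// Illustrative helper (hypothetical name/signature): decide where the
// tileloadd64 pointer should come from. If the shape values already dominate
// the vector load, the load can be folded directly; otherwise spill the
// loaded vector to a stack slot right after the load and read it back at the
// cast, where the shape is already defined.
static Value *getTileLoadPointer(DominatorTree &DT, IRBuilder<> &B, Value *Row,
                                 Value *Col, LoadInst *LD, Instruction *Cast,
                                 bool &EraseLoad) {
  if (DT.dominates(Row, LD) && DT.dominates(Col, LD)) {
    EraseLoad = true;                // the original load becomes dead
    return LD->getPointerOperand();  // fold the load into tileloadd64
  }
  // Spill path: alloca in the entry block, store right after the load.
  BasicBlock &Entry = LD->getFunction()->getEntryBlock();
  B.SetInsertPoint(&Entry, Entry.begin());
  AllocaInst *Slot = B.CreateAlloca(LD->getType());
  B.SetInsertPoint(LD->getNextNode());
  B.CreateStore(LD, Slot);
  B.SetInsertPoint(Cast);            // reload happens after the Row/Col defs
  EraseLoad = false;                 // the load still feeds the spill store
  return Slot;
}
```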
1 parent: a214c52

2 files changed (+86, -21 lines)

llvm/lib/Target/X86/X86LowerAMXType.cpp

Lines changed: 33 additions & 10 deletions
@@ -700,11 +700,12 @@ namespace {
 
 class X86LowerAMXCast {
   Function &Func;
+  std::unique_ptr<DominatorTree> DT;
 
 public:
-  X86LowerAMXCast(Function &F) : Func(F) {}
+  X86LowerAMXCast(Function &F) : Func(F), DT(nullptr) {}
   void combineCastStore(IntrinsicInst *Cast, StoreInst *ST);
-  void combineLoadCast(IntrinsicInst *Cast, LoadInst *LD);
+  bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD);
   bool combineLdSt(SmallVectorImpl<Instruction *> &Casts);
   bool combineAMXcast(TargetLibraryInfo *TLI);
   bool transformAMXCast(IntrinsicInst *AMXCast);
@@ -942,26 +943,46 @@ void X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {
 // -->
 // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
 //                                                   i8* %p, i64 64)
-void X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
+bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
+  bool EraseLoad = true;
   Value *Row = nullptr, *Col = nullptr;
   Use &U = *(Cast->use_begin());
   unsigned OpNo = U.getOperandNo();
   auto *II = cast<IntrinsicInst>(U.getUser());
   // TODO: If it is cast intrinsic or phi node, we can propagate the
   // shape information through def-use chain.
   if (!isAMXIntrinsic(II))
-    return;
+    return false;
   std::tie(Row, Col) = getShape(II, OpNo);
   IRBuilder<> Builder(LD);
   // Use the maximum column as stride.
   Value *Stride = Builder.getInt64(64);
-  Value *I8Ptr =
-      Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
+  Value *I8Ptr;
+
+  // To save compile time, we create the dominator tree only when it is
+  // really needed.
+  if (!DT)
+    DT.reset(new DominatorTree(Func));
+  if (!DT->dominates(Row, LD) || !DT->dominates(Col, LD)) {
+    // Store the value to stack and reload it from stack before the cast.
+    auto *AllocaAddr =
+        createAllocaInstAtEntry(Builder, Cast->getParent(), LD->getType());
+    Builder.SetInsertPoint(&*std::next(LD->getIterator()));
+    Builder.CreateStore(LD, AllocaAddr);
+
+    Builder.SetInsertPoint(Cast);
+    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
+    EraseLoad = false;
+  } else {
+    I8Ptr = Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
+  }
   std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
 
   Value *NewInst =
       Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
   Cast->replaceAllUsesWith(NewInst);
+
+  return EraseLoad;
 }
 
 bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
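A note on the lazily created DominatorTree in the hunk above (an observation, not something stated in the commit): it is only built once a load/cast pair actually reaches the dominance check, so functions without that pattern pay nothing, and because this code only inserts and erases non-branch instructions, the CFG, and therefore the tree, should remain valid for the rest of the function. A minimal sketch of the same construct-on-first-use pattern, with an illustrative helper name:

```
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include <memory>

// Illustrative only (not part of the patch): build the dominator tree the
// first time a dominance query is needed, then reuse it for later queries
// on the same function.
static llvm::DominatorTree &
getOrBuildDomTree(llvm::Function &F,
                  std::unique_ptr<llvm::DominatorTree> &DT) {
  if (!DT)
    DT = std::make_unique<llvm::DominatorTree>(F);
  return *DT;
}
```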
@@ -995,10 +1016,11 @@ bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
       // -->
       // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
       //                                                   i8* %p, i64 64)
-      combineLoadCast(cast<IntrinsicInst>(Cast), Load);
-      // Set the operand to null so that the load instruction can be erased.
-      Cast->setOperand(0, nullptr);
-      Load->eraseFromParent();
+      if (combineLoadCast(cast<IntrinsicInst>(Cast), Load)) {
+        // Set the operand to null so that the load instruction can be erased.
+        Cast->setOperand(0, nullptr);
+        Load->eraseFromParent();
+      }
     }
   }
   return Change;
@@ -1198,6 +1220,7 @@ class X86LowerAMXTypeLegacyPass : public FunctionPass {
     TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
     TargetLibraryInfo *TLI =
         &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
+
     X86LowerAMXCast LAC(F);
     C |= LAC.combineAMXcast(TLI);
     // There might be remaining AMXcast after combineAMXcast and they should be

llvm/test/CodeGen/X86/AMX/amx-combine.ll

Lines changed: 53 additions & 11 deletions
@@ -18,9 +18,9 @@ define <256 x i32> @combine_store_2user(ptr%p) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = alloca <256 x i32>, align 64
 ; CHECK-NEXT:    [[T1:%.*]] = call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
 ; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr [[TMP1]], i64 64, x86_amx [[T1]])
-; CHECK-NEXT:    [[TMP3:%.*]] = load <256 x i32>, ptr [[TMP1]], align 1024
+; CHECK-NEXT:    [[TMP2:%.*]] = load <256 x i32>, ptr [[TMP1]], align 1024
 ; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr [[P:%.*]], i64 64, x86_amx [[T1]])
-; CHECK-NEXT:    ret <256 x i32> [[TMP3]]
+; CHECK-NEXT:    ret <256 x i32> [[TMP2]]
 ;
   %t1 = call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
   %t2 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t1)
@@ -30,8 +30,8 @@ define <256 x i32> @combine_store_2user(ptr%p) {
 
 define void @combine_load(ptr%p, ptr%p2) {
 ; CHECK-LABEL: @combine_load(
-; CHECK-NEXT:    [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, ptr [[P:%.*]], i64 64)
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr [[P2:%.*]], i64 64, x86_amx [[TMP2]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, ptr [[P:%.*]], i64 64)
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr [[P2:%.*]], i64 64, x86_amx [[TMP1]])
 ; CHECK-NEXT:    ret void
 ;
   %t1 = load <256 x i32>, ptr %p, align 64
@@ -42,9 +42,9 @@ define void @combine_load(ptr%p, ptr%p2) {
 
 define void @combine_cast_across_store(ptr%p, ptr%p2) {
 ; CHECK-LABEL: @combine_cast_across_store(
-; CHECK-NEXT:    [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, ptr [[P:%.*]], i64 64)
+; CHECK-NEXT:    [[TMP1:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, ptr [[P:%.*]], i64 64)
 ; CHECK-NEXT:    store <256 x i32> zeroinitializer, ptr [[P]], align 64
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr [[P2:%.*]], i64 64, x86_amx [[TMP2]])
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr [[P2:%.*]], i64 64, x86_amx [[TMP1]])
 ; CHECK-NEXT:    ret void
 ;
   %t1 = load <256 x i32>, ptr %p, align 64
@@ -59,8 +59,8 @@ define <256 x i32> @combine_load_2user(ptr%p, ptr%p2) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = alloca <256 x i32>, align 64
 ; CHECK-NEXT:    [[T1:%.*]] = load <256 x i32>, ptr [[P:%.*]], align 64
 ; CHECK-NEXT:    store <256 x i32> [[T1]], ptr [[TMP1]], align 1024
-; CHECK-NEXT:    [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, ptr [[TMP1]], i64 64)
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr [[P2:%.*]], i64 64, x86_amx [[TMP3]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, ptr [[TMP1]], i64 64)
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr [[P2:%.*]], i64 64, x86_amx [[TMP2]])
 ; CHECK-NEXT:    ret <256 x i32> [[T1]]
 ;
   %t1 = load <256 x i32>, ptr %p, align 64
@@ -75,9 +75,9 @@ define <256 x i32> @combine_load_3user(ptr%p, ptr%p2) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = alloca <256 x i32>, align 64
 ; CHECK-NEXT:    [[T1:%.*]] = load <256 x i32>, ptr [[P:%.*]], align 64
 ; CHECK-NEXT:    store <256 x i32> [[T1]], ptr [[TMP1]], align 1024
-; CHECK-NEXT:    [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 16, ptr [[TMP1]], i64 16)
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr [[P2:%.*]], i64 64, x86_amx [[TMP3]])
-; CHECK-NEXT:    [[TMP4:%.*]] = call x86_amx @llvm.x86.tdpbssd.internal(i16 16, i16 16, i16 64, x86_amx [[TMP3]], x86_amx [[TMP3]], x86_amx [[TMP3]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 16, ptr [[TMP1]], i64 16)
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr [[P2:%.*]], i64 64, x86_amx [[TMP2]])
+; CHECK-NEXT:    [[TMP3:%.*]] = call x86_amx @llvm.x86.tdpbssd.internal(i16 16, i16 16, i16 64, x86_amx [[TMP2]], x86_amx [[TMP2]], x86_amx [[TMP2]])
 ; CHECK-NEXT:    ret <256 x i32> [[T1]]
 ;
   %t1 = load <256 x i32>, ptr %p, align 64
@@ -88,6 +88,48 @@ define <256 x i32> @combine_load_3user(ptr%p, ptr%p2) {
   ret <256 x i32> %t3
 }
 
+; The shape is loaded after the tile.
+%struct.__tile1024i_str = type <{ i16, i16, [60 x i8], <256 x i32> }>
+define void @test_tile_dpbssd(ptr byval(%struct.__tile1024i_str) align 64 %a, ptr byval(%struct.__tile1024i_str) align 64 %b, ptr byval(%struct.__tile1024i_str) align 64 %c) {
+; CHECK-LABEL: @test_tile_dpbssd(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT:    [[B_ROW_PTR:%.*]] = getelementptr inbounds i8, ptr [[B:%.*]], i64 2
+; CHECK-NEXT:    [[B_ROW:%.*]] = load i16, ptr [[B_ROW_PTR]], align 2
+; CHECK-NEXT:    [[B_TILE_PTR:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 64
+; CHECK-NEXT:    [[B_TILE:%.*]] = load <256 x i32>, ptr [[B_TILE_PTR]], align 64
+; CHECK-NEXT:    store <256 x i32> [[B_TILE]], ptr [[TMP0]], align 1024
+; CHECK-NEXT:    [[A_ROW:%.*]] = load i16, ptr [[A:%.*]], align 64
+; CHECK-NEXT:    [[A_COL_PTR:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 2
+; CHECK-NEXT:    [[A_COL:%.*]] = load i16, ptr [[A_COL_PTR]], align 2
+; CHECK-NEXT:    [[TMP1:%.*]] = udiv i16 [[A_COL]], 4
+; CHECK-NEXT:    [[A_TILE_PTR:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 64
+; CHECK-NEXT:    [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[A_ROW]], i16 [[A_COL]], ptr [[A_TILE_PTR]], i64 64)
+; CHECK-NEXT:    [[C_TILE_PTR:%.*]] = getelementptr inbounds [[STRUCT___TILE1024I_STR:%.*]], ptr [[C:%.*]], i64 0, i32 3
+; CHECK-NEXT:    [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[A_ROW]], i16 [[B_ROW]], ptr [[C_TILE_PTR]], i64 64)
+; CHECK-NEXT:    [[TMP4:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP1]], i16 [[B_ROW]], ptr [[TMP0]], i64 64)
+; CHECK-NEXT:    [[RES:%.*]] = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 [[A_ROW]], i16 [[B_ROW]], i16 [[A_COL]], x86_amx [[TMP3]], x86_amx [[TMP2]], x86_amx [[TMP4]])
+; CHECK-NEXT:    ret void
+;
+entry:
+  %b.row.ptr = getelementptr inbounds i8, ptr %b, i64 2
+  %b.row = load i16, ptr %b.row.ptr, align 2
+  %b.tile.ptr = getelementptr inbounds i8, ptr %b, i64 64
+  %b.tile = load <256 x i32>, ptr %b.tile.ptr, align 64
+  %a.row = load i16, ptr %a, align 64
+  %a.col.ptr = getelementptr inbounds i8, ptr %a, i64 2
+  %a.col = load i16, ptr %a.col.ptr, align 2
+  %a.tile.ptr = getelementptr inbounds i8, ptr %a, i64 64
+  %a.tile = load <256 x i32>, ptr %a.tile.ptr, align 64
+  %c.tile.ptr = getelementptr inbounds %struct.__tile1024i_str, ptr %c, i64 0, i32 3
+  %c.tile = load <256 x i32>, ptr %c.tile.ptr, align 64
+  %c.amx = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %c.tile)
+  %a.amx = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %a.tile)
+  %b.amx = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %b.tile)
+  %res = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %a.row, i16 %b.row, i16 %a.col, x86_amx %c.amx, x86_amx %a.amx, x86_amx %b.amx)
+  ret void
+}
+
 declare x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32>)
 declare <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx)
 declare x86_amx @llvm.x86.tilezero.internal(i16, i16)
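For context, the new test_tile_dpbssd test corresponds roughly to source like the following. This is a hedged reconstruction: __tile1024i and __tile_dpbssd come from clang's amxintrin.h (pulled in via immintrin.h) and require the AMX feature flags to be enabled, and the byval %struct.__tile1024i_str arguments seen in the IR are one possible lowering of passing __tile1024i by value.

```
#include <immintrin.h>

// Approximate source for the test_tile_dpbssd IR above; passing the
// __tile1024i values by value is what produces the byval
// %struct.__tile1024i_str arguments whose row/col fields are loaded
// separately from the tile data.
void test_tile_dpbssd(__tile1024i a, __tile1024i b, __tile1024i c) {
  __tile_dpbssd(&c, a, b);
}
```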
