AMDGPU: Handle folding vector splats of split f64 inline immediates #140878

Merged
103 changes: 70 additions & 33 deletions llvm/lib/Target/AMDGPU/SIFoldOperands.cpp
@@ -227,12 +227,12 @@ class SIFoldOperandsImpl {
getRegSeqInit(SmallVectorImpl<std::pair<MachineOperand *, unsigned>> &Defs,
Register UseReg) const;

std::pair<MachineOperand *, const TargetRegisterClass *>
std::pair<int64_t, const TargetRegisterClass *>
isRegSeqSplat(MachineInstr &RegSeg) const;

MachineOperand *tryFoldRegSeqSplat(MachineInstr *UseMI, unsigned UseOpIdx,
MachineOperand *SplatVal,
const TargetRegisterClass *SplatRC) const;
bool tryFoldRegSeqSplat(MachineInstr *UseMI, unsigned UseOpIdx,
int64_t SplatVal,
const TargetRegisterClass *SplatRC) const;

bool tryToFoldACImm(const FoldableDef &OpToFold, MachineInstr *UseMI,
unsigned UseOpIdx,
@@ -966,15 +966,15 @@ const TargetRegisterClass *SIFoldOperandsImpl::getRegSeqInit(
return getRegSeqInit(*Def, Defs);
}

std::pair<MachineOperand *, const TargetRegisterClass *>
std::pair<int64_t, const TargetRegisterClass *>
SIFoldOperandsImpl::isRegSeqSplat(MachineInstr &RegSeq) const {
SmallVector<std::pair<MachineOperand *, unsigned>, 32> Defs;
const TargetRegisterClass *SrcRC = getRegSeqInit(RegSeq, Defs);
if (!SrcRC)
return {};

// TODO: Recognize 64-bit splats broken into 32-bit pieces (i.e. recognize
// every other element is 0 for 64-bit immediates)
bool TryToMatchSplat64 = false;

int64_t Imm;
for (unsigned I = 0, E = Defs.size(); I != E; ++I) {
const MachineOperand *Op = Defs[I].first;
@@ -986,38 +986,75 @@ SIFoldOperandsImpl::isRegSeqSplat(MachineInstr &RegSeq) const {
Imm = SubImm;
continue;
}
if (Imm != SubImm)

if (Imm != SubImm) {
if (I == 1 && (E & 1) == 0) {
// If we have an even number of inputs, there's a chance this is a
// 64-bit element splat broken into 32-bit pieces.
TryToMatchSplat64 = true;
break;
}

return {}; // Can only fold splat constants
}
}

if (!TryToMatchSplat64)
return {Defs[0].first->getImm(), SrcRC};

// Fallback to recognizing 64-bit splats broken into 32-bit pieces
// (i.e. recognize every other element is 0 for 64-bit immediates)
int64_t SplatVal64;
for (unsigned I = 0, E = Defs.size(); I != E; I += 2) {
const MachineOperand *Op0 = Defs[I].first;
const MachineOperand *Op1 = Defs[I + 1].first;

if (!Op0->isImm() || !Op1->isImm())
return {};

unsigned SubReg0 = Defs[I].second;
unsigned SubReg1 = Defs[I + 1].second;

// Assume we're going to generally encounter reg_sequences with sorted
// subreg indexes, so reject any that aren't consecutive.
if (TRI->getChannelFromSubReg(SubReg0) + 1 !=
TRI->getChannelFromSubReg(SubReg1))
return {};

int64_t MergedVal = Make_64(Op1->getImm(), Op0->getImm());
if (I == 0)
SplatVal64 = MergedVal;
else if (SplatVal64 != MergedVal)
return {};
}

return {Defs[0].first, SrcRC};
const TargetRegisterClass *RC64 = TRI->getSubRegisterClass(
MRI->getRegClass(RegSeq.getOperand(0).getReg()), AMDGPU::sub0_sub1);

return {SplatVal64, RC64};
}
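// A minimal standalone sketch (not part of this patch) of the pairwise merge
// above, assuming the immediates arrive as plain 32-bit values in subregister
// order: consecutive (lo, hi) pairs must all form the same 64-bit value.
// make64 stands in for llvm::Make_64(High, Low).
#include <cstdint>
#include <optional>
#include <vector>

static uint64_t make64(uint32_t Hi, uint32_t Lo) {
  return (uint64_t(Hi) << 32) | Lo;
}

static std::optional<uint64_t> matchSplat64(const std::vector<uint32_t> &Elts) {
  if (Elts.size() < 2 || (Elts.size() & 1))
    return std::nullopt; // need an even number of 32-bit pieces
  uint64_t Splat = make64(Elts[1], Elts[0]);
  for (size_t I = 2; I != Elts.size(); I += 2)
    if (make64(Elts[I + 1], Elts[I]) != Splat)
      return std::nullopt; // pairs disagree: not a 64-bit element splat
  // e.g. {0x0, 0x3ff00000, 0x0, 0x3ff00000} -> 0x3ff0000000000000 (1.0).
  return Splat;
}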

MachineOperand *SIFoldOperandsImpl::tryFoldRegSeqSplat(
MachineInstr *UseMI, unsigned UseOpIdx, MachineOperand *SplatVal,
bool SIFoldOperandsImpl::tryFoldRegSeqSplat(
MachineInstr *UseMI, unsigned UseOpIdx, int64_t SplatVal,
const TargetRegisterClass *SplatRC) const {
const MCInstrDesc &Desc = UseMI->getDesc();
if (UseOpIdx >= Desc.getNumOperands())
return nullptr;
return false;

// Filter out unhandled pseudos.
if (!AMDGPU::isSISrcOperand(Desc, UseOpIdx))
return nullptr;
return false;

int16_t RCID = Desc.operands()[UseOpIdx].RegClass;
if (RCID == -1)
return nullptr;
return false;

const TargetRegisterClass *OpRC = TRI->getRegClass(RCID);

// Special case 0/-1, since when interpreted as a 64-bit element both halves
// have the same bits. Effectively this code does not handle 64-bit element
// operands correctly, as the incoming 64-bit constants are already split into
// 32-bit sequence elements.
//
// TODO: We should try to figure out how to interpret the reg_sequence as a
// split 64-bit splat constant, or use 64-bit pseudos for materializing f64
// constants.
if (SplatVal->getImm() != 0 && SplatVal->getImm() != -1) {
const TargetRegisterClass *OpRC = TRI->getRegClass(RCID);
// have the same bits. These are the only cases where a splat has the same
// interpretation for 32-bit and 64-bit splats.
if (SplatVal != 0 && SplatVal != -1) {
// We need to figure out the scalar type read by the operand. e.g. the MFMA
// operand will be AReg_128, and we want to check if it's compatible with an
// AReg_32 constant.
@@ -1031,17 +1068,18 @@ MachineOperand *SIFoldOperandsImpl::tryFoldRegSeqSplat(
OpRC = TRI->getSubRegisterClass(OpRC, AMDGPU::sub0_sub1);
break;
default:
return nullptr;
return false;
}

if (!TRI->getCommonSubClass(OpRC, SplatRC))
return nullptr;
return false;
}

if (!TII->isOperandLegal(*UseMI, UseOpIdx, SplatVal))
return nullptr;
MachineOperand TmpOp = MachineOperand::CreateImm(SplatVal);
if (!TII->isOperandLegal(*UseMI, UseOpIdx, &TmpOp))
return false;

return SplatVal;
return true;
}
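// Sketch (not part of this patch) of why the subregister class check above is
// skipped only for 0 and -1: those are the only 32-bit values whose two-copy
// 64-bit pattern (hi:lo = V:V) equals the value's own 64-bit interpretation,
// so a 32-bit splat and a 64-bit element splat describe identical registers.
#include <cassert>
#include <cstdint>

static uint64_t twoCopies(uint32_t V) { return (uint64_t(V) << 32) | V; }
static uint64_t asInt64(uint32_t V) { return uint64_t(int64_t(int32_t(V))); }

static void checkSplatSpecialCases() {
  assert(twoCopies(0u) == asInt64(0u));                   // 0 -> 0
  assert(twoCopies(0xffffffffu) == asInt64(0xffffffffu)); // -1 -> -1
  assert(twoCopies(64u) != asInt64(64u));                 // 64:64 != 64
}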

bool SIFoldOperandsImpl::tryToFoldACImm(
@@ -1119,7 +1157,7 @@ void SIFoldOperandsImpl::foldOperand(
Register RegSeqDstReg = UseMI->getOperand(0).getReg();
unsigned RegSeqDstSubReg = UseMI->getOperand(UseOpIdx + 1).getImm();

MachineOperand *SplatVal;
int64_t SplatVal;
const TargetRegisterClass *SplatRC;
std::tie(SplatVal, SplatRC) = isRegSeqSplat(*UseMI);

@@ -1130,10 +1168,9 @@ void SIFoldOperandsImpl::foldOperand(
MachineInstr *RSUseMI = RSUse->getParent();
unsigned OpNo = RSUseMI->getOperandNo(RSUse);

if (SplatVal) {
if (MachineOperand *Foldable =
tryFoldRegSeqSplat(RSUseMI, OpNo, SplatVal, SplatRC)) {
FoldableDef SplatDef(*Foldable, SplatRC);
if (SplatRC) {
if (tryFoldRegSeqSplat(RSUseMI, OpNo, SplatVal, SplatRC)) {
FoldableDef SplatDef(SplatVal, SplatRC);
appendFoldCandidate(FoldList, RSUseMI, OpNo, SplatDef);
continue;
}
41 changes: 6 additions & 35 deletions llvm/test/CodeGen/AMDGPU/llvm.amdgcn.mfma.gfx90a.ll
@@ -165,19 +165,9 @@ bb:
}

; GCN-LABEL: {{^}}test_mfma_f64_16x16x4f64_splat_imm_1:
; GCN: v_mov_b32_e32 [[HIGH_BITS:v[0-9]+]], 0x3ff00000
; GCN: v_accvgpr_write_b32 a[[A_HIGH_BITS_0:[0-9]+]], [[HIGH_BITS]]
; GCN: v_accvgpr_write_b32 a[[A_LOW_BITS_0:[0-9]+]], 0{{$}}
; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
; GCN: v_accvgpr_mov_b32 a[[LAST_CONST_REG:[0-9]+]], a[[A_HIGH_BITS_0]]

; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], 1.0{{$}}
; GFX90A: v_mfma_f64_16x16x4f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 blgp:3
; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], 1.0{{$}}
; GFX942: v_mfma_f64_16x16x4_f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 neg:[1,1,0]
; GCN: global_store_dwordx4
; GCN: global_store_dwordx4
@@ -190,19 +180,9 @@ bb:
}

; GCN-LABEL: {{^}}test_mfma_f64_16x16x4f64_splat_imm_neg1:
; GCN: v_mov_b32_e32 [[HIGH_BITS:v[0-9]+]], 0xbff00000
; GCN: v_accvgpr_write_b32 a[[A_HIGH_BITS_0:[0-9]+]], [[HIGH_BITS]]
; GCN: v_accvgpr_write_b32 a[[A_LOW_BITS_0:[0-9]+]], 0{{$}}
; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
; GCN: v_accvgpr_mov_b32 a[[LAST_CONST_REG:[0-9]+]], a[[A_HIGH_BITS_0]]

; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], -1.0{{$}}
; GFX90A: v_mfma_f64_16x16x4f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 blgp:3
; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], -1.0{{$}}
; GFX942: v_mfma_f64_16x16x4_f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 neg:[1,1,0]
; GCN: global_store_dwordx4
; GCN: global_store_dwordx4
@@ -215,18 +195,9 @@ bb:
}

; GCN-LABEL: {{^}}test_mfma_f64_16x16x4f64_splat_imm_int_64:
; GCN: v_accvgpr_write_b32 a[[A_LOW_BITS_0:[0-9]+]], 64{{$}}
; GCN: v_accvgpr_write_b32 a[[A_HIGH_BITS_0:[0-9]+]], 0
; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
; GCN: v_accvgpr_mov_b32 a[[LAST_CONST_REG:[0-9]+]], a[[A_HIGH_BITS_0]]

; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], 64{{$}}
; GFX90A: v_mfma_f64_16x16x4f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 blgp:3
; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], 64{{$}}
; GFX942: v_mfma_f64_16x16x4_f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 neg:[1,1,0]
; GCN: global_store_dwordx4
; GCN: global_store_dwordx4
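
The inline constants that now appear directly on the MFMA operands are the 64-bit values assembled from the 32-bit halves that the old check lines materialized through accvgpr writes and moves. A small standalone sketch (not part of the test suite) of those bit patterns, using std::bit_cast from C++20:

#include <bit>      // std::bit_cast
#include <cassert>
#include <cstdint>

int main() {
  // Make_64(High, Low) as used by the fold: hi:lo halves of each 64-bit lane.
  auto make64 = [](uint32_t Hi, uint32_t Lo) { return (uint64_t(Hi) << 32) | Lo; };
  assert(std::bit_cast<double>(make64(0x3ff00000u, 0u)) == 1.0);  // splat_imm_1
  assert(std::bit_cast<double>(make64(0xbff00000u, 0u)) == -1.0); // splat_imm_neg1
  assert(make64(0u, 64u) == 64u);                                 // splat_imm_int_64
  return 0;
}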