Skip to content

Commit 1a5d4a8

Browse files
[RISCV][llvm-mca] Vector Unit Stride Loads and stores use EEW and EMUL based on instruction EEW
Vector Unit Stride Loads and stores EEW and EMUL depend on the EEW given in the instruction name. llvm-mca needs some help to correctly report this information.
1 parent 460e843 commit 1a5d4a8

File tree

2 files changed

+1370
-6
lines changed

2 files changed

+1370
-6
lines changed

llvm/lib/Target/RISCV/MCA/RISCVCustomBehaviour.cpp

Lines changed: 121 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
#include "TargetInfo/RISCVTargetInfo.h"
1818
#include "llvm/MC/TargetRegistry.h"
1919
#include "llvm/Support/Debug.h"
20+
#include <numeric>
21+
#include <set>
2022

2123
#define DEBUG_TYPE "llvm-mca-riscv-custombehaviour"
2224

@@ -185,6 +187,109 @@ RISCVInstrumentManager::createInstruments(const MCInst &Inst) {
185187
return SmallVector<UniqueInstrument>();
186188
}
187189

190+
/// Return EMUL = (EEW / SEW) * LMUL
191+
inline static std::pair<unsigned, bool>
192+
getEMULEqualsEEWDivSEWTimesLMUL(unsigned EEW, unsigned SEW,
193+
RISCVII::VLMUL VLMUL) {
194+
// Calculate (EEW/SEW)*LMUL preserving fractions less than 1. Use GCD
195+
// to put fraction in simplest form.
196+
auto [LMUL, Fractional] = RISCVVType::decodeVLMUL(VLMUL);
197+
unsigned Num = EEW, Denom = SEW;
198+
int GCD =
199+
Fractional ? std::gcd(Num, Denom * LMUL) : std::gcd(Num * LMUL, Denom);
200+
Num = Fractional ? Num / GCD : Num * LMUL / GCD;
201+
Denom = Fractional ? Denom * LMUL / GCD : Denom / GCD;
202+
return std::make_pair(Num > Denom ? Num : Denom, Denom > Num);
203+
}
204+
205+
static std::pair<uint8_t, uint8_t>
206+
getEEWAndEMULForUnitStrideLoadStore(unsigned Opcode, uint8_t LMUL,
207+
uint8_t SEW) {
208+
uint8_t EEW;
209+
switch (Opcode) {
210+
case RISCV::VLM_V:
211+
case RISCV::VSM_V:
212+
case RISCV::VLE8_V:
213+
case RISCV::VSE8_V:
214+
EEW = 8;
215+
break;
216+
case RISCV::VLE16_V:
217+
case RISCV::VSE16_V:
218+
EEW = 16;
219+
break;
220+
case RISCV::VLE32_V:
221+
case RISCV::VSE32_V:
222+
EEW = 32;
223+
break;
224+
case RISCV::VLE64_V:
225+
case RISCV::VSE64_V:
226+
EEW = 64;
227+
break;
228+
default:
229+
llvm_unreachable("Opcode is not a vector unit stride load nor store");
230+
}
231+
232+
RISCVII::VLMUL VLMUL;
233+
switch (LMUL) {
234+
case 0b000:
235+
VLMUL = RISCVII::LMUL_1;
236+
break;
237+
case 0b001:
238+
VLMUL = RISCVII::LMUL_2;
239+
break;
240+
case 0b010:
241+
VLMUL = RISCVII::LMUL_4;
242+
break;
243+
case 0b011:
244+
VLMUL = RISCVII::LMUL_8;
245+
break;
246+
case 0b111:
247+
VLMUL = RISCVII::LMUL_F2;
248+
break;
249+
case 0b110:
250+
VLMUL = RISCVII::LMUL_F4;
251+
break;
252+
case 0b101:
253+
VLMUL = RISCVII::LMUL_F8;
254+
break;
255+
case RISCVII::LMUL_RESERVED:
256+
llvm_unreachable("LMUL cannot be LMUL_RESERVED");
257+
}
258+
259+
auto [EMULPart, Fractional] =
260+
getEMULEqualsEEWDivSEWTimesLMUL(EEW, SEW, VLMUL);
261+
assert(RISCVVType::isValidLMUL(EMULPart, Fractional) &&
262+
"Unexpected EEW from instruction used with LMUL and SEW");
263+
264+
uint8_t EMUL;
265+
switch (RISCVVType::encodeLMUL(EMULPart, Fractional)) {
266+
case RISCVII::LMUL_1:
267+
EMUL = 0b000;
268+
break;
269+
case RISCVII::LMUL_2:
270+
EMUL = 0b001;
271+
break;
272+
case RISCVII::LMUL_4:
273+
EMUL = 0b010;
274+
break;
275+
case RISCVII::LMUL_8:
276+
EMUL = 0b011;
277+
break;
278+
case RISCVII::LMUL_F2:
279+
EMUL = 0b111;
280+
break;
281+
case RISCVII::LMUL_F4:
282+
EMUL = 0b110;
283+
break;
284+
case RISCVII::LMUL_F8:
285+
EMUL = 0b101;
286+
break;
287+
case RISCVII::LMUL_RESERVED:
288+
llvm_unreachable("Cannot create instrument for LMUL_RESERVED");
289+
}
290+
return std::make_pair(EEW, EMUL);
291+
}
292+
188293
unsigned RISCVInstrumentManager::getSchedClassID(
189294
const MCInstrInfo &MCII, const MCInst &MCI,
190295
const llvm::SmallVector<Instrument *> &IVec) const {
@@ -214,12 +319,22 @@ unsigned RISCVInstrumentManager::getSchedClassID(
214319
// or (Opcode, LMUL, SEW) if SEW instrument is active, and depends on LMUL
215320
// and SEW, or (Opcode, LMUL, 0) if does not depend on SEW.
216321
uint8_t SEW = SI ? SI->getSEW() : 0;
217-
// Check if it depends on LMUL and SEW
218-
const RISCVVInversePseudosTable::PseudoInfo *RVV =
219-
RISCVVInversePseudosTable::getBaseInfo(Opcode, LMUL, SEW);
220-
// Check if it depends only on LMUL
221-
if (!RVV)
222-
RVV = RISCVVInversePseudosTable::getBaseInfo(Opcode, LMUL, 0);
322+
323+
const RISCVVInversePseudosTable::PseudoInfo *RVV = nullptr;
324+
if (Opcode == RISCV::VLM_V || Opcode == RISCV::VSM_V ||
325+
Opcode == RISCV::VLE8_V || Opcode == RISCV::VSE8_V ||
326+
Opcode == RISCV::VLE16_V || Opcode == RISCV::VSE16_V ||
327+
Opcode == RISCV::VLE32_V || Opcode == RISCV::VSE32_V ||
328+
Opcode == RISCV::VLE64_V || Opcode == RISCV::VSE64_V) {
329+
auto [EEW, EMUL] = getEEWAndEMULForUnitStrideLoadStore(Opcode, LMUL, SEW);
330+
RVV = RISCVVInversePseudosTable::getBaseInfo(Opcode, EMUL, EEW);
331+
} else {
332+
// Check if it depends on LMUL and SEW
333+
RVV = RISCVVInversePseudosTable::getBaseInfo(Opcode, LMUL, SEW);
334+
// Check if it depends only on LMUL
335+
if (!RVV)
336+
RVV = RISCVVInversePseudosTable::getBaseInfo(Opcode, LMUL, 0);
337+
}
223338

224339
// Not a RVV instr
225340
if (!RVV) {

0 commit comments

Comments
 (0)