Skip to content

Commit b966978

Browse files
committed
[GlobalISel][NFC] Introduce a GVecReduce wrapper class and a minor refactor.
1 parent da56750 commit b966978

File tree

2 files changed

+81
-56
lines changed

2 files changed

+81
-56
lines changed

llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,82 @@ class GIntrinsic final : public GenericMachineInstr {
400400
}
401401
};
402402

403+
// Represents a (non-sequential) vector reduction operation.
404+
class GVecReduce : public GenericMachineInstr {
405+
public:
406+
static bool classof(const MachineInstr *MI) {
407+
switch (MI->getOpcode()) {
408+
case TargetOpcode::G_VECREDUCE_FADD:
409+
case TargetOpcode::G_VECREDUCE_FMUL:
410+
case TargetOpcode::G_VECREDUCE_FMAX:
411+
case TargetOpcode::G_VECREDUCE_FMIN:
412+
case TargetOpcode::G_VECREDUCE_ADD:
413+
case TargetOpcode::G_VECREDUCE_MUL:
414+
case TargetOpcode::G_VECREDUCE_AND:
415+
case TargetOpcode::G_VECREDUCE_OR:
416+
case TargetOpcode::G_VECREDUCE_XOR:
417+
case TargetOpcode::G_VECREDUCE_SMAX:
418+
case TargetOpcode::G_VECREDUCE_SMIN:
419+
case TargetOpcode::G_VECREDUCE_UMAX:
420+
case TargetOpcode::G_VECREDUCE_UMIN:
421+
return true;
422+
default:
423+
return false;
424+
}
425+
}
426+
427+
/// Get the opcode for the equivalent scalar operation for this reduction.
428+
/// E.g. for G_VECREDUCE_FADD, this returns G_FADD.
429+
unsigned getScalarOpcForReduction() {
430+
unsigned ScalarOpc;
431+
switch (getOpcode()) {
432+
case TargetOpcode::G_VECREDUCE_FADD:
433+
ScalarOpc = TargetOpcode::G_FADD;
434+
break;
435+
case TargetOpcode::G_VECREDUCE_FMUL:
436+
ScalarOpc = TargetOpcode::G_FMUL;
437+
break;
438+
case TargetOpcode::G_VECREDUCE_FMAX:
439+
ScalarOpc = TargetOpcode::G_FMAXNUM;
440+
break;
441+
case TargetOpcode::G_VECREDUCE_FMIN:
442+
ScalarOpc = TargetOpcode::G_FMINNUM;
443+
break;
444+
case TargetOpcode::G_VECREDUCE_ADD:
445+
ScalarOpc = TargetOpcode::G_ADD;
446+
break;
447+
case TargetOpcode::G_VECREDUCE_MUL:
448+
ScalarOpc = TargetOpcode::G_MUL;
449+
break;
450+
case TargetOpcode::G_VECREDUCE_AND:
451+
ScalarOpc = TargetOpcode::G_AND;
452+
break;
453+
case TargetOpcode::G_VECREDUCE_OR:
454+
ScalarOpc = TargetOpcode::G_OR;
455+
break;
456+
case TargetOpcode::G_VECREDUCE_XOR:
457+
ScalarOpc = TargetOpcode::G_XOR;
458+
break;
459+
case TargetOpcode::G_VECREDUCE_SMAX:
460+
ScalarOpc = TargetOpcode::G_SMAX;
461+
break;
462+
case TargetOpcode::G_VECREDUCE_SMIN:
463+
ScalarOpc = TargetOpcode::G_SMIN;
464+
break;
465+
case TargetOpcode::G_VECREDUCE_UMAX:
466+
ScalarOpc = TargetOpcode::G_UMAX;
467+
break;
468+
case TargetOpcode::G_VECREDUCE_UMIN:
469+
ScalarOpc = TargetOpcode::G_UMIN;
470+
break;
471+
default:
472+
llvm_unreachable("Unhandled reduction");
473+
}
474+
return ScalarOpc;
475+
}
476+
};
477+
478+
403479
} // namespace llvm
404480

405481
#endif // LLVM_CODEGEN_GLOBALISEL_GENERICMACHINEINSTRS_H

llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp

Lines changed: 5 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -4428,73 +4428,22 @@ LegalizerHelper::LegalizeResult LegalizerHelper::fewerElementsVectorShuffle(
44284428
return Legalized;
44294429
}
44304430

4431-
static unsigned getScalarOpcForReduction(unsigned Opc) {
4432-
unsigned ScalarOpc;
4433-
switch (Opc) {
4434-
case TargetOpcode::G_VECREDUCE_FADD:
4435-
ScalarOpc = TargetOpcode::G_FADD;
4436-
break;
4437-
case TargetOpcode::G_VECREDUCE_FMUL:
4438-
ScalarOpc = TargetOpcode::G_FMUL;
4439-
break;
4440-
case TargetOpcode::G_VECREDUCE_FMAX:
4441-
ScalarOpc = TargetOpcode::G_FMAXNUM;
4442-
break;
4443-
case TargetOpcode::G_VECREDUCE_FMIN:
4444-
ScalarOpc = TargetOpcode::G_FMINNUM;
4445-
break;
4446-
case TargetOpcode::G_VECREDUCE_ADD:
4447-
ScalarOpc = TargetOpcode::G_ADD;
4448-
break;
4449-
case TargetOpcode::G_VECREDUCE_MUL:
4450-
ScalarOpc = TargetOpcode::G_MUL;
4451-
break;
4452-
case TargetOpcode::G_VECREDUCE_AND:
4453-
ScalarOpc = TargetOpcode::G_AND;
4454-
break;
4455-
case TargetOpcode::G_VECREDUCE_OR:
4456-
ScalarOpc = TargetOpcode::G_OR;
4457-
break;
4458-
case TargetOpcode::G_VECREDUCE_XOR:
4459-
ScalarOpc = TargetOpcode::G_XOR;
4460-
break;
4461-
case TargetOpcode::G_VECREDUCE_SMAX:
4462-
ScalarOpc = TargetOpcode::G_SMAX;
4463-
break;
4464-
case TargetOpcode::G_VECREDUCE_SMIN:
4465-
ScalarOpc = TargetOpcode::G_SMIN;
4466-
break;
4467-
case TargetOpcode::G_VECREDUCE_UMAX:
4468-
ScalarOpc = TargetOpcode::G_UMAX;
4469-
break;
4470-
case TargetOpcode::G_VECREDUCE_UMIN:
4471-
ScalarOpc = TargetOpcode::G_UMIN;
4472-
break;
4473-
default:
4474-
llvm_unreachable("Unhandled reduction");
4475-
}
4476-
return ScalarOpc;
4477-
}
4478-
44794431
LegalizerHelper::LegalizeResult LegalizerHelper::fewerElementsVectorReductions(
44804432
MachineInstr &MI, unsigned int TypeIdx, LLT NarrowTy) {
4481-
unsigned Opc = MI.getOpcode();
4482-
assert(Opc != TargetOpcode::G_VECREDUCE_SEQ_FADD &&
4483-
Opc != TargetOpcode::G_VECREDUCE_SEQ_FMUL &&
4484-
"Sequential reductions not expected");
4433+
auto &RdxMI = cast<GVecReduce>(MI);
44854434

44864435
if (TypeIdx != 1)
44874436
return UnableToLegalize;
44884437

44894438
// The semantics of the normal non-sequential reductions allow us to freely
44904439
// re-associate the operation.
4491-
auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs();
4440+
auto [DstReg, DstTy, SrcReg, SrcTy] = RdxMI.getFirst2RegLLTs();
44924441

44934442
if (NarrowTy.isVector() &&
44944443
(SrcTy.getNumElements() % NarrowTy.getNumElements() != 0))
44954444
return UnableToLegalize;
44964445

4497-
unsigned ScalarOpc = getScalarOpcForReduction(Opc);
4446+
unsigned ScalarOpc = RdxMI.getScalarOpcForReduction();
44984447
SmallVector<Register> SplitSrcs;
44994448
// If NarrowTy is a scalar then we're being asked to scalarize.
45004449
const unsigned NumParts =
@@ -4539,10 +4488,10 @@ LegalizerHelper::LegalizeResult LegalizerHelper::fewerElementsVectorReductions(
45394488
SmallVector<Register> PartialReductions;
45404489
for (unsigned Part = 0; Part < NumParts; ++Part) {
45414490
PartialReductions.push_back(
4542-
MIRBuilder.buildInstr(Opc, {DstTy}, {SplitSrcs[Part]}).getReg(0));
4491+
MIRBuilder.buildInstr(RdxMI.getOpcode(), {DstTy}, {SplitSrcs[Part]})
4492+
.getReg(0));
45434493
}
45444494

4545-
45464495
// If the types involved are powers of 2, we can generate intermediate vector
45474496
// ops, before generating a final reduction operation.
45484497
if (isPowerOf2_32(SrcTy.getNumElements()) &&

0 commit comments

Comments
 (0)