Skip to content

Commit 7eccd52

Browse files
committed
Reland "[mlir][gpu] Align reduction operations with vector combining kinds (#73423)"
This reverts commit dd09221 and relands #73423.

* Updated `gpu.all_reduce` `min`/`max` in CUDA integration tests.
1 parent c644486 commit 7eccd52

File tree

13 files changed

+354
-173
lines changed

13 files changed

+354
-173
lines changed

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 59 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -931,38 +931,53 @@ def GPU_YieldOp : GPU_Op<"yield", [Pure, Terminator]>,
931931
}];
932932
}
933933

934-
// add, mul mirror the XLA ComparisonDirection enum.
934+
// These mirror the reduction combining kinds from the vector dialect.
935935
def GPU_AllReduceOpAdd : I32EnumAttrCase<"ADD", 0, "add">;
936-
def GPU_AllReduceOpAnd : I32EnumAttrCase<"AND", 1, "and">;
937-
def GPU_AllReduceOpMax : I32EnumAttrCase<"MAX", 2, "max">;
938-
def GPU_AllReduceOpMin : I32EnumAttrCase<"MIN", 3, "min">;
939-
def GPU_AllReduceOpMul : I32EnumAttrCase<"MUL", 4, "mul">;
940-
def GPU_AllReduceOpOr : I32EnumAttrCase<"OR", 5, "or">;
941-
def GPU_AllReduceOpXor : I32EnumAttrCase<"XOR", 6, "xor">;
936+
def GPU_AllReduceOpMul : I32EnumAttrCase<"MUL", 1, "mul">;
937+
def GPU_AllReduceOpMinUI : I32EnumAttrCase<"MINUI", 2, "minui">;
938+
def GPU_AllReduceOpMinSI : I32EnumAttrCase<"MINSI", 3, "minsi">;
939+
// Follows the `arith.minnumf` semantics.
940+
def GPU_AllReduceOpMinF : I32EnumAttrCase<"MINF", 4, "minf">;
941+
def GPU_AllReduceOpMaxUI : I32EnumAttrCase<"MAXUI", 5, "maxui">;
942+
def GPU_AllReduceOpMaxSI : I32EnumAttrCase<"MAXSI", 6, "maxsi">;
943+
// Follows the `arith.maxnumf` semantics.
944+
def GPU_AllReduceOpMaxF : I32EnumAttrCase<"MAXF", 7, "maxf">;
945+
def GPU_AllReduceOpAnd : I32EnumAttrCase<"AND", 8, "and">;
946+
def GPU_AllReduceOpOr : I32EnumAttrCase<"OR", 9, "or">;
947+
def GPU_AllReduceOpXor : I32EnumAttrCase<"XOR", 10, "xor">;
948+
// Follows the `arith.minimumf` semantics.
949+
def GPU_AllReduceOpMinimumF : I32EnumAttrCase<"MINIMUMF", 11, "minimumf">;
950+
// Follows the `arith.maximumf` semantics.
951+
def GPU_AllReduceOpMaximumF : I32EnumAttrCase<"MAXIMUMF", 12, "maximumf">;
942952

943953
def GPU_AllReduceOperation : I32EnumAttr<"AllReduceOperation",
944954
"built-in reduction operations supported by gpu.allreduce.",
945955
[
946956
GPU_AllReduceOpAdd,
947-
GPU_AllReduceOpAnd,
948-
GPU_AllReduceOpMax,
949-
GPU_AllReduceOpMin,
950957
GPU_AllReduceOpMul,
958+
GPU_AllReduceOpMinUI,
959+
GPU_AllReduceOpMinSI,
960+
GPU_AllReduceOpMinF,
961+
GPU_AllReduceOpMaxUI,
962+
GPU_AllReduceOpMaxSI,
963+
GPU_AllReduceOpMaxF,
964+
GPU_AllReduceOpAnd,
951965
GPU_AllReduceOpOr,
952-
GPU_AllReduceOpXor
966+
GPU_AllReduceOpXor,
967+
GPU_AllReduceOpMinimumF,
968+
GPU_AllReduceOpMaximumF
953969
]>{
954970
let genSpecializedAttr = 0;
955971
let cppNamespace = "::mlir::gpu";
956972
}
973+
974+
def AnyIntegerOrFloat : AnyTypeOf<[AnySignlessInteger, AnyFloat], "Integer or Float">;
975+
957976
def GPU_AllReduceOperationAttr : EnumAttr<GPU_Dialect, GPU_AllReduceOperation,
958977
"all_reduce_op">;
959978

960979
def GPU_AllReduceOp : GPU_Op<"all_reduce",
961-
[SameOperandsAndResultType, IsolatedFromAbove]>,
962-
Arguments<(ins AnyType:$value,
963-
OptionalAttr<GPU_AllReduceOperationAttr>:$op,
964-
UnitAttr:$uniform)>,
965-
Results<(outs AnyType)> {
980+
[SameOperandsAndResultType, IsolatedFromAbove]> {
966981
let summary = "Reduce values among workgroup.";
967982
let description = [{
968983
The `all_reduce` op reduces the value of every work item across a local
@@ -981,12 +996,23 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce",
981996

982997
compute the sum of each work item's %0 value. The first version specifies
983998
the accumulation as operation, whereas the second version specifies the
984-
accumulation as code region. The accumulation operation must be one of:
985-
`add`, `and`, `max`, `min`, `mul`, `or`, `xor`.
999+
accumulation as code region. The reduction operation must be one of:
1000+
* Integer types: `add`, `mul`, `minui`, `minsi`, `maxui`, `maxsi`, `and`,
1001+
`or`, `xor`
1002+
* Floating point types: `add`, `mul`, `minf`, `maxf`, `minimumf`,
1003+
`maximumf`
9861004

9871005
If `uniform` flag is set either none or all work items of a workgroup
9881006
need to execute this op in convergence.
9891007
}];
1008+
1009+
let arguments = (ins
1010+
AnyIntegerOrFloat:$value,
1011+
OptionalAttr<GPU_AllReduceOperationAttr>:$op,
1012+
UnitAttr:$uniform
1013+
);
1014+
let results = (outs AnyIntegerOrFloat:$result);
1015+
9901016
let regions = (region AnyRegion:$body);
9911017
let assemblyFormat = [{ custom<AllReduceOperation>($op) $value
9921018
(`uniform` $uniform^)? $body attr-dict
@@ -996,12 +1022,7 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce",
9961022
let hasRegionVerifier = 1;
9971023
}
9981024

999-
def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce",
1000-
[SameOperandsAndResultType]>,
1001-
Arguments<(ins AnyType:$value,
1002-
GPU_AllReduceOperationAttr:$op,
1003-
UnitAttr:$uniform)>,
1004-
Results<(outs AnyType)> {
1025+
def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce", [SameOperandsAndResultType]> {
10051026
let summary = "Reduce values among subgroup.";
10061027
let description = [{
10071028
The `subgroup_reduce` op reduces the value of every work item across a
@@ -1014,8 +1035,21 @@ def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce",
10141035
```
10151036

10161037
If `uniform` flag is set either none or all work items of a subgroup
1017-
need to execute this op in convergence.
1038+
need to execute this op in convergence. The reduction operation must be one
1039+
of:
1040+
* Integer types: `add`, `mul`, `minui`, `minsi`, `maxui`, `maxsi`, `and`,
1041+
`or`, `xor`
1042+
* Floating point types: `add`, `mul`, `minf`, `maxf`, `minimumf`,
1043+
`maximumf`
10181044
}];
1045+
1046+
let arguments = (ins
1047+
AnyIntegerOrFloat:$value,
1048+
GPU_AllReduceOperationAttr:$op,
1049+
UnitAttr:$uniform
1050+
);
1051+
let results = (outs AnyIntegerOrFloat:$result);
1052+
10191053
let assemblyFormat = [{ custom<AllReduceOperation>($op) $value
10201054
(`uniform` $uniform^)? attr-dict
10211055
`:` functional-type(operands, results) }];

mlir/include/mlir/IR/CommonTypeConstraints.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def IsFixedVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
3434
!::llvm::cast<VectorType>($_self).isScalable()}]>;
3535

3636
// Whether a type is a scalable VectorType.
37-
def IsVectorTypeWithAnyDimScalablePred
37+
def IsVectorTypeWithAnyDimScalablePred
3838
: CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
3939
::llvm::cast<VectorType>($_self).isScalable()}]>;
4040

mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,17 +65,28 @@ convertReduxKind(gpu::AllReduceOperation mode) {
6565
switch (mode) {
6666
case gpu::AllReduceOperation::ADD:
6767
return NVVM::ReduxKind::ADD;
68+
case gpu::AllReduceOperation::MUL:
69+
return std::nullopt;
70+
case gpu::AllReduceOperation::MINSI:
71+
return NVVM::ReduxKind::MIN;
72+
case gpu::AllReduceOperation::MINUI:
73+
return std::nullopt;
74+
case gpu::AllReduceOperation::MINF:
75+
return NVVM::ReduxKind::MIN;
76+
case gpu::AllReduceOperation::MAXSI:
77+
return NVVM::ReduxKind::MAX;
78+
case gpu::AllReduceOperation::MAXUI:
79+
return std::nullopt;
80+
case gpu::AllReduceOperation::MAXF:
81+
return NVVM::ReduxKind::MAX;
6882
case gpu::AllReduceOperation::AND:
6983
return NVVM::ReduxKind::AND;
70-
case gpu::AllReduceOperation::MAX:
71-
return NVVM::ReduxKind::MAX;
72-
case gpu::AllReduceOperation::MIN:
73-
return NVVM::ReduxKind::MIN;
7484
case gpu::AllReduceOperation::OR:
7585
return NVVM::ReduxKind::OR;
7686
case gpu::AllReduceOperation::XOR:
7787
return NVVM::ReduxKind::XOR;
78-
case gpu::AllReduceOperation::MUL:
88+
case gpu::AllReduceOperation::MINIMUMF:
89+
case gpu::AllReduceOperation::MAXIMUMF:
7990
return std::nullopt;
8091
}
8192
return std::nullopt;

mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp

Lines changed: 43 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -503,26 +503,53 @@ static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
503503
return std::nullopt;
504504
}
505505

506+
// TODO(https://github.com/llvm/llvm-project/issues/73459): The SPIR-V spec
507+
// does not specify how -0.0 / +0.0 and NaN values are handled in *FMin/*FMax
508+
// reduction ops. We should account for possible precision requirements in this
509+
// conversion.
510+
506511
using ReduceType = gpu::AllReduceOperation;
507-
namespace spv = spirv;
508512
const OpHandler handlers[] = {
509513
{ReduceType::ADD,
510-
&createGroupReduceOpImpl<spv::GroupIAddOp, spv::GroupNonUniformIAddOp>,
511-
&createGroupReduceOpImpl<spv::GroupFAddOp, spv::GroupNonUniformFAddOp>},
514+
&createGroupReduceOpImpl<spirv::GroupIAddOp,
515+
spirv::GroupNonUniformIAddOp>,
516+
&createGroupReduceOpImpl<spirv::GroupFAddOp,
517+
spirv::GroupNonUniformFAddOp>},
512518
{ReduceType::MUL,
513-
&createGroupReduceOpImpl<spv::GroupIMulKHROp,
514-
spv::GroupNonUniformIMulOp>,
515-
&createGroupReduceOpImpl<spv::GroupFMulKHROp,
516-
spv::GroupNonUniformFMulOp>},
517-
{ReduceType::MIN,
518-
&createGroupReduceOpImpl<spv::GroupSMinOp, spv::GroupNonUniformSMinOp>,
519-
&createGroupReduceOpImpl<spv::GroupFMinOp, spv::GroupNonUniformFMinOp>},
520-
{ReduceType::MAX,
521-
&createGroupReduceOpImpl<spv::GroupSMaxOp, spv::GroupNonUniformSMaxOp>,
522-
&createGroupReduceOpImpl<spv::GroupFMaxOp, spv::GroupNonUniformFMaxOp>},
523-
};
524-
525-
for (auto &handler : handlers)
519+
&createGroupReduceOpImpl<spirv::GroupIMulKHROp,
520+
spirv::GroupNonUniformIMulOp>,
521+
&createGroupReduceOpImpl<spirv::GroupFMulKHROp,
522+
spirv::GroupNonUniformFMulOp>},
523+
{ReduceType::MINUI,
524+
&createGroupReduceOpImpl<spirv::GroupUMinOp,
525+
spirv::GroupNonUniformUMinOp>,
526+
nullptr},
527+
{ReduceType::MINSI,
528+
&createGroupReduceOpImpl<spirv::GroupSMinOp,
529+
spirv::GroupNonUniformSMinOp>,
530+
nullptr},
531+
{ReduceType::MINF, nullptr,
532+
&createGroupReduceOpImpl<spirv::GroupFMinOp,
533+
spirv::GroupNonUniformFMinOp>},
534+
{ReduceType::MAXUI,
535+
&createGroupReduceOpImpl<spirv::GroupUMaxOp,
536+
spirv::GroupNonUniformUMaxOp>,
537+
nullptr},
538+
{ReduceType::MAXSI,
539+
&createGroupReduceOpImpl<spirv::GroupSMaxOp,
540+
spirv::GroupNonUniformSMaxOp>,
541+
nullptr},
542+
{ReduceType::MAXF, nullptr,
543+
&createGroupReduceOpImpl<spirv::GroupFMaxOp,
544+
spirv::GroupNonUniformFMaxOp>},
545+
{ReduceType::MINIMUMF, nullptr,
546+
&createGroupReduceOpImpl<spirv::GroupFMinOp,
547+
spirv::GroupNonUniformFMinOp>},
548+
{ReduceType::MAXIMUMF, nullptr,
549+
&createGroupReduceOpImpl<spirv::GroupFMaxOp,
550+
spirv::GroupNonUniformFMaxOp>}};
551+
552+
for (const OpHandler &handler : handlers)
526553
if (handler.type == opType)
527554
return (handler.*handlerPtr)(builder, loc, arg, isGroup, isUniform);
528555

mlir/lib/Dialect/GPU/IR/GPUDialect.cpp

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@
2727
#include "mlir/IR/TypeUtilities.h"
2828
#include "mlir/Interfaces/FunctionImplementation.h"
2929
#include "mlir/Interfaces/SideEffectInterfaces.h"
30+
#include "mlir/Support/LogicalResult.h"
3031
#include "mlir/Transforms/InliningUtils.h"
32+
#include "llvm/ADT/STLExtras.h"
3133
#include "llvm/ADT/TypeSwitch.h"
3234
#include "llvm/Support/CommandLine.h"
3335
#include "llvm/Support/ErrorHandling.h"
@@ -486,12 +488,23 @@ static LogicalResult verifyAttributions(Operation *op,
486488
// AllReduceOp
487489
//===----------------------------------------------------------------------===//
488490

489-
static bool verifyReduceOpAndType(gpu::AllReduceOperation opName,
490-
Type resType) {
491-
return (opName != gpu::AllReduceOperation::AND &&
492-
opName != gpu::AllReduceOperation::OR &&
493-
opName != gpu::AllReduceOperation::XOR) ||
494-
llvm::isa<IntegerType>(resType);
491+
static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName,
492+
Type resType) {
493+
using Kind = gpu::AllReduceOperation;
494+
if (llvm::is_contained(
495+
{Kind::MINF, Kind::MAXF, Kind::MINIMUMF, Kind::MAXIMUMF}, opName)) {
496+
if (!isa<FloatType>(resType))
497+
return failure();
498+
}
499+
500+
if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
501+
Kind::AND, Kind::OR, Kind::XOR},
502+
opName)) {
503+
if (!isa<IntegerType>(resType))
504+
return failure();
505+
}
506+
507+
return success();
495508
}
496509

497510
LogicalResult gpu::AllReduceOp::verifyRegions() {
@@ -518,12 +531,13 @@ LogicalResult gpu::AllReduceOp::verifyRegions() {
518531
return emitError("expected gpu.yield op in region");
519532
} else {
520533
gpu::AllReduceOperation opName = *getOp();
521-
if (!verifyReduceOpAndType(opName, getType())) {
522-
return emitError()
523-
<< '`' << gpu::stringifyAllReduceOperation(opName)
524-
<< "` accumulator is only compatible with Integer type";
534+
if (failed(verifyReduceOpAndType(opName, getType()))) {
535+
return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
536+
<< "` reduction operation is not compatible with type "
537+
<< getType();
525538
}
526539
}
540+
527541
return success();
528542
}
529543

@@ -574,9 +588,10 @@ static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
574588

575589
LogicalResult gpu::SubgroupReduceOp::verify() {
576590
gpu::AllReduceOperation opName = getOp();
577-
if (!verifyReduceOpAndType(opName, getType())) {
591+
if (failed(verifyReduceOpAndType(opName, getType()))) {
578592
return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
579-
<< "` accumulator is only compatible with Integer type";
593+
<< "` reduction operation is not compatible with type "
594+
<< getType();
580595
}
581596
return success();
582597
}

mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -214,54 +214,49 @@ struct GpuAllReduceRewriter {
214214

215215
/// Returns an accumulator factory that creates an op specified by opName.
216216
AccumulatorFactory getFactory(gpu::AllReduceOperation opName) {
217+
using Kind = gpu::AllReduceOperation;
217218
bool isFloatingPoint = isa<FloatType>(valueType);
218219
switch (opName) {
219-
case gpu::AllReduceOperation::ADD:
220+
case Kind::ADD:
220221
return isFloatingPoint ? getFactory<arith::AddFOp>()
221222
: getFactory<arith::AddIOp>();
222-
case gpu::AllReduceOperation::MUL:
223+
case Kind::MUL:
223224
return isFloatingPoint ? getFactory<arith::MulFOp>()
224225
: getFactory<arith::MulIOp>();
225-
case gpu::AllReduceOperation::AND:
226+
case Kind::MINSI:
227+
return getFactory<arith::MinSIOp>();
228+
case Kind::MINUI:
229+
return getFactory<arith::MinUIOp>();
230+
case Kind::MINF:
231+
return getFactory<arith::MinNumFOp>();
232+
case Kind::MAXSI:
233+
return getFactory<arith::MaxSIOp>();
234+
case Kind::MAXUI:
235+
return getFactory<arith::MaxUIOp>();
236+
case Kind::MAXF:
237+
return getFactory<arith::MaxNumFOp>();
238+
case Kind::AND:
226239
return getFactory<arith::AndIOp>();
227-
case gpu::AllReduceOperation::OR:
240+
case Kind::OR:
228241
return getFactory<arith::OrIOp>();
229-
case gpu::AllReduceOperation::XOR:
242+
case Kind::XOR:
230243
return getFactory<arith::XOrIOp>();
231-
case gpu::AllReduceOperation::MAX:
232-
return isFloatingPoint
233-
? getCmpFactory<arith::CmpFOp, arith::CmpFPredicate,
234-
arith::CmpFPredicate::UGT>()
235-
: getCmpFactory<arith::CmpIOp, arith::CmpIPredicate,
236-
arith::CmpIPredicate::ugt>();
237-
case gpu::AllReduceOperation::MIN:
238-
return isFloatingPoint
239-
? getCmpFactory<arith::CmpFOp, arith::CmpFPredicate,
240-
arith::CmpFPredicate::ULT>()
241-
: getCmpFactory<arith::CmpIOp, arith::CmpIPredicate,
242-
arith::CmpIPredicate::ult>();
244+
case Kind::MINIMUMF:
245+
return getFactory<arith::MinimumFOp>();
246+
case Kind::MAXIMUMF:
247+
return getFactory<arith::MaximumFOp>();
243248
}
244249
llvm_unreachable("unknown GPU AllReduceOperation");
245250
}
246251

247252
/// Returns an accumulator factory that creates an op of type T.
248253
template <typename T>
249254
AccumulatorFactory getFactory() {
250-
return [&](Value lhs, Value rhs) {
255+
return [this](Value lhs, Value rhs) {
251256
return create<T>(lhs.getType(), lhs, rhs);
252257
};
253258
}
254259

255-
/// Returns an accumulator for comparison such as min, max. T is the type
256-
/// of the compare op.
257-
template <typename T, typename PredicateEnum, PredicateEnum predicate>
258-
AccumulatorFactory getCmpFactory() const {
259-
return [&](Value lhs, Value rhs) {
260-
Value cmp = rewriter.create<T>(loc, predicate, lhs, rhs);
261-
return rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
262-
};
263-
}
264-
265260
/// Creates an if-block skeleton and calls the two factories to generate the
266261
/// ops in the `then` and `else` blocks.
267262
///

0 commit comments

Comments
 (0)