Skip to content

[mlir][gpu] Align reduction operations with vector combining kinds #73423

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 59 additions & 25 deletions mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -868,38 +868,53 @@ def GPU_YieldOp : GPU_Op<"yield", [Pure, Terminator]>,
}];
}

// add, mul mirror the XLA ComparisonDirection enum.
// These mirror the reduction combining kinds from the vector dialect.
def GPU_AllReduceOpAdd : I32EnumAttrCase<"ADD", 0, "add">;
def GPU_AllReduceOpAnd : I32EnumAttrCase<"AND", 1, "and">;
def GPU_AllReduceOpMax : I32EnumAttrCase<"MAX", 2, "max">;
def GPU_AllReduceOpMin : I32EnumAttrCase<"MIN", 3, "min">;
def GPU_AllReduceOpMul : I32EnumAttrCase<"MUL", 4, "mul">;
def GPU_AllReduceOpOr : I32EnumAttrCase<"OR", 5, "or">;
def GPU_AllReduceOpXor : I32EnumAttrCase<"XOR", 6, "xor">;
def GPU_AllReduceOpMul : I32EnumAttrCase<"MUL", 1, "mul">;
def GPU_AllReduceOpMinUI : I32EnumAttrCase<"MINUI", 2, "minui">;
def GPU_AllReduceOpMinSI : I32EnumAttrCase<"MINSI", 3, "minsi">;
// Follows the `arith.minnumf` semantics.
def GPU_AllReduceOpMinF : I32EnumAttrCase<"MINF", 4, "minf">;
def GPU_AllReduceOpMaxUI : I32EnumAttrCase<"MAXUI", 5, "maxui">;
def GPU_AllReduceOpMaxSI : I32EnumAttrCase<"MAXSI", 6, "maxsi">;
// Follows the `arith.maxnumf` semantics.
def GPU_AllReduceOpMaxF : I32EnumAttrCase<"MAXF", 7, "maxf">;
def GPU_AllReduceOpAnd : I32EnumAttrCase<"AND", 8, "and">;
def GPU_AllReduceOpOr : I32EnumAttrCase<"OR", 9, "or">;
def GPU_AllReduceOpXor : I32EnumAttrCase<"XOR", 10, "xor">;
// Follows the `arith.minimumf` semantics.
def GPU_AllReduceOpMinimumF : I32EnumAttrCase<"MINIMUMF", 11, "minimumf">;
// Follows the `arith.maximumf` semantics.
def GPU_AllReduceOpMaximumF : I32EnumAttrCase<"MAXIMUMF", 12, "maximumf">;

def GPU_AllReduceOperation : I32EnumAttr<"AllReduceOperation",
"built-in reduction operations supported by gpu.allreduce.",
[
GPU_AllReduceOpAdd,
GPU_AllReduceOpAnd,
GPU_AllReduceOpMax,
GPU_AllReduceOpMin,
GPU_AllReduceOpMul,
GPU_AllReduceOpMinUI,
GPU_AllReduceOpMinSI,
GPU_AllReduceOpMinF,
GPU_AllReduceOpMaxUI,
GPU_AllReduceOpMaxSI,
GPU_AllReduceOpMaxF,
GPU_AllReduceOpAnd,
GPU_AllReduceOpOr,
GPU_AllReduceOpXor
GPU_AllReduceOpXor,
GPU_AllReduceOpMinimumF,
GPU_AllReduceOpMaximumF
]>{
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::gpu";
}

// Type constraint satisfied by any signless integer type or any floating-point
// type; the string is the human-readable summary of the constraint.
def AnyIntegerOrFloat : AnyTypeOf<[AnySignlessInteger, AnyFloat], "Integer or Float">;

def GPU_AllReduceOperationAttr : EnumAttr<GPU_Dialect, GPU_AllReduceOperation,
"all_reduce_op">;

def GPU_AllReduceOp : GPU_Op<"all_reduce",
[SameOperandsAndResultType, IsolatedFromAbove]>,
Arguments<(ins AnyType:$value,
OptionalAttr<GPU_AllReduceOperationAttr>:$op,
UnitAttr:$uniform)>,
Results<(outs AnyType)> {
[SameOperandsAndResultType, IsolatedFromAbove]> {
let summary = "Reduce values among workgroup.";
let description = [{
The `all_reduce` op reduces the value of every work item across a local
Expand All @@ -918,12 +933,23 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce",

compute the sum of each work item's %0 value. The first version specifies
the accumulation as operation, whereas the second version specifies the
accumulation as code region. The accumulation operation must be one of:
`add`, `and`, `max`, `min`, `mul`, `or`, `xor`.
accumulation as code region. The reduction operation must be one of:
* Integer types: `add`, `mul`, `minui`, `minsi`, `maxui`, `maxsi`, `and`,
`or`, `xor`
* Floating point types: `add`, `mul`, `minf`, `maxf`, `minimumf`,
`maximumf`

If `uniform` flag is set either none or all work items of a workgroup
need to execute this op in convergence.
}];

let arguments = (ins
AnyIntegerOrFloat:$value,
OptionalAttr<GPU_AllReduceOperationAttr>:$op,
UnitAttr:$uniform
);
let results = (outs AnyIntegerOrFloat:$result);

let regions = (region AnyRegion:$body);
let assemblyFormat = [{ custom<AllReduceOperation>($op) $value
(`uniform` $uniform^)? $body attr-dict
Expand All @@ -933,12 +959,7 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce",
let hasRegionVerifier = 1;
}

def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce",
[SameOperandsAndResultType]>,
Arguments<(ins AnyType:$value,
GPU_AllReduceOperationAttr:$op,
UnitAttr:$uniform)>,
Results<(outs AnyType)> {
def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce", [SameOperandsAndResultType]> {
let summary = "Reduce values among subgroup.";
let description = [{
The `subgroup_reduce` op reduces the value of every work item across a
Expand All @@ -951,8 +972,21 @@ def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce",
```

If `uniform` flag is set either none or all work items of a subgroup
need to execute this op in convergence.
need to execute this op in convergence. The reduction operation must be one
of:
* Integer types: `add`, `mul`, `minui`, `minsi`, `maxui`, `maxsi`, `and`,
`or`, `xor`
* Floating point types: `add`, `mul`, `minf`, `maxf`, `minimumf`,
`maximumf`
}];

let arguments = (ins
AnyIntegerOrFloat:$value,
GPU_AllReduceOperationAttr:$op,
UnitAttr:$uniform
);
let results = (outs AnyIntegerOrFloat:$result);

let assemblyFormat = [{ custom<AllReduceOperation>($op) $value
(`uniform` $uniform^)? attr-dict
`:` functional-type(operands, results) }];
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/IR/CommonTypeConstraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def IsFixedVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
!::llvm::cast<VectorType>($_self).isScalable()}]>;

// Whether a type is a scalable VectorType.
def IsVectorTypeWithAnyDimScalablePred
def IsVectorTypeWithAnyDimScalablePred
: CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
::llvm::cast<VectorType>($_self).isScalable()}]>;

Expand Down
21 changes: 16 additions & 5 deletions mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,28 @@ convertReduxKind(gpu::AllReduceOperation mode) {
switch (mode) {
case gpu::AllReduceOperation::ADD:
return NVVM::ReduxKind::ADD;
case gpu::AllReduceOperation::MUL:
return std::nullopt;
case gpu::AllReduceOperation::MINSI:
return NVVM::ReduxKind::MIN;
case gpu::AllReduceOperation::MINUI:
return std::nullopt;
case gpu::AllReduceOperation::MINF:
return NVVM::ReduxKind::MIN;
case gpu::AllReduceOperation::MAXSI:
return NVVM::ReduxKind::MAX;
case gpu::AllReduceOperation::MAXUI:
return std::nullopt;
case gpu::AllReduceOperation::MAXF:
return NVVM::ReduxKind::MAX;
case gpu::AllReduceOperation::AND:
return NVVM::ReduxKind::AND;
case gpu::AllReduceOperation::MAX:
return NVVM::ReduxKind::MAX;
case gpu::AllReduceOperation::MIN:
return NVVM::ReduxKind::MIN;
case gpu::AllReduceOperation::OR:
return NVVM::ReduxKind::OR;
case gpu::AllReduceOperation::XOR:
return NVVM::ReduxKind::XOR;
case gpu::AllReduceOperation::MUL:
case gpu::AllReduceOperation::MINIMUMF:
case gpu::AllReduceOperation::MAXIMUMF:
return std::nullopt;
}
return std::nullopt;
Expand Down
59 changes: 43 additions & 16 deletions mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -503,26 +503,53 @@ static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
return std::nullopt;
}

// TODO(https://github.com/llvm/llvm-project/issues/73459): The SPIR-V spec
// does not specify how -0.0 / +0.0 and NaN values are handled in *FMin/*FMax
// reduction ops. We should account for possible precision requirements in
// this conversion.

using ReduceType = gpu::AllReduceOperation;
namespace spv = spirv;
const OpHandler handlers[] = {
{ReduceType::ADD,
&createGroupReduceOpImpl<spv::GroupIAddOp, spv::GroupNonUniformIAddOp>,
&createGroupReduceOpImpl<spv::GroupFAddOp, spv::GroupNonUniformFAddOp>},
&createGroupReduceOpImpl<spirv::GroupIAddOp,
spirv::GroupNonUniformIAddOp>,
&createGroupReduceOpImpl<spirv::GroupFAddOp,
spirv::GroupNonUniformFAddOp>},
{ReduceType::MUL,
&createGroupReduceOpImpl<spv::GroupIMulKHROp,
spv::GroupNonUniformIMulOp>,
&createGroupReduceOpImpl<spv::GroupFMulKHROp,
spv::GroupNonUniformFMulOp>},
{ReduceType::MIN,
&createGroupReduceOpImpl<spv::GroupSMinOp, spv::GroupNonUniformSMinOp>,
&createGroupReduceOpImpl<spv::GroupFMinOp, spv::GroupNonUniformFMinOp>},
{ReduceType::MAX,
&createGroupReduceOpImpl<spv::GroupSMaxOp, spv::GroupNonUniformSMaxOp>,
&createGroupReduceOpImpl<spv::GroupFMaxOp, spv::GroupNonUniformFMaxOp>},
};

for (auto &handler : handlers)
&createGroupReduceOpImpl<spirv::GroupIMulKHROp,
spirv::GroupNonUniformIMulOp>,
&createGroupReduceOpImpl<spirv::GroupFMulKHROp,
spirv::GroupNonUniformFMulOp>},
{ReduceType::MINUI,
&createGroupReduceOpImpl<spirv::GroupUMinOp,
spirv::GroupNonUniformUMinOp>,
nullptr},
{ReduceType::MINSI,
&createGroupReduceOpImpl<spirv::GroupSMinOp,
spirv::GroupNonUniformSMinOp>,
nullptr},
{ReduceType::MINF, nullptr,
&createGroupReduceOpImpl<spirv::GroupFMinOp,
spirv::GroupNonUniformFMinOp>},
{ReduceType::MAXUI,
&createGroupReduceOpImpl<spirv::GroupUMaxOp,
spirv::GroupNonUniformUMaxOp>,
nullptr},
{ReduceType::MAXSI,
&createGroupReduceOpImpl<spirv::GroupSMaxOp,
spirv::GroupNonUniformSMaxOp>,
nullptr},
{ReduceType::MAXF, nullptr,
&createGroupReduceOpImpl<spirv::GroupFMaxOp,
spirv::GroupNonUniformFMaxOp>},
{ReduceType::MINIMUMF, nullptr,
&createGroupReduceOpImpl<spirv::GroupFMinOp,
spirv::GroupNonUniformFMinOp>},
{ReduceType::MAXIMUMF, nullptr,
&createGroupReduceOpImpl<spirv::GroupFMaxOp,
spirv::GroupNonUniformFMaxOp>}};

for (const OpHandler &handler : handlers)
if (handler.type == opType)
return (handler.*handlerPtr)(builder, loc, arg, isGroup, isUniform);

Expand Down
39 changes: 27 additions & 12 deletions mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
Expand Down Expand Up @@ -485,12 +487,23 @@ static LogicalResult verifyAttributions(Operation *op,
// AllReduceOp
//===----------------------------------------------------------------------===//

static bool verifyReduceOpAndType(gpu::AllReduceOperation opName,
Type resType) {
return (opName != gpu::AllReduceOperation::AND &&
opName != gpu::AllReduceOperation::OR &&
opName != gpu::AllReduceOperation::XOR) ||
llvm::isa<IntegerType>(resType);
/// Checks that the reduction kind `opName` may be applied to a value of type
/// `resType`.
///
/// Float-only kinds (minf, maxf, minimumf, maximumf) require a FloatType;
/// integer-only kinds (signed/unsigned min and max, and, or, xor) require an
/// IntegerType. `add` and `mul` place no restriction on the type here.
/// Returns success() when the combination is allowed and failure() otherwise.
static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName,
                                           Type resType) {
  using Kind = gpu::AllReduceOperation;
  switch (opName) {
  // Reductions that are only meaningful on floating-point values.
  case Kind::MINF:
  case Kind::MAXF:
  case Kind::MINIMUMF:
  case Kind::MAXIMUMF:
    return success(isa<FloatType>(resType));
  // Reductions that are only meaningful on integer values.
  case Kind::MINSI:
  case Kind::MINUI:
  case Kind::MAXSI:
  case Kind::MAXUI:
  case Kind::AND:
  case Kind::OR:
  case Kind::XOR:
    return success(isa<IntegerType>(resType));
  // Type-agnostic reductions.
  case Kind::ADD:
  case Kind::MUL:
    break;
  }
  return success();
}

LogicalResult gpu::AllReduceOp::verifyRegions() {
Expand All @@ -517,12 +530,13 @@ LogicalResult gpu::AllReduceOp::verifyRegions() {
return emitError("expected gpu.yield op in region");
} else {
gpu::AllReduceOperation opName = *getOp();
if (!verifyReduceOpAndType(opName, getType())) {
return emitError()
<< '`' << gpu::stringifyAllReduceOperation(opName)
<< "` accumulator is only compatible with Integer type";
if (failed(verifyReduceOpAndType(opName, getType()))) {
return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
<< "` reduction operation is not compatible with type "
<< getType();
}
}

return success();
}

Expand Down Expand Up @@ -573,9 +587,10 @@ static void printAllReduceOperation(AsmPrinter &printer, Operation *op,

LogicalResult gpu::SubgroupReduceOp::verify() {
gpu::AllReduceOperation opName = getOp();
if (!verifyReduceOpAndType(opName, getType())) {
if (failed(verifyReduceOpAndType(opName, getType()))) {
return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
<< "` accumulator is only compatible with Integer type";
<< "` reduction operation is not compatible with type "
<< getType();
}
return success();
}
Expand Down
51 changes: 23 additions & 28 deletions mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,54 +214,49 @@ struct GpuAllReduceRewriter {

/// Returns an accumulator factory that creates an op specified by opName.
AccumulatorFactory getFactory(gpu::AllReduceOperation opName) {
using Kind = gpu::AllReduceOperation;
bool isFloatingPoint = isa<FloatType>(valueType);
switch (opName) {
case gpu::AllReduceOperation::ADD:
case Kind::ADD:
return isFloatingPoint ? getFactory<arith::AddFOp>()
: getFactory<arith::AddIOp>();
case gpu::AllReduceOperation::MUL:
case Kind::MUL:
return isFloatingPoint ? getFactory<arith::MulFOp>()
: getFactory<arith::MulIOp>();
case gpu::AllReduceOperation::AND:
case Kind::MINSI:
return getFactory<arith::MinSIOp>();
case Kind::MINUI:
return getFactory<arith::MinUIOp>();
case Kind::MINF:
return getFactory<arith::MinNumFOp>();
case Kind::MAXSI:
return getFactory<arith::MaxSIOp>();
case Kind::MAXUI:
return getFactory<arith::MaxUIOp>();
case Kind::MAXF:
return getFactory<arith::MaxNumFOp>();
case Kind::AND:
return getFactory<arith::AndIOp>();
case gpu::AllReduceOperation::OR:
case Kind::OR:
return getFactory<arith::OrIOp>();
case gpu::AllReduceOperation::XOR:
case Kind::XOR:
return getFactory<arith::XOrIOp>();
case gpu::AllReduceOperation::MAX:
return isFloatingPoint
? getCmpFactory<arith::CmpFOp, arith::CmpFPredicate,
arith::CmpFPredicate::UGT>()
: getCmpFactory<arith::CmpIOp, arith::CmpIPredicate,
arith::CmpIPredicate::ugt>();
case gpu::AllReduceOperation::MIN:
return isFloatingPoint
? getCmpFactory<arith::CmpFOp, arith::CmpFPredicate,
arith::CmpFPredicate::ULT>()
: getCmpFactory<arith::CmpIOp, arith::CmpIPredicate,
arith::CmpIPredicate::ult>();
case Kind::MINIMUMF:
return getFactory<arith::MinimumFOp>();
case Kind::MAXIMUMF:
return getFactory<arith::MaximumFOp>();
}
llvm_unreachable("unknown GPU AllReduceOperation");
}

/// Returns an accumulator factory that creates an op of type T.
template <typename T>
AccumulatorFactory getFactory() {
return [&](Value lhs, Value rhs) {
return [this](Value lhs, Value rhs) {
return create<T>(lhs.getType(), lhs, rhs);
};
}

/// Builds an accumulator factory for compare-and-select reductions such as
/// min and max. `T` is the compare op to create, `PredicateEnum` is the
/// predicate enum type that `T` expects, and `predicate` is the comparison
/// whose result chooses between the two operands.
template <typename T, typename PredicateEnum, PredicateEnum predicate>
AccumulatorFactory getCmpFactory() const {
  return [&](Value lhs, Value rhs) {
    // The compare result is the select condition: `lhs` is the value chosen
    // when it holds, `rhs` otherwise.
    Value lhsSelected = rewriter.create<T>(loc, predicate, lhs, rhs);
    return rewriter.create<arith::SelectOp>(loc, lhsSelected, lhs, rhs);
  };
}

/// Creates an if-block skeleton and calls the two factories to generate the
/// ops in the `then` and `else` blocks.
///
Expand Down
Loading