
[mlir][gpu] Allow subgroup reductions over 1-d vector types #76015

Merged (4 commits) on Dec 21, 2023
16 changes: 12 additions & 4 deletions mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -19,10 +19,11 @@ include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
include "mlir/Dialect/GPU/IR/CompilationAttrs.td"
include "mlir/Dialect/GPU/IR/ParallelLoopMapperAttr.td"
include "mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td"
include "mlir/IR/CommonTypeConstraints.td"
include "mlir/IR/EnumAttr.td"
include "mlir/Interfaces/FunctionInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/DataLayoutInterfaces.td"
include "mlir/Interfaces/FunctionInterfaces.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -1022,16 +1023,23 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce",
let hasRegionVerifier = 1;
}

def AnyIntegerOrFloatOr1DVector :
AnyTypeOf<[AnyIntegerOrFloat, VectorOfRankAndType<[1], [AnyIntegerOrFloat]>]>;

def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce", [SameOperandsAndResultType]> {
let summary = "Reduce values among subgroup.";
let description = [{
The `subgroup_reduce` op reduces the value of every work item across a
subgroup. The result is equal for all work items of a subgroup.

When the reduced value is of a vector type, each vector element is reduced
independently. Only 1-d vector types are allowed.

Example:

```mlir
%1 = gpu.subgroup_reduce add %0 : (f32) -> (f32)
%1 = gpu.subgroup_reduce add %a : (f32) -> (f32)
%2 = gpu.subgroup_reduce add %b : (vector<4xf16>) -> (vector<4xf16>)
```

If `uniform` flag is set either none or all work items of a subgroup
@@ -1044,11 +1052,11 @@ def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce", [SameOperandsAndResultType]
}];

let arguments = (ins
AnyIntegerOrFloat:$value,
AnyIntegerOrFloatOr1DVector:$value,
GPU_AllReduceOperationAttr:$op,
UnitAttr:$uniform
);
let results = (outs AnyIntegerOrFloat:$result);
let results = (outs AnyIntegerOrFloatOr1DVector:$result);

let assemblyFormat = [{ custom<AllReduceOperation>($op) $value
(`uniform` $uniform^)? attr-dict
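With the new `AnyIntegerOrFloatOr1DVector` constraint, the operand and result may be either a scalar or a fixed-length 1-D vector. A rough summary of what the declaration and verifier admit (value names are illustrative, not from this patch):

```mlir
// Scalars reduce exactly as before.
%s = gpu.subgroup_reduce add %x : (f32) -> (f32)
// A 1-D vector is reduced lane-wise: element i of the result is the subgroup
// reduction of element i of the operand.
%v = gpu.subgroup_reduce add %y : (vector<4xf16>) -> (vector<4xf16>)
// Rejected: multi-dimensional vectors (e.g. vector<2x2xf32>) by the type
// constraint, and scalable vectors (e.g. vector<[4]xf32>) by the verifier.
```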
12 changes: 8 additions & 4 deletions mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -16,10 +16,12 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include <optional>

@@ -591,10 +593,12 @@ class GPUSubgroupReduceConversion final
LogicalResult
matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto opType = op.getOp();
auto result =
createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(), opType,
/*isGroup*/ false, op.getUniform());
if (!isa<spirv::ScalarType>(adaptor.getValue().getType()))
return rewriter.notifyMatchFailure(op, "reduction type is not a scalar");

auto result = createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(),
adaptor.getOp(),
/*isGroup=*/false, adaptor.getUniform());
if (!result)
return failure();

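The pattern now reports a match failure for non-scalar converted operands instead of emitting invalid SPIR-V. A `vector<1xi32>` operand still converts, because the SPIR-V type converter maps single-element vectors to their element type, so `adaptor.getValue()` is already an `i32` when the check runs; this is what the new `reductions.mlir` test below exercises. A sketch of that case (SSA names illustrative):

```mlir
// Input handed to the conversion:
%r = gpu.subgroup_reduce maxsi %arg : (vector<1xi32>) -> (vector<1xi32>)
// Expected SPIR-V, operating on the converted i32 scalar:
%r_spv = spirv.GroupNonUniformSMax "Subgroup" "Reduce" %arg_i32 : i32
```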
11 changes: 10 additions & 1 deletion mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -19,6 +19,7 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
@@ -588,8 +589,16 @@ static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
//===----------------------------------------------------------------------===//

LogicalResult gpu::SubgroupReduceOp::verify() {
Type elemType = getType();
if (auto vecTy = dyn_cast<VectorType>(elemType)) {
if (vecTy.isScalable())
return emitOpError() << "is not compatible with scalable vector types";

elemType = vecTy.getElementType();
}

gpu::AllReduceOperation opName = getOp();
if (failed(verifyReduceOpAndType(opName, getType()))) {
if (failed(verifyReduceOpAndType(opName, elemType))) {
return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
<< "` reduction operation is not compatible with type "
<< getType();
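The verifier first peels off the 1-D vector wrapper (rejecting scalable vectors), then runs the existing `verifyReduceOpAndType` check on the element type, while the diagnostic still prints the full operand type. Two illustrative failures (value names are assumptions, mirroring the `invalid.mlir` updates below):

```mlir
// Scalable vectors are rejected before the op/type check:
%a = gpu.subgroup_reduce add %sv : (vector<[2]xf32>) -> (vector<[2]xf32>)
// A float-only reduction kind over an integer element type fails the
// op/type check; the error message names vector<4xi32>, not i32:
%b = gpu.subgroup_reduce maximumf %vi : (vector<4xi32>) -> (vector<4xi32>)
```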
38 changes: 38 additions & 0 deletions mlir/test/Conversion/GPUToSPIRV/reductions.mlir
@@ -655,6 +655,26 @@ gpu.module @kernels {

// -----

module attributes {
gpu.container_module,
spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
} {

gpu.module @kernels {
// CHECK-LABEL: spirv.func @test
// CHECK-SAME: (%[[ARG:.*]]: i32)
gpu.func @test(%arg : vector<1xi32>) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
// CHECK: %{{.*}} = spirv.GroupNonUniformSMax "Subgroup" "Reduce" %[[ARG]] : i32
%r0 = gpu.subgroup_reduce maxsi %arg : (vector<1xi32>) -> (vector<1xi32>)
gpu.return
}
}

}

// -----

// TODO: Handle boolean reductions.

module attributes {
Expand Down Expand Up @@ -751,3 +771,21 @@ gpu.module @kernels {
}
}
}

// -----

// Vector reductions need to be lowered to scalar reductions first.

module attributes {
gpu.container_module,
spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
} {
gpu.module @kernels {
gpu.func @maxui(%arg : vector<2xi32>) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
// expected-error @+1 {{failed to legalize operation 'gpu.subgroup_reduce'}}
%r0 = gpu.subgroup_reduce maxui %arg : (vector<2xi32>) -> (vector<2xi32>)
gpu.return
}
}
}
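The comment above ("Vector reductions need to be lowered to scalar reductions first.") refers to scalarizing wider vectors before this conversion runs; the negative test checks that an unlowered `vector<2xi32>` reduction fails to legalize. A rough, hand-written sketch of such a scalarization using the vector dialect's extract/insert ops (value names and the choice of ops are illustrative, not part of this patch):

```mlir
%e0 = vector.extract %arg[0] : i32 from vector<2xi32>
%e1 = vector.extract %arg[1] : i32 from vector<2xi32>
%m0 = gpu.subgroup_reduce maxui %e0 : (i32) -> (i32)
%m1 = gpu.subgroup_reduce maxui %e1 : (i32) -> (i32)
%v0 = vector.insert %m0, %arg[0] : i32 into vector<2xi32>
%r  = vector.insert %m1, %v0[1] : i32 into vector<2xi32>
```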
14 changes: 11 additions & 3 deletions mlir/test/Dialect/GPU/invalid.mlir
@@ -333,9 +333,17 @@ func.func @reduce_invalid_op_type_maximumf(%arg0 : i32) {

// -----

func.func @subgroup_reduce_bad_type(%arg0 : vector<2xf32>) {
// expected-error@+1 {{'gpu.subgroup_reduce' op operand #0 must be Integer or Float}}
%res = gpu.subgroup_reduce add %arg0 : (vector<2xf32>) -> vector<2xf32>
func.func @subgroup_reduce_bad_type(%arg0 : vector<2x2xf32>) {
// expected-error@+1 {{'gpu.subgroup_reduce' op operand #0 must be Integer or Float or vector of}}
%res = gpu.subgroup_reduce add %arg0 : (vector<2x2xf32>) -> vector<2x2xf32>
return
}

// -----

func.func @subgroup_reduce_bad_type_scalable(%arg0 : vector<[2]xf32>) {
// expected-error@+1 {{is not compatible with scalable vector types}}
%res = gpu.subgroup_reduce add %arg0 : (vector<[2]xf32>) -> vector<[2]xf32>
return
}

5 changes: 5 additions & 0 deletions mlir/test/Dialect/GPU/ops.mlir
@@ -84,6 +84,8 @@ module attributes {gpu.container_module} {

%one = arith.constant 1.0 : f32

%vec = vector.broadcast %arg0 : f32 to vector<4xf32>

// CHECK: %{{.*}} = gpu.all_reduce add %{{.*}} {
// CHECK-NEXT: } : (f32) -> f32
%sum = gpu.all_reduce add %one {} : (f32) -> (f32)
@@ -98,6 +100,9 @@ module attributes {gpu.container_module} {
// CHECK: %{{.*}} = gpu.subgroup_reduce add %{{.*}} uniform : (f32) -> f32
%sum_subgroup1 = gpu.subgroup_reduce add %one uniform : (f32) -> f32

// CHECK: %{{.*}} = gpu.subgroup_reduce add %{{.*}} : (vector<4xf32>) -> vector<4xf32>
%sum_subgroup2 = gpu.subgroup_reduce add %vec : (vector<4xf32>) -> vector<4xf32>

%width = arith.constant 7 : i32
%offset = arith.constant 3 : i32
// CHECK: gpu.shuffle xor %{{.*}}, %{{.*}}, %{{.*}} : f32