Skip to content

Commit 72003ad

Browse files
authored
[mlir][gpu] Allow subgroup reductions over 1-d vector types (#76015)
[mlir][gpu] Allow subgroup reductions over 1-d vector types (#76015)
Each vector element is reduced independently, which is a form of multi-reduction. The plan is to allow for gradual lowering of multi-reductions, resulting in fewer `gpu.shuffle` ops at the end: 1d `vector.multi_reduction` --> 1d `gpu.subgroup_reduce` --> smaller 1d `gpu.subgroup_reduce` --> packed `gpu.shuffle` over i32. For example, we can perform 2 independent f16 reductions with a series of `gpu.shuffle`s over i32, reducing the final number of `gpu.shuffle`s by 2x.
1 parent e6751c1 commit 72003ad

File tree

6 files changed

+84
-12
lines changed

6 files changed

+84
-12
lines changed

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,11 @@ include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
1919
include "mlir/Dialect/GPU/IR/CompilationAttrs.td"
2020
include "mlir/Dialect/GPU/IR/ParallelLoopMapperAttr.td"
2121
include "mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td"
22+
include "mlir/IR/CommonTypeConstraints.td"
2223
include "mlir/IR/EnumAttr.td"
23-
include "mlir/Interfaces/FunctionInterfaces.td"
2424
include "mlir/IR/SymbolInterfaces.td"
2525
include "mlir/Interfaces/DataLayoutInterfaces.td"
26+
include "mlir/Interfaces/FunctionInterfaces.td"
2627
include "mlir/Interfaces/InferIntRangeInterface.td"
2728
include "mlir/Interfaces/InferTypeOpInterface.td"
2829
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -1023,16 +1024,23 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce",
10231024
let hasRegionVerifier = 1;
10241025
}
10251026

1027+
def AnyIntegerOrFloatOr1DVector :
1028+
AnyTypeOf<[AnyIntegerOrFloat, VectorOfRankAndType<[1], [AnyIntegerOrFloat]>]>;
1029+
10261030
def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce", [SameOperandsAndResultType]> {
10271031
let summary = "Reduce values among subgroup.";
10281032
let description = [{
10291033
The `subgroup_reduce` op reduces the value of every work item across a
10301034
subgroup. The result is equal for all work items of a subgroup.
10311035

1036+
When the reduced value is of a vector type, each vector element is reduced
1037+
independently. Only 1-d vector types are allowed.
1038+
10321039
Example:
10331040

10341041
```mlir
1035-
%1 = gpu.subgroup_reduce add %0 : (f32) -> (f32)
1042+
%1 = gpu.subgroup_reduce add %a : (f32) -> (f32)
1043+
%2 = gpu.subgroup_reduce add %b : (vector<4xf16>) -> (vector<4xf16>)
10361044
```
10371045

10381046
If `uniform` flag is set either none or all work items of a subgroup
@@ -1045,11 +1053,11 @@ def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce", [SameOperandsAndResultType]
10451053
}];
10461054

10471055
let arguments = (ins
1048-
AnyIntegerOrFloat:$value,
1056+
AnyIntegerOrFloatOr1DVector:$value,
10491057
GPU_AllReduceOperationAttr:$op,
10501058
UnitAttr:$uniform
10511059
);
1052-
let results = (outs AnyIntegerOrFloat:$result);
1060+
let results = (outs AnyIntegerOrFloatOr1DVector:$result);
10531061

10541062
let assemblyFormat = [{ custom<AllReduceOperation>($op) $value
10551063
(`uniform` $uniform^)? attr-dict

mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@
1616
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
1717
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
1818
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
19+
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
1920
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
2021
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
2122
#include "mlir/IR/BuiltinOps.h"
2223
#include "mlir/IR/Matchers.h"
24+
#include "mlir/Support/LogicalResult.h"
2325
#include "mlir/Transforms/DialectConversion.h"
2426
#include <optional>
2527

@@ -591,10 +593,12 @@ class GPUSubgroupReduceConversion final
591593
LogicalResult
592594
matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
593595
ConversionPatternRewriter &rewriter) const override {
594-
auto opType = op.getOp();
595-
auto result =
596-
createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(), opType,
597-
/*isGroup*/ false, op.getUniform());
596+
if (!isa<spirv::ScalarType>(adaptor.getValue().getType()))
597+
return rewriter.notifyMatchFailure(op, "reduction type is not a scalar");
598+
599+
auto result = createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(),
600+
adaptor.getOp(),
601+
/*isGroup=*/false, adaptor.getUniform());
598602
if (!result)
599603
return failure();
600604

mlir/lib/Dialect/GPU/IR/GPUDialect.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "mlir/IR/BuiltinAttributes.h"
2020
#include "mlir/IR/BuiltinOps.h"
2121
#include "mlir/IR/BuiltinTypes.h"
22+
#include "mlir/IR/Diagnostics.h"
2223
#include "mlir/IR/DialectImplementation.h"
2324
#include "mlir/IR/Matchers.h"
2425
#include "mlir/IR/OpImplementation.h"
@@ -588,8 +589,16 @@ static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
588589
//===----------------------------------------------------------------------===//
589590

590591
LogicalResult gpu::SubgroupReduceOp::verify() {
592+
Type elemType = getType();
593+
if (auto vecTy = dyn_cast<VectorType>(elemType)) {
594+
if (vecTy.isScalable())
595+
return emitOpError() << "is not compatible with scalable vector types";
596+
597+
elemType = vecTy.getElementType();
598+
}
599+
591600
gpu::AllReduceOperation opName = getOp();
592-
if (failed(verifyReduceOpAndType(opName, getType()))) {
601+
if (failed(verifyReduceOpAndType(opName, elemType))) {
593602
return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
594603
<< "` reduction operation is not compatible with type "
595604
<< getType();

mlir/test/Conversion/GPUToSPIRV/reductions.mlir

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,26 @@ gpu.module @kernels {
655655

656656
// -----
657657

658+
module attributes {
659+
gpu.container_module,
660+
spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
661+
} {
662+
663+
gpu.module @kernels {
664+
// CHECK-LABEL: spirv.func @test
665+
// CHECK-SAME: (%[[ARG:.*]]: i32)
666+
gpu.func @test(%arg : vector<1xi32>) kernel
667+
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
668+
// CHECK: %{{.*}} = spirv.GroupNonUniformSMax "Subgroup" "Reduce" %[[ARG]] : i32
669+
%r0 = gpu.subgroup_reduce maxsi %arg : (vector<1xi32>) -> (vector<1xi32>)
670+
gpu.return
671+
}
672+
}
673+
674+
}
675+
676+
// -----
677+
658678
// TODO: Handle boolean reductions.
659679

660680
module attributes {
@@ -751,3 +771,21 @@ gpu.module @kernels {
751771
}
752772
}
753773
}
774+
775+
// -----
776+
777+
// Vector reductions need to be lowered to scalar reductions first.
778+
779+
module attributes {
780+
gpu.container_module,
781+
spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
782+
} {
783+
gpu.module @kernels {
784+
gpu.func @maxui(%arg : vector<2xi32>) kernel
785+
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
786+
// expected-error @+1 {{failed to legalize operation 'gpu.subgroup_reduce'}}
787+
%r0 = gpu.subgroup_reduce maxui %arg : (vector<2xi32>) -> (vector<2xi32>)
788+
gpu.return
789+
}
790+
}
791+
}

mlir/test/Dialect/GPU/invalid.mlir

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -333,9 +333,17 @@ func.func @reduce_invalid_op_type_maximumf(%arg0 : i32) {
333333

334334
// -----
335335

336-
func.func @subgroup_reduce_bad_type(%arg0 : vector<2xf32>) {
337-
// expected-error@+1 {{'gpu.subgroup_reduce' op operand #0 must be Integer or Float}}
338-
%res = gpu.subgroup_reduce add %arg0 : (vector<2xf32>) -> vector<2xf32>
336+
func.func @subgroup_reduce_bad_type(%arg0 : vector<2x2xf32>) {
337+
// expected-error@+1 {{'gpu.subgroup_reduce' op operand #0 must be Integer or Float or vector of}}
338+
%res = gpu.subgroup_reduce add %arg0 : (vector<2x2xf32>) -> vector<2x2xf32>
339+
return
340+
}
341+
342+
// -----
343+
344+
func.func @subgroup_reduce_bad_type_scalable(%arg0 : vector<[2]xf32>) {
345+
// expected-error@+1 {{is not compatible with scalable vector types}}
346+
%res = gpu.subgroup_reduce add %arg0 : (vector<[2]xf32>) -> vector<[2]xf32>
339347
return
340348
}
341349

mlir/test/Dialect/GPU/ops.mlir

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ module attributes {gpu.container_module} {
8484

8585
%one = arith.constant 1.0 : f32
8686

87+
%vec = vector.broadcast %arg0 : f32 to vector<4xf32>
88+
8789
// CHECK: %{{.*}} = gpu.all_reduce add %{{.*}} {
8890
// CHECK-NEXT: } : (f32) -> f32
8991
%sum = gpu.all_reduce add %one {} : (f32) -> (f32)
@@ -98,6 +100,9 @@ module attributes {gpu.container_module} {
98100
// CHECK: %{{.*}} = gpu.subgroup_reduce add %{{.*}} uniform : (f32) -> f32
99101
%sum_subgroup1 = gpu.subgroup_reduce add %one uniform : (f32) -> f32
100102

103+
// CHECK: %{{.*}} = gpu.subgroup_reduce add %{{.*}} : (vector<4xf32>) -> vector<4xf32>
104+
%sum_subgroup2 = gpu.subgroup_reduce add %vec : (vector<4xf32>) -> vector<4xf32>
105+
101106
%width = arith.constant 7 : i32
102107
%offset = arith.constant 3 : i32
103108
// CHECK: gpu.shuffle xor %{{.*}}, %{{.*}}, %{{.*}} : f32

0 commit comments

Comments
 (0)