
Commit 2c7827d
[mlir][spirv] Lower GPU subgroup MMA matrix-times-scalar to spirv.MatrixTimesScalar

Along the way, make the default pattern fail instead of crashing when an elementwise op is not yet supported.

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D139280

Parent: 96d6399

File tree: 2 files changed (+109, -40)

mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir

mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp

Lines changed: 78 additions & 20 deletions
@@ -24,42 +24,47 @@
 
 using namespace mlir;
 
-// See SPV_NV_cooperative_matrix for supported element wise ops.
-static void createElementWiseOp(ConversionPatternRewriter &builder,
+/// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
+/// when the elementwise op directly supports the cooperative matrix type.
+/// Returns false otherwise.
+///
+/// See SPV_NV_cooperative_matrix for supported elementwise ops.
+static bool createElementwiseOp(ConversionPatternRewriter &builder,
                                 gpu::SubgroupMmaElementwiseOp op,
                                 spirv::CooperativeMatrixNVType coopType,
                                 ValueRange operands) {
   switch (op.getOpType()) {
   case gpu::MMAElementwiseOp::ADDF:
     builder.replaceOpWithNewOp<spirv::FAddOp>(op, coopType, operands);
-    return;
+    return true;
   case gpu::MMAElementwiseOp::ADDI:
     builder.replaceOpWithNewOp<spirv::IAddOp>(op, coopType, operands);
-    return;
+    return true;
   case gpu::MMAElementwiseOp::SUBF:
     builder.replaceOpWithNewOp<spirv::FSubOp>(op, coopType, operands);
-    return;
+    return true;
   case gpu::MMAElementwiseOp::SUBI:
     builder.replaceOpWithNewOp<spirv::ISubOp>(op, coopType, operands);
-    return;
+    return true;
   case gpu::MMAElementwiseOp::DIVF:
     builder.replaceOpWithNewOp<spirv::FDivOp>(op, coopType, operands);
-    return;
+    return true;
   case gpu::MMAElementwiseOp::DIVS:
     builder.replaceOpWithNewOp<spirv::SDivOp>(op, coopType, operands);
-    return;
+    return true;
   case gpu::MMAElementwiseOp::DIVU:
     builder.replaceOpWithNewOp<spirv::UDivOp>(op, coopType, operands);
-    return;
+    return true;
   case gpu::MMAElementwiseOp::NEGATEF:
     builder.replaceOpWithNewOp<spirv::FNegateOp>(op, coopType, operands);
-    return;
+    return true;
   case gpu::MMAElementwiseOp::NEGATES:
     builder.replaceOpWithNewOp<spirv::SNegateOp>(op, coopType, operands);
-    return;
+    return true;
   default:
-    llvm_unreachable("unknown op");
+    break;
   }
+  return false;
 }
 
 namespace {
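The boolean return is designed to be folded directly into a LogicalResult at the call site, so an unsupported elementwise op now surfaces as a pattern-match failure rather than an assertion crash. A minimal sketch of the intended call site, mirroring the default lowering pattern below (the surrounding matchAndRewrite scaffolding is elided):

    // Inside a ConversionPattern: turn the bool into success/failure so the
    // conversion driver can try other patterns or report a proper error.
    auto coopType = convertMMAToSPIRVType(
        elementwiseOp.getType().cast<gpu::MMAMatrixType>());
    return success(createElementwiseOp(rewriter, elementwiseOp, coopType,
                                       adaptor.getOperands()));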
@@ -163,13 +168,14 @@ struct WmmaElementwiseOpToSPIRVLowering
   }
 };
 
-/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops.
-struct WmmaElementwiseOpToSPIRVLowering
+/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
+/// the default case.
+struct WmmaElementwiseOpToSPIRVDefaultLowering
     : public OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(gpu::SubgroupMmaElementwiseOp subgroupMmaElementwiseOp,
+  matchAndRewrite(gpu::SubgroupMmaElementwiseOp elementwiseOp,
                   OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // All operands should be of cooperative matrix types.
@@ -178,9 +184,58 @@ struct WmmaElementwiseOpToSPIRVLowering
         return failure();
     }
     auto coopType = convertMMAToSPIRVType(
-        subgroupMmaElementwiseOp.getType().cast<gpu::MMAMatrixType>());
-    createElementWiseOp(rewriter, subgroupMmaElementwiseOp, coopType,
-                        adaptor.getOperands());
+        elementwiseOp.getType().cast<gpu::MMAMatrixType>());
+    return success(createElementwiseOp(rewriter, elementwiseOp, coopType,
+                                       adaptor.getOperands()));
+  }
+};
+
+/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
+/// the matrix-times-scalar case.
+struct WmmaElementwiseOpToSPIRVScalarMulLowering
+    : public OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaElementwiseOp elementwiseOp,
+                  OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (adaptor.getOperands().size() != 2)
+      return failure();
+    // All operands should be of cooperative matrix types.
+    for (Value operand : adaptor.getOperands()) {
+      if (!operand.getType().isa<spirv::CooperativeMatrixNVType>())
+        return failure();
+    }
+
+    // Use the original operands to check whether one of the operands is a
+    // splat scalar value.
+    Value lhs = elementwiseOp.getOperands().front();
+    Value rhs = elementwiseOp.getOperands().back();
+    Value splat = nullptr;
+    Value matrix = nullptr;
+    if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
+      splat = adaptor.getOperands().front();
+      matrix = adaptor.getOperands().back();
+    } else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
+      matrix = adaptor.getOperands().front();
+      splat = adaptor.getOperands().back();
+    }
+    if (!splat || !matrix)
+      return failure();
+
+    // Constant MMA matrix ops are converted to spirv.CompositeConstruct ops.
+    Value scalar = nullptr;
+    auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>();
+    if (!cc)
+      return failure();
+    assert(cc.getConstituents().size() == 1);
+    scalar = cc.getConstituents().front();
+
+    auto coopType = convertMMAToSPIRVType(
+        elementwiseOp.getType().cast<gpu::MMAMatrixType>());
+    rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
+        elementwiseOp, coopType, ValueRange{matrix, scalar});
     return success();
   }
 };
@@ -198,8 +253,11 @@ mlir::convertMMAToSPIRVType(gpu::MMAMatrixType type) {
 
 void mlir::populateGpuWMMAToSPIRVConversionPatterns(
     SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
   patterns.add<WmmaLoadOpToSPIRVLowering, WmmaMmaOpToSPIRVLowering,
                WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
-               WmmaElementwiseOpToSPIRVLowering>(converter,
-                                                 patterns.getContext());
+               WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
+  // Give the following patterns higher benefit to prevail over the default
+  // one.
+  patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
+                                                          /*benefit=*/2);
 }
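For context, here is a condensed sketch of how a conversion pass might consume these patterns. The driver name runWmmaToSPIRV and the pass scaffolding are illustrative assumptions (the upstream GPUToSPIRV pass does more work, e.g. handling gpu.module entry points); the MLIR APIs used are the standard ones:

    #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
    #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
    #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
    #include "mlir/Transforms/DialectConversion.h"

    using namespace mlir;

    // Hypothetical driver: build a SPIR-V conversion target and type converter
    // from the module's target environment, register the WMMA patterns, and
    // run a full conversion.
    static LogicalResult runWmmaToSPIRV(ModuleOp module) {
      spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(module);
      auto target = SPIRVConversionTarget::get(targetAttr);

      SPIRVTypeConverter typeConverter(targetAttr);
      RewritePatternSet patterns(module.getContext());
      populateGpuWMMAToSPIRVConversionPatterns(typeConverter, patterns);

      // The benefit-2 scalar-multiplication pattern is attempted first; when
      // it returns failure(), the driver falls back to the benefit-1 default.
      return applyFullConversion(module, *target, std::move(patterns));
    }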

mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir

Lines changed: 31 additions & 20 deletions
@@ -4,10 +4,8 @@ module attributes {
   gpu.container_module,
   spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
   gpu.module @kernels {
-    // CHECK: spirv.module @{{.*}} Logical GLSL450 {
     // CHECK-LABEL: spirv.func @gpu_wmma_load_op
-    // CHECK-SAME: {{%.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
-    // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>
+    // CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer>
     gpu.func @gpu_wmma_load_op(%arg0 : memref<32x32xf16, #spirv.storage_class<StorageBuffer>>) kernel
       attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
       %i = arith.constant 16 : index
@@ -27,7 +25,6 @@ module attributes {
   gpu.container_module,
   spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
   gpu.module @kernels {
-    // CHECK: spirv.module @{{.*}} Logical GLSL450 {
     // CHECK-LABEL: spirv.func @gpu_wmma_load_op_transpose
     // CHECK-SAME: {{%.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
     // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>
@@ -50,11 +47,9 @@ module attributes {
   gpu.container_module,
   spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
   gpu.module @kernels {
-    // CHECK: spirv.module @{{.*}} Logical GLSL450 {
     // CHECK-LABEL: spirv.func @gpu_wmma_store_op
-    // CHECK-SAME: {{%.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
-    // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>})
-    // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>
+    // CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer>
+    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
     gpu.func @gpu_wmma_store_op(%arg0 : memref<32x32xf16, #spirv.storage_class<StorageBuffer>>, %arg1 : !gpu.mma_matrix<16x16xf16, "COp">) kernel
       attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
       %i = arith.constant 16 : index
@@ -74,7 +69,6 @@ module attributes {
   gpu.container_module,
   spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
   gpu.module @kernels {
-    // CHECK: spirv.module @{{.*}} Logical GLSL450 {
     // CHECK-LABEL: spirv.func @gpu_wmma_store_op_transpose
     // CHECK-SAME: {{%.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
     // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>})
@@ -98,12 +92,10 @@ module attributes {
   gpu.container_module,
   spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
   gpu.module @kernels {
-    // CHECK: spirv.module @{{.*}} Logical GLSL450 {
     // CHECK-LABEL: spirv.func @gpu_wmma_mma_op
-    // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
-    // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}
-    // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 2)>})
-    // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>
+    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
+    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
+    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
     gpu.func @gpu_wmma_mma_op(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) kernel
       attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
       // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>, !spirv.coopmatrix<16x16xf16, Subgroup> -> !spirv.coopmatrix<16x16xf16, Subgroup>
@@ -120,7 +112,6 @@ module attributes {
   gpu.container_module,
   spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
   gpu.module @kernels {
-    // CHECK: spirv.module @{{.*}} Logical GLSL450 {
     // CHECK-LABEL: spirv.func @gpu_wmma_constant_op
     gpu.func @gpu_wmma_constant_op() kernel
       attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
@@ -140,11 +131,10 @@ module attributes {
   gpu.container_module,
   spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
   gpu.module @kernels {
-    // CHECK: spirv.module @{{.*}} Logical GLSL450 {
-    // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op
-    // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
-    // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>})
-    gpu.func @gpu_wmma_elementwise_op(%A : !gpu.mma_matrix<16x16xf16, "COp">, %B : !gpu.mma_matrix<16x16xf16, "COp">) kernel
+    // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_default
+    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
+    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
+    gpu.func @gpu_wmma_elementwise_op_default(%A : !gpu.mma_matrix<16x16xf16, "COp">, %B : !gpu.mma_matrix<16x16xf16, "COp">) kernel
       attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
       // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>
       %C = gpu.subgroup_mma_elementwise addf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
@@ -157,3 +147,24 @@ module attributes {
     }
   }
 }
+
+// -----
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
+  gpu.module @kernels {
+    // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_matrix_times_scalar
+    // CHECK-SAME: %[[A:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup>
+    // CHECK-SAME: %[[S:.+]]: f16
+    gpu.func @gpu_wmma_elementwise_op_matrix_times_scalar(%A : !gpu.mma_matrix<16x16xf16, "COp">, %scalar : f16) kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+      %B = gpu.subgroup_mma_constant_matrix %scalar : !gpu.mma_matrix<16x16xf16, "COp">
+      // CHECK: %{{.+}} = spirv.MatrixTimesScalar %[[A]], %[[S]] : !spirv.coopmatrix<16x16xf16, Subgroup>, f16
+      %C = gpu.subgroup_mma_elementwise mulf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+      // CHECK: %{{.+}} = spirv.MatrixTimesScalar %[[A]], %[[S]] : !spirv.coopmatrix<16x16xf16, Subgroup>, f16
+      %D = gpu.subgroup_mma_elementwise mulf %B, %A : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+      gpu.return
+    }
+  }
+}
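Note that the new test exercises both operand orders, %A * %B and %B * %A; each lowers to the same spirv.MatrixTimesScalar %[[A]], %[[S]] form, confirming that the splat detection in WmmaElementwiseOpToSPIRVScalarMulLowering handles the scalar on either side.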
