
Commit c0321ed

qedawkins authored and ThomasRaoux committed
[mlir][gpu] Adding support for transposed mma_load_matrix
Enables transposed gpu.subgroup_mma_load_matrix and updates the lowerings
in Vector to GPU and GPU to SPIRV. Needed to enable B transpose matmuls
lowering to wmma ops.

Taken over from author: stanley-nod <[email protected]>

Reviewed By: ThomasRaoux, antiagainst

Differential Revision: https://reviews.llvm.org/D138770
1 parent f49d069 commit c0321ed
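For context, the whole change hinges on one new unit attribute on the load op. A minimal sketch of the new form, mirroring the CHECK lines in the test below (shapes and value names are illustrative):

    // With `transpose` present, the fragment is loaded transposed;
    // without it, the op behaves exactly as before.
    %b = gpu.subgroup_mma_load_matrix %src[%c0, %c0]
           {leadDimension = 16 : index, transpose}
           : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">

This is what lets a B-transposed matmul reach the wmma ops without an explicit transpose in between.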

File tree: 5 files changed, +49 -9 lines changed


mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 3 additions & 2 deletions
@@ -1102,7 +1102,7 @@ def GPU_SubgroupMmaLoadMatrixOp : GPU_Op<"subgroup_mma_load_matrix",
     determined using `indices`. The matrix being loaded into is the result. The
     `leadDimension` attribute specifies the leading dimension size of the source
     matrix which eventually allows the lowering to determine the size of each
-    row.
+    row. If the `transpose` attribute is present then the op does a transposed load.
 
     This op is often meant to be used along with `gpu.subgroup_mma_store_matrix` and
     `gpu.subgroup_mma_compute`.
@@ -1117,7 +1117,8 @@ def GPU_SubgroupMmaLoadMatrixOp : GPU_Op<"subgroup_mma_load_matrix",
 
   let arguments = (ins Arg<GPU_MMAMemRef, "", [MemRead]>:$srcMemref,
                    Variadic<Index>:$indices,
-                   IndexAttr:$leadDimension);
+                   IndexAttr:$leadDimension,
+                   OptionalAttr<UnitAttr>:$transpose);
 
   let results = (outs GPU_MMAMatrix:$res);

mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp

Lines changed: 4 additions & 0 deletions
@@ -77,6 +77,10 @@ struct WmmaLoadOpToNVVMLowering
     if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
       return failure();
 
+    // TODO: Support transposed mma loads.
+    if (subgroupMmaLoadMatrixOp.getTranspose())
+      return failure();
+
     // Get the shape of the MMAMatrix type being returned. The shape will
     // choose which intrinsic this op will be lowered to.
     gpu::MMAMatrixType retType =

mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp

Lines changed: 5 additions & 3 deletions
@@ -87,10 +87,12 @@ struct WmmaLoadOpToSPIRVLowering
     auto i32Type = rewriter.getI32Type();
     auto strideValue = rewriter.create<spirv::ConstantOp>(
         loc, i32Type, IntegerAttr::get(i32Type, stride));
-    auto coloumnMajor = rewriter.create<spirv::ConstantOp>(
-        loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
+    bool useColMajor =
+        static_cast<bool>(subgroupMmaLoadMatrixOp.getTranspose());
+    auto columnMajor = rewriter.create<spirv::ConstantOp>(
+        loc, rewriter.getI1Type(), rewriter.getBoolAttr(useColMajor));
     rewriter.replaceOpWithNewOp<spirv::NVCooperativeMatrixLoadOp>(
-        subgroupMmaLoadMatrixOp, coopType, bufferPtr, strideValue, coloumnMajor,
+        subgroupMmaLoadMatrixOp, coopType, bufferPtr, strideValue, columnMajor,
         spirv::MemoryAccessAttr());
     return success();
   }
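On the SPIR-V side, `transpose` simply flips the column-major flag of the cooperative-matrix load, which was previously hard-coded to false. A hedged sketch of the IR this lowering should now produce for a transposed load (op and type spellings assumed from the SPIR-V dialect around this revision; `%ptr` and shapes are illustrative):

    %stride = spirv.Constant 16 : i32
    // `transpose` on the gpu op becomes a true column-major flag here.
    %colMajor = spirv.Constant true
    %frag = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %colMajor
        : !spirv.ptr<f16, StorageBuffer> as !spirv.coopmatrix<16x16xf16, Subgroup>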

mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp

Lines changed: 18 additions & 4 deletions
@@ -92,6 +92,19 @@ static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
   return true;
 }
 
+// Return true if the given map represents a transposed matrix load,
+// i.e. (d0, d1, ...) -> (dn-1, dn-2).
+static bool isTransposeMatrixLoadMap(OpBuilder &b, AffineMap permutationMap) {
+  auto nDim = permutationMap.getNumDims();
+  if (nDim < 2)
+    return false;
+
+  AffineExpr innerDim = b.getAffineDimExpr(nDim - 1);
+  AffineExpr outerDim = b.getAffineDimExpr(nDim - 2);
+  return permutationMap ==
+         AffineMap::get(nDim, 0, {innerDim, outerDim}, b.getContext());
+}
+
 // Return the stride for the dimension 0 of |type| if it is a memref and has a
 // constant stride.
 static std::optional<int64_t>
@@ -129,9 +142,9 @@ static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp,
                                         readOp.getContext());
 
   if (!useNvGpu) {
-    // TODO: Support transpose once it is added to GPU dialect ops.
-    // For now we only support (d0, d1) -> (d0, d1) and (d0, d1) -> (0, d1).
-    return map.isMinorIdentity() || map == broadcastInnerDim;
+    bool result = map.isMinorIdentity() || map == broadcastInnerDim ||
+                  isTransposeMatrixLoadMap(b, map);
+    return result;
   }
 
   return true;
@@ -445,9 +458,10 @@ static void convertTransferReadOp(vector::TransferReadOp op,
       gpu::MMAMatrixType::get(op.getVectorType().getShape(),
                               op.getVectorType().getElementType(), fragType);
   OpBuilder b(op);
+  bool isTranspose = isTransposeMatrixLoadMap(b, map);
   Value load = b.create<gpu::SubgroupMmaLoadMatrixOp>(
       op.getLoc(), type, op.getSource(), op.getIndices(),
-      b.getIndexAttr(*stride));
+      b.getIndexAttr(*stride), isTranspose ? b.getUnitAttr() : UnitAttr());
   valueMapping[op.getResult()] = load;
 }

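In other words, `isTransposeMatrixLoadMap` matches permutation maps whose two results are the minor dimensions swapped, and `convertTransferReadOp` then attaches the unit attribute. A small sketch of a `vector.transfer_read` that now converts instead of being rejected (value names are illustrative; this mirrors the `matmul_transposed` test below):

    #transposed = affine_map<(d0, d1) -> (d1, d0)>

    // Recognized as a transposed load: results are (d1, d0).
    %b = vector.transfer_read %mem[%c0, %c0], %pad
           {permutation_map = #transposed, in_bounds = [true, true]}
           : memref<16x16xf16>, vector<16x16xf16>

Under `-convert-vector-to-gpu` this should become `gpu.subgroup_mma_load_matrix` with the `transpose` attribute, while identity maps like `(d0, d1) -> (d0, d1)` keep taking the minor-identity path.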
mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir

Lines changed: 19 additions & 0 deletions
@@ -5,6 +5,7 @@
 #map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
 #map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
 #map4 = affine_map<(d0) -> (d0, 0)>
+#map5 = affine_map<(d0, d1) -> (d1, d0)>
 
 // CHECK-LABEL: func @matmul
 // CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
@@ -170,3 +171,21 @@ func.func @matmul_memref_strided(%arg0: memref<2x16x16xf16, affine_map<(d0, d1,
   vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<2x16x16xf16>
   return
 }
+
+// CHECK-LABEL: func @matmul_transposed
+// CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+// CHECK-DAG: %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index, transpose} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+// CHECK-DAG: %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
+func.func @matmul_transposed(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<16x16xf16>) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map5, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+  vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
+  return
+}
