
Commit cafb628

[mlir][VectorToGPU] Update memref stride preconditions on nvgpu.mma.sync path
This change removes the requirement that the row stride be statically known when converting `vector.transfer_read` and `vector.transfer_write` to distributed SIMT operations in the `nvgpu` lowering path. It also adds a check verifying that the last dimension of the source memref has a statically known stride of 1, since the conversion logic assumes this. No other changes are required, because the generated `vector.load` operations are only ever created along the last dimension. The routines that check the preconditions on `vector.transfer_read/write` are moved under the nvgpu utilities.

The change is NFC with respect to the GPU dialect lowering path.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D155753
1 parent e8ad9b0 commit cafb628

File tree: 4 files changed (+214, -28 lines)
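To make the new precondition concrete, the sketch below (distilled from the test cases added in this commit; the function name is illustrative, not part of the change) contrasts a shared-memory layout the nvgpu path now accepts, namely a dynamic row stride with a static unit stride in the innermost dimension, against a fully dynamic layout that is still rejected.

// Accepted: the row stride is dynamic, but the innermost stride is statically 1,
// so the lowering can still emit contiguous loads along the last dimension.
!ok_type = memref<20x20xf16, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>
// Rejected: the innermost stride is also dynamic, so contiguity of the last
// dimension cannot be proven and the ops are left for the default lowering.
!bad_type = memref<20x20xf16, strided<[?, ?], offset: ?>, #gpu.address_space<workgroup>>

func.func @accepted_read(%src: !ok_type) -> vector<16x16xf16> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0.000000e+00 : f16
  %v = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]}
      : !ok_type, vector<16x16xf16>
  return %v : vector<16x16xf16>
}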

mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h

Lines changed: 12 additions & 0 deletions
@@ -93,6 +93,18 @@ FailureOr<AffineMap>
 getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc,
                                const LdMatrixParams &params);
 
+/// Returns whether the `vector.transfer_read` instruction can be interpreted
+/// as a warp-level cooperative matrix load operation. This function is meant to
+/// be used to establish whether `op` is part of a chain of such warp-level
+/// operations.
+bool canLowerToWarpMatrixOperation(vector::TransferReadOp op);
+
+/// Returns whether the `vector.transfer_write` instruction can be interpreted
+/// as a warp-level cooperative matrix store operation. This function is meant
+/// to be used to establish whether `op` is part of a chain of such warp-level
+/// operations.
+bool canLowerToWarpMatrixOperation(vector::TransferWriteOp op);
+
 } // namespace nvgpu
 } // namespace mlir
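As an illustration of the kind of IR these predicates classify, here is a minimal MLIR sketch (modelled on the existing mma.sync tests; function and value names are illustrative, not part of the commit) of a `vector.transfer_read` that satisfies the checks behind `canLowerToWarpMatrixOperation`: unmasked, in-bounds, a static rank-2 result, and a memref source whose innermost stride is 1.

func.func @candidate_warp_load(%smem: memref<128x128xf16, #gpu.address_space<workgroup>>) -> vector<16x16xf16> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0.000000e+00 : f16
  // A read like this can be interpreted as a warp-level cooperative matrix load.
  %frag = vector.transfer_read %smem[%c0, %c0], %pad {in_bounds = [true, true]}
      : memref<128x128xf16, #gpu.address_space<workgroup>>, vector<16x16xf16>
  return %frag : vector<16x16xf16>
}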

mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp

Lines changed: 27 additions & 26 deletions
@@ -119,10 +119,9 @@ static bool isTransposeMatrixLoadMap(AffineMap permutationMap) {
          permutationMap == AffineMap::get(nDim, 0, {innerDim, zero}, ctx);
 }
 
-// Return the stride for the dimension 0 of |type| if it is a memref and has a
-// constant stride.
-static std::optional<int64_t>
-getMemrefConstantHorizontalStride(ShapedType type) {
+// Return the stride for the second-to-last dimension of |type| if it is a memref
+// and has a constant stride.
+static std::optional<int64_t> getStaticallyKnownRowStride(ShapedType type) {
   auto memrefType = dyn_cast<MemRefType>(type);
   if (!memrefType)
     return false;

@@ -141,35 +140,27 @@ getMemrefConstantHorizontalStride(ShapedType type) {
 }
 
 // Return true if the transfer op can be converted to a MMA matrix load.
-static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp,
-                                              bool useNvGpu) {
+static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
   if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
       readOp.getVectorType().getRank() != 2)
     return false;
-  if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
+  if (!getStaticallyKnownRowStride(readOp.getShapedType()))
     return false;
 
   // Only allow integer types if the signedness can be inferred.
-  if (!useNvGpu && readOp.getVectorType().getElementType().isInteger(8))
+  if (readOp.getVectorType().getElementType().isInteger(8))
     if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(*readOp->user_begin()) &&
                                  !isa<arith::ExtUIOp>(*readOp->user_begin())))
       return false;
 
   AffineMap map = readOp.getPermutationMap();
-
   MLIRContext *ctx = readOp.getContext();
   AffineExpr innerDim = getAffineDimExpr(map.getNumDims() - 1, ctx);
   AffineExpr zero = getAffineConstantExpr(0, ctx);
   auto broadcastInnerDim =
       AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, ctx);
-
-  if (!useNvGpu) {
-    bool result = map.isMinorIdentity() || map == broadcastInnerDim ||
-                  isTransposeMatrixLoadMap(map);
-    return result;
-  }
-
-  return true;
+  return map.isMinorIdentity() || map == broadcastInnerDim ||
+         isTransposeMatrixLoadMap(map);
 }
 
 // Return true if the transfer op can be converted to a MMA matrix store.

@@ -182,7 +173,7 @@ transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp)
   if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
       writeOp.getVectorType().getRank() != 2)
     return false;
-  if (!getMemrefConstantHorizontalStride(writeOp.getShapedType()))
+  if (!getStaticallyKnownRowStride(writeOp.getShapedType()))
     return false;
   // TODO: Support transpose once it is added to GPU dialect ops.
   if (!writeOp.getPermutationMap().isMinorIdentity())

@@ -285,9 +276,11 @@ static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
   if (isa<scf::ForOp, scf::YieldOp>(op))
     return true;
   if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
-    return transferReadSupportsMMAMatrixType(transferRead, useNvGpu);
+    return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferRead)
+                    : transferReadSupportsMMAMatrixType(transferRead);
   if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
-    return transferWriteSupportsMMAMatrixType(transferWrite);
+    return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferWrite)
+                    : transferWriteSupportsMMAMatrixType(transferWrite);
   if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
     return useNvGpu &&
            extractStridedSliceSupportsMMAMatrixType(extractStridedSlice);

@@ -372,9 +365,14 @@ static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
     // chain. MMA matrix are stored in an opaque type so they cannot be used
     // by all operations.
     if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
-          return !supportsMMaMatrixType(op, useNvGpu);
+          if (!supportsMMaMatrixType(op, useNvGpu)) {
+            LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n");
+            return true;
+          }
+          return false;
         }))
      return;
+
    opToConvert.insert(dependentOps.begin(), dependentOps.end());
  });
  // Sort the operations so that we can convert them in topological order.

@@ -537,10 +535,11 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
   rewriter.setInsertionPoint(op);
 
   assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
-  assert(transferReadSupportsMMAMatrixType(op, /*useNvGpu=*/false));
+  assert(transferReadSupportsMMAMatrixType(op) &&
+         "expected convertible operation");
 
   std::optional<int64_t> stride =
-      getMemrefConstantHorizontalStride(op.getShapedType());
+      getStaticallyKnownRowStride(op.getShapedType());
   if (!stride.has_value()) {
     LLVM_DEBUG(DBGS() << "no stride\n");
     return rewriter.notifyMatchFailure(op, "no stride");

@@ -591,7 +590,7 @@ convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
 
   assert(transferWriteSupportsMMAMatrixType(op));
   std::optional<int64_t> stride =
-      getMemrefConstantHorizontalStride(op.getShapedType());
+      getStaticallyKnownRowStride(op.getShapedType());
   if (!stride.has_value()) {
     LLVM_DEBUG(DBGS() << "no stride\n");
     return rewriter.notifyMatchFailure(op, "no stride");

@@ -1303,7 +1302,8 @@ LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter,
               return op->emitError() << "unhandled vector to mma type: " << *op;
             })
             .failed()) {
-      return op->emitError() << "Failed to convert op " << *op;
+      return op->emitOpError()
+             << "failed to convert op during vector-to-nvgpu conversion";
     }
   }
   return success();

@@ -1326,10 +1326,11 @@ struct ConvertVectorToGPUPass
       return signalPassFailure();
 
     IRRewriter rewriter(&getContext());
-    if (useNvGpu.getValue()) {
+    if (useNvGpu) {
       if (failed(
               convertVectorToNVVMCompatibleMMASync(rewriter, getOperation())))
         return signalPassFailure();
+      return;
     }
     (void)convertVectorToMMAOps(rewriter, getOperation());
   }

mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp

Lines changed: 51 additions & 0 deletions
@@ -272,3 +272,54 @@ nvgpu::getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc,
 
   return failure();
 }
+
+bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferReadOp op) {
+  if (op.getMask() || op.hasOutOfBoundsDim())
+    return false;
+  VectorType type = op.getType();
+  // The result type should be 2D. Note that it is possible to expand support so
+  // that we are robust to extra unit dimensions that failed to fold, but that
+  // would significantly increase downstream code complexity in the conversion
+  // step. For now, we rely on other patterns to ensure canonical 2D form is
+  // used when targeting the `nvgpu.mma.sync` lowering path.
+  if (!type.hasStaticShape() || type.getRank() != 2)
+    return false;
+
+  // Currently we can't support reads on tensor types because we need stride
+  // information to ensure correctness of downstream assumptions. It is possible
+  // to enable this if caller can assert that tensor will be lowered in a
+  // particular manner.
+  auto sourceType = dyn_cast<MemRefType>(op.getSource().getType());
+  if (!sourceType)
+    return false;
+
+  // Check that the last dimension of the read is contiguous. Note that it is
+  // possible to expand support for this by scalarizing all the loads during
+  // conversion.
+  auto [strides, offset] = mlir::getStridesAndOffset(sourceType);
+  return strides.back() == 1;
+}
+
+bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferWriteOp op) {
+  if (op.getMask() || op.hasOutOfBoundsDim() || op.getTransferRank() == 0)
+    return false;
+  VectorType type = op.getVectorType();
+  if (!type.hasStaticShape() || type.getRank() != 2)
+    return false;
+  // TODO: Currently we rely on lowering to a `vector.store` operation. We could
+  // support the transposed write case by lowering to scalarized `memref.store`
+  // operations.
+  if (!op.getPermutationMap().isMinorIdentity())
+    return false;
+  // Currently we can't support reads on tensor types because we need stride
+  // information to ensure correctness of downstream assumptions.
+  auto sourceType = dyn_cast<MemRefType>(op.getSource().getType());
+  if (!sourceType)
+    return false;
+
+  // Check that the last dimension of the target memref is contiguous. Note that
+  // it is possible to expand support for this by scalarizing all the stores
+  // during conversion.
+  auto [strides, offset] = mlir::getStridesAndOffset(sourceType);
+  return strides.back() == 1;
+}
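To ground the write-side predicate, the following hedged sketch (adapted from the `unsupported_transposed_store` test added below; the function name is illustrative) shows a transposed `vector.transfer_write` that `canLowerToWarpMatrixOperation` rejects because its permutation map is not a minor identity, even though the destination memref has a unit innermost stride.

!smem = memref<20x20xf16, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>
func.func @transposed_store(%acc: vector<16x8xf16>, %dst: !smem) {
  %c0 = arith.constant 0 : index
  // The permutation map swaps the two dimensions, so the predicate returns false
  // and the surrounding chain is left to the default vector lowering.
  vector.transfer_write %acc, %dst[%c0, %c0]
      {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
      : vector<16x8xf16>, !smem
  return
}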

mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir

Lines changed: 124 additions & 2 deletions
@@ -47,7 +47,7 @@ func.func @m16n8k32_int8_row_row_row(%arg0: memref<128x128xi8, #gpu.address_spac
   // CHECK: nvgpu.ldmatrix %arg0[[[m_coord]], [[k_coord]]] {numTiles = 4 : i32, transpose = false} : memref<128x128xi8, #gpu.address_space<workgroup>> -> vector<4x4xi8>
 
   // Verify that the operandB load is lowered to scalar load to be able
-  // to transpose at 8-bit granularity. ldmatrix can only transpose at 
+  // to transpose at 8-bit granularity. ldmatrix can only transpose at
   // 16-bit granularity.
 
   // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB0_map]]()[{{%.+}}]

@@ -282,7 +282,7 @@ func.func @multi_dim_m16n8k16_fp16_row_row_row(%arg0: memref<4x32x1x32xf16, #gpu
   // CHECK-DAG: [[k_coord:%.+]] = affine.apply [[$strided_map]]
   // CHECK-DAG: [[fragmentB:%.+]] = nvgpu.ldmatrix %arg1[[[c0]], [[c0]], [[k_coord]], [[n_coord]]] {numTiles = 4 : i32, transpose = true}
   %B = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true], permutation_map = #map_b} : memref<4x1x32x32xf16, #gpu.address_space<workgroup>>, vector<16x16xf16>
-  
+
   // CHECK-DAG: [[m_coord:%.+]] = affine.apply [[$strided_map]]
   // CHECK-DAG: [[n_coord:%.+]] = affine.apply [[$contiguous_map]]
   // CHECK-DAG: [[fragmentC:%.*]] = nvgpu.ldmatrix %arg2[[[c0]], [[m_coord]], [[n_coord]]] {numTiles = 4 : i32, transpose = false}

@@ -713,3 +713,125 @@ func.func @m16n8k32_int8_row_col_row(%arg0: memref<128x128xi8, #gpu.address_spac
   vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xi32>, memref<128x128xi32>
   return
 }
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+!smem_type = memref<20x20xf16, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>
+
+// This test case is identical to the m16n8k16 test case, but it tests that a
+// row dimension with unknown stride is handled correctly.
+
+// CHECK-DAG: [[$strided_map:#.+]] = affine_map<()[s0] -> (s0 mod 16)>
+// CHECK-DAG: [[$contiguous_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8)>
+// CHECK-LABEL: func @strided_memref_read_write
+func.func @strided_memref_read_write(%arg0: !smem_type,
+                                     %arg1: !smem_type,
+                                     %arg2: !smem_type) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf16>
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+
+  // CHECK-DAG: [[m_coord:%.+]] = affine.apply [[$strided_map]]
+  // CHECK-DAG: [[k_coord:%.+]] = affine.apply [[$contiguous_map]]
+  // CHECK: nvgpu.ldmatrix %arg0[[[m_coord]], [[k_coord]]] {numTiles = 4 : i32, transpose = false}
+  // CHECK-DAG: [[n_coord:%.+]] = affine.apply [[$contiguous_map]]
+  // CHECK-DAG: [[k_coord:%.+]] = affine.apply [[$strided_map]]
+  // CHECK: nvgpu.ldmatrix %arg1[[[k_coord]], [[n_coord]]] {numTiles = 2 : i32, transpose = true}
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : !smem_type, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : !smem_type, vector<8x16xf16>
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : !smem_type, vector<16x8xf16>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+    %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
+  vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, !smem_type
+  return
+}
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d2, d0, d3)>
+#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+!smem_type = memref<20x20x20xf16, strided<[?, ?, 1], offset: ?>, #gpu.address_space<workgroup>>
+
+// CHECK-LABEL: func @unsupported_non_2d_load_store
+func.func @unsupported_non_2d_load_store(%arg0: !smem_type,
+                                         %arg1: !smem_type,
+                                         %arg2: !smem_type) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf16>
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+
+  // CHECK-NOT: nvgpu.ldmatrix
+  // CHECK-NOT: nvgpu.mma
+  %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : !smem_type, vector<1x16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true, true]} : !smem_type, vector<8x1x16xf16>
+  %C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : !smem_type, vector<1x16x8xf16>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+    %A, %B, %C : vector<1x16x16xf16>, vector<8x1x16xf16> into vector<1x16x8xf16>
+  vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x16x8xf16>, !smem_type
+  return
+}
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+!smem_type = memref<20x20xf16, strided<[?, ?], offset: ?>, #gpu.address_space<workgroup>>
+
+// CHECK-LABEL: func @unsupported_fully_dynamic_strides
+func.func @unsupported_fully_dynamic_strides(%arg0: !smem_type,
+                                             %arg1: !smem_type,
+                                             %arg2: !smem_type) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf16>
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+
+  // CHECK-NOT: nvgpu.ldmatrix
+  // CHECK-NOT: nvgpu.mma
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : !smem_type, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : !smem_type, vector<8x16xf16>
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : !smem_type, vector<16x8xf16>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+    %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
+  vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, !smem_type
+  return
+}
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+!smem_type = memref<20x20xf16, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>
+
+// CHECK-LABEL: func @unsupported_transposed_store
+func.func @unsupported_transposed_store(%arg0: !smem_type,
+                                        %arg1: !smem_type,
+                                        %arg2: !smem_type) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf16>
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+
+  // CHECK-NOT: nvgpu.ldmatrix
+  // CHECK-NOT: nvgpu.mma
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : !smem_type, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : !smem_type, vector<8x16xf16>
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : !smem_type, vector<16x8xf16>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+    %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
+  vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true], permutation_map = affine_map<(d0, d1)->(d1, d0)>} : vector<16x8xf16>, !smem_type
+  return
+}
