Skip to content

Commit 15ea230

Browse files
[mlir][NVGPU] Support N-D masks in transform.nvgpu.create_async_groups
Support IR that is generated by the vector-to-scf lowering of N-D vector transfers with a mask. (Until now only 1-D and 2-D transfers were supported.) Only transfers that were fully unrolled are supported. Differential Revision: https://reviews.llvm.org/D157286
1 parent 9329723 commit 15ea230

File tree

2 files changed

+85
-22
lines changed

2 files changed

+85
-22
lines changed

mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -68,19 +68,16 @@ static FailureOr<TransferMask> getMaskOp(Operation *loadOp) {
6868
transferRead.getMask().getDefiningOp<vector::CreateMaskOp>())
6969
return TransferMask{maskOp, {}};
7070

71-
// Case 2: Mask is the result of a vector.extract(vector.create_mask). Only
72-
// 2D -> 1D extracts are supported at the moment.
71+
// Case 2: Mask is the result of a vector.extract(vector.create_mask).
7372
if (auto extractOp =
7473
transferRead.getMask().getDefiningOp<vector::ExtractOp>())
7574
if (auto maskOp =
7675
extractOp.getVector().getDefiningOp<vector::CreateMaskOp>())
77-
if (extractOp.getPosition().size() == 1 &&
78-
extractOp.getSourceVectorType().getRank() == 2)
79-
return TransferMask{maskOp,
80-
SmallVector<int64_t>(extractOp.getPosition())};
76+
return TransferMask{maskOp,
77+
SmallVector<int64_t>(extractOp.getPosition())};
8178

8279
// All other cases: not supported.
83-
return {};
80+
return failure();
8481
}
8582

8683
/// Build an SSA value that represents the number of read elements.
@@ -102,18 +99,27 @@ static Value buildNumReadElements(OpBuilder &b, Location loc,
10299

103100
// vector.extract(vector.create_mask).
104101
// If extract_pos < num_ones, take number of elements from the least
105-
// significant dimension.
106-
assert(transferMask->createMaskOp.getVectorType().getRank() == 2 &&
107-
"expected 2D mask");
108-
assert(transferMask->extractPosition.size() == 1 &&
109-
"expected 2D->1D extract");
110-
Value cmp = b.create<arith::CmpIOp>(
111-
loc, arith::CmpIPredicate::slt,
112-
b.create<arith::ConstantIndexOp>(loc,
113-
transferMask->extractPosition.front()),
114-
transferMask->createMaskOp->getOperands().front());
102+
// significant dimension. (Do this for all dimensions and bit-AND the
103+
// conditions.)
104+
assert(transferMask->createMaskOp.getVectorType().getRank() -
105+
transferMask->extractPosition.size() ==
106+
1 &&
107+
"expected N-D -> (N-1)-D extract");
108+
Value cond;
109+
// Note: There is one more `sz` than `pos`. The loop ends with the last `pos`.
110+
for (auto [pos, sz] : llvm::zip(transferMask->extractPosition,
111+
transferMask->createMaskOp->getOperands())) {
112+
Value cmp =
113+
b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
114+
b.create<arith::ConstantIndexOp>(loc, pos), sz);
115+
if (!cond) {
116+
cond = cmp;
117+
continue;
118+
}
119+
cond = b.create<arith::AndIOp>(loc, cmp, cond);
120+
}
115121
return b.create<arith::SelectOp>(
116-
loc, cmp, transferMask->createMaskOp->getOperands().back(),
122+
loc, cond, transferMask->createMaskOp->getOperands().back(),
117123
b.create<arith::ConstantIndexOp>(loc, 0));
118124
}
119125

mlir/test/Dialect/NVGPU/transform-create-async-groups.mlir

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -165,10 +165,6 @@ builtin.module {
165165
%0 = memref.alloc() : memref<4x32x16xf32, #gpu.address_space<workgroup>>
166166
%c0 = arith.constant 0 : index
167167
%cst_0 = arith.constant 0.000000e+00 : f32
168-
// CHECK: %[[mask:.*]] = vector.create_mask
169-
// CHECK: %[[e0:.*]] = vector.extract %[[mask]][0] : vector<3x4xi1>
170-
// CHECK: %[[e1:.*]] = vector.extract %[[mask]][1] : vector<3x4xi1>
171-
// CHECK: %[[e2:.*]] = vector.extract %[[mask]][2] : vector<3x4xi1>
172168

173169
// CHECK: %[[cmpi0:.*]] = arith.cmpi slt, %[[c0]], %[[sz0]]
174170
// CHECK: %[[s0:.*]] = arith.select %[[cmpi0]], %[[sz1]], %[[c0]]
@@ -199,3 +195,64 @@ builtin.module {
199195
transform.apply_cse to %top_level_func_2 : !transform.any_op
200196
}
201197
}
198+
199+
// -----
200+
201+
// 3D vector.transfer_read with a mask.
202+
builtin.module {
203+
// CHECK-LABEL: @read_3d_with_mask(
204+
// CHECK-SAME: %[[sz0:.*]]: index, %[[sz1:.*]]: index, %[[sz2:.*]]: index, %[[a:.*]]: memref<1024x1024x1024xf32>
205+
func.func @read_3d_with_mask(%sz0: index, %sz1: index, %sz2: index, %a: memref<1024x1024x1024xf32>) {
206+
// CHECK: %[[c0:.*]] = arith.constant 0 : index
207+
// CHECK: %[[c1:.*]] = arith.constant 1 : index
208+
// CHECK: %[[c2:.*]] = arith.constant 2 : index
209+
%0 = memref.alloc() : memref<4x32x16xf32, #gpu.address_space<workgroup>>
210+
%c0 = arith.constant 0 : index
211+
%cst_0 = arith.constant 0.000000e+00 : f32
212+
213+
// CHECK: %[[cmpi0:.*]] = arith.cmpi slt, %[[c0]], %[[sz0]]
214+
// CHECK: %[[cmpi1:.*]] = arith.cmpi slt, %[[c0]], %[[sz1]]
215+
// CHECK: %[[cond0:.*]] = arith.andi %[[cmpi1]], %[[cmpi0]]
216+
// CHECK: %[[s0:.*]] = arith.select %[[cond0]], %[[sz2]], %[[c0]]
217+
// CHECK: nvgpu.device_async_copy %[[a]][%[[c0]], %[[c0]], %[[c0]]], {{.*}}, 4, %[[s0]] {bypassL1}
218+
219+
// CHECK: %[[cmpi2:.*]] = arith.cmpi slt, %[[c1]], %[[sz1]]
220+
// CHECK: %[[cond1:.*]] = arith.andi %[[cmpi2]], %[[cmpi0]]
221+
// CHECK: %[[s1:.*]] = arith.select %[[cond1]], %[[sz2]], %[[c0]]
222+
// CHECK: nvgpu.device_async_copy %[[a]][%[[c0]], %[[c1]], %[[c0]]], {{.*}}, 4, %[[s1]] {bypassL1}
223+
224+
// CHECK: %[[cmpi3:.*]] = arith.cmpi slt, %[[c2]], %[[sz1]]
225+
// CHECK: %[[cond2:.*]] = arith.andi %[[cmpi3]], %[[cmpi0]]
226+
// CHECK: %[[s2:.*]] = arith.select %[[cond2]], %[[sz2]], %[[c0]]
227+
// CHECK: nvgpu.device_async_copy %[[a]][%[[c0]], %[[c2]], %[[c0]]], {{.*}}, 4, %[[s2]] {bypassL1}
228+
229+
// CHECK: %[[cmpi4:.*]] = arith.cmpi slt, %[[c1]], %[[sz0]]
230+
// CHECK: %[[cond3:.*]] = arith.andi %[[cmpi1]], %[[cmpi4]]
231+
// CHECK: %[[s3:.*]] = arith.select %[[cond3]], %[[sz2]], %[[c0]]
232+
// CHECK: nvgpu.device_async_copy %[[a]][%[[c1]], %[[c0]], %[[c0]]], {{.*}}, 4, %[[s3]] {bypassL1}
233+
234+
// CHECK: %[[cond4:.*]] = arith.andi %[[cmpi2]], %[[cmpi4]]
235+
// CHECK: %[[s4:.*]] = arith.select %[[cond4]], %[[sz2]], %[[c0]]
236+
// CHECK: nvgpu.device_async_copy %[[a]][%[[c1]], %[[c1]], %[[c0]]], {{.*}}, 4, %[[s4]] {bypassL1}
237+
238+
// CHECK: %[[cond5:.*]] = arith.andi %[[cmpi3]], %[[cmpi4]]
239+
// CHECK: %[[s5:.*]] = arith.select %[[cond5]], %[[sz2]], %[[c0]]
240+
// CHECK: nvgpu.device_async_copy %[[a]][%[[c1]], %[[c2]], %[[c0]]], {{.*}}, 4, %[[s5]] {bypassL1}
241+
%mask = vector.create_mask %sz0, %sz1, %sz2 : vector<2x3x4xi1>
242+
%1 = vector.transfer_read %a[%c0, %c0, %c0], %cst_0, %mask {in_bounds = [true, true, true]} : memref<1024x1024x1024xf32>, vector<2x3x4xf32>
243+
vector.transfer_write %1, %0[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<2x3x4xf32>, memref<4x32x16xf32, #gpu.address_space<workgroup>>
244+
245+
return
246+
}
247+
248+
transform.sequence failures(propagate) {
249+
^bb1(%variant_op: !transform.any_op):
250+
%top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
251+
transform.apply_patterns to %top_level_func {
252+
transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true
253+
} : !transform.any_op
254+
transform.nvgpu.create_async_groups %top_level_func {bypass_l1} : (!transform.any_op) -> (!transform.any_op)
255+
%top_level_func_2 = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
256+
transform.apply_cse to %top_level_func_2 : !transform.any_op
257+
}
258+
}

0 commit comments

Comments
 (0)