
Commit dae8c72

[mlir][linalg] TileToForallOp: Support memref ops
Support tiling of ops with memref semantics.

Differential Revision: https://reviews.llvm.org/D153353
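For illustration only, a hand-written sketch (not IR copied from the new tests) of what this change enables: tiling a linalg op with buffer semantics now produces an scf.forall whose body works on memref.subview slices of the original buffers, instead of failing with the old "tensor semantic" error. The function name, shapes, and tile counts below are made up for the example, and the output is simplified to the evenly divisible case; the actual transform emits affine.apply/affine.min-based offset and size computations.

// Hypothetical input: linalg.copy on memrefs, 64x64 so that 4x4 threads divide evenly.
func.func @copy_memref_example(%A: memref<64x64xf32>, %B: memref<64x64xf32>) {
  linalg.copy ins(%A : memref<64x64xf32>) outs(%B : memref<64x64xf32>)
  return
}

// Roughly what tile_to_forall_op with num_threads [4, 4] is expected to produce (schematic):
func.func @copy_memref_example(%A: memref<64x64xf32>, %B: memref<64x64xf32>) {
  scf.forall (%i, %j) in (4, 4) {
    // Per-thread tile offsets: each thread owns a 16x16 tile of the 64x64 buffers.
    %off_i = affine.apply affine_map<(d0) -> (d0 * 16)>(%i)
    %off_j = affine.apply affine_map<(d0) -> (d0 * 16)>(%j)
    // Memref inits are not swapped for block arguments; the body simply reads
    // from and writes through subviews of the original buffers.
    %src = memref.subview %A[%off_i, %off_j] [16, 16] [1, 1]
        : memref<64x64xf32> to memref<16x16xf32, strided<[64, 1], offset: ?>>
    %dst = memref.subview %B[%off_i, %off_j] [16, 16] [1, 1]
        : memref<64x64xf32> to memref<16x16xf32, strided<[64, 1], offset: ?>>
    linalg.copy ins(%src : memref<16x16xf32, strided<[64, 1], offset: ?>>)
                outs(%dst : memref<16x16xf32, strided<[64, 1], offset: ?>>)
  }
  return
}

With tensor semantics, by contrast, the init operands are still swapped for the scf.forall block arguments; the Tiling.cpp change below keeps that path intact and only skips the swap for memref inits.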
1 parent 9fa7998 commit dae8c72

4 files changed (+91, -36 lines)

mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp

Lines changed: 8 additions & 5 deletions
@@ -380,11 +380,14 @@ static FailureOr<ForallTilingResult> tileToForallOpImpl(
   auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp);
   if (destinationStyleOp) {
     for (OpOperand *outOperand : destinationStyleOp.getDpsInitOperands()) {
-      auto *it = llvm::find(dest, outOperand->get());
-      if (it == dest.end())
-        return op->emitOpError("must have \"tensor semantic\" for tiling");
-      unsigned destNum = std::distance(dest.begin(), it);
-      outOperand->set(destBbArgs[destNum]);
+      // Swap tensor inits with the corresponding block argument of the
+      // scf.forall op. Memref inits remain as is.
+      if (outOperand->get().getType().isa<TensorType>()) {
+        auto *it = llvm::find(dest, outOperand->get());
+        assert(it != dest.end() && "could not find destination tensor");
+        unsigned destNum = std::distance(dest.begin(), it);
+        outOperand->set(destBbArgs[destNum]);
+      }
     }
   }

mlir/test/Dialect/GPU/transform-gpu-failing.mlir

Lines changed: 0 additions & 31 deletions
@@ -274,34 +274,3 @@ transform.sequence failures(propagate) {
   // expected-error @below {{duplicated attribute, cannot map different loops to the same processor}}
   transform.gpu.map_nested_forall_to_threads %funcop block_dims = [32, 32, 1] : (!transform.any_op) -> !transform.any_op
 }
-
-// -----
-
-func.func @tiling_buffer_semantic_op(%x: memref<32x32xf32>, %y: memref<32x32xf32>, %stream : !gpu.async.token) {
-  %one = arith.constant 1 : index
-  %name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one)
-             threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one)
-  {
-    // expected-error @below {{'linalg.generic' op must have "tensor semantic" for tiling}}
-    // expected-note @below {{when applied to this op}}
-    linalg.generic
-      {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
-                        affine_map<(d0, d1) -> (d0, d1)>],
-       iterator_types = ["parallel", "parallel"]}
-      ins(%x : memref<32x32xf32>)
-      outs(%y : memref<32x32xf32>) {
-    ^bb0(%in: f32, %out: f32):
-      linalg.yield %in : f32
-    }
-    gpu.terminator
-  }
-  return
-}
-
-transform.sequence failures(propagate) {
-^bb1(%arg0: !transform.any_op):
-  %matmul = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-  // expected-error @below {{transform.structured.tile_to_forall_op failed to apply}}
-  %forall, %tiled = transform.structured.tile_to_forall_op %matmul num_threads [10, 20, 30] (mapping = [ #gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z> ] )
-      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-}

mlir/test/Dialect/GPU/transform-gpu.mlir

Lines changed: 36 additions & 0 deletions
@@ -307,3 +307,39 @@ transform.sequence failures(propagate) {
   transform.gpu.map_nested_forall_to_threads %funcop
     block_dims = [12, 11, 1] warp_dims = [3, 2, 1] : (!transform.any_op) -> !transform.any_op
 }
+
+// -----
+
+// CHECK-LABEL: func.func @tiling_buffer_semantic_op(
+// CHECK: gpu.launch {{.*}} {
+// CHECK: scf.forall {{.*}} {
+// CHECK: memref.subview
+// CHECK: memref.subview
+// CHECK: linalg.generic
+// CHECK: }
+// CHECK: }
+func.func @tiling_buffer_semantic_op(%x: memref<32x32xf32>, %y: memref<32x32xf32>, %stream : !gpu.async.token) {
+  %one = arith.constant 1 : index
+  %name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one)
+             threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one)
+  {
+    linalg.generic
+      {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                        affine_map<(d0, d1) -> (d0, d1)>],
+       iterator_types = ["parallel", "parallel"]}
+      ins(%x : memref<32x32xf32>)
+      outs(%y : memref<32x32xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      linalg.yield %in : f32
+    }
+    gpu.terminator
+  }
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg0: !transform.any_op):
+  %matmul = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  %forall, %tiled = transform.structured.tile_to_forall_op %matmul num_threads [10, 20, 30] (mapping = [ #gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z> ] )
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+}

mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir

Lines changed: 47 additions & 0 deletions
@@ -40,6 +40,53 @@ module {
 
 // -----
 
+module {
+  // CHECK-LABEL: func @matmul_memref(
+  // CHECK: scf.forall (%{{.*}}, %{{.*}}) in (10, 20) {
+  // CHECK: memref.subview
+  // CHECK: memref.subview
+  // CHECK: memref.subview
+  // CHECK: linalg.matmul
+  // CHECK: } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
+  func.func @matmul_memref(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
+    linalg.matmul ins(%A, %B : memref<?x?xf32>, memref<?x?xf32>)
+                  outs(%C : memref<?x?xf32>)
+    return
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb1(%arg1: !transform.any_op):
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1:2 = transform.structured.tile_to_forall_op %0 num_threads [10, 20] (mapping = [ #gpu.thread<y>, #gpu.thread<x> ] )
+           : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+  }
+}
+
+// -----
+
+module {
+  // CHECK-LABEL: func @copy_memref(
+  // CHECK: scf.forall (%{{.*}}, %{{.*}}) in (10, 20) {
+  // CHECK: memref.subview
+  // CHECK: memref.subview
+  // CHECK: linalg.copy
+  // CHECK: } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
+  func.func @copy_memref(%A: memref<?x?xf32>, %B: memref<?x?xf32>) {
+    linalg.copy ins(%A: memref<?x?xf32>)
+                outs(%B : memref<?x?xf32>)
+    return
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb1(%arg1: !transform.any_op):
+    %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1:2 = transform.structured.tile_to_forall_op %0 num_threads [10, 20] (mapping = [ #gpu.thread<y>, #gpu.thread<x> ] )
+           : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+  }
+}
+
+// -----
+
 // In this test case, matmul dims and tile size are dynamic.
 
 // CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1] -> (s0 ceildiv s1)>

0 commit comments
