Commit c41286a
[mlir][linalg] Emit a warning when tile_using_forall generates non-thread-safe code (#80813)
**Description**

The documentation of `transform.structured.tile_using_forall` says: _"It is the user's responsibility to ensure that num_threads/tile_sizes is a valid tiling specification (i.e. that only tiles parallel dimensions, e.g. in the Linalg case)."_ In other words, tiling a non-parallel dimension generates code with data races, which is not safe to parallelize. For example, consider this case (included in the tests in this PR):

```mlir
func.func @tile_thread_safety2(%arg0: tensor<100x300x8xf32>, %arg1: tensor<300x8xf32>) -> tensor<300x8xf32> {
  %0 = scf.forall (%arg2) in (8) shared_outs(%arg3 = %arg1) -> (tensor<300x8xf32>) {
    %1 = affine.min #map(%arg2)
    %2 = affine.max #map1(%1)
    %3 = affine.apply #map2(%arg2)
    %extracted_slice = tensor.extract_slice %arg0[%3, 0, 0] [%2, 300, 8] [1, 1, 1] : tensor<100x300x8xf32> to tensor<?x300x8xf32>
    %4 = linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["reduction", "parallel", "parallel"]} ins(%extracted_slice : tensor<?x300x8xf32>) outs(%arg3 : tensor<300x8xf32>) {
    ^bb0(%in: f32, %out: f32):
      %5 = arith.addf %in, %out : f32
      linalg.yield %5 : f32
    } -> tensor<300x8xf32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %4 into %arg3[0, 0] [300, 8] [1, 1] : tensor<300x8xf32> into tensor<300x8xf32>
    }
  }
  return %0 : tensor<300x8xf32>
}
```

We can easily see that this is not safe to parallelize, because all threads would be writing to the same position in `%arg3` (in the `scf.forall.in_parallel`). This PR detects whether it is safe to `tile_using_forall` and emits a warning when it is not.

**Brief explanation**

The check first generates a vector of affine expressions representing the tile values and stores it in `dimExprs`. These affine expressions are compared with the affine expressions coming from the results of the affine map of each output in the linalg op.
Going back to the previous example, the original transform is:

```mlir
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>

func.func @tile_thread_safety2(%arg0: tensor<100x300x8xf32>, %arg1: tensor<300x8xf32>) -> tensor<300x8xf32> {
  // expected-warning@+1 {{tiling is not thread safe at axis #0}}
  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["reduction", "parallel", "parallel"]} ins(%arg0 : tensor<100x300x8xf32>) outs(%arg1 : tensor<300x8xf32>) {
  ^bb0(%in: f32, %out: f32):
    %1 = arith.addf %in, %out : f32
    linalg.yield %1 : f32
  } -> tensor<300x8xf32>
  return %0 : tensor<300x8xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [8] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}
```

The `num_threads` attribute is represented as `(d0)`. Because the linalg op has only one output (`%arg1`), it is only checked against the results of `#map1`, which are `(d1, d2)`. The idea is to check that all affine expressions in `dimExprs` are present in the output affine map. In this example, `d0` is not in `(d1, d2)`, so tiling that axis is considered not thread-safe.
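The membership check described above can be sketched outside of MLIR with standard containers. This is a simplified model, not the MLIR API: affine dim expressions `d0, d1, ...` are represented by their plain indices, and the function and parameter names (`axisIsThreadSafe`, `outputMapResults`) are hypothetical.

```cpp
#include <algorithm>
#include <vector>

// Model each affine dim expression `d_i` by its index i. A tiled axis is
// considered thread-safe only if its dim expression appears among the
// result expressions of every output's indexing map.
bool axisIsThreadSafe(unsigned axis,
                      const std::vector<std::vector<unsigned>> &outputMapResults) {
  for (const std::vector<unsigned> &results : outputMapResults) {
    // The tiled axis must occur in this output map's results, otherwise
    // all threads write to the same output positions along that axis.
    if (std::find(results.begin(), results.end(), axis) == results.end())
      return false;
  }
  return true;
}
```

For the example above, `#map1`'s results `(d1, d2)` become `{1, 2}`: axis 0 is reported unsafe, while axes 1 and 2 would pass the check.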
1 parent ceabaa7 commit c41286a

3 files changed: +180 −3 lines changed
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (3 additions & 1 deletion)

```diff
@@ -1918,7 +1918,9 @@ def TileUsingForallOp :
 
     It is the user's responsibility to ensure that `num_threads/tile_sizes` is
     a valid tiling specification (i.e. that only tiles parallel dimensions,
-    e.g. in the Linalg case).
+    e.g. in the Linalg case). If the dimension is not parallelizable, a warning
+    is issued to notify the user that the generated code is not safe to
+    parallelize.
 
     If non-empty, the `mapping` is added as an attribute to the
     resulting `scf.forall`.
```

mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp (36 additions & 2 deletions)

```diff
@@ -304,6 +304,28 @@ static void calculateTileOffsetsAndSizes(
   }
 }
 
+/// Returns a vector of bools representing if, for each axis, `op` can be tiled
+/// without incurring in a race condition and thus it is thread-safe to do the
+/// tiling. This is checked by iterating over numThreads and ensuring that the
+/// corresponding iterator type is "parallel". If it is not, then we know that
+/// such dimension is unsafe to tile.
+SmallVector<bool> safeToTileToForall(mlir::MLIRContext *ctx, LinalgOp linalgOp,
+                                     ArrayRef<OpFoldResult> numThreads) {
+  auto iterators = linalgOp.getIteratorTypesArray();
+  SmallVector<bool> safeToTile(numThreads.size(), true);
+
+  for (unsigned i = 0, e = numThreads.size(); i != e; i++) {
+    if (auto attr = llvm::dyn_cast_if_present<Attribute>(numThreads[i])) {
+      if (cast<IntegerAttr>(attr).getValue().getSExtValue() > 1) {
+        safeToTile[i] = iterators[i] == utils::IteratorType::parallel;
+      }
+    } else {
+      safeToTile[i] = iterators[i] == utils::IteratorType::parallel;
+    }
+  }
+  return safeToTile;
+}
+
 /// Rewrite a TilingInterface `op` to a tiled `scf.forall`. The
 /// tiling is specified by the number of tiles/threads `numThreads` and the
 /// optional nominal tile size `nominalTileSizes`. If `nominalTilSizes` is
@@ -314,8 +336,10 @@ static void calculateTileOffsetsAndSizes(
 /// size of data.
 /// It is the user's responsibility to ensure that `numThreads` is a valid
 /// tiling specification (i.e. that only tiles parallel dimensions, e.g. in the
-/// Linalg case). If `omitTileOffsetBoundsCheck` is true, then the function will
-/// assume that `tileSize[i] * (numThread[i] -1) <= dimSize[i]` holds.
+/// Linalg case). If the dimension is not parallelizable, a warning is issued to
+/// notify the user that the generated code is not safe to parallelize. If
+/// `omitTileOffsetBoundsCheck` is true, then the function will assume that
+/// `tileSize[i] * (numThread[i] -1) <= dimSize[i]` holds.
 static FailureOr<ForallTilingResult> tileToForallOpImpl(
     RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> numThreads,
     std::optional<ArrayRef<OpFoldResult>> nominalTileSizes,
@@ -344,6 +368,16 @@ static FailureOr<ForallTilingResult> tileToForallOpImpl(
     return getValueOrCreateConstantIndexOp(b, loc, ofr);
   }));
 
+  LinalgOp linalgOp = dyn_cast<LinalgOp>(op.getOperation());
+  if (linalgOp) {
+    // Check if tiling is thread safe and print a warning if not.
+    SmallVector<bool> tilingSafety =
+        safeToTileToForall(b.getContext(), linalgOp, numThreads);
+    for (size_t i = 0; i < tilingSafety.size(); i++)
+      if (!tilingSafety[i])
+        op.emitWarning() << "tiling is not thread safe at axis #" << i;
+  }
+
   // 1. Create the ForallOp. We don't use the lambda body-builder
   // version because we require the use of RewriterBase in the body, so we
   // manually move the insertion point to the body below.
```
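The core logic of `safeToTileToForall` can be mirrored with standard types. This is a standalone sketch under simplifying assumptions: `numThreads` entries are plain integers (the real code also has to handle dynamic `OpFoldResult` values, which it conservatively checks as well), and iterator types are strings instead of `utils::IteratorType`.

```cpp
#include <string>
#include <vector>

// Simplified mirror of the check above: an axis that is distributed across
// more than one thread is safe only when its iterator type is "parallel".
std::vector<bool> safeToTile(const std::vector<long> &numThreads,
                             const std::vector<std::string> &iterators) {
  std::vector<bool> safe(numThreads.size(), true);
  for (size_t i = 0; i < numThreads.size(); ++i) {
    // An axis with 0 or 1 threads is never actually distributed, so it
    // stays safe regardless of its iterator type.
    if (numThreads[i] > 1)
      safe[i] = (iterators[i] == "parallel");
  }
  return safe;
}
```

For the matmul test below (`iterator_types` `["parallel", "parallel", "reduction"]` with `num_threads [2, 0, 8]`), only axis #2 is flagged: axis #0 is parallel and axis #1 is not tiled at all.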

mlir/test/Dialect/Linalg/tile-to-forall.mlir (141 additions & 0 deletions)

```diff
@@ -586,3 +586,144 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+
+func.func @tile_thread_safety1(%arg0: tensor<100x300xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
+  // expected-warning@below {{tiling is not thread safe at axis #1}}
+  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<100x300xf32>) outs(%arg1 : tensor<100xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %1 = arith.addf %in, %out : f32
+    linalg.yield %1 : f32
+  } -> tensor<100xf32>
+  return %0 : tensor<100xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [4, 2]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+
+func.func @tile_thread_safety2(%arg0: tensor<100x300x8xf32>, %arg1: tensor<300x8xf32>) -> tensor<300x8xf32> {
+  // expected-warning@below {{tiling is not thread safe at axis #0}}
+  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["reduction", "parallel", "parallel"]} ins(%arg0 : tensor<100x300x8xf32>) outs(%arg1 : tensor<300x8xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %1 = arith.addf %in, %out : f32
+    linalg.yield %1 : f32
+  } -> tensor<300x8xf32>
+  return %0 : tensor<300x8xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [8]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+
+func.func @tile_thread_safety3(%arg0: tensor<100x300x8xf32>, %arg1: tensor<100x8xf32>) -> tensor<100x8xf32> {
+  // expected-warning@below {{tiling is not thread safe at axis #1}}
+  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction", "parallel"]} ins(%arg0 : tensor<100x300x8xf32>) outs(%arg1 : tensor<100x8xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %1 = arith.addf %in, %out : f32
+    linalg.yield %1 : f32
+  } -> tensor<100x8xf32>
+  return %0 : tensor<100x8xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [8, 4, 2]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d2)>
+
+func.func @tile_thread_safety4(%arg0: tensor<100x300x8xf32>, %arg1: tensor<100x8xf32>, %arg2 : tensor<8xf32>) -> (tensor<100x8xf32>, tensor<8xf32>) {
+  // expected-warning@+2 {{tiling is not thread safe at axis #0}}
+  // expected-warning@below {{tiling is not thread safe at axis #1}}
+  %0:2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel"]} ins(%arg0 : tensor<100x300x8xf32>) outs(%arg1, %arg2 : tensor<100x8xf32>, tensor<8xf32>) {
+  ^bb0(%in: f32, %out1: f32, %out2: f32):
+    %1 = arith.addf %in, %out1 : f32
+    %2 = arith.addf %in, %out2 : f32
+    linalg.yield %1, %2 : f32, f32
+  } -> (tensor<100x8xf32>, tensor<8xf32>)
+  return %0#0, %0#1 : tensor<100x8xf32>, tensor<8xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [8, 4, 2]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+
+func.func @tile_thread_safety5(%arg0: tensor<100x300xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
+  // expected-warning@below {{tiling is not thread safe at axis #1}}
+  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<100x300xf32>) outs(%arg1 : tensor<100xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %1 = arith.addf %in, %out : f32
+    linalg.yield %1 : f32
+  } -> tensor<100xf32>
+  return %0 : tensor<100xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %forall, %tiled_generic = transform.structured.tile_using_forall %0 tile_sizes [10, 1]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @tile_thread_safety6(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // expected-warning@below {{tiling is not thread safe at axis #2}}
+  %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+                     outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
+  return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [2, 0, 8]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
```
