You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
[mlir][linalg] Emit a warning when tile_using_forall generates non thread-safe code (#80813)
**Description**
The documentation of `transform.structured.tile_using_forall` says:
_"It is the user’s responsibility to ensure that num_threads/tile_sizes
is a valid tiling specification (i.e. that only tiles parallel
dimensions, e.g. in the Linalg case)."_
In other words, tiling a non-parallel dimension would generate code with
data races which is not safe to parallelize. For example, consider this
example (included in the tests in this PR):
```
func.func @tile_thread_safety2(%arg0: tensor<100x300x8xf32>, %arg1: tensor<300x8xf32>) -> tensor<300x8xf32> {
%0 = scf.forall (%arg2) in (8) shared_outs(%arg3 = %arg1) -> (tensor<300x8xf32>) {
%1 = affine.min #map(%arg2)
%2 = affine.max #map1(%1)
%3 = affine.apply #map2(%arg2)
%extracted_slice = tensor.extract_slice %arg0[%3, 0, 0] [%2, 300, 8] [1, 1, 1] : tensor<100x300x8xf32> to tensor<?x300x8xf32>
%4 = linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["reduction", "parallel", "parallel"]} ins(%extracted_slice : tensor<?x300x8xf32>) outs(%arg3 : tensor<300x8xf32>) {
^bb0(%in: f32, %out: f32):
%5 = arith.addf %in, %out : f32
linalg.yield %5 : f32
} -> tensor<300x8xf32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %4 into %arg3[0, 0] [300, 8] [1, 1] : tensor<300x8xf32> into tensor<300x8xf32>
}
}
return %0 : tensor<300x8xf32>
}
```
We can easily see that this is not safe to parallelize because all
threads would be writing to the same position in `%arg3` (in the
`scf.forall.in_parallel`.
This PR detects wether it's safe to `tile_using_forall` and emits a
warning in the case it is not.
**Brief explanation**
It first generates a vector of affine expressions representing the tile
values and stores it in `dimExprs`. These affine expressions are
compared with the affine expressions coming from the results of the
affine map of each output in the linalg op. So going back to the
previous example, the original transform is:
```
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
func.func @tile_thread_safety2(%arg0: tensor<100x300x8xf32>, %arg1: tensor<300x8xf32>) -> tensor<300x8xf32> {
// expected-warning@+1 {{tiling is not thread safe at axis #0}}
%0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["reduction", "parallel", "parallel"]} ins(%arg0 : tensor<100x300x8xf32>) outs(%arg1 : tensor<300x8xf32>) {
^bb0(%in: f32, %out: f32):
%1 = arith.addf %in, %out : f32
linalg.yield %1 : f32
} -> tensor<300x8xf32>
return %0 : tensor<300x8xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [8]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
```
The `num_threads` attribute would be represented as `(d0)`. Because the
linalg op has only one output (`arg1`) it would only check against the
results of `#map1`, which are `(d1, d2)`. The idea is to check that all
affine expressions in `dimExprs` are present in the output affine map.
In this example, `d0` is not in `(d1, d2)`, so tiling that axis is
considered not thread safe.
0 commit comments