Skip to content

Commit 2312064

Browse files
fairywreathtomtor
authored andcommitted
[mlir][spirv] Fix FuncOpVectorUnroll to process placeholder values in all blocks (llvm#142339)
`FuncOpVectorUnroll` contains logic that replaces function arguments by placeholders values. These replacements also involve changing all instructions in the function that use the arguments to use these placeholders. These placeholder values will later be changed back to use the function arguments (either new or original if already legal). The current implementation however only replaces back (the second replacement, i.e. replacing the placeholder values to new/legal arguments) the first block of instructions and not all of the blocks. This may leave some instructions to use these placeholder values (which for already legal arguments are just zeroattr values that will get DCE'd) instead of the arguments, which is incorrect. Closes llvm#132158.
1 parent 88affd3 commit 2312064

File tree

2 files changed

+86
-13
lines changed

2 files changed

+86
-13
lines changed

mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1020,22 +1020,22 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
10201020
SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
10211021
entryBlock.addArguments(convertedTypes, locs);
10221022

1023-
// Replace the placeholder values with the new arguments. We assume there is
1024-
// only one block for now.
1023+
// Replace all uses of placeholders for initially legal arguments with their
1024+
// original function arguments (that were added to `newFuncOp`).
1025+
for (auto &[placeholderOp, argIdx] : tmpOps) {
1026+
if (!placeholderOp)
1027+
continue;
1028+
Value replacement = newFuncOp.getArgument(argIdx);
1029+
rewriter.replaceAllUsesWith(placeholderOp->getResult(0), replacement);
1030+
}
1031+
1032+
// Replace dummy operands of new `vector.insert_strided_slice` ops with
1033+
// their corresponding new function arguments. The new
1034+
// `vector.insert_strided_slice` ops are inserted only into the entry block,
1035+
// so iterating over that block is sufficient.
10251036
size_t unrolledInputIdx = 0;
10261037
for (auto [count, op] : enumerate(entryBlock.getOperations())) {
1027-
// We first look for operands that are placeholders for initially legal
1028-
// arguments.
10291038
Operation &curOp = op;
1030-
for (auto [operandIdx, operandVal] : llvm::enumerate(op.getOperands())) {
1031-
Operation *operandOp = operandVal.getDefiningOp();
1032-
if (auto it = tmpOps.find(operandOp); it != tmpOps.end()) {
1033-
size_t idx = operandIdx;
1034-
rewriter.modifyOpInPlace(&curOp, [&curOp, &newFuncOp, it, idx] {
1035-
curOp.setOperand(idx, newFuncOp.getArgument(it->second));
1036-
});
1037-
}
1038-
}
10391039
// Since all newly created operations are in the beginning, reaching the
10401040
// end of them means that any later `vector.insert_strided_slice` should
10411041
// not be touched.

mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,3 +189,76 @@ func.func @unsupported_scalable(%arg0 : vector<[8]xi32>) -> (vector<[8]xi32>) {
189189
return %arg0 : vector<[8]xi32>
190190
}
191191

192+
// -----
193+
194+
// Check that already legal function parameters are properly preserved across multiple blocks.
195+
196+
// CHECK-LABEL: func.func @legal_params_multiple_blocks_simple
197+
// CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32) -> i32
198+
func.func @legal_params_multiple_blocks_simple(%arg0: i32, %arg1: i32) -> i32 {
199+
// CHECK: %[[ADD0:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i32
200+
// CHECK: %[[ADD1:.*]] = arith.addi %[[ADD0]], %[[ARG1]] : i32
201+
// CHECK: return %[[ADD1]] : i32
202+
cf.br ^bb1(%arg0 : i32)
203+
^bb1(%acc0: i32):
204+
%acc1_val = arith.addi %acc0, %arg1 : i32
205+
cf.br ^bb2(%acc1_val : i32)
206+
^bb2(%acc1: i32):
207+
%acc2_val = arith.addi %acc1, %arg1 : i32
208+
cf.br ^bb3(%acc2_val : i32)
209+
^bb3(%acc_final: i32):
210+
return %acc_final : i32
211+
}
212+
213+
// -----
214+
215+
// Check that legal parameters and existing `vector.insert_strided_slice`s are properly preserved across multiple blocks.
216+
217+
// CHECK-LABEL: func.func @legal_params_with_vec_insert_multiple_blocks
218+
// CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32, %[[ARG2:.+]]: vector<4xi32>) -> vector<4xi32>
219+
func.func @legal_params_with_vec_insert_multiple_blocks(%arg0: i32, %arg1: i32, %arg2: vector<4xi32>) -> vector<4xi32> {
220+
// CHECK: %[[ADD0:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i32
221+
// CHECK: %[[ADD1:.*]] = arith.addi %[[ADD0]], %[[ARG1]] : i32
222+
// CHECK: %[[VEC1D:.*]] = vector.broadcast %[[ADD1]] : i32 to vector<1xi32>
223+
// CHECK: %[[VEC0:.*]] = vector.insert_strided_slice %[[VEC1D]], %[[ARG2]] {offsets = [1], strides = [1]} : vector<1xi32> into vector<4xi32>
224+
// CHECK: %[[VEC1:.*]] = vector.insert_strided_slice %[[VEC1D]], %[[VEC0]] {offsets = [2], strides = [1]} : vector<1xi32> into vector<4xi32>
225+
// CHECK: %[[RESULT:.*]] = vector.insert_strided_slice %[[VEC1D]], %[[VEC1]] {offsets = [3], strides = [1]} : vector<1xi32> into vector<4xi32>
226+
// CHECK: return %[[RESULT]] : vector<4xi32>
227+
cf.br ^bb1(%arg0 : i32)
228+
^bb1(%acc0: i32):
229+
%acc1_val = arith.addi %acc0, %arg1 : i32
230+
cf.br ^bb2(%acc1_val : i32)
231+
^bb2(%acc1: i32):
232+
%acc2_val = arith.addi %acc1, %arg1 : i32
233+
cf.br ^bb3(%acc2_val : i32)
234+
^bb3(%acc_final: i32):
235+
%scalar_vec = vector.broadcast %acc_final : i32 to vector<1xi32>
236+
%vec0 = vector.insert_strided_slice %scalar_vec, %arg2 {offsets = [1], strides = [1]} : vector<1xi32> into vector<4xi32>
237+
%vec1 = vector.insert_strided_slice %scalar_vec, %vec0 {offsets = [2], strides = [1]} : vector<1xi32> into vector<4xi32>
238+
%result = vector.insert_strided_slice %scalar_vec, %vec1 {offsets = [3], strides = [1]} : vector<1xi32> into vector<4xi32>
239+
return %result : vector<4xi32>
240+
}
241+
242+
// -----
243+
244+
// Check that already legal function parameters are preserved across a loop (which contains multiple blocks).
245+
246+
// CHECK-LABEL: @legal_params_for_loop
247+
// CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32)
248+
func.func @legal_params_for_loop(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
249+
// CHECK: %[[CST0:.*]] = arith.constant 0 : index
250+
// CHECK: %[[CST1:.*]] = arith.constant 1 : index
251+
// CHECK: %[[UB:.*]] = arith.index_cast %[[ARG2]] : i32 to index
252+
// CHECK: %[[RESULT:.*]] = scf.for %[[STEP:.*]] = %[[CST0]] to %[[UB]] step %[[CST1]] iter_args(%[[ACC:.*]] = %[[ARG0]]) -> (i32) {
253+
// CHECK: %[[ADD:.*]] = arith.addi %[[ACC]], %[[ARG1]] : i32
254+
// CHECK: scf.yield %[[ADD]] : i32
255+
// CHECK: return %[[RESULT]] : i32
256+
%zero = arith.constant 0 : index
257+
%one = arith.constant 1 : index
258+
%ub = arith.index_cast %arg2 : i32 to index
259+
%result = scf.for %i = %zero to %ub step %one iter_args(%acc = %arg0) -> (i32) {
260+
%new_acc = arith.addi %acc, %arg1 : i32
261+
scf.yield %new_acc : i32
262+
}
263+
return %result : i32
264+
}

0 commit comments

Comments
 (0)