Skip to content

Commit 7bdcc9f

Browse files
committed
[mlir][vectorize] Support affine.apply in SuperVectorize
There is no need to vectorize affine.apply inside a loop being vectorized; instead, it is regenerated in its original scalar form, with its operands replaced by their scalar counterparts produced in the new vector loop (e.g., the new induction variables).
1 parent 375bd22 commit 7bdcc9f

File tree

2 files changed

+96
-4
lines changed

2 files changed

+96
-4
lines changed

mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -721,8 +721,7 @@ struct VectorizationState {
721721
/// Example:
722722
/// * 'replaced': induction variable of a loop to be vectorized.
723723
/// * 'replacement': new induction variable in the new vector loop.
724-
void registerValueScalarReplacement(BlockArgument replaced,
725-
BlockArgument replacement);
724+
void registerValueScalarReplacement(Value replaced, Value replacement);
726725

727726
/// Registers the scalar replacement of a scalar result returned from a
728727
/// reduction loop. 'replacement' must be scalar.
@@ -854,8 +853,8 @@ void VectorizationState::registerValueVectorReplacementImpl(Value replaced,
854853
/// Example:
855854
/// * 'replaced': induction variable of a loop to be vectorized.
856855
/// * 'replacement': new induction variable in the new vector loop.
857-
void VectorizationState::registerValueScalarReplacement(
858-
BlockArgument replaced, BlockArgument replacement) {
856+
void VectorizationState::registerValueScalarReplacement(Value replaced,
857+
Value replacement) {
859858
registerValueScalarReplacementImpl(replaced, replacement);
860859
}
861860

@@ -978,6 +977,32 @@ static arith::ConstantOp vectorizeConstant(arith::ConstantOp constOp,
978977
return newConstOp;
979978
}
980979

980+
/// We have no need to vectorize affine.apply. However, we still need to
981+
/// generate it and replace the operands with values in valueScalarReplacement.
982+
static Operation *vectorizeAffineApplyOp(AffineApplyOp applyOp,
983+
VectorizationState &state) {
984+
SmallVector<Value, 8> updatedOperands;
985+
for (Value operand : applyOp.getOperands()) {
986+
Value updatedOperand = operand;
987+
if (state.valueScalarReplacement.contains(operand)) {
988+
updatedOperand = state.valueScalarReplacement.lookupOrDefault(operand);
989+
} else if (state.valueVectorReplacement.contains(operand)) {
990+
LLVM_DEBUG(
991+
dbgs() << "\n[early-vect]+++++ affine.apply on vector operand\n");
992+
return nullptr;
993+
}
994+
updatedOperands.push_back(updatedOperand);
995+
}
996+
997+
auto newApplyOp = state.builder.create<AffineApplyOp>(
998+
applyOp.getLoc(), applyOp.getAffineMap(), updatedOperands);
999+
1000+
// Register the new affine.apply result.
1001+
state.registerValueScalarReplacement(applyOp.getResult(),
1002+
newApplyOp.getResult());
1003+
return newApplyOp;
1004+
}
1005+
9811006
/// Creates a constant vector filled with the neutral elements of the given
9821007
/// reduction. The scalar type of vector elements will be taken from
9831008
/// `oldOperand`.
@@ -1493,6 +1518,8 @@ static Operation *vectorizeOneOperation(Operation *op,
14931518
return vectorizeAffineYieldOp(yieldOp, state);
14941519
if (auto constant = dyn_cast<arith::ConstantOp>(op))
14951520
return vectorizeConstant(constant, state);
1521+
if (auto applyOp = dyn_cast<AffineApplyOp>(op))
1522+
return vectorizeAffineApplyOp(applyOp, state);
14961523

14971524
// Other ops with regions are not supported.
14981525
if (op->getNumRegions() != 0)
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
// RUN: mlir-opt %s -affine-super-vectorize="virtual-vector-size=8 test-fastest-varying=0" -split-input-file | FileCheck %s
2+
3+
// CHECK-DAG: #[[$MAP_ID0:map[0-9a-zA-Z_]*]] = affine_map<(d0) -> (d0 mod 12)>
4+
// CHECK-DAG: #[[$MAP_ID1:map[0-9a-zA-Z_]*]] = affine_map<(d0) -> (d0 mod 16)>
5+
6+
// Positive case: both affine.apply ops depend only on loop induction
// variables, so the innermost loop is vectorized (note the `step 8` and the
// vector.transfer_read/write below) and the applies are regenerated in scalar
// form, taking the new induction variables as operands.
// CHECK-LABEL: vec_affine_apply
// CHECK-SAME: (%[[ARG0:.*]]: memref<8x12x16xf32>, %[[ARG1:.*]]: memref<8x24x48xf32>) {
func.func @vec_affine_apply(%arg0: memref<8x12x16xf32>, %arg1: memref<8x24x48xf32>) {
// CHECK: affine.for %[[ARG2:.*]] = 0 to 8 {
// CHECK-NEXT: affine.for %[[ARG3:.*]] = 0 to 24 {
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 48 step 8 {
// CHECK-NEXT: %[[S0:.*]] = affine.apply #[[$MAP_ID0]](%[[ARG3]])
// CHECK-NEXT: %[[S1:.*]] = affine.apply #[[$MAP_ID1]](%[[ARG4]])
// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: %[[S2:.*]] = vector.transfer_read %[[ARG0]][%[[ARG2]], %[[S0]], %[[S1]]], %[[CST]] : memref<8x12x16xf32>, vector<8xf32>
// CHECK-NEXT: vector.transfer_write %[[S2]], %[[ARG1]][%[[ARG2]], %[[ARG3]], %[[ARG4]]] : vector<8xf32>, memref<8x24x48xf32>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return
  affine.for %arg2 = 0 to 8 {
    affine.for %arg3 = 0 to 24 {
      affine.for %arg4 = 0 to 48 {
        %0 = affine.apply affine_map<(d0) -> (d0 mod 12)>(%arg3)
        %1 = affine.apply affine_map<(d0) -> (d0 mod 16)>(%arg4)
        %2 = affine.load %arg0[%arg2, %0, %1] : memref<8x12x16xf32>
        affine.store %2, %arg1[%arg2, %arg3, %arg4] : memref<8x24x48xf32>
      }
    }
  }
  return
}
33+
34+
// Negative case: the last affine.apply (%4) takes %3, which is derived (via
// arith.index_cast) from a loaded value that would become a vector during
// vectorization. An affine.apply cannot take a vector operand, so the whole
// nest is left scalar — the CHECK lines below show no `step 8` and a scalar
// affine.load/affine.store.
// CHECK-LABEL: no_vec_affine_apply
// CHECK-SAME: (%[[ARG0:.*]]: memref<8x12x16xi32>, %[[ARG1:.*]]: memref<8x24x48xi32>) {
func.func @no_vec_affine_apply(%arg0: memref<8x12x16xi32>, %arg1: memref<8x24x48xi32>) {
// CHECK: affine.for %[[ARG2:.*]] = 0 to 8 {
// CHECK-NEXT: affine.for %[[ARG3:.*]] = 0 to 24 {
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 48 {
// CHECK-NEXT: %[[S0:.*]] = affine.apply #[[$MAP_ID0]](%[[ARG3]])
// CHECK-NEXT: %[[S1:.*]] = affine.apply #[[$MAP_ID1]](%[[ARG4]])
// CHECK-NEXT: %[[S2:.*]] = affine.load %[[ARG0]][%[[ARG2]], %[[S0]], %[[S1]]] : memref<8x12x16xi32>
// CHECK-NEXT: %[[S3:.*]] = arith.index_cast %[[S2]] : i32 to index
// CHECK-NEXT: %[[S4:.*]] = affine.apply #[[$MAP_ID1]](%[[S3]])
// CHECK-NEXT: %[[S5:.*]] = arith.index_cast %[[S4]] : index to i32
// CHECK-NEXT: affine.store %[[S5]], %[[ARG1]][%[[ARG2]], %[[ARG3]], %[[ARG4]]] : memref<8x24x48xi32>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return
  affine.for %arg2 = 0 to 8 {
    affine.for %arg3 = 0 to 24 {
      affine.for %arg4 = 0 to 48 {
        %0 = affine.apply affine_map<(d0) -> (d0 mod 12)>(%arg3)
        %1 = affine.apply affine_map<(d0) -> (d0 mod 16)>(%arg4)
        %2 = affine.load %arg0[%arg2, %0, %1] : memref<8x12x16xi32>
        %3 = arith.index_cast %2 : i32 to index
        %4 = affine.apply affine_map<(d0) -> (d0 mod 16)>(%3)
        %5 = arith.index_cast %4 : index to i32
        affine.store %5, %arg1[%arg2, %arg3, %arg4] : memref<8x24x48xi32>
      }
    }
  }
  return
}

0 commit comments

Comments
 (0)