Skip to content

Commit 24f52eb

Browse files
committed
[mlir][vectorize] Support affine.apply in SuperVectorize
We have no need to vectorize affine.apply inside the vectorizing loop. However, we still need to generate it in the original scalar form. We have to replace all its operands with the generated scalar operands in the vectorizing loop, e.g., induction variables.
1 parent a7bc9cb commit 24f52eb

File tree

2 files changed

+139
-23
lines changed

2 files changed

+139
-23
lines changed

mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp

Lines changed: 43 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -711,18 +711,16 @@ struct VectorizationState {
711711
BlockArgument replacement);
712712

713713
/// Registers the scalar replacement of a scalar value. 'replacement' must be
714-
/// scalar. Both values must be block arguments. Operation results should be
715-
/// replaced using the 'registerOp*' utilitites.
714+
/// scalar.
716715
///
717716
/// This utility is used to register the replacement of block arguments
718-
/// that are within the loop to be vectorized and will continue being scalar
719-
/// within the vector loop.
717+
/// or affine.apply results that are within the loop be vectorized and will
718+
/// continue being scalar within the vector loop.
720719
///
721720
/// Example:
722721
/// * 'replaced': induction variable of a loop to be vectorized.
723722
/// * 'replacement': new induction variable in the new vector loop.
724-
void registerValueScalarReplacement(BlockArgument replaced,
725-
BlockArgument replacement);
723+
void registerValueScalarReplacement(Value replaced, Value replacement);
726724

727725
/// Registers the scalar replacement of a scalar result returned from a
728726
/// reduction loop. 'replacement' must be scalar.
@@ -772,7 +770,6 @@ struct VectorizationState {
772770
/// Internal implementation to map input scalar values to new vector or scalar
773771
/// values.
774772
void registerValueVectorReplacementImpl(Value replaced, Value replacement);
775-
void registerValueScalarReplacementImpl(Value replaced, Value replacement);
776773
};
777774

778775
} // namespace
@@ -844,19 +841,22 @@ void VectorizationState::registerValueVectorReplacementImpl(Value replaced,
844841
}
845842

846843
/// Registers the scalar replacement of a scalar value. 'replacement' must be
847-
/// scalar. Both values must be block arguments. Operation results should be
848-
/// replaced using the 'registerOp*' utilitites.
844+
/// scalar.
849845
///
850846
/// This utility is used to register the replacement of block arguments
851-
/// that are within the loop to be vectorized and will continue being scalar
852-
/// within the vector loop.
847+
/// or affine.apply results that are within the loop be vectorized and will
848+
/// continue being scalar within the vector loop.
853849
///
854850
/// Example:
855851
/// * 'replaced': induction variable of a loop to be vectorized.
856852
/// * 'replacement': new induction variable in the new vector loop.
857-
void VectorizationState::registerValueScalarReplacement(
858-
BlockArgument replaced, BlockArgument replacement) {
859-
registerValueScalarReplacementImpl(replaced, replacement);
853+
void VectorizationState::registerValueScalarReplacement(Value replaced,
854+
Value replacement) {
855+
assert(!valueScalarReplacement.contains(replaced) &&
856+
"Scalar value replacement already registered");
857+
assert(!isa<VectorType>(replacement.getType()) &&
858+
"Expected scalar type in scalar replacement");
859+
valueScalarReplacement.map(replaced, replacement);
860860
}
861861

862862
/// Registers the scalar replacement of a scalar result returned from a
@@ -879,15 +879,6 @@ void VectorizationState::registerLoopResultScalarReplacement(
879879
loopResultScalarReplacement[replaced] = replacement;
880880
}
881881

882-
void VectorizationState::registerValueScalarReplacementImpl(Value replaced,
883-
Value replacement) {
884-
assert(!valueScalarReplacement.contains(replaced) &&
885-
"Scalar value replacement already registered");
886-
assert(!isa<VectorType>(replacement.getType()) &&
887-
"Expected scalar type in scalar replacement");
888-
valueScalarReplacement.map(replaced, replacement);
889-
}
890-
891882
/// Returns in 'replacedVals' the scalar replacement for values in 'inputVals'.
892883
void VectorizationState::getScalarValueReplacementsFor(
893884
ValueRange inputVals, SmallVectorImpl<Value> &replacedVals) {
@@ -978,6 +969,33 @@ static arith::ConstantOp vectorizeConstant(arith::ConstantOp constOp,
978969
return newConstOp;
979970
}
980971

972+
/// We have no need to vectorize affine.apply. However, we still need to
973+
/// generate it and replace the operands with values in valueScalarReplacement.
974+
static Operation *vectorizeAffineApplyOp(AffineApplyOp applyOp,
975+
VectorizationState &state) {
976+
SmallVector<Value, 8> updatedOperands;
977+
for (Value operand : applyOp.getOperands()) {
978+
if (state.valueVectorReplacement.contains(operand)) {
979+
LLVM_DEBUG(
980+
dbgs() << "\n[early-vect]+++++ affine.apply on vector operand\n");
981+
return nullptr;
982+
} else {
983+
Value updatedOperand = state.valueScalarReplacement.lookupOrNull(operand);
984+
if (!updatedOperand)
985+
updatedOperand = operand;
986+
updatedOperands.push_back(updatedOperand);
987+
}
988+
}
989+
990+
auto newApplyOp = state.builder.create<AffineApplyOp>(
991+
applyOp.getLoc(), applyOp.getAffineMap(), updatedOperands);
992+
993+
// Register the new affine.apply result.
994+
state.registerValueScalarReplacement(applyOp.getResult(),
995+
newApplyOp.getResult());
996+
return newApplyOp;
997+
}
998+
981999
/// Creates a constant vector filled with the neutral elements of the given
9821000
/// reduction. The scalar type of vector elements will be taken from
9831001
/// `oldOperand`.
@@ -1493,6 +1511,8 @@ static Operation *vectorizeOneOperation(Operation *op,
14931511
return vectorizeAffineYieldOp(yieldOp, state);
14941512
if (auto constant = dyn_cast<arith::ConstantOp>(op))
14951513
return vectorizeConstant(constant, state);
1514+
if (auto applyOp = dyn_cast<AffineApplyOp>(op))
1515+
return vectorizeAffineApplyOp(applyOp, state);
14961516

14971517
// Other ops with regions are not supported.
14981518
if (op->getNumRegions() != 0)
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
// RUN: mlir-opt %s -affine-super-vectorize="virtual-vector-size=8 test-fastest-varying=0" -split-input-file | FileCheck %s
2+
3+
// CHECK-DAG: #[[$MAP_ID0:map[0-9a-zA-Z_]*]] = affine_map<(d0) -> (d0 mod 12)>
4+
// CHECK-DAG: #[[$MAP_ID1:map[0-9a-zA-Z_]*]] = affine_map<(d0) -> (d0 mod 16)>
5+
6+
// CHECK-LABEL: vec_affine_apply
7+
// CHECK-SAME: (%[[ARG0:.*]]: memref<8x12x16xf32>, %[[ARG1:.*]]: memref<8x24x48xf32>) {
8+
func.func @vec_affine_apply(%arg0: memref<8x12x16xf32>, %arg1: memref<8x24x48xf32>) {
9+
// CHECK: affine.for %[[ARG2:.*]] = 0 to 8 {
10+
// CHECK-NEXT: affine.for %[[ARG3:.*]] = 0 to 24 {
11+
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 48 step 8 {
12+
// CHECK-NEXT: %[[S0:.*]] = affine.apply #[[$MAP_ID0]](%[[ARG3]])
13+
// CHECK-NEXT: %[[S1:.*]] = affine.apply #[[$MAP_ID1]](%[[ARG4]])
14+
// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
15+
// CHECK-NEXT: %[[S2:.*]] = vector.transfer_read %[[ARG0]][%[[ARG2]], %[[S0]], %[[S1]]], %[[CST]] : memref<8x12x16xf32>, vector<8xf32>
16+
// CHECK-NEXT: vector.transfer_write %[[S2]], %[[ARG1]][%[[ARG2]], %[[ARG3]], %[[ARG4]]] : vector<8xf32>, memref<8x24x48xf32>
17+
// CHECK-NEXT: }
18+
// CHECK-NEXT: }
19+
// CHECK-NEXT: }
20+
// CHECK-NEXT: return
21+
affine.for %arg2 = 0 to 8 {
22+
affine.for %arg3 = 0 to 24 {
23+
affine.for %arg4 = 0 to 48 {
24+
%0 = affine.apply affine_map<(d0) -> (d0 mod 12)>(%arg3)
25+
%1 = affine.apply affine_map<(d0) -> (d0 mod 16)>(%arg4)
26+
%2 = affine.load %arg0[%arg2, %0, %1] : memref<8x12x16xf32>
27+
affine.store %2, %arg1[%arg2, %arg3, %arg4] : memref<8x24x48xf32>
28+
}
29+
}
30+
}
31+
return
32+
}
33+
34+
// -----
35+
36+
// CHECK-DAG: #[[$MAP:map[0-9a-zA-Z_]*]] = affine_map<(d0) -> (d0 mod 16 + 1)>
37+
38+
// CHECK-LABEL: vec_affine_apply_2
39+
// CHECK-SAME: (%[[ARG0:.*]]: memref<8x12x16xf32>, %[[ARG1:.*]]: memref<8x24x48xf32>) {
40+
func.func @vec_affine_apply_2(%arg0: memref<8x12x16xf32>, %arg1: memref<8x24x48xf32>) {
41+
// CHECK: affine.for %[[ARG2:.*]] = 0 to 8 {
42+
// CHECK-NEXT: affine.for %[[ARG3:.*]] = 0 to 12 {
43+
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 48 step 8 {
44+
// CHECK-NEXT: %[[S0:.*]] = affine.apply #[[$MAP]](%[[ARG4]])
45+
// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
46+
// CHECK-NEXT: %[[S1:.*]] = vector.transfer_read %[[ARG0]][%[[ARG2]], %[[ARG3]], %[[S0]]], %[[CST]] : memref<8x12x16xf32>, vector<8xf32>
47+
// CHECK-NEXT: vector.transfer_write %[[S1]], %[[ARG1]][%[[ARG2]], %[[ARG3]], %[[ARG4]]] : vector<8xf32>, memref<8x24x48xf32>
48+
// CHECK-NEXT: }
49+
// CHECK-NEXT: }
50+
// CHECK-NEXT: }
51+
affine.for %arg2 = 0 to 8 {
52+
affine.for %arg3 = 0 to 12 {
53+
affine.for %arg4 = 0 to 48 {
54+
%1 = affine.apply affine_map<(d0) -> (d0 mod 16 + 1)>(%arg4)
55+
%2 = affine.load %arg0[%arg2, %arg3, %1] : memref<8x12x16xf32>
56+
affine.store %2, %arg1[%arg2, %arg3, %arg4] : memref<8x24x48xf32>
57+
}
58+
}
59+
}
60+
return
61+
}
62+
63+
// -----
64+
65+
// CHECK-LABEL: no_vec_affine_apply
66+
// CHECK-SAME: (%[[ARG0:.*]]: memref<8x12x16xi32>, %[[ARG1:.*]]: memref<8x24x48xi32>) {
67+
func.func @no_vec_affine_apply(%arg0: memref<8x12x16xi32>, %arg1: memref<8x24x48xi32>) {
68+
// CHECK: affine.for %[[ARG2:.*]] = 0 to 8 {
69+
// CHECK-NEXT: affine.for %[[ARG3:.*]] = 0 to 24 {
70+
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 48 {
71+
// CHECK-NEXT: %[[S0:.*]] = affine.apply #[[$MAP_ID0]](%[[ARG3]])
72+
// CHECK-NEXT: %[[S1:.*]] = affine.apply #[[$MAP_ID1]](%[[ARG4]])
73+
// CHECK-NEXT: %[[S2:.*]] = affine.load %[[ARG0]][%[[ARG2]], %[[S0]], %[[S1]]] : memref<8x12x16xi32>
74+
// CHECK-NEXT: %[[S3:.*]] = arith.index_cast %[[S2]] : i32 to index
75+
// CHECK-NEXT: %[[S4:.*]] = affine.apply #[[$MAP_ID1]](%[[S3]])
76+
// CHECK-NEXT: %[[S5:.*]] = arith.index_cast %[[S4]] : index to i32
77+
// CHECK-NEXT: affine.store %[[S5]], %[[ARG1]][%[[ARG2]], %[[ARG3]], %[[ARG4]]] : memref<8x24x48xi32>
78+
// CHECK-NEXT: }
79+
// CHECK-NEXT: }
80+
// CHECK-NEXT: }
81+
// CHECK-NEXT: return
82+
affine.for %arg2 = 0 to 8 {
83+
affine.for %arg3 = 0 to 24 {
84+
affine.for %arg4 = 0 to 48 {
85+
%0 = affine.apply affine_map<(d0) -> (d0 mod 12)>(%arg3)
86+
%1 = affine.apply affine_map<(d0) -> (d0 mod 16)>(%arg4)
87+
%2 = affine.load %arg0[%arg2, %0, %1] : memref<8x12x16xi32>
88+
%3 = arith.index_cast %2 : i32 to index
89+
%4 = affine.apply affine_map<(d0) -> (d0 mod 16)>(%3)
90+
%5 = arith.index_cast %4 : index to i32
91+
affine.store %5, %arg1[%arg2, %arg3, %arg4] : memref<8x24x48xi32>
92+
}
93+
}
94+
}
95+
return
96+
}

0 commit comments

Comments
 (0)