Skip to content

[mlir][vectorize] Support affine.apply in SuperVectorize #77968

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 51 additions & 25 deletions mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -711,18 +711,16 @@ struct VectorizationState {
BlockArgument replacement);

/// Registers the scalar replacement of a scalar value. 'replacement' must be
/// scalar. Both values must be block arguments. Operation results should be
/// replaced using the 'registerOp*' utilitites.
/// scalar.
///
/// This utility is used to register the replacement of block arguments
/// that are within the loop to be vectorized and will continue being scalar
/// within the vector loop.
/// or affine.apply results that are within the loop be vectorized and will
/// continue being scalar within the vector loop.
///
/// Example:
/// * 'replaced': induction variable of a loop to be vectorized.
/// * 'replacement': new induction variable in the new vector loop.
void registerValueScalarReplacement(BlockArgument replaced,
BlockArgument replacement);
void registerValueScalarReplacement(Value replaced, Value replacement);

/// Registers the scalar replacement of a scalar result returned from a
/// reduction loop. 'replacement' must be scalar.
Expand Down Expand Up @@ -772,7 +770,6 @@ struct VectorizationState {
/// Internal implementation to map input scalar values to new vector or scalar
/// values.
void registerValueVectorReplacementImpl(Value replaced, Value replacement);
void registerValueScalarReplacementImpl(Value replaced, Value replacement);
};

} // namespace
Expand Down Expand Up @@ -844,19 +841,22 @@ void VectorizationState::registerValueVectorReplacementImpl(Value replaced,
}

/// Registers the scalar replacement of a scalar value. 'replacement' must be
/// scalar. Both values must be block arguments. Operation results should be
/// replaced using the 'registerOp*' utilitites.
/// scalar.
///
/// This utility is used to register the replacement of block arguments
/// that are within the loop to be vectorized and will continue being scalar
/// within the vector loop.
/// or affine.apply results that are within the loop be vectorized and will
/// continue being scalar within the vector loop.
///
/// Example:
/// * 'replaced': induction variable of a loop to be vectorized.
/// * 'replacement': new induction variable in the new vector loop.
void VectorizationState::registerValueScalarReplacement(
BlockArgument replaced, BlockArgument replacement) {
registerValueScalarReplacementImpl(replaced, replacement);
void VectorizationState::registerValueScalarReplacement(Value replaced,
Value replacement) {
assert(!valueScalarReplacement.contains(replaced) &&
"Scalar value replacement already registered");
assert(!isa<VectorType>(replacement.getType()) &&
"Expected scalar type in scalar replacement");
valueScalarReplacement.map(replaced, replacement);
}

/// Registers the scalar replacement of a scalar result returned from a
Expand All @@ -879,15 +879,6 @@ void VectorizationState::registerLoopResultScalarReplacement(
loopResultScalarReplacement[replaced] = replacement;
}

void VectorizationState::registerValueScalarReplacementImpl(Value replaced,
Value replacement) {
assert(!valueScalarReplacement.contains(replaced) &&
"Scalar value replacement already registered");
assert(!isa<VectorType>(replacement.getType()) &&
"Expected scalar type in scalar replacement");
valueScalarReplacement.map(replaced, replacement);
}

/// Returns in 'replacedVals' the scalar replacement for values in 'inputVals'.
void VectorizationState::getScalarValueReplacementsFor(
ValueRange inputVals, SmallVectorImpl<Value> &replacedVals) {
Expand Down Expand Up @@ -978,6 +969,33 @@ static arith::ConstantOp vectorizeConstant(arith::ConstantOp constOp,
return newConstOp;
}

/// We have no need to vectorize affine.apply. However, we still need to
/// generate it and replace the operands with values in valueScalarReplacement.
static Operation *vectorizeAffineApplyOp(AffineApplyOp applyOp,
VectorizationState &state) {
SmallVector<Value, 8> updatedOperands;
for (Value operand : applyOp.getOperands()) {
if (state.valueVectorReplacement.contains(operand)) {
LLVM_DEBUG(
dbgs() << "\n[early-vect]+++++ affine.apply on vector operand\n");
return nullptr;
} else {
Value updatedOperand = state.valueScalarReplacement.lookupOrNull(operand);
if (!updatedOperand)
updatedOperand = operand;
updatedOperands.push_back(updatedOperand);
}
}

auto newApplyOp = state.builder.create<AffineApplyOp>(
applyOp.getLoc(), applyOp.getAffineMap(), updatedOperands);

// Register the new affine.apply result.
state.registerValueScalarReplacement(applyOp.getResult(),
newApplyOp.getResult());
return newApplyOp;
}

/// Creates a constant vector filled with the neutral elements of the given
/// reduction. The scalar type of vector elements will be taken from
/// `oldOperand`.
Expand Down Expand Up @@ -1184,11 +1202,17 @@ static Operation *vectorizeAffineLoad(AffineLoadOp loadOp,
SmallVector<Value, 8> indices;
indices.reserve(memRefType.getRank());
if (loadOp.getAffineMap() !=
state.builder.getMultiDimIdentityMap(memRefType.getRank()))
state.builder.getMultiDimIdentityMap(memRefType.getRank())) {
// Check the operand in loadOp affine map does not come from AffineApplyOp.
for (auto op : mapOperands) {
if (op.getDefiningOp<AffineApplyOp>())
return nullptr;
}
computeMemoryOpIndices(loadOp, loadOp.getAffineMap(), mapOperands, state,
indices);
else
} else {
indices.append(mapOperands.begin(), mapOperands.end());
}

// Compute permutation map using the information of new vector loops.
auto permutationMap = makePermutationMap(state.builder.getInsertionBlock(),
Expand Down Expand Up @@ -1493,6 +1517,8 @@ static Operation *vectorizeOneOperation(Operation *op,
return vectorizeAffineYieldOp(yieldOp, state);
if (auto constant = dyn_cast<arith::ConstantOp>(op))
return vectorizeConstant(constant, state);
if (auto applyOp = dyn_cast<AffineApplyOp>(op))
return vectorizeAffineApplyOp(applyOp, state);

// Other ops with regions are not supported.
if (op->getNumRegions() != 0)
Expand Down
159 changes: 159 additions & 0 deletions mlir/test/Dialect/Affine/SuperVectorize/vectorize_affine_apply.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
// RUN: mlir-opt %s -affine-super-vectorize="virtual-vector-size=8 test-fastest-varying=0" -split-input-file | FileCheck %s

// CHECK-DAG: #[[$MAP_ID0:map[0-9a-zA-Z_]*]] = affine_map<(d0) -> (d0 mod 12)>
// CHECK-DAG: #[[$MAP_ID1:map[0-9a-zA-Z_]*]] = affine_map<(d0) -> (d0 mod 16)>

// CHECK-LABEL: vec_affine_apply
// CHECK-SAME: (%[[ARG0:.*]]: memref<8x12x16xf32>, %[[ARG1:.*]]: memref<8x24x48xf32>) {
func.func @vec_affine_apply(%arg0: memref<8x12x16xf32>, %arg1: memref<8x24x48xf32>) {
// CHECK: affine.for %[[ARG2:.*]] = 0 to 8 {
// CHECK-NEXT: affine.for %[[ARG3:.*]] = 0 to 24 {
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 48 step 8 {
// CHECK-NEXT: %[[S0:.*]] = affine.apply #[[$MAP_ID0]](%[[ARG3]])
// CHECK-NEXT: %[[S1:.*]] = affine.apply #[[$MAP_ID1]](%[[ARG4]])
// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: %[[S2:.*]] = vector.transfer_read %[[ARG0]][%[[ARG2]], %[[S0]], %[[S1]]], %[[CST]] : memref<8x12x16xf32>, vector<8xf32>
// CHECK-NEXT: vector.transfer_write %[[S2]], %[[ARG1]][%[[ARG2]], %[[ARG3]], %[[ARG4]]] : vector<8xf32>, memref<8x24x48xf32>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return
affine.for %arg2 = 0 to 8 {
affine.for %arg3 = 0 to 24 {
affine.for %arg4 = 0 to 48 {
%0 = affine.apply affine_map<(d0) -> (d0 mod 12)>(%arg3)
%1 = affine.apply affine_map<(d0) -> (d0 mod 16)>(%arg4)
%2 = affine.load %arg0[%arg2, %0, %1] : memref<8x12x16xf32>
affine.store %2, %arg1[%arg2, %arg3, %arg4] : memref<8x24x48xf32>
}
}
}
return
}

// -----

// CHECK-DAG: #[[$MAP_ID2:map[0-9a-zA-Z_]*]] = affine_map<(d0) -> (d0 mod 16 + 1)>

// CHECK-LABEL: vec_affine_apply_2
// CHECK-SAME: (%[[ARG0:.*]]: memref<8x12x16xf32>, %[[ARG1:.*]]: memref<8x24x48xf32>) {
func.func @vec_affine_apply_2(%arg0: memref<8x12x16xf32>, %arg1: memref<8x24x48xf32>) {
// CHECK: affine.for %[[ARG2:.*]] = 0 to 8 {
// CHECK-NEXT: affine.for %[[ARG3:.*]] = 0 to 12 {
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 48 step 8 {
// CHECK-NEXT: %[[S0:.*]] = affine.apply #[[$MAP_ID2]](%[[ARG4]])
// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: %[[S1:.*]] = vector.transfer_read %[[ARG0]][%[[ARG2]], %[[ARG3]], %[[S0]]], %[[CST]] : memref<8x12x16xf32>, vector<8xf32>
// CHECK-NEXT: vector.transfer_write %[[S1]], %[[ARG1]][%[[ARG2]], %[[ARG3]], %[[ARG4]]] : vector<8xf32>, memref<8x24x48xf32>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
affine.for %arg2 = 0 to 8 {
affine.for %arg3 = 0 to 12 {
affine.for %arg4 = 0 to 48 {
%1 = affine.apply affine_map<(d0) -> (d0 mod 16 + 1)>(%arg4)
%2 = affine.load %arg0[%arg2, %arg3, %1] : memref<8x12x16xf32>
affine.store %2, %arg1[%arg2, %arg3, %arg4] : memref<8x24x48xf32>
}
}
}
return
}

// -----

// CHECK-LABEL: no_vec_affine_apply
// CHECK-SAME: (%[[ARG0:.*]]: memref<8x12x16xi32>, %[[ARG1:.*]]: memref<8x24x48xi32>) {
func.func @no_vec_affine_apply(%arg0: memref<8x12x16xi32>, %arg1: memref<8x24x48xi32>) {
// CHECK: affine.for %[[ARG2:.*]] = 0 to 8 {
// CHECK-NEXT: affine.for %[[ARG3:.*]] = 0 to 24 {
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 48 {
// CHECK-NEXT: %[[S0:.*]] = affine.apply #[[$MAP_ID0]](%[[ARG3]])
// CHECK-NEXT: %[[S1:.*]] = affine.apply #[[$MAP_ID1]](%[[ARG4]])
// CHECK-NEXT: %[[S2:.*]] = affine.load %[[ARG0]][%[[ARG2]], %[[S0]], %[[S1]]] : memref<8x12x16xi32>
// CHECK-NEXT: %[[S3:.*]] = arith.index_cast %[[S2]] : i32 to index
// CHECK-NEXT: %[[S4:.*]] = affine.apply #[[$MAP_ID1]](%[[S3]])
// CHECK-NEXT: %[[S5:.*]] = arith.index_cast %[[S4]] : index to i32
// CHECK-NEXT: affine.store %[[S5]], %[[ARG1]][%[[ARG2]], %[[ARG3]], %[[ARG4]]] : memref<8x24x48xi32>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return
affine.for %arg2 = 0 to 8 {
affine.for %arg3 = 0 to 24 {
affine.for %arg4 = 0 to 48 {
%0 = affine.apply affine_map<(d0) -> (d0 mod 12)>(%arg3)
%1 = affine.apply affine_map<(d0) -> (d0 mod 16)>(%arg4)
%2 = affine.load %arg0[%arg2, %0, %1] : memref<8x12x16xi32>
%3 = arith.index_cast %2 : i32 to index
%4 = affine.apply affine_map<(d0) -> (d0 mod 16)>(%3)
%5 = arith.index_cast %4 : index to i32
affine.store %5, %arg1[%arg2, %arg3, %arg4] : memref<8x24x48xi32>
}
}
}
return
}

// -----

// CHECK-DAG: #[[$MAP_ID1:map[0-9a-zA-Z_]*]] = affine_map<(d0) -> (d0 mod 16)>

// CHECK-LABEL: affine_map_with_expr
// CHECK-SAME: (%[[ARG0:.*]]: memref<8x12x16xf32>, %[[ARG1:.*]]: memref<8x24x48xf32>) {
func.func @affine_map_with_expr(%arg0: memref<8x12x16xf32>, %arg1: memref<8x24x48xf32>) {
// CHECK: affine.for %[[ARG2:.*]] = 0 to 8 {
// CHECK-NEXT: affine.for %[[ARG3:.*]] = 0 to 12 {
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 48 {
// CHECK-NEXT: %[[S0:.*]] = affine.apply #[[$MAP_ID1]](%[[ARG4]])
// CHECK-NEXT: %[[S1:.*]] = affine.load %[[ARG0]][%[[ARG2]], %[[ARG3]], %[[S0]] + 1] : memref<8x12x16xf32>
// CHECK-NEXT: affine.store %[[S1]], %[[ARG1]][%[[ARG2]], %[[ARG3]], %[[ARG4]]] : memref<8x24x48xf32>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return
affine.for %arg2 = 0 to 8 {
affine.for %arg3 = 0 to 12 {
affine.for %arg4 = 0 to 48 {
%1 = affine.apply affine_map<(d0) -> (d0 mod 16)>(%arg4)
%2 = affine.load %arg0[%arg2, %arg3, %1 + 1] : memref<8x12x16xf32>
affine.store %2, %arg1[%arg2, %arg3, %arg4] : memref<8x24x48xf32>
}
}
}
return
}

// -----

// CHECK-DAG: #[[$MAP_ID3:map[0-9a-zA-Z_]*]] = affine_map<(d0, d1, d2) -> (d0)>
// CHECK-DAG: #[[$MAP_ID4:map[0-9a-zA-Z_]*]] = affine_map<(d0, d1, d2) -> (d1)>
// CHECK-DAG: #[[$MAP_ID5:map[0-9a-zA-Z_]*]] = affine_map<(d0, d1, d2) -> (d2 + 1)>
// CHECK-DAG: #[[$MAP_ID6:map[0-9a-zA-Z_]*]] = affine_map<(d0, d1, d2) -> (0)>

// CHECK-LABEL: affine_map_with_expr_2
// CHECK-SAME: (%[[ARG0:.*]]: memref<8x12x16xf32>, %[[ARG1:.*]]: memref<8x24x48xf32>, %[[I0:.*]]: index) {
func.func @affine_map_with_expr_2(%arg0: memref<8x12x16xf32>, %arg1: memref<8x24x48xf32>, %i: index) {
// CHECK: affine.for %[[ARG3:.*]] = 0 to 8 {
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 12 {
// CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 48 step 8 {
// CHECK-NEXT: %[[S0:.*]] = affine.apply #[[$MAP_ID3]](%[[ARG3]], %[[ARG4]], %[[I0]])
// CHECK-NEXT: %[[S1:.*]] = affine.apply #[[$MAP_ID4]](%[[ARG3]], %[[ARG4]], %[[I0]])
// CHECK-NEXT: %[[S2:.*]] = affine.apply #[[$MAP_ID5]](%[[ARG3]], %[[ARG4]], %[[I0]])
// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: %[[S3:.*]] = vector.transfer_read %[[ARG0]][%[[S0]], %[[S1]], %[[S2]]], %[[CST]] {permutation_map = #[[$MAP_ID6]]} : memref<8x12x16xf32>, vector<8xf32>
// CHECK-NEXT: vector.transfer_write %[[S3]], %[[ARG1]][%[[ARG3]], %[[ARG4]], %[[ARG5]]] : vector<8xf32>, memref<8x24x48xf32>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return
affine.for %arg2 = 0 to 8 {
affine.for %arg3 = 0 to 12 {
affine.for %arg4 = 0 to 48 {
%2 = affine.load %arg0[%arg2, %arg3, %i + 1] : memref<8x12x16xf32>
affine.store %2, %arg1[%arg2, %arg3, %arg4] : memref<8x24x48xf32>
}
}
}
return
}