Skip to content

Commit 7517e24

Browse files
committed
[mlir][linalg] Promote operands for convolution vectorization
We are already doing this for depthwise convolution and pooling. This helps to preserve the promotion semantics from Linalg op definitions to lower layers. Along the way, fixed the type mismatch issue in the existing `promote` implementation. Reviewed By: kuhar Differential Revision: https://reviews.llvm.org/D148471
1 parent 6c4219f commit 7517e24

File tree

2 files changed

+33
-22
lines changed

2 files changed

+33
-22
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2512,13 +2512,38 @@ struct Conv1DGenerator
25122512
.getOperation();
25132513
}
25142514

2515+
// Take a value and widen to have the same element type as `ty`.
2516+
Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
2517+
const Type srcElementType = getElementTypeOrSelf(val.getType());
2518+
const Type dstElementType = getElementTypeOrSelf(ty);
2519+
assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
2520+
if (srcElementType == dstElementType)
2521+
return val;
2522+
2523+
const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
2524+
const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
2525+
const Type dstType =
2526+
cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
2527+
2528+
if (isa<FloatType>(dstElementType) && srcWidth < dstWidth)
2529+
return rewriter.create<arith::ExtFOp>(loc, dstType, val);
2530+
2531+
if (isa<IntegerType>(dstElementType) && srcWidth < dstWidth)
2532+
return rewriter.create<arith::ExtSIOp>(loc, dstType, val);
2533+
2534+
assert(false && "unhandled promotion case");
2535+
return nullptr;
2536+
}
2537+
25152538
// Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
25162539
Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
25172540
Value lhs, Value rhs, Value res) {
25182541
vector::IteratorType par = vector::IteratorType::parallel;
25192542
vector::IteratorType red = vector::IteratorType::reduction;
25202543
AffineExpr n, w, f, c;
25212544
bindDims(ctx, n, w, f, c);
2545+
lhs = promote(rewriter, loc, lhs, res.getType());
2546+
rhs = promote(rewriter, loc, rhs, res.getType());
25222547
return rewriter.create<vector::ContractionOp>(
25232548
loc, lhs, rhs, res,
25242549
/*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
@@ -2666,24 +2691,6 @@ struct Conv1DGenerator
26662691
.getOperation();
26672692
}
26682693

2669-
// Take a value of element type T and widen to the destination type.
2670-
Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
2671-
if (val.getType() == ty)
2672-
return val;
2673-
2674-
const int64_t srcWidth =
2675-
getElementTypeOrSelf(val.getType()).getIntOrFloatBitWidth();
2676-
const int64_t destWidth = getElementTypeOrSelf(ty).getIntOrFloatBitWidth();
2677-
2678-
if (getElementTypeOrSelf(ty).isa<FloatType>() && srcWidth < destWidth)
2679-
return rewriter.create<arith::ExtFOp>(loc, ty, val);
2680-
2681-
if (getElementTypeOrSelf(ty).isa<IntegerType>() && srcWidth < destWidth)
2682-
return rewriter.create<arith::ExtSIOp>(loc, ty, val);
2683-
2684-
return nullptr;
2685-
}
2686-
26872694
/// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to MulAcc
26882695
Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
26892696
Value lhs, Value rhs, Value res) {

mlir/test/Dialect/Linalg/vectorize-convolution.mlir

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,18 +100,22 @@ func.func @conv1d_nwc_4x2x8_i8i8i32_memref(%input: memref<4x6x3xi8>, %filter: me
100100
// CHECK-SAME: {offsets = [0, 1, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xi32> to vector<4x1x8xi32>
101101

102102
/// w == 0, kw == 0
103+
// CHECK: %[[EXT_LHS_0:.+]] = arith.extsi %[[V_INPUT_0]] : vector<4x1x3xi8> to vector<4x1x3xi32>
104+
// CHECK: %[[EXT_RHS_0:.+]] = arith.extsi %[[V_FILTER]] : vector<3x8xi8> to vector<3x8xi32>
103105
// CHECK: %[[CONTRACT_0:.+]] = vector.contract {
104106
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
105107
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
106-
// CHECK-SAME: %[[V_INPUT_0]], %[[V_FILTER]], %[[V_OUTPUT_0]]
107-
// CHECK-SAME: : vector<4x1x3xi8>, vector<3x8xi8> into vector<4x1x8xi32>
108+
// CHECK-SAME: %[[EXT_LHS_0]], %[[EXT_RHS_0]], %[[V_OUTPUT_0]]
109+
// CHECK-SAME: : vector<4x1x3xi32>, vector<3x8xi32> into vector<4x1x8xi32>
108110

109111
/// w == 1, kw == 0
112+
// CHECK: %[[EXT_LHS_1:.+]] = arith.extsi %[[V_INPUT_1]] : vector<4x1x3xi8> to vector<4x1x3xi32>
113+
// CHECK: %[[EXT_RHS_1:.+]] = arith.extsi %[[V_FILTER]] : vector<3x8xi8> to vector<3x8xi32>
110114
// CHECK: %[[CONTRACT_1:.+]] = vector.contract {
111115
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
112116
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
113-
// CHECK-SAME: %[[V_INPUT_1]], %[[V_FILTER]], %[[V_OUTPUT_1]]
114-
// CHECK-SAME: : vector<4x1x3xi8>, vector<3x8xi8> into vector<4x1x8xi32>
117+
// CHECK-SAME: %[[EXT_LHS_1]], %[[EXT_RHS_1]], %[[V_OUTPUT_1]]
118+
// CHECK-SAME: : vector<4x1x3xi32>, vector<3x8xi32> into vector<4x1x8xi32>
115119

116120
/// w == 0, kw == 0
117121
// CHECK: %[[RES_0:.+]] = vector.insert_strided_slice %[[CONTRACT_0]], %[[V_OUTPUT_R]]

0 commit comments

Comments
 (0)