-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][linalg] Use vector.shuffle to flatten conv filter #75038
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][linalg] Use vector.shuffle to flatten conv filter #75038
Conversation
Updates the vectorisation of 1D depthwise convolution when flattening the channel dimension (introduced in llvm#71918). In particular - how the convolution filter is "flattened". ATM, the vectoriser will use `vector.shape_cast`: ```mlir %b_filter = vector.broadcast %filter : vector<4xf32> to vector<3x2x4xf32> %sc_filter = vector.shape_cast %b_filter : vector<3x2x4xf32> to vector<3x8xf32> ``` This lowering is not ideal - `vector.shape_cast` can be convenient when it's folded away, but that's not happening in this case. Instead, this patch updates the vectoriser to use `vector.shuffle` (the overall result is identical): ```mlir %sh_filter = vector.shuffle %filter, %filter [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xf32>, vector<4xf32> %b_filter = vector.broadcast %sh_filter : vector<8xf32> to vector<3x8xf32> ```
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir Author: Andrzej Warzyński (banach-space) ChangesUpdates the vectorisation of 1D depthwise convolution when flattening %b_filter = vector.broadcast %filter : vector<4xf32> to vector<3x2x4xf32>
%sc_filter = vector.shape_cast %b_filter : vector<3x2x4xf32> to vector<3x8xf32> This lowering is not ideal - %sh_filter = vector.shuffle %filter, %filter
[0, 1, 2, 3, 0, 1, 2, 3] : vector<4xf32>, vector<4xf32>
%b_filter = vector.broadcast %sh_filter : vector<8xf32> to vector<3x8xf32> Full diff: https://github.com/llvm/llvm-project/pull/75038.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index c21d007c931b9b..d956fd4fdd9bd8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2910,17 +2910,16 @@ struct Conv1DGenerator
for (int64_t w = 0; w < wSize; w += wSizeStep) {
Value lhsVal = lhsVals[linearIndex(kw, w)];
Value resVal = resVals[w];
- ShapedType filterBCastTy = cast<ShapedType>(resVal.getType());
if (flatten) {
- // Flatten the input and filter vectors (collapse the channel
+ // Flatten the input and output vectors (collapse the channel
// dimension)
lhsVal = rewriter.create<vector::ShapeCastOp>(
loc, lhsCastType, lhsVals[linearIndex(kw, w)]);
resVal = rewriter.create<vector::ShapeCastOp>(loc, resCastType,
resVals[w]);
}
- resVals[w] = depthwiseConv1dSliceAsMulAcc(
- rewriter, loc, lhsVal, rhsVals[kw], resVal, filterBCastTy, flatten);
+ resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
+ rhsVals[kw], resVal, flatten);
if (flatten) {
// Un-flatten the output vector (restore the channel dimension)
resVals[w] = rewriter.create<vector::ShapeCastOp>(
@@ -2964,20 +2963,32 @@ struct Conv1DGenerator
/// to MulAcc.
Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
Value lhs, Value rhs, Value res,
- ShapedType bcastTy, bool flatten) {
+ bool flatten) {
auto rhsTy = cast<ShapedType>(rhs.getType());
auto resTy = cast<ShapedType>(res.getType());
// TODO(suderman): Change this to use a vector.ima intrinsic.
lhs = promote(rewriter, loc, lhs, resTy);
- rhs = rewriter.create<vector::BroadcastOp>(
- loc, bcastTy.clone(rhsTy.getElementType()), rhs);
if (flatten) {
- // Flatten the channel dimension
- rhs = rewriter.create<vector::ShapeCastOp>(
- loc, resTy.clone(rhsTy.getElementType()), rhs);
+ // There are two options for handling the filter:
+ // * shape_cast(broadcast(filter))
+ // * broadcast(shuffle(filter))
+ // Opt for the option without shape_cast to simplify the codegen.
+ auto rhsSize = rhs.getType().cast<VectorType>().getShape()[0];
+ auto resSize = res.getType().cast<VectorType>().getShape()[1];
+
+ SmallVector<int64_t, 16> indicies;
+ for (int i = 0; i < resSize / rhsSize; ++i) {
+ for (int j = 0; j < rhsSize; ++j)
+ indicies.push_back(j);
+ }
+
+ rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indicies);
}
+ // Broadcast the filter to match the output vector
+ rhs = rewriter.create<vector::BroadcastOp>(
+ loc, resTy.clone(rhsTy.getElementType()), rhs);
rhs = promote(rewriter, loc, rhs, resTy);
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir
index a242d09671825b..afb59cb26188a6 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir
@@ -36,9 +36,10 @@ module attributes {transform.with_named_sequence} {
/// w == 0, kw = 0
// CHECK: %[[SC_INPUT:.*]] = vector.shape_cast %[[V_INPUT_R]] : vector<1x8x3xi8> to vector<1x24xi8>
// CHECK: %[[SC_OUTPUT:.*]] = vector.shape_cast %[[V_OUTPUT_R]] : vector<1x8x3xi8> to vector<1x24xi8>
-// CHECK: %[[B_FILTER:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<3xi8> to vector<1x8x3xi8>
-// CHECK: %[[SC_FILTER:.*]] = vector.shape_cast %[[B_FILTER]] : vector<1x8x3xi8> to vector<1x24xi8>
-// CHECK: %[[MULI:.*]] = arith.muli %[[SC_INPUT]], %[[SC_FILTER]] : vector<1x24xi8>
+// CHECK: %[[SH_FILTER_0:.*]] = vector.shuffle %[[V_FILTER_0]], %[[V_FILTER_0]]
+// CHECK-SAME: [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2] : vector<3xi8>, vector<3xi8>
+// CHECK: %[[B_FILTER:.*]] = vector.broadcast %[[SH_FILTER_0]] : vector<24xi8> to vector<1x24xi8>
+// CHECK: %[[MULI:.*]] = arith.muli %[[SC_INPUT]], %[[B_FILTER]] : vector<1x24xi8>
// CHECK: %[[ADDI:.*]] = arith.addi %[[MULI]], %[[SC_OUTPUT]] : vector<1x24xi8>
// Write the result back in one shot.
@@ -80,15 +81,17 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2(%input: memref<3
/// w == 0, kw = 0
// CHECK: %[[SC_V_INPUT_0:.*]] = vector.shape_cast %[[V_INPUT_0]] : vector<3x2x4xf32> to vector<3x8xf32>
// CHECK: %[[SC_V_OUTPUT_R:.*]] = vector.shape_cast %[[V_OUTPUT_R]] : vector<3x2x4xf32> to vector<3x8xf32>
-// CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xf32> to vector<3x2x4xf32>
-// CHECK: %[[SC_B_FILTER_0:.*]] = vector.shape_cast %[[B_FILTER_0]] : vector<3x2x4xf32> to vector<3x8xf32>
-// CHECK: %[[FMA_0:.*]] = vector.fma %[[SC_V_INPUT_0]], %[[SC_B_FILTER_0]], %[[SC_V_OUTPUT_R]] : vector<3x8xf32>
+// CHECK: %[[SH_FILTER_0:.*]] = vector.shuffle %[[V_FILTER_0]], %[[V_FILTER_0]]
+// CHECK-SAME: [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[SH_FILTER_0]] : vector<8xf32> to vector<3x8xf32>
+// CHECK: %[[FMA_0:.*]] = vector.fma %[[SC_V_INPUT_0]], %[[B_FILTER_0]], %[[SC_V_OUTPUT_R]] : vector<3x8xf32>
/// w == 0, kw = 1
// CHECK: %[[SC_V_INPUT_1:.*]] = vector.shape_cast %[[V_INPUT_1]] : vector<3x2x4xf32> to vector<3x8xf32>
-// CHECK: %[[B_V_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xf32> to vector<3x2x4xf32>
-// CHECK: %[[SC_B_FILTER_1:.*]] = vector.shape_cast %[[B_V_FILTER_1]] : vector<3x2x4xf32> to vector<3x8xf32>
-// CHECK: %[[FMA_1:.*]] = vector.fma %[[SC_V_INPUT_1]], %[[SC_B_FILTER_1]], %[[FMA_0]] : vector<3x8xf32>
+// CHECK: %[[SH_FILTER_1:.*]] = vector.shuffle %[[V_FILTER_1]], %[[V_FILTER_1]]
+// CHECK-SAME: [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[B_FILTER_1:.*]] = vector.broadcast %[[SH_FILTER_1]] : vector<8xf32> to vector<3x8xf32>
+// CHECK: %[[FMA_1:.*]] = vector.fma %[[SC_V_INPUT_1]], %[[B_FILTER_1]], %[[FMA_0]] : vector<3x8xf32>
// Write the result back in one shot.
// CHECK: %[[SC_FMA_1:.*]] = vector.shape_cast %[[FMA_1]] : vector<3x8xf32> to vector<3x2x4xf32>
@@ -138,19 +141,21 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref_dilation_2(%input: memref<3x5
// CHECK: %[[SC_V_INPUT_0:.*]] = vector.shape_cast %[[V_INPUT_0]] : vector<3x2x4xi8> to vector<3x8xi8>
// CHECK: %[[SC_V_OUTPUT_R:.*]] = vector.shape_cast %[[V_OUTPUT_R]] : vector<3x2x4xi32> to vector<3x8xi32>
// CHECK: %[[EXT_INPUT_0:.*]] = arith.extsi %[[SC_V_INPUT_0]] : vector<3x8xi8> to vector<3x8xi32>
-// CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x2x4xi8>
-// CHECK: %[[SC_B_FILTER_0:.*]] = vector.shape_cast %[[B_FILTER_0]] : vector<3x2x4xi8> to vector<3x8xi8>
-// CHECK: %[[EXT_FILTER_0:.*]] = arith.extsi %[[SC_B_FILTER_0]] : vector<3x8xi8> to vector<3x8xi32>
-// CHECK: %[[MUL_0:.*]] = arith.muli %[[EXT_INPUT_0]], %[[EXT_FILTER_0]] : vector<3x8xi32>
+// CHECK: %[[SH_FILTER_0:.*]] = vector.shuffle %[[V_FILTER_0]], %[[V_FILTER_0]]
+// CHECK-SAME: [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xi8>, vector<4xi8>
+// CHECK: %[[EXT_FILTER_0:.*]] = arith.extsi %[[SH_FILTER_0]] : vector<8xi8> to vector<8xi32>
+// CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[EXT_FILTER_0]] : vector<8xi32> to vector<3x8xi32>
+// CHECK: %[[MUL_0:.*]] = arith.muli %[[EXT_INPUT_0]], %[[B_FILTER_0]] : vector<3x8xi32>
// CHECK: %[[ADD_0:.*]] = arith.addi %[[MUL_0]], %[[SC_V_OUTPUT_R]] : vector<3x8xi32>
/// w == 0, kw = 1
// CHECK: %[[SC_V_INPUT_1:.*]] = vector.shape_cast %[[V_INPUT_1]] : vector<3x2x4xi8> to vector<3x8xi8>
// CHECK: %[[EXT_INPUT_1:.*]] = arith.extsi %[[SC_V_INPUT_1]] : vector<3x8xi8> to vector<3x8xi32>
-// CHECK: %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x2x4xi8>
-// CHECK: %[[SC_B_FILTER_1:.*]] = vector.shape_cast %[[B_FILTER_1]] : vector<3x2x4xi8> to vector<3x8xi8>
-// CHECK: %[[EXT_FILTER_1:.*]] = arith.extsi %[[SC_B_FILTER_1]] : vector<3x8xi8> to vector<3x8xi32>
-// CHECK: %[[MUL_1:.*]] = arith.muli %[[EXT_INPUT_1]], %[[EXT_FILTER_1]] : vector<3x8xi32>
+// CHECK: %[[SH_FILTER_1:.*]] = vector.shuffle %[[V_FILTER_1]], %[[V_FILTER_1]]
+// CHECK-SAME: [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xi8>, vector<4xi8>
+// CHECK: %[[EXT_FILTER_1:.*]] = arith.extsi %[[SH_FILTER_1]] : vector<8xi8> to vector<8xi32>
+// CHECK: %[[B_FILTER_1:.*]] = vector.broadcast %[[EXT_FILTER_1]] : vector<8xi32> to vector<3x8xi32>
+// CHECK: %[[MUL_1:.*]] = arith.muli %[[EXT_INPUT_1]], %[[B_FILTER_1]] : vector<3x8xi32>
// CHECK: %[[ADD_1:.*]] = arith.addi %[[MUL_1]], %[[ADD_0]] : vector<3x8xi32>
// Write the result back in one shot.
@@ -223,69 +228,60 @@ func.func @depthwise_conv1d_nwc_wc_3x9x4xi8_tensor_stride_2(%input: tensor<3x9x4
/// w == 0, kw == 0
// CHECK: %[[VAL_23:.*]] = vector.shape_cast %[[V_INPUT_0]] : vector<3x1x4xi8> to vector<3x4xi8>
// CHECK: %[[VAL_24:.*]] = vector.shape_cast %[[V_OUTPUT_0]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x1x4xi8>
-// CHECK: %[[VAL_26:.*]] = vector.shape_cast %[[B_FILTER_0]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK: %[[VAL_27:.*]] = arith.muli %[[VAL_23]], %[[VAL_26]] : vector<3x4xi8>
+// CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x4xi8>
+// CHECK: %[[VAL_27:.*]] = arith.muli %[[VAL_23]], %[[B_FILTER_0]] : vector<3x4xi8>
// CHECK: %[[VAL_28:.*]] = arith.addi %[[VAL_27]], %[[VAL_24]] : vector<3x4xi8>
/// w == 1, kw == 0
// CHECK: %[[VAL_29:.*]] = vector.shape_cast %[[V_INPUT_1]] : vector<3x1x4xi8> to vector<3x4xi8>
// CHECK: %[[VAL_30:.*]] = vector.shape_cast %[[V_OUTPUT_1]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x1x4xi8>
-// CHECK: %[[VAL_32:.*]] = vector.shape_cast %[[B_FILTER_0]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK: %[[VAL_33:.*]] = arith.muli %[[VAL_29]], %[[VAL_32]] : vector<3x4xi8>
+// CHECK: %[[B_FILTER_0_1:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x4xi8>
+// CHECK: %[[VAL_33:.*]] = arith.muli %[[VAL_29]], %[[B_FILTER_0_1]] : vector<3x4xi8>
// CHECK: %[[VAL_34:.*]] = arith.addi %[[VAL_33]], %[[VAL_30]] : vector<3x4xi8>
/// w == 2, kw == 0
// CHECK: %[[VAL_35:.*]] = vector.shape_cast %[[V_INPUT_2]] : vector<3x1x4xi8> to vector<3x4xi8>
// CHECK: %[[VAL_36:.*]] = vector.shape_cast %[[V_OUTPUT_2]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x1x4xi8>
-// CHECK: %[[VAL_38:.*]] = vector.shape_cast %[[B_FILTER_0]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK: %[[VAL_39:.*]] = arith.muli %[[VAL_35]], %[[VAL_38]] : vector<3x4xi8>
+// CHECK: %[[B_FILTER_0_2:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x4xi8>
+// CHECK: %[[VAL_39:.*]] = arith.muli %[[VAL_35]], %[[B_FILTER_0_2]] : vector<3x4xi8>
// CHECK: %[[VAL_40:.*]] = arith.addi %[[VAL_39]], %[[VAL_36]] : vector<3x4xi8>
/// w == 3, kw == 1
// CHECK: %[[VAL_41:.*]] = vector.shape_cast %[[V_INPUT_3]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK: %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x1x4xi8>
-// CHECK: %[[VAL_43:.*]] = vector.shape_cast %[[B_FILTER_1]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK: %[[VAL_44:.*]] = arith.muli %[[VAL_41]], %[[VAL_43]] : vector<3x4xi8>
+// CHECK: %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x4xi8>
+// CHECK: %[[VAL_44:.*]] = arith.muli %[[VAL_41]], %[[B_FILTER_1]] : vector<3x4xi8>
// CHECK: %[[VAL_45:.*]] = arith.addi %[[VAL_44]], %[[VAL_28]] : vector<3x4xi8>
/// w == 4, kw == 1
// CHECK: %[[VAL_46:.*]] = vector.shape_cast %[[V_INPUT_4]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK: %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x1x4xi8>
-// CHECK: %[[VAL_48:.*]] = vector.shape_cast %[[B_FILTER_1]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK: %[[VAL_49:.*]] = arith.muli %[[VAL_46]], %[[VAL_48]] : vector<3x4xi8>
+// CHECK: %[[B_FILTER_1_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x4xi8>
+// CHECK: %[[VAL_49:.*]] = arith.muli %[[VAL_46]], %[[B_FILTER_1_1]] : vector<3x4xi8>
// CHECK: %[[VAL_50:.*]] = arith.addi %[[VAL_49]], %[[VAL_34]] : vector<3x4xi8>
/// w == 5, kw == 1
// CHECK: %[[VAL_51:.*]] = vector.shape_cast %[[V_INPUT_5]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK: %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x1x4xi8>
-// CHECK: %[[VAL_53:.*]] = vector.shape_cast %[[B_FILTER_1]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK: %[[VAL_54:.*]] = arith.muli %[[VAL_51]], %[[VAL_53]] : vector<3x4xi8>
+// CHECK: %[[B_FILTER_1_2:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x4xi8>
+// CHECK: %[[VAL_54:.*]] = arith.muli %[[VAL_51]], %[[B_FILTER_1_2]] : vector<3x4xi8>
// CHECK: %[[VAL_55:.*]] = arith.addi %[[VAL_54]], %[[VAL_40]] : vector<3x4xi8>
/// w == 6, kw == 2
// CHECK: %[[VAL_56:.*]] = vector.shape_cast %[[V_INPUT_6]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK: %[[B_FILTER_2:.*]] = vector.broadcast %[[V_FILTER_2]] : vector<4xi8> to vector<3x1x4xi8>
-// CHECK: %[[VAL_58:.*]] = vector.shape_cast %[[B_FILTER_2]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK: %[[VAL_59:.*]] = arith.muli %[[VAL_56]], %[[VAL_58]] : vector<3x4xi8>
+// CHECK: %[[B_FILTER_2:.*]] = vector.broadcast %[[V_FILTER_2]] : vector<4xi8> to vector<3x4xi8>
+// CHECK: %[[VAL_59:.*]] = arith.muli %[[VAL_56]], %[[B_FILTER_2]] : vector<3x4xi8>
// CHECK: %[[VAL_60:.*]] = arith.addi %[[VAL_59]], %[[VAL_45]] : vector<3x4xi8>
/// w == 7, kw == 2
// CHECK: %[[VAL_61:.*]] = vector.shape_cast %[[VAL_60]] : vector<3x4xi8> to vector<3x1x4xi8>
// CHECK: %[[VAL_62:.*]] = vector.shape_cast %[[V_INPUT_7]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK: %[[B_FILTER_2:.*]] = vector.broadcast %[[V_FILTER_2]] : vector<4xi8> to vector<3x1x4xi8>
-// CHECK: %[[VAL_64:.*]] = vector.shape_cast %[[B_FILTER_2]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK: %[[VAL_65:.*]] = arith.muli %[[VAL_62]], %[[VAL_64]] : vector<3x4xi8>
+// CHECK: %[[B_FILTER_2_1:.*]] = vector.broadcast %[[V_FILTER_2]] : vector<4xi8> to vector<3x4xi8>
+// CHECK: %[[VAL_65:.*]] = arith.muli %[[VAL_62]], %[[B_FILTER_2_1]] : vector<3x4xi8>
// CHECK: %[[VAL_66:.*]] = arith.addi %[[VAL_65]], %[[VAL_50]] : vector<3x4xi8>
/// w == 8, kw == 2
// CHECK: %[[VAL_67:.*]] = vector.shape_cast %[[VAL_66]] : vector<3x4xi8> to vector<3x1x4xi8>
// CHECK: %[[VAL_68:.*]] = vector.shape_cast %[[V_INPUT_8]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK: %[[B_FILTER_2:.*]] = vector.broadcast %[[V_FILTER_2]] : vector<4xi8> to vector<3x1x4xi8>
-// CHECK: %[[VAL_70:.*]] = vector.shape_cast %[[B_FILTER_2]] : vector<3x1x4xi8> to vector<3x4xi8>
-// CHECK: %[[VAL_71:.*]] = arith.muli %[[VAL_68]], %[[VAL_70]] : vector<3x4xi8>
+// CHECK: %[[B_FILTER_2_2:.*]] = vector.broadcast %[[V_FILTER_2]] : vector<4xi8> to vector<3x4xi8>
+// CHECK: %[[VAL_71:.*]] = arith.muli %[[VAL_68]], %[[B_FILTER_2_2]] : vector<3x4xi8>
// CHECK: %[[VAL_72:.*]] = arith.addi %[[VAL_71]], %[[VAL_55]] : vector<3x4xi8>
// Write the result back.
|
@nicolasvasilache If you don't have strong objections, I would like to go ahead and land this. We discussed this offline on Tue and you suggested that it would be better to create a pattern to match: %b_filter = vector.broadcast %filter : vector<4xf32> to vector<3x2x4xf32>
%sc_filter = vector.shape_cast %b_filter : vector<3x2x4xf32> to vector<3x8xf32> and to replace it with: %sh_filter = vector.shuffle %filter, %filter
[0, 1, 2, 3, 0, 1, 2, 3] : vector<4xf32>, vector<4xf32>
%b_filter = vector.broadcast %sh_filter : vector<8xf32> to vector<3x8xf32> However, it's not clear to me how to go about it. Also, mindful of the recent discussions on Separately, this removes one problem with "flattened" convs (flattening the filter). There are still a few other issues that I need to address (flattening the input/output vectors and making that work at both Tensor and Memref level). By landing this I can re-focus on the other issues. If there are no new comments I will merge it this week. Happy to work on an alternative solution, but I may need some guidance :) |
Updates the vectorisation of 1D depthwise convolution when flattening
the channel dimension (introduced in #71918). In particular - how the
convolution filter is "flattened". ATM, the vectoriser will use
vector.shape_cast
:This lowering is not ideal -
vector.shape_cast
can be convenient whenit's folded away, but that's not happening in this case. Instead, this
patch updates the vectoriser to use
vector.shuffle
(the overall resultis identical):