Skip to content

Commit f996270

Browse files
committed
fixup! fixup! [mlir][linalg] Add masked vectorisation for depthwise convolutions
Address comments from Ben and Crefeda
1 parent 0a06be8 commit f996270

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2857,7 +2857,6 @@ struct Conv1DGenerator
28572857
return;
28582858
break;
28592859
}
2860-
hasTensorSemantics = linalgOp.hasPureTensorSemantics();
28612860
// The op is now known to be valid.
28622861
valid = true;
28632862
}
@@ -3175,6 +3174,9 @@ struct Conv1DGenerator
31753174
if (ShapedType::isDynamic(cSize)) {
31763175
assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
31773176
cSize = channelDimVecSize;
3177+
// Scalable vectors are only used when both conditions are met:
3178+
// 1. channel dim is dynamic
3179+
// 2. channelDimScalableFlag is set
31783180
scalableChDim = channelDimScalableFlag;
31793181
useMasking = true;
31803182
}
@@ -3216,7 +3218,7 @@ struct Conv1DGenerator
32163218
auto maskType =
32173219
VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
32183220
SmallVector<OpFoldResult> mixedSourceDims =
3219-
hasTensorSemantics
3221+
cast<LinalgOp>(op).hasPureTensorSemantics()
32203222
? TypeSwitch<Operation *, SmallVector<OpFoldResult>>(opToMask)
32213223
.Case<vector::TransferReadOp>([&](auto readOp) {
32223224
return tensor::getMixedSizes(rewriter, loc,
@@ -3490,7 +3492,7 @@ struct Conv1DGenerator
34903492
/// Entry point that transposes into the common form:
34913493
/// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
34923494
FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
3493-
bool vecChDimScalableFlag = true,
3495+
bool vecChDimScalableFlag = false,
34943496
bool flatten = false) {
34953497
AffineExpr n, w, c, kw;
34963498
bindDims(ctx, n, w, c, kw);
@@ -3514,7 +3516,6 @@ struct Conv1DGenerator
35143516
StringAttr redOp;
35153517
StringAttr poolExtOp;
35163518
bool isPoolExt = false;
3517-
bool hasTensorSemantics = false;
35183519
int strideW, dilationW;
35193520
Value lhsShaped, rhsShaped, resShaped;
35203521
ShapedType lhsShapedType, rhsShapedType, resShapedType;

mlir/test/Dialect/Linalg/vectorize-conv-scalable.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ module attributes {transform.with_named_sequence} {
124124

125125
// -----
126126

127-
func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2(%input: memref<3x5x?xf32>,
127+
func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dilation_2(%input: memref<3x5x?xf32>,
128128
%filter: memref<2x?xf32>,
129129
%output: memref<3x2x?xf32>) {
130130
linalg.depthwise_conv_1d_nwc_wc
@@ -135,7 +135,7 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2(%input: memref<3
135135
}
136136

137137
// TODO - nice variable names
138-
// CHECK-LABEL: func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2(
138+
// CHECK-LABEL: func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dilation_2(
139139
// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5x?xf32>,
140140
// CHECK-SAME: %[[VAL_1:.*]]: memref<2x?xf32>,
141141
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x2x?xf32>) {

0 commit comments

Comments
 (0)