@@ -2857,7 +2857,6 @@ struct Conv1DGenerator
         return;
       break;
     }
-    hasTensorSemantics = linalgOp.hasPureTensorSemantics();
     // The op is now known to be valid.
     valid = true;
   }
@@ -3175,6 +3174,9 @@ struct Conv1DGenerator
     if (ShapedType::isDynamic(cSize)) {
       assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
       cSize = channelDimVecSize;
+      // Scalable vectors are only used when both conditions are met:
+      //  1. channel dim is dynamic
+      //  2. channelDimScalableFlag is set
       scalableChDim = channelDimScalableFlag;
       useMasking = true;
     }
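
A minimal sketch of what the added comment describes, assuming the surrounding MLIR context (a rewriter, an f32 element type, and an illustrative channel vector size); `chVecSize` stands in for `channelDimVecSize` and is an assumption, not taken from the patch. When the scalable flag is set, the channel dimension of the resulting vector type becomes scalable (e.g. vector<[4]xf32>):

    // Illustrative only: how a scalable channel dimension surfaces in a
    // vector type once scalableChDim is true.
    SmallVector<int64_t> shape = {chVecSize};           // e.g. {4}
    SmallVector<bool> scalableDims = {scalableChDim};   // true -> vector<[4]xf32>
    auto chVecType =
        VectorType::get(shape, rewriter.getF32Type(), scalableDims);
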
@@ -3216,7 +3218,7 @@ struct Conv1DGenerator
       auto maskType =
           VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
       SmallVector<OpFoldResult> mixedSourceDims =
-          hasTensorSemantics
+          cast<LinalgOp>(op).hasPureTensorSemantics()
               ? TypeSwitch<Operation *, SmallVector<OpFoldResult>>(opToMask)
                     .Case<vector::TransferReadOp>([&](auto readOp) {
                       return tensor::getMixedSizes(rewriter, loc,
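
The replaced line is where the generator decides how to query the masked transfer's source sizes. A condensed sketch of that distinction, assuming `src` is the source value of the vector.transfer_read/write being masked and that the memref dialect's `getMixedSizes` counterpart is used on the buffer path:

    // Sketch, not part of the patch: pick the size query based on whether
    // the op runs on tensors or on memrefs.
    SmallVector<OpFoldResult> sizes =
        cast<LinalgOp>(op).hasPureTensorSemantics()
            ? tensor::getMixedSizes(rewriter, loc, src)
            : memref::getMixedSizes(rewriter, loc, src);
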
@@ -3490,7 +3492,7 @@ struct Conv1DGenerator
   /// Entry point that transposes into the common form:
   /// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
   FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
-                                             bool vecChDimScalableFlag = true,
+                                             bool vecChDimScalableFlag = false,
                                              bool flatten = false) {
     AffineExpr n, w, c, kw;
     bindDims(ctx, n, w, c, kw);
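
With the default flipped to false, scalable vectorisation of the channel dimension is opt-in for callers of this entry point. A hypothetical call site (the `gen` instance and the argument values below are assumptions for illustration, not from the patch):

    // Hypothetical usage sketch. Fixed-width vectorisation needs no flag;
    // scalable channel vectorisation must now be requested explicitly.
    FailureOr<Operation *> fixed = gen.generateDilatedConv(/*vecChDimSize=*/4);
    FailureOr<Operation *> scalable =
        gen.generateDilatedConv(/*vecChDimSize=*/4,
                                /*vecChDimScalableFlag=*/true);
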
@@ -3514,7 +3516,6 @@ struct Conv1DGenerator
   StringAttr redOp;
   StringAttr poolExtOp;
   bool isPoolExt = false;
-  bool hasTensorSemantics = false;
   int strideW, dilationW;
   Value lhsShaped, rhsShaped, resShaped;
   ShapedType lhsShapedType, rhsShapedType, resShapedType;