@@ -3060,8 +3060,11 @@ LogicalResult WinogradInputTransformOp::verify() {
3060
3060
int m = getM ();
3061
3061
int r = getR ();
3062
3062
int64_t tileSize = m + r - 1 ;
3063
- bool leftTransform = inputH != 1 ;
3064
- bool rightTransform = inputW != 1 ;
3063
+
3064
+ auto outputType = cast<ShapedType>(getOutput ().getType ());
3065
+ ArrayRef<int64_t > outputShape = outputType.getShape ();
3066
+ bool leftTransform = outputShape[getOutputAlphaHDim ()] != 1 ;
3067
+ bool rightTransform = outputShape[getOutputAlphaWDim ()] != 1 ;
3065
3068
3066
3069
SmallVector<int64_t > expectedOutputShape (6 , inputH);
3067
3070
if (ShapedType::isDynamic (inputH)) {
@@ -3070,21 +3073,19 @@ LogicalResult WinogradInputTransformOp::verify() {
3070
3073
} else {
3071
3074
expectedOutputShape[getOutputAlphaHDim ()] = leftTransform ? tileSize : 1 ;
3072
3075
expectedOutputShape[getOutputTileHDim ()] =
3073
- leftTransform ? (inputH - (r - 1 )) / m : 1 ;
3076
+ leftTransform ? (inputH - (r - 1 )) / m : inputH ;
3074
3077
}
3075
3078
if (ShapedType::isDynamic (inputW)) {
3076
3079
expectedOutputShape[getOutputAlphaWDim ()] = tileSize;
3077
3080
expectedOutputShape[getOutputTileWDim ()] = ShapedType::kDynamic ;
3078
3081
} else {
3079
3082
expectedOutputShape[getOutputAlphaWDim ()] = rightTransform ? tileSize : 1 ;
3080
3083
expectedOutputShape[getOutputTileWDim ()] =
3081
- rightTransform ? (inputW - (r - 1 )) / m : 1 ;
3084
+ rightTransform ? (inputW - (r - 1 )) / m : inputW ;
3082
3085
}
3083
3086
expectedOutputShape[getOutputNDim ()] = inputShape[getInputNDim ()];
3084
3087
expectedOutputShape[getOutputCDim ()] = inputShape[getInputCDim ()];
3085
3088
3086
- auto outputType = cast<ShapedType>(getOutput ().getType ());
3087
- ArrayRef<int64_t > outputShape = outputType.getShape ();
3088
3089
if (failed (verifyCompatibleShape (expectedOutputShape, outputShape))) {
3089
3090
return emitOpError (" the output shape is not expected" );
3090
3091
}
@@ -3121,15 +3122,17 @@ LogicalResult WinogradInputTransformOp::getResultTilePosition(
3121
3122
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3122
3123
SmallVector<OpFoldResult> &resultSizes) {
3123
3124
IntegerAttr zeroAttr = builder.getI64IntegerAttr (0 );
3124
- ShapedType inputType = getInputOperandType ();
3125
- ArrayRef<int64_t > inputShape = inputType.getShape ();
3126
- int64_t inputH = inputShape[getInputHDim ()];
3127
- int64_t inputW = inputShape[getInputWDim ()];
3125
+ ShapedType outputType = getOutputOperandType ();
3126
+ ArrayRef<int64_t > outputShape = outputType.getShape ();
3127
+ int64_t outputAlphaH = outputShape[getOutputAlphaHDim ()];
3128
+ int64_t outputAlphaW = outputShape[getOutputAlphaWDim ()];
3129
+
3128
3130
int64_t m = getM ();
3129
3131
int64_t r = getR ();
3130
3132
int64_t alpha = m + r - 1 ;
3131
- int64_t alphaH = inputH != 1 ? alpha : 1 ;
3132
- int64_t alphaW = inputW != 1 ? alpha : 1 ;
3133
+ int64_t alphaH = outputAlphaH != 1 ? alpha : 1 ;
3134
+ int64_t alphaW = outputAlphaW != 1 ? alpha : 1 ;
3135
+
3133
3136
IntegerAttr alphaHAttr = builder.getI64IntegerAttr (alphaH);
3134
3137
IntegerAttr alphaWAttr = builder.getI64IntegerAttr (alphaW);
3135
3138
@@ -3154,22 +3157,26 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
3154
3157
ArrayRef<OpFoldResult> offsets,
3155
3158
ArrayRef<OpFoldResult> sizes) {
3156
3159
IntegerAttr oneAttr = builder.getI64IntegerAttr (1 );
3157
- IntegerAttr zeroAttr = builder.getI64IntegerAttr (0 );
3158
- ShapedType inputType = getInputOperandType ();
3159
- ArrayRef<int64_t > inputShape = inputType.getShape ();
3160
- int64_t inputH = inputShape[getInputHDim ()];
3161
- int64_t inputW = inputShape[getInputWDim ()];
3162
3160
int64_t m = getM ();
3163
3161
int64_t r = getR ();
3164
3162
3163
+ ShapedType outputType = getOutputOperandType ();
3164
+ ArrayRef<int64_t > outputShape = outputType.getShape ();
3165
+ int64_t alphaH = outputShape[getOutputAlphaHDim ()];
3166
+ int64_t alphaW = outputShape[getOutputAlphaWDim ()];
3167
+
3165
3168
Location loc = getLoc ();
3166
3169
MLIRContext *context = builder.getContext ();
3170
+ auto identityAffineMap =
3171
+ AffineMap::get (1 , 0 , {builder.getAffineDimExpr (0 )}, context);
3167
3172
auto offsetAffineMap =
3168
3173
AffineMap::get (1 , 0 , {builder.getAffineDimExpr (0 ) * m}, context);
3169
3174
Value mappedOffsetH = affine::makeComposedAffineApply (
3170
- builder, loc, offsetAffineMap, offsets[getOutputTileHDim ()]);
3175
+ builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
3176
+ offsets[getOutputTileHDim ()]);
3171
3177
Value mappedOffsetW = affine::makeComposedAffineApply (
3172
- builder, loc, offsetAffineMap, offsets[getOutputTileWDim ()]);
3178
+ builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
3179
+ offsets[getOutputTileWDim ()]);
3173
3180
auto sizeAffineMap = AffineMap::get (
3174
3181
1 , 0 , {builder.getAffineDimExpr (0 ) * m + (r - 1 )}, context);
3175
3182
Value mappedSizeH = affine::makeComposedAffineApply (
@@ -3180,16 +3187,14 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
3180
3187
SmallVector<Value> tiledOperands;
3181
3188
SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3182
3189
3183
- OpFoldResult offsetH =
3184
- inputH != 1 ? OpFoldResult (mappedOffsetH) : OpFoldResult (zeroAttr);
3185
- OpFoldResult offsetW =
3186
- inputW != 1 ? OpFoldResult (mappedOffsetW) : OpFoldResult (zeroAttr);
3190
+ OpFoldResult offsetH = OpFoldResult (mappedOffsetH);
3191
+ OpFoldResult offsetW = OpFoldResult (mappedOffsetW);
3187
3192
sliceOffsets.append (
3188
3193
{offsets[getOutputNDim ()], offsetH, offsetW, offsets[getOutputCDim ()]});
3189
3194
OpFoldResult sizeH =
3190
- inputH != 1 ? OpFoldResult (mappedSizeH) : OpFoldResult (oneAttr);
3195
+ alphaH != 1 ? OpFoldResult (mappedSizeH) : OpFoldResult (oneAttr);
3191
3196
OpFoldResult sizeW =
3192
- inputW != 1 ? OpFoldResult (mappedSizeW) : OpFoldResult (oneAttr);
3197
+ alphaW != 1 ? OpFoldResult (mappedSizeW) : OpFoldResult (oneAttr);
3193
3198
sliceSizes.append (
3194
3199
{sizes[getOutputNDim ()], sizeH, sizeW, sizes[getOutputCDim ()]});
3195
3200
int64_t inputRank = getInputOperandRank ();
@@ -3297,28 +3302,29 @@ LogicalResult WinogradOutputTransformOp::getResultTilePosition(
3297
3302
3298
3303
Location loc = getLoc ();
3299
3304
MLIRContext *context = builder.getContext ();
3305
+ auto identityAffineMap =
3306
+ AffineMap::get (1 , 0 , {builder.getAffineDimExpr (0 )}, context);
3300
3307
auto affineMap =
3301
3308
AffineMap::get (1 , 0 , {builder.getAffineDimExpr (0 ) * m}, context);
3302
3309
3310
+ ShapedType valueType = getValueOperandType ();
3311
+ ArrayRef<int64_t > valueShape = valueType.getShape ();
3312
+ int64_t valueH = valueShape[0 ];
3313
+ int64_t valueW = valueShape[1 ];
3303
3314
Value mappedOffsetH = affine::makeComposedAffineApply (
3304
- builder, loc, affineMap, offsets[getValueTileHDim ()]);
3315
+ builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
3316
+ offsets[getValueTileHDim ()]);
3305
3317
Value mappedOffsetW = affine::makeComposedAffineApply (
3306
- builder, loc, affineMap, offsets[getValueTileWDim ()]);
3318
+ builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
3319
+ offsets[getValueTileWDim ()]);
3307
3320
Value mappedSizeH = affine::makeComposedAffineApply (
3308
3321
builder, loc, affineMap, sizes[getValueTileHDim ()]);
3309
3322
Value mappedSizeW = affine::makeComposedAffineApply (
3310
3323
builder, loc, affineMap, sizes[getValueTileWDim ()]);
3311
3324
3312
- ShapedType valueType = getValueOperandType ();
3313
- ArrayRef<int64_t > valueShape = valueType.getShape ();
3314
- int64_t valueH = valueShape[0 ];
3315
- int64_t valueW = valueShape[1 ];
3316
3325
IntegerAttr oneAttr = builder.getI64IntegerAttr (1 );
3317
- IntegerAttr zeroAttr = builder.getI64IntegerAttr (0 );
3318
- OpFoldResult offsetH =
3319
- valueH != 1 ? OpFoldResult (mappedOffsetH) : OpFoldResult (zeroAttr);
3320
- OpFoldResult offsetW =
3321
- valueW != 1 ? OpFoldResult (mappedOffsetW) : OpFoldResult (zeroAttr);
3326
+ OpFoldResult offsetH = OpFoldResult (mappedOffsetH);
3327
+ OpFoldResult offsetW = OpFoldResult (mappedOffsetW);
3322
3328
OpFoldResult sizeH =
3323
3329
valueH != 1 ? OpFoldResult (mappedSizeH) : OpFoldResult (oneAttr);
3324
3330
OpFoldResult sizeW =
0 commit comments