Skip to content

Commit f20b8e3

Browse files
authored
[MLIR][Linalg] Fixes for Winograd decomposition and for tiling (llvm#123675)
This PR addresses issues with 1 x r and r x 1 filters, and with tiling.
Signed-off-by: Dmitriy Smirnov <[email protected]>
1 parent 690f251 commit f20b8e3

File tree

5 files changed

+220
-101
lines changed

5 files changed

+220
-101
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
155155
}
156156

157157
def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
158-
[AllElementTypesMatch<["filter", "output"]>,
158+
[AllElementTypesMatch<["filter", "output"]>, DestinationStyleOpInterface,
159159
DeclareOpInterfaceMethods<TilingInterface,
160160
["getIterationDomain",
161161
"getLoopIteratorTypes",
@@ -220,12 +220,13 @@ def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
220220
int64_t getFilterCDim() {
221221
return 3;
222222
}
223+
// DestinationStyleOpInterface hook (the interface is listed in this op's
// traits): forwards to getOutputMutable(), so the op's DPS inits are exactly
// its `output` operand range.
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
223224
}];
224225
let hasVerifier = 1;
225226
}
226227

227228
def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
228-
[AllElementTypesMatch<["input", "output"]>,
229+
[AllElementTypesMatch<["input", "output"]>, DestinationStyleOpInterface,
229230
DeclareOpInterfaceMethods<TilingInterface,
230231
["getIterationDomain",
231232
"getLoopIteratorTypes",
@@ -308,6 +309,7 @@ def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
308309
int64_t getOutputCDim() {
309310
return 5;
310311
}
312+
// DestinationStyleOpInterface hook (the interface is listed in this op's
// traits): forwards to getOutputMutable(), so the op's DPS inits are exactly
// its `output` operand range.
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
311313
}];
312314
let hasVerifier = 1;
313315
}

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 42 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -3060,8 +3060,11 @@ LogicalResult WinogradInputTransformOp::verify() {
30603060
int m = getM();
30613061
int r = getR();
30623062
int64_t tileSize = m + r - 1;
3063-
bool leftTransform = inputH != 1;
3064-
bool rightTransform = inputW != 1;
3063+
3064+
auto outputType = cast<ShapedType>(getOutput().getType());
3065+
ArrayRef<int64_t> outputShape = outputType.getShape();
3066+
bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
3067+
bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
30653068

30663069
SmallVector<int64_t> expectedOutputShape(6, inputH);
30673070
if (ShapedType::isDynamic(inputH)) {
@@ -3070,21 +3073,19 @@ LogicalResult WinogradInputTransformOp::verify() {
30703073
} else {
30713074
expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
30723075
expectedOutputShape[getOutputTileHDim()] =
3073-
leftTransform ? (inputH - (r - 1)) / m : 1;
3076+
leftTransform ? (inputH - (r - 1)) / m : inputH;
30743077
}
30753078
if (ShapedType::isDynamic(inputW)) {
30763079
expectedOutputShape[getOutputAlphaWDim()] = tileSize;
30773080
expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
30783081
} else {
30793082
expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
30803083
expectedOutputShape[getOutputTileWDim()] =
3081-
rightTransform ? (inputW - (r - 1)) / m : 1;
3084+
rightTransform ? (inputW - (r - 1)) / m : inputW;
30823085
}
30833086
expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
30843087
expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
30853088

3086-
auto outputType = cast<ShapedType>(getOutput().getType());
3087-
ArrayRef<int64_t> outputShape = outputType.getShape();
30883089
if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
30893090
return emitOpError("the output shape is not expected");
30903091
}
@@ -3121,15 +3122,17 @@ LogicalResult WinogradInputTransformOp::getResultTilePosition(
31213122
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
31223123
SmallVector<OpFoldResult> &resultSizes) {
31233124
IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3124-
ShapedType inputType = getInputOperandType();
3125-
ArrayRef<int64_t> inputShape = inputType.getShape();
3126-
int64_t inputH = inputShape[getInputHDim()];
3127-
int64_t inputW = inputShape[getInputWDim()];
3125+
ShapedType outputType = getOutputOperandType();
3126+
ArrayRef<int64_t> outputShape = outputType.getShape();
3127+
int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
3128+
int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
3129+
31283130
int64_t m = getM();
31293131
int64_t r = getR();
31303132
int64_t alpha = m + r - 1;
3131-
int64_t alphaH = inputH != 1 ? alpha : 1;
3132-
int64_t alphaW = inputW != 1 ? alpha : 1;
3133+
int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
3134+
int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
3135+
31333136
IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
31343137
IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
31353138

@@ -3154,22 +3157,26 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
31543157
ArrayRef<OpFoldResult> offsets,
31553158
ArrayRef<OpFoldResult> sizes) {
31563159
IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3157-
IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3158-
ShapedType inputType = getInputOperandType();
3159-
ArrayRef<int64_t> inputShape = inputType.getShape();
3160-
int64_t inputH = inputShape[getInputHDim()];
3161-
int64_t inputW = inputShape[getInputWDim()];
31623160
int64_t m = getM();
31633161
int64_t r = getR();
31643162

3163+
ShapedType outputType = getOutputOperandType();
3164+
ArrayRef<int64_t> outputShape = outputType.getShape();
3165+
int64_t alphaH = outputShape[getOutputAlphaHDim()];
3166+
int64_t alphaW = outputShape[getOutputAlphaWDim()];
3167+
31653168
Location loc = getLoc();
31663169
MLIRContext *context = builder.getContext();
3170+
auto identityAffineMap =
3171+
AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
31673172
auto offsetAffineMap =
31683173
AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
31693174
Value mappedOffsetH = affine::makeComposedAffineApply(
3170-
builder, loc, offsetAffineMap, offsets[getOutputTileHDim()]);
3175+
builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
3176+
offsets[getOutputTileHDim()]);
31713177
Value mappedOffsetW = affine::makeComposedAffineApply(
3172-
builder, loc, offsetAffineMap, offsets[getOutputTileWDim()]);
3178+
builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
3179+
offsets[getOutputTileWDim()]);
31733180
auto sizeAffineMap = AffineMap::get(
31743181
1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
31753182
Value mappedSizeH = affine::makeComposedAffineApply(
@@ -3180,16 +3187,14 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
31803187
SmallVector<Value> tiledOperands;
31813188
SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
31823189

3183-
OpFoldResult offsetH =
3184-
inputH != 1 ? OpFoldResult(mappedOffsetH) : OpFoldResult(zeroAttr);
3185-
OpFoldResult offsetW =
3186-
inputW != 1 ? OpFoldResult(mappedOffsetW) : OpFoldResult(zeroAttr);
3190+
OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3191+
OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
31873192
sliceOffsets.append(
31883193
{offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
31893194
OpFoldResult sizeH =
3190-
inputH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3195+
alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
31913196
OpFoldResult sizeW =
3192-
inputW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3197+
alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
31933198
sliceSizes.append(
31943199
{sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
31953200
int64_t inputRank = getInputOperandRank();
@@ -3297,28 +3302,29 @@ LogicalResult WinogradOutputTransformOp::getResultTilePosition(
32973302

32983303
Location loc = getLoc();
32993304
MLIRContext *context = builder.getContext();
3305+
auto identityAffineMap =
3306+
AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
33003307
auto affineMap =
33013308
AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
33023309

3310+
ShapedType valueType = getValueOperandType();
3311+
ArrayRef<int64_t> valueShape = valueType.getShape();
3312+
int64_t valueH = valueShape[0];
3313+
int64_t valueW = valueShape[1];
33033314
Value mappedOffsetH = affine::makeComposedAffineApply(
3304-
builder, loc, affineMap, offsets[getValueTileHDim()]);
3315+
builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
3316+
offsets[getValueTileHDim()]);
33053317
Value mappedOffsetW = affine::makeComposedAffineApply(
3306-
builder, loc, affineMap, offsets[getValueTileWDim()]);
3318+
builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
3319+
offsets[getValueTileWDim()]);
33073320
Value mappedSizeH = affine::makeComposedAffineApply(
33083321
builder, loc, affineMap, sizes[getValueTileHDim()]);
33093322
Value mappedSizeW = affine::makeComposedAffineApply(
33103323
builder, loc, affineMap, sizes[getValueTileWDim()]);
33113324

3312-
ShapedType valueType = getValueOperandType();
3313-
ArrayRef<int64_t> valueShape = valueType.getShape();
3314-
int64_t valueH = valueShape[0];
3315-
int64_t valueW = valueShape[1];
33163325
IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3317-
IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3318-
OpFoldResult offsetH =
3319-
valueH != 1 ? OpFoldResult(mappedOffsetH) : OpFoldResult(zeroAttr);
3320-
OpFoldResult offsetW =
3321-
valueW != 1 ? OpFoldResult(mappedOffsetW) : OpFoldResult(zeroAttr);
3326+
OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3327+
OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
33223328
OpFoldResult sizeH =
33233329
valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
33243330
OpFoldResult sizeW =

mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -514,12 +514,14 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
514514
Value CIter = ivs[3];
515515

516516
auto context = builder.getContext();
517+
518+
auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
517519
auto affineMap =
518520
AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
519-
Value heightOffset =
520-
builder.create<affine::AffineApplyOp>(loc, affineMap, tileHIter);
521-
Value widthOffset =
522-
builder.create<affine::AffineApplyOp>(loc, affineMap, tileWIter);
521+
Value heightOffset = builder.create<affine::AffineApplyOp>(
522+
loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
523+
Value widthOffset = builder.create<affine::AffineApplyOp>(
524+
loc, rightTransform ? affineMap : identityAffineMap, tileWIter);
523525

524526
// Extract (H, W) from (N, H, W, C).
525527
auto extractInput =
@@ -753,12 +755,13 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
753755
Value zero = builder.create<arith::ConstantOp>(
754756
loc, rewriter.getZeroAttr(elementType));
755757

758+
auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
756759
auto affineMap =
757760
AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
758-
Value heightOffset =
759-
builder.create<affine::AffineApplyOp>(loc, affineMap, tileHIter);
760-
Value widthOffset =
761-
builder.create<affine::AffineApplyOp>(loc, affineMap, tileWIter);
761+
Value heightOffset = builder.create<affine::AffineApplyOp>(
762+
loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
763+
Value widthOffset = builder.create<affine::AffineApplyOp>(
764+
loc, rightTransform ? affineMap : identityAffineMap, tileWIter);
762765

763766
Value outInitVal =
764767
extract2DDataFrom4D(builder, loc, args[0], NIter, FIter, heightOffset,
@@ -1075,16 +1078,17 @@ FailureOr<Operation *>
10751078
decomposeWinogradInputTransformHelper(RewriterBase &rewriter,
10761079
linalg::WinogradInputTransformOp op) {
10771080
Location loc = op.getLoc();
1078-
Value input = op.getInput();
1079-
auto inputType = cast<ShapedType>(input.getType());
1080-
auto inputShape = inputType.getShape();
1081-
int64_t inputH = inputShape[1];
1082-
int64_t inputW = inputShape[2];
1081+
Value output = op.getOutput();
1082+
auto outputType = cast<ShapedType>(output.getType());
1083+
auto outputShape = outputType.getShape();
1084+
1085+
int64_t outputH = outputShape[0];
1086+
int64_t outputW = outputShape[1];
10831087

10841088
// For F(m x 1, r x 1), we only need to do left side transform.
1085-
bool leftTransform = inputH != 1;
1089+
bool leftTransform = outputH != 1;
10861090
// For F(1 x m, 1 x r), we only need to do right side transform.
1087-
bool rightTransform = inputW != 1;
1091+
bool rightTransform = outputW != 1;
10881092
Value transformedInput =
10891093
inputTransform(rewriter, loc, op.getInput(), op.getOutput(), op.getM(),
10901094
op.getR(), leftTransform, rightTransform);

0 commit comments

Comments
 (0)