Skip to content

Commit 1596105

Browse files
committed
Address comments
1 parent fc3bf65 commit 1596105

File tree

5 files changed

+256
-507
lines changed

5 files changed

+256
-507
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -155,12 +155,7 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
155155
}
156156

157157
def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
158-
[AllElementTypesMatch<["filter", "output"]>,
159-
DeclareOpInterfaceMethods<TilingInterface,
160-
["getIterationDomain",
161-
"getLoopIteratorTypes",
162-
"getResultTilePosition",
163-
"getTiledImplementation"]>]> {
158+
[AllElementTypesMatch<["filter", "output"]>]> {
164159
let summary = "Winograd filter transform operator";
165160
let description = [{
166161
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2644,16 +2644,16 @@ def DecomposeWinogradOp : Op<Transform_Dialect,
26442644
TransformOpInterface, TransformEachOpTrait,
26452645
ReportTrackingListenerFailuresOpTrait]> {
26462646
let description = [{
2647-
Decompose winograd operators. It will convert filter, input and output
2648-
transform operators into a combination of scf, tensor, and linalg
2649-
equivalent operators. Before applying this transform operator, users
2650-
need to tile winograd transform operators into supported sizes.
2647+
Decompose winograd operations. It will convert filter, input and output
2648+
transform operations into a combination of scf, tensor, and linalg
2649+
equivalent operations. Before applying this transform operations, users
2650+
need to tile winograd transform operations into supported sizes.
26512651

26522652
#### Return modes:
26532653

26542654
This operation fails if `target` is unsupported. Otherwise, the operation
26552655
succeeds and returns a handle of the sequence that replaces the original
2656-
operator.
2656+
operations.
26572657
}];
26582658

26592659
let arguments = (ins TransformHandleTypeInterface:$target);

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 27 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -2755,93 +2755,19 @@ LogicalResult WinogradFilterTransformOp::verify() {
27552755
return success();
27562756
}
27572757

2758-
SmallVector<Range>
2759-
WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
2760-
Location loc = getLoc();
2761-
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
2762-
Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
2763-
Value output = getOutput();
2764-
SmallVector<Range> loopBounds(6);
2765-
for (unsigned dim = 0; dim < 6; ++dim) {
2766-
loopBounds[dim].offset = zero;
2767-
loopBounds[dim].size = getDimValue(builder, loc, output, dim);
2768-
loopBounds[dim].stride = one;
2769-
}
2770-
return loopBounds;
2771-
}
2772-
2773-
SmallVector<utils::IteratorType>
2774-
WinogradFilterTransformOp::getLoopIteratorTypes() {
2775-
SmallVector<utils::IteratorType> iteratorTypes(6,
2776-
utils::IteratorType::parallel);
2777-
return iteratorTypes;
2778-
}
2758+
//===----------------------------------------------------------------------===//
2759+
// WinogradInputTransformOp
2760+
//===----------------------------------------------------------------------===//
27792761

27802762
Value getValueFromOpFoldResult(OpFoldResult opFoldResult, OpBuilder &builder,
27812763
Location loc) {
2782-
if (auto val = opFoldResult.dyn_cast<Value>()) {
2783-
return val;
2784-
} else if (auto attr = opFoldResult.dyn_cast<Attribute>()) {
2764+
if (auto attr = opFoldResult.dyn_cast<Attribute>()) {
27852765
auto intAttr = cast<IntegerAttr>(attr);
27862766
return builder.create<arith::ConstantOp>(loc, intAttr);
27872767
}
2788-
// This should never happen if OpFoldResult is correctly formed.
2789-
return nullptr;
2768+
return opFoldResult.get<Value>();
27902769
}
27912770

2792-
LogicalResult WinogradFilterTransformOp::getResultTilePosition(
2793-
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2794-
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2795-
SmallVector<OpFoldResult> &resultSizes) {
2796-
auto zeroAttr = builder.getI64IntegerAttr(0);
2797-
auto oneAttr = builder.getI64IntegerAttr(1);
2798-
2799-
resultOffsets.push_back(offsets[0]);
2800-
resultOffsets.push_back(offsets[1]);
2801-
resultOffsets.push_back(zeroAttr);
2802-
resultOffsets.push_back(zeroAttr);
2803-
resultOffsets.push_back(zeroAttr);
2804-
resultOffsets.push_back(zeroAttr);
2805-
resultSizes.push_back(oneAttr);
2806-
resultSizes.push_back(oneAttr);
2807-
resultSizes.push_back(sizes[2]);
2808-
resultSizes.push_back(sizes[3]);
2809-
resultSizes.push_back(sizes[4]);
2810-
resultSizes.push_back(sizes[5]);
2811-
2812-
return success();
2813-
}
2814-
2815-
FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
2816-
OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
2817-
ArrayRef<OpFoldResult> sizes) {
2818-
auto oneAttr = builder.getI64IntegerAttr(1);
2819-
2820-
Location loc = getLoc();
2821-
SmallVector<OpFoldResult> strides(6, oneAttr);
2822-
SmallVector<Value> tiledOperands;
2823-
tiledOperands.emplace_back(getFilter());
2824-
2825-
SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
2826-
if (failed(getResultTilePosition(builder, 1, offsets, sizes, sliceOffsets,
2827-
sliceSizes)))
2828-
return failure();
2829-
2830-
tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
2831-
loc, getOutput(), sliceOffsets, sliceSizes, strides));
2832-
2833-
SmallVector<Type, 4> resultTypes;
2834-
resultTypes.push_back(tiledOperands[1].getType());
2835-
Operation *tiledOp =
2836-
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2837-
2838-
return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
2839-
}
2840-
2841-
//===----------------------------------------------------------------------===//
2842-
// WinogradInputTransformOp
2843-
//===----------------------------------------------------------------------===//
2844-
28452771
LogicalResult WinogradInputTransformOp::verify() {
28462772
auto inputType = cast<ShapedType>(getInput().getType());
28472773
ArrayRef<int64_t> inputShape = inputType.getShape();
@@ -2887,14 +2813,15 @@ LogicalResult WinogradInputTransformOp::verify() {
28872813
SmallVector<Range>
28882814
WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
28892815
Location loc = getLoc();
2890-
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
2891-
Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
2816+
auto indexType = builder.getIndexType();
2817+
auto zeroAttr = builder.getIntegerAttr(indexType, 0);
2818+
auto oneAttr = builder.getIntegerAttr(indexType, 1);
28922819
Value output = getOutput();
28932820
SmallVector<Range> loopBounds(6);
28942821
for (unsigned dim = 0; dim < 6; ++dim) {
2895-
loopBounds[dim].offset = zero;
2822+
loopBounds[dim].offset = zeroAttr;
28962823
loopBounds[dim].size = getDimValue(builder, loc, output, dim);
2897-
loopBounds[dim].stride = one;
2824+
loopBounds[dim].stride = oneAttr;
28982825
}
28992826
return loopBounds;
29002827
}
@@ -2913,16 +2840,16 @@ LogicalResult WinogradInputTransformOp::getResultTilePosition(
29132840
auto zeroAttr = builder.getI64IntegerAttr(0);
29142841
auto oneAttr = builder.getI64IntegerAttr(1);
29152842

2916-
resultOffsets.push_back(offsets[0]);
2917-
resultOffsets.push_back(offsets[1]);
29182843
resultOffsets.push_back(zeroAttr);
29192844
resultOffsets.push_back(zeroAttr);
2845+
resultOffsets.push_back(offsets[2]);
2846+
resultOffsets.push_back(offsets[3]);
29202847
resultOffsets.push_back(zeroAttr);
29212848
resultOffsets.push_back(zeroAttr);
2849+
resultSizes.push_back(sizes[0]);
2850+
resultSizes.push_back(sizes[1]);
29222851
resultSizes.push_back(oneAttr);
29232852
resultSizes.push_back(oneAttr);
2924-
resultSizes.push_back(sizes[2]);
2925-
resultSizes.push_back(sizes[3]);
29262853
resultSizes.push_back(sizes[4]);
29272854
resultSizes.push_back(sizes[5]);
29282855

@@ -2956,9 +2883,9 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
29562883
auto affineMap =
29572884
AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
29582885
Value mappedOffset1 = builder.create<affine::AffineApplyOp>(
2959-
loc, affineMap, getValueFromOpFoldResult(offsets[0], builder, loc));
2886+
loc, affineMap, getValueFromOpFoldResult(offsets[2], builder, loc));
29602887
Value mappedOffset2 = builder.create<affine::AffineApplyOp>(
2961-
loc, affineMap, getValueFromOpFoldResult(offsets[1], builder, loc));
2888+
loc, affineMap, getValueFromOpFoldResult(offsets[3], builder, loc));
29622889

29632890
sliceOffsets.push_back(zeroAttr);
29642891
sliceOffsets.push_back(mappedOffset1);
@@ -3033,14 +2960,15 @@ LogicalResult WinogradOutputTransformOp::verify() {
30332960
SmallVector<Range>
30342961
WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
30352962
Location loc = getLoc();
3036-
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
3037-
Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
2963+
auto indexType = builder.getIndexType();
2964+
auto zeroAttr = builder.getIntegerAttr(indexType, 0);
2965+
auto oneAttr = builder.getIntegerAttr(indexType, 1);
30382966
Value value = getValue();
30392967
SmallVector<Range> loopBounds(6);
30402968
for (unsigned dim = 0; dim < 6; ++dim) {
3041-
loopBounds[dim].offset = zero;
2969+
loopBounds[dim].offset = zeroAttr;
30422970
loopBounds[dim].size = getDimValue(builder, loc, value, dim);
3043-
loopBounds[dim].stride = one;
2971+
loopBounds[dim].stride = oneAttr;
30442972
}
30452973
return loopBounds;
30462974
}
@@ -3071,9 +2999,9 @@ LogicalResult WinogradOutputTransformOp::getResultTilePosition(
30712999
auto affineMap =
30723000
AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
30733001
Value mappedOffset1 = builder.create<affine::AffineApplyOp>(
3074-
loc, affineMap, getValueFromOpFoldResult(offsets[0], builder, loc));
3002+
loc, affineMap, getValueFromOpFoldResult(offsets[2], builder, loc));
30753003
Value mappedOffset2 = builder.create<affine::AffineApplyOp>(
3076-
loc, affineMap, getValueFromOpFoldResult(offsets[1], builder, loc));
3004+
loc, affineMap, getValueFromOpFoldResult(offsets[3], builder, loc));
30773005

30783006
resultOffsets.push_back(zeroAttr);
30793007
resultOffsets.push_back(mappedOffset1);
@@ -3095,16 +3023,16 @@ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
30953023
SmallVector<Value> tiledOperands;
30963024
SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
30973025

3098-
sliceOffsets.push_back(offsets[0]);
3099-
sliceOffsets.push_back(offsets[1]);
31003026
sliceOffsets.push_back(zeroAttr);
31013027
sliceOffsets.push_back(zeroAttr);
3028+
sliceOffsets.push_back(offsets[2]);
3029+
sliceOffsets.push_back(offsets[3]);
31023030
sliceOffsets.push_back(zeroAttr);
31033031
sliceOffsets.push_back(zeroAttr);
3032+
sliceSizes.push_back(sizes[0]);
3033+
sliceSizes.push_back(sizes[1]);
31043034
sliceSizes.push_back(oneAttr);
31053035
sliceSizes.push_back(oneAttr);
3106-
sliceSizes.push_back(sizes[2]);
3107-
sliceSizes.push_back(sizes[3]);
31083036
sliceSizes.push_back(sizes[4]);
31093037
sliceSizes.push_back(sizes[5]);
31103038
SmallVector<OpFoldResult> sliceStrides(6, oneAttr);

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3518,23 +3518,37 @@ DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne(
35183518
transform::ApplyToEachResultList &results,
35193519
transform::TransformState &state) {
35203520
rewriter.setInsertionPoint(target);
3521-
auto maybeTransformed =
3522-
TypeSwitch<Operation *, FailureOr<Operation *>>(target)
3521+
FailureOr<Operation *> maybeTransformed = failure();
3522+
bool supported =
3523+
TypeSwitch<Operation *, bool>(target)
35233524
.Case([&](linalg::WinogradFilterTransformOp op) {
3524-
return decomposeWinogradFilterTransformOp(rewriter, op);
3525+
maybeTransformed = decomposeWinogradFilterTransformOp(rewriter, op);
3526+
return true;
35253527
})
35263528
.Case([&](linalg::WinogradInputTransformOp op) {
3527-
return decomposeWinogradInputTransformOp(rewriter, op);
3529+
maybeTransformed = decomposeWinogradInputTransformOp(rewriter, op);
3530+
return true;
35283531
})
35293532
.Case([&](linalg::WinogradOutputTransformOp op) {
3530-
return decomposeWinogradOutputTransformOp(rewriter, op);
3533+
maybeTransformed = decomposeWinogradOutputTransformOp(rewriter, op);
3534+
return true;
35313535
})
3532-
.Default([&](Operation *op) {
3533-
return rewriter.notifyMatchFailure(op, "not supported");
3534-
});
3536+
.Default([&](Operation *op) { return false; });
35353537

3536-
if (failed(maybeTransformed))
3537-
return emitDefaultSilenceableFailure(target);
3538+
if (!supported) {
3539+
DiagnosedSilenceableFailure diag =
3540+
emitSilenceableError()
3541+
<< "this operation is not supported to decompose into other operations";
3542+
diag.attachNote(target->getLoc()) << "target op";
3543+
return diag;
3544+
}
3545+
3546+
if (supported && failed(maybeTransformed)) {
3547+
DiagnosedSilenceableFailure diag =
3548+
emitSilenceableError() << "decompose Winograd operations failed";
3549+
diag.attachNote(target->getLoc()) << "target op";
3550+
return diag;
3551+
}
35383552

35393553
results.push_back(*maybeTransformed);
35403554
return DiagnosedSilenceableFailure::success();

0 commit comments

Comments
 (0)