Skip to content

Commit 7247795

Browse files
committed
[mlir][linalg] Implement TilingInterface for winograd operations
In order to support arbitrary-size input data for conv2d, implement TilingInterface for winograd operations. Before converting winograd operations into nested loops with matrix multiplies, tile the input of conv2d into the supported size first. Add a transform operation structured.decompose_winograd_op to decompose winograd operations. Before applying the transform op, use tile_using_for to tile the input data into the supported size. The test case shows how to tile and decompose winograd operations.
1 parent 27ee33d commit 7247795

File tree

7 files changed

+633
-6
lines changed

7 files changed

+633
-6
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,8 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
154154
let hasVerifier = 1;
155155
}
156156

157-
def Linalg_WinogradFilterTransformOp :
158-
Linalg_Op<"winograd_filter_transform", [AllElementTypesMatch<["filter", "output"]>]> {
157+
def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
158+
[AllElementTypesMatch<["filter", "output"]>]> {
159159
let summary = "Winograd filter transform operator";
160160
let description = [{
161161
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -193,8 +193,13 @@ def Linalg_WinogradFilterTransformOp :
193193
let hasVerifier = 1;
194194
}
195195

196-
def Linalg_WinogradInputTransformOp :
197-
Linalg_Op<"winograd_input_transform", [AllElementTypesMatch<["input", "output"]>]> {
196+
def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
197+
[AllElementTypesMatch<["input", "output"]>,
198+
DeclareOpInterfaceMethods<TilingInterface,
199+
["getIterationDomain",
200+
"getLoopIteratorTypes",
201+
"getResultTilePosition",
202+
"getTiledImplementation"]>]> {
198203
let summary = "Winograd input transform operator";
199204
let description = [{
200205
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -232,8 +237,13 @@ def Linalg_WinogradInputTransformOp :
232237
let hasVerifier = 1;
233238
}
234239

235-
def Linalg_WinogradOutputTransformOp :
236-
Linalg_Op<"winograd_output_transform", [AllElementTypesMatch<["value", "output"]>]> {
240+
def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
241+
[AllElementTypesMatch<["value", "output"]>,
242+
DeclareOpInterfaceMethods<TilingInterface,
243+
["getIterationDomain",
244+
"getLoopIteratorTypes",
245+
"getResultTilePosition",
246+
"getTiledImplementation"]>]> {
237247
let summary = "Winograd output transform operator";
238248
let description = [{
239249
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2697,4 +2697,41 @@ def WinogradConv2DOp : Op<Transform_Dialect,
26972697
}];
26982698
}
26992699

2700+
def DecomposeWinogradOp : Op<Transform_Dialect,
2701+
"structured.decompose_winograd_op",
2702+
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
2703+
TransformOpInterface, TransformEachOpTrait,
2704+
ReportTrackingListenerFailuresOpTrait]> {
2705+
let description = [{
2706+
Decompose winograd operations. It will convert filter, input and output
2707+
transform operations into a combination of scf, tensor, and linalg
2708+
equivalent operations. Before applying this transform operation, users
2709+
need to tile winograd transform operations into supported sizes.
2710+
2711+
#### Return modes:
2712+
2713+
This operation fails if `target` is unsupported. Otherwise, the operation
2714+
succeeds and returns a handle of the sequence that replaces the original
2715+
operations.
2716+
}];
2717+
2718+
let arguments = (ins TransformHandleTypeInterface:$target);
2719+
let results = (outs TransformHandleTypeInterface:$transformed);
2720+
2721+
let assemblyFormat =
2722+
"$target attr-dict `:` functional-type($target, results)";
2723+
2724+
let builders = [
2725+
OpBuilder<(ins "Value":$target)>
2726+
];
2727+
2728+
let extraClassDeclaration = [{
2729+
::mlir::DiagnosedSilenceableFailure applyToOne(
2730+
::mlir::transform::TransformRewriter &rewriter,
2731+
::mlir::Operation *target,
2732+
::mlir::transform::ApplyToEachResultList &results,
2733+
::mlir::transform::TransformState &state);
2734+
}];
2735+
}
2736+
27002737
#endif // LINALG_TRANSFORM_OPS

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1339,6 +1339,51 @@ FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
13391339
linalg::Conv2DNhwcFhwcOp op, int64_t m,
13401340
int64_t r);
13411341

1342+
/// Rewrite linalg.winograd_filter_transform. The data layout of the filter is
1343+
/// FHWC. The transformation matrix is 2-dimension. We need to extract H x W
1344+
/// from FHWC first. We need to generate 2 levels of loops to iterate on F and
1345+
/// C. After the rewriting, we get
1346+
///
1347+
/// scf.for %f = lo_f to hi_f step 1
1348+
/// scf.for %c = lo_c to hi_c step 1
1349+
/// %extracted = extract filter<h x w> from filter<f x h x w x c>
1350+
/// %ret = linalg.matmul G, %extracted
1351+
/// %ret = linalg.matmul %ret, GT
1352+
/// %inserted = insert %ret into filter<h x w x c x f>
1353+
FailureOr<Operation *>
1354+
decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
1355+
linalg::WinogradFilterTransformOp op);
1356+
1357+
/// Rewrite linalg.winograd_input_transform. The data layout of the input is
1358+
/// NHWC. The transformation matrix is 2-dimensional. We need to extract H x W
1359+
/// from NHWC first. We need to generate 2 levels of loops to iterate on N and
1360+
/// C. After the rewriting, we get
1361+
///
1362+
/// scf.for %n = lo_n to hi_n step 1
1363+
/// scf.for %c = lo_c to hi_c step 1
1364+
/// %extracted = extract input<h x w> from input<n x h x w x c>
1365+
/// %ret = linalg.matmul BT, %extracted
1366+
/// %ret = linalg.matmul %ret, B
1367+
/// %inserted = insert %ret into input<h x w x n x c>
1368+
FailureOr<Operation *>
1369+
decomposeWinogradInputTransformOp(RewriterBase &rewriter,
1370+
linalg::WinogradInputTransformOp op);
1371+
1372+
/// Rewrite linalg.winograd_output_transform. The data layout of the output is
1373+
/// HWNF. The transformation matrix is 2-dimensional. We need to extract H x W
1374+
/// from HWNF first. We need to generate 2 levels of loops to iterate on N and
1375+
/// F. After the transformation, we get
1376+
///
1377+
/// scf.for %n = lo_n to hi_n step 1
1378+
/// scf.for %f = lo_f to hi_f step 1
1379+
/// %extracted = extract input<h x w> from result<h x w x n x f>
1380+
/// %ret = linalg.matmul AT, %extracted
1381+
/// %ret = linalg.matmul %ret, A
1382+
/// %inserted = insert %ret into ret<n x h x w x f>
1383+
FailureOr<Operation *>
1384+
decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
1385+
linalg::WinogradOutputTransformOp op);
1386+
13421387
//===----------------------------------------------------------------------===//
13431388
// Rewrite patterns wrapping transformations.
13441389
// TODO: every single such pattern should be a close to noop wrapper around a

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2776,6 +2776,15 @@ LogicalResult WinogradFilterTransformOp::verify() {
27762776
// WinogradInputTransformOp
27772777
//===----------------------------------------------------------------------===//
27782778

2779+
/// Materialize an OpFoldResult as an SSA Value: a constant IntegerAttr becomes
/// an arith.constant op at `loc`; an existing Value passes through unchanged.
/// File-local helper (hence `static`) shared by the winograd tiling methods
/// below.
static Value getValueFromOpFoldResult(OpFoldResult opFoldResult,
                                      OpBuilder &builder, Location loc) {
  // Use the free-function casts; the member-style dyn_cast/get on
  // PointerUnion-based OpFoldResult are deprecated.
  if (auto attr = dyn_cast<Attribute>(opFoldResult)) {
    auto intAttr = cast<IntegerAttr>(attr);
    return builder.create<arith::ConstantOp>(loc, intAttr);
  }
  return cast<Value>(opFoldResult);
}
2787+
27792788
LogicalResult WinogradInputTransformOp::verify() {
27802789
auto inputType = cast<ShapedType>(getInput().getType());
27812790
ArrayRef<int64_t> inputShape = inputType.getShape();
@@ -2813,6 +2822,113 @@ LogicalResult WinogradInputTransformOp::verify() {
28132822
return success();
28142823
}
28152824

2825+
/// Return the 6-D iteration domain of the op: every dimension spans
/// [0, dim(output, d)) with unit stride, where `output` is the 6-D
/// transformed-input tensor.
SmallVector<Range>
WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
  Location loc = getLoc();
  IndexType indexType = builder.getIndexType();
  IntegerAttr zeroAttr = builder.getIntegerAttr(indexType, 0);
  IntegerAttr oneAttr = builder.getIntegerAttr(indexType, 1);
  Value output = getOutput();
  SmallVector<Range> bounds;
  bounds.reserve(6);
  for (unsigned dim = 0; dim < 6; ++dim)
    bounds.push_back(
        Range{zeroAttr, getDimValue(builder, loc, output, dim), oneAttr});
  return bounds;
}
2840+
2841+
/// All six loops of the input transform are parallel.
SmallVector<utils::IteratorType>
WinogradInputTransformOp::getLoopIteratorTypes() {
  return SmallVector<utils::IteratorType>(6, utils::IteratorType::parallel);
}
2847+
2848+
/// Compute the position of a result tile inside the 6-D output tensor.
/// Only the tile-row/tile-col dimensions (iteration dims 2 and 3) follow the
/// requested offsets and are tiled to a single element; the alpha dimensions
/// (0, 1) and the N/C dimensions (4, 5) keep the requested sizes at offset 0.
LogicalResult WinogradInputTransformOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);

  resultOffsets.append(
      {zeroAttr, zeroAttr, offsets[2], offsets[3], zeroAttr, zeroAttr});
  resultSizes.append(
      {sizes[0], sizes[1], oneAttr, oneAttr, sizes[4], sizes[5]});

  return success();
}
2870+
2871+
/// Build the tiled version of winograd_input_transform: extract a slice of the
/// NHWC input covering one alpha x alpha patch per tile, extract the matching
/// slice of the 6-D output, and clone the op onto the two slices.
FailureOr<TilingResult>
WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
                                                 ArrayRef<OpFoldResult> offsets,
                                                 ArrayRef<OpFoldResult> sizes) {
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  Location loc = getLoc();

  // Each tile reads alpha = m + r - 1 points along H and W, unless the
  // corresponding input dimension is degenerate (size 1).
  ShapedType inputType = cast<ShapedType>(getInput().getType());
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t m = getM();
  int64_t alpha = m + getR() - 1;
  IntegerAttr alphaHAttr =
      builder.getI64IntegerAttr(inputShape[1] != 1 ? alpha : 1);
  IntegerAttr alphaWAttr =
      builder.getI64IntegerAttr(inputShape[2] != 1 ? alpha : 1);

  // Tile indices along dims 2/3 address m x m output tiles; scale them by m to
  // obtain the corresponding positions in the untransformed input.
  MLIRContext *context = builder.getContext();
  AffineMap scaleByM =
      AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
  Value hOffset = builder.create<affine::AffineApplyOp>(
      loc, scaleByM, getValueFromOpFoldResult(offsets[2], builder, loc));
  Value wOffset = builder.create<affine::AffineApplyOp>(
      loc, scaleByM, getValueFromOpFoldResult(offsets[3], builder, loc));

  SmallVector<Value> tiledOperands;

  // Slice of the NHWC input feeding this tile.
  SmallVector<OpFoldResult> sliceOffsets = {zeroAttr, hOffset, wOffset,
                                            zeroAttr};
  SmallVector<OpFoldResult> sliceSizes = {sizes[4], alphaHAttr, alphaWAttr,
                                          sizes[5]};
  SmallVector<OpFoldResult> inputStrides(4, oneAttr);
  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
      loc, getInput(), sliceOffsets, sliceSizes, inputStrides));

  // Matching slice of the 6-D output tensor.
  sliceOffsets.clear();
  sliceSizes.clear();
  if (failed(getResultTilePosition(builder, 1, offsets, sizes, sliceOffsets,
                                   sliceSizes)))
    return failure();
  SmallVector<OpFoldResult> outputStrides(6, oneAttr);
  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
      loc, getOutput(), sliceOffsets, sliceSizes, outputStrides));

  SmallVector<Type> resultTypes{tiledOperands[1].getType()};
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
}
2931+
28162932
//===----------------------------------------------------------------------===//
28172933
// WinogradOutputTransformOp
28182934
//===----------------------------------------------------------------------===//
@@ -2855,6 +2971,106 @@ LogicalResult WinogradOutputTransformOp::verify() {
28552971
return success();
28562972
}
28572973

2974+
/// Return the 6-D iteration domain of the op: every dimension spans
/// [0, dim(value, d)) with unit stride, where `value` is the 6-D
/// transformed-value operand.
SmallVector<Range>
WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
  Location loc = getLoc();
  IndexType indexType = builder.getIndexType();
  IntegerAttr zeroAttr = builder.getIntegerAttr(indexType, 0);
  IntegerAttr oneAttr = builder.getIntegerAttr(indexType, 1);
  Value value = getValue();
  SmallVector<Range> bounds;
  bounds.reserve(6);
  for (unsigned dim = 0; dim < 6; ++dim)
    bounds.push_back(
        Range{zeroAttr, getDimValue(builder, loc, value, dim), oneAttr});
  return bounds;
}
2989+
2990+
/// All six loops of the output transform are parallel.
SmallVector<utils::IteratorType>
WinogradOutputTransformOp::getLoopIteratorTypes() {
  return SmallVector<utils::IteratorType>(6, utils::IteratorType::parallel);
}
2996+
2997+
/// Compute where a result tile lives in the final output tensor. The tile
/// indices along iteration dims 2/3 are scaled by m to positions in the
/// untiled output; each tile covers an m x m patch along H/W (clamped to 1 on
/// degenerate output dimensions), while N and F keep the requested sizes.
LogicalResult WinogradOutputTransformOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  ShapedType outputType = cast<ShapedType>(getOutput().getType());
  ArrayRef<int64_t> outputShape = outputType.getShape();
  int64_t m = getM();
  IntegerAttr heightM = builder.getI64IntegerAttr(outputShape[1] != 1 ? m : 1);
  IntegerAttr widthM = builder.getI64IntegerAttr(outputShape[2] != 1 ? m : 1);

  // Map tile indices to output coordinates via an affine d0 * m.
  Location loc = getLoc();
  MLIRContext *context = builder.getContext();
  AffineMap scaleByM =
      AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
  Value hOffset = builder.create<affine::AffineApplyOp>(
      loc, scaleByM, getValueFromOpFoldResult(offsets[2], builder, loc));
  Value wOffset = builder.create<affine::AffineApplyOp>(
      loc, scaleByM, getValueFromOpFoldResult(offsets[3], builder, loc));

  resultOffsets.append({zeroAttr, hOffset, wOffset, zeroAttr});
  resultSizes.append({sizes[4], heightM, widthM, sizes[5]});
  return success();
}
3030+
3031+
/// Build the tiled version of winograd_output_transform: extract a slice of
/// the 6-D transformed value for one tile, extract the matching slice of the
/// final output, and clone the op onto the two slices.
FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes) {
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  Location loc = getLoc();
  SmallVector<Value> tiledOperands;

  // Slice of the 6-D value operand: follow the tile along dims 2/3 (tiled to a
  // single element); take the alpha, N and F extents as requested at offset 0.
  SmallVector<OpFoldResult> sliceOffsets = {zeroAttr,   zeroAttr, offsets[2],
                                            offsets[3], zeroAttr, zeroAttr};
  SmallVector<OpFoldResult> sliceSizes = {sizes[0], sizes[1], oneAttr,
                                          oneAttr,  sizes[4], sizes[5]};
  SmallVector<OpFoldResult> valueStrides(6, oneAttr);
  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
      loc, getValue(), sliceOffsets, sliceSizes, valueStrides));

  // Matching slice of the 4-D output tensor.
  sliceOffsets.clear();
  sliceSizes.clear();
  if (failed(getResultTilePosition(builder, 1, offsets, sizes, sliceOffsets,
                                   sliceSizes)))
    return failure();
  SmallVector<OpFoldResult> outputStrides(4, oneAttr);
  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
      loc, getOutput(), sliceOffsets, sliceSizes, outputStrides));

  SmallVector<Type> resultTypes{tiledOperands[1].getType()};
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
}
3073+
28583074
//===----------------------------------------------------------------------===//
28593075
// LinalgDialect
28603076
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)