Skip to content

Commit 936356f

Browse files
committed
Address ftynse's comments
1 parent 7247795 commit 936356f

File tree

2 files changed

+226
-228
lines changed

2 files changed

+226
-228
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 44 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -2776,15 +2776,6 @@ LogicalResult WinogradFilterTransformOp::verify() {
27762776
// WinogradInputTransformOp
27772777
//===----------------------------------------------------------------------===//
27782778

2779-
Value getValueFromOpFoldResult(OpFoldResult opFoldResult, OpBuilder &builder,
2780-
Location loc) {
2781-
if (auto attr = opFoldResult.dyn_cast<Attribute>()) {
2782-
auto intAttr = cast<IntegerAttr>(attr);
2783-
return builder.create<arith::ConstantOp>(loc, intAttr);
2784-
}
2785-
return opFoldResult.get<Value>();
2786-
}
2787-
27882779
LogicalResult WinogradInputTransformOp::verify() {
27892780
auto inputType = cast<ShapedType>(getInput().getType());
27902781
ArrayRef<int64_t> inputShape = inputType.getShape();
@@ -2825,9 +2816,9 @@ LogicalResult WinogradInputTransformOp::verify() {
28252816
SmallVector<Range>
28262817
WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
28272818
Location loc = getLoc();
2828-
auto indexType = builder.getIndexType();
2829-
auto zeroAttr = builder.getIntegerAttr(indexType, 0);
2830-
auto oneAttr = builder.getIntegerAttr(indexType, 1);
2819+
IndexType indexType = builder.getIndexType();
2820+
IntegerAttr zeroAttr = builder.getIntegerAttr(indexType, 0);
2821+
IntegerAttr oneAttr = builder.getIntegerAttr(indexType, 1);
28312822
Value output = getOutput();
28322823
SmallVector<Range> loopBounds(6);
28332824
for (unsigned dim = 0; dim < 6; ++dim) {
@@ -2849,21 +2840,13 @@ LogicalResult WinogradInputTransformOp::getResultTilePosition(
28492840
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
28502841
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
28512842
SmallVector<OpFoldResult> &resultSizes) {
2852-
auto zeroAttr = builder.getI64IntegerAttr(0);
2853-
auto oneAttr = builder.getI64IntegerAttr(1);
2843+
IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
2844+
IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
28542845

2855-
resultOffsets.push_back(zeroAttr);
2856-
resultOffsets.push_back(zeroAttr);
2857-
resultOffsets.push_back(offsets[2]);
2858-
resultOffsets.push_back(offsets[3]);
2859-
resultOffsets.push_back(zeroAttr);
2860-
resultOffsets.push_back(zeroAttr);
2861-
resultSizes.push_back(sizes[0]);
2862-
resultSizes.push_back(sizes[1]);
2863-
resultSizes.push_back(oneAttr);
2864-
resultSizes.push_back(oneAttr);
2865-
resultSizes.push_back(sizes[4]);
2866-
resultSizes.push_back(sizes[5]);
2846+
resultOffsets.append(
2847+
{zeroAttr, zeroAttr, offsets[2], offsets[3], zeroAttr, zeroAttr});
2848+
resultSizes.append(
2849+
{sizes[0], sizes[1], oneAttr, oneAttr, sizes[4], sizes[5]});
28672850

28682851
return success();
28692852
}
@@ -2872,41 +2855,37 @@ FailureOr<TilingResult>
28722855
WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
28732856
ArrayRef<OpFoldResult> offsets,
28742857
ArrayRef<OpFoldResult> sizes) {
2875-
auto oneAttr = builder.getI64IntegerAttr(1);
2876-
auto zeroAttr = builder.getI64IntegerAttr(0);
2858+
IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
2859+
IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
28772860
Value input = getInput();
28782861
auto inputType = cast<ShapedType>(input.getType());
2879-
auto inputShape = inputType.getShape();
2862+
ArrayRef<int64_t> inputShape = inputType.getShape();
28802863
int64_t inputH = inputShape[1];
28812864
int64_t inputW = inputShape[2];
28822865
int64_t m = getM();
28832866
int64_t r = getR();
28842867
int64_t alpha = m + r - 1;
28852868
int64_t alphaH = inputH != 1 ? alpha : 1;
28862869
int64_t alphaW = inputW != 1 ? alpha : 1;
2887-
auto alphaHAttr = builder.getI64IntegerAttr(alphaH);
2888-
auto alphaWAttr = builder.getI64IntegerAttr(alphaW);
2870+
IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
2871+
IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
28892872

28902873
Location loc = getLoc();
28912874
SmallVector<Value> tiledOperands;
28922875
SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
28932876

2894-
auto context = builder.getContext();
2877+
MLIRContext *context = builder.getContext();
28952878
auto affineMap =
28962879
AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
28972880
Value mappedOffset1 = builder.create<affine::AffineApplyOp>(
2898-
loc, affineMap, getValueFromOpFoldResult(offsets[2], builder, loc));
2881+
loc, affineMap,
2882+
getValueOrCreateConstantIndexOp(builder, loc, offsets[2]));
28992883
Value mappedOffset2 = builder.create<affine::AffineApplyOp>(
2900-
loc, affineMap, getValueFromOpFoldResult(offsets[3], builder, loc));
2901-
2902-
sliceOffsets.push_back(zeroAttr);
2903-
sliceOffsets.push_back(mappedOffset1);
2904-
sliceOffsets.push_back(mappedOffset2);
2905-
sliceOffsets.push_back(zeroAttr);
2906-
sliceSizes.push_back(sizes[4]);
2907-
sliceSizes.push_back(alphaHAttr);
2908-
sliceSizes.push_back(alphaWAttr);
2909-
sliceSizes.push_back(sizes[5]);
2884+
loc, affineMap,
2885+
getValueOrCreateConstantIndexOp(builder, loc, offsets[3]));
2886+
2887+
sliceOffsets.append({zeroAttr, mappedOffset1, mappedOffset2, zeroAttr});
2888+
sliceSizes.append({sizes[4], alphaHAttr, alphaWAttr, sizes[5]});
29102889
SmallVector<OpFoldResult> inputStrides(4, oneAttr);
29112890
tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
29122891
loc, getInput(), sliceOffsets, sliceSizes, inputStrides));
@@ -2921,7 +2900,7 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
29212900
tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
29222901
loc, getOutput(), sliceOffsets, sliceSizes, outputStrides));
29232902

2924-
SmallVector<Type, 4> resultTypes;
2903+
SmallVector<Type> resultTypes;
29252904
resultTypes.push_back(tiledOperands[1].getType());
29262905
Operation *tiledOp =
29272906
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
@@ -2974,9 +2953,9 @@ LogicalResult WinogradOutputTransformOp::verify() {
29742953
SmallVector<Range>
29752954
WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
29762955
Location loc = getLoc();
2977-
auto indexType = builder.getIndexType();
2978-
auto zeroAttr = builder.getIntegerAttr(indexType, 0);
2979-
auto oneAttr = builder.getIntegerAttr(indexType, 1);
2956+
IndexType indexType = builder.getIndexType();
2957+
IntegerAttr zeroAttr = builder.getIntegerAttr(indexType, 0);
2958+
IntegerAttr oneAttr = builder.getIntegerAttr(indexType, 1);
29802959
Value value = getValue();
29812960
SmallVector<Range> loopBounds(6);
29822961
for (unsigned dim = 0; dim < 6; ++dim) {
@@ -2998,57 +2977,44 @@ LogicalResult WinogradOutputTransformOp::getResultTilePosition(
29982977
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
29992978
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
30002979
SmallVector<OpFoldResult> &resultSizes) {
3001-
auto zeroAttr = builder.getI64IntegerAttr(0);
2980+
IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
30022981
Value output = getOutput();
30032982
auto outputType = cast<ShapedType>(output.getType());
3004-
auto outputShape = outputType.getShape();
2983+
ArrayRef<int64_t> outputShape = outputType.getShape();
30052984
int64_t outputH = outputShape[1];
30062985
int64_t outputW = outputShape[2];
30072986
int64_t m = getM();
3008-
auto heightM = builder.getI64IntegerAttr(outputH != 1 ? m : 1);
3009-
auto widthM = builder.getI64IntegerAttr(outputW != 1 ? m : 1);
2987+
IntegerAttr heightM = builder.getI64IntegerAttr(outputH != 1 ? m : 1);
2988+
IntegerAttr widthM = builder.getI64IntegerAttr(outputW != 1 ? m : 1);
30102989

30112990
Location loc = getLoc();
3012-
auto context = builder.getContext();
2991+
MLIRContext *context = builder.getContext();
30132992
auto affineMap =
30142993
AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
30152994
Value mappedOffset1 = builder.create<affine::AffineApplyOp>(
3016-
loc, affineMap, getValueFromOpFoldResult(offsets[2], builder, loc));
2995+
loc, affineMap,
2996+
getValueOrCreateConstantIndexOp(builder, loc, offsets[2]));
30172997
Value mappedOffset2 = builder.create<affine::AffineApplyOp>(
3018-
loc, affineMap, getValueFromOpFoldResult(offsets[3], builder, loc));
3019-
3020-
resultOffsets.push_back(zeroAttr);
3021-
resultOffsets.push_back(mappedOffset1);
3022-
resultOffsets.push_back(mappedOffset2);
3023-
resultOffsets.push_back(zeroAttr);
3024-
resultSizes.push_back(sizes[4]);
3025-
resultSizes.push_back(heightM);
3026-
resultSizes.push_back(widthM);
3027-
resultSizes.push_back(sizes[5]);
2998+
loc, affineMap,
2999+
getValueOrCreateConstantIndexOp(builder, loc, offsets[3]));
3000+
3001+
resultOffsets.append({zeroAttr, mappedOffset1, mappedOffset2, zeroAttr});
3002+
resultSizes.append({sizes[4], heightM, widthM, sizes[5]});
30283003
return success();
30293004
}
30303005

30313006
FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
30323007
OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
30333008
ArrayRef<OpFoldResult> sizes) {
3034-
auto oneAttr = builder.getI64IntegerAttr(1);
3035-
auto zeroAttr = builder.getI64IntegerAttr(0);
3009+
IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3010+
IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
30363011
Location loc = getLoc();
30373012
SmallVector<Value> tiledOperands;
30383013
SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
30393014

3040-
sliceOffsets.push_back(zeroAttr);
3041-
sliceOffsets.push_back(zeroAttr);
3042-
sliceOffsets.push_back(offsets[2]);
3043-
sliceOffsets.push_back(offsets[3]);
3044-
sliceOffsets.push_back(zeroAttr);
3045-
sliceOffsets.push_back(zeroAttr);
3046-
sliceSizes.push_back(sizes[0]);
3047-
sliceSizes.push_back(sizes[1]);
3048-
sliceSizes.push_back(oneAttr);
3049-
sliceSizes.push_back(oneAttr);
3050-
sliceSizes.push_back(sizes[4]);
3051-
sliceSizes.push_back(sizes[5]);
3015+
sliceOffsets.append(
3016+
{zeroAttr, zeroAttr, offsets[2], offsets[3], zeroAttr, zeroAttr});
3017+
sliceSizes.append({sizes[0], sizes[1], oneAttr, oneAttr, sizes[4], sizes[5]});
30523018
SmallVector<OpFoldResult> sliceStrides(6, oneAttr);
30533019
tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
30543020
loc, getValue(), sliceOffsets, sliceSizes, sliceStrides));
@@ -3063,7 +3029,7 @@ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
30633029
tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
30643030
loc, getOutput(), sliceOffsets, sliceSizes, strides));
30653031

3066-
SmallVector<Type, 4> resultTypes;
3032+
SmallVector<Type> resultTypes;
30673033
resultTypes.push_back(tiledOperands[1].getType());
30683034
Operation *tiledOp =
30693035
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

0 commit comments

Comments
 (0)