Skip to content

Commit 4226de9

Browse files
committed
[mlir][tosa] Update value to values for ConstShapeOp
Signed-off-by: Jerry Ge <[email protected]> Change-Id: Ia31f233f4e051eb7a565c26f496a02cb17cf9828
1 parent c1d0380 commit 4226de9

26 files changed

+418
-416
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,11 @@ def Tosa_ConstShapeOp : Tosa_ShapeOp<"const_shape", [ConstantLike, Pure]> {
6767

6868
```mlir
6969
// Generic form
70-
%out = "tosa.const_shape"() {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
70+
%out = "tosa.const_shape"() {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
7171
```
7272
}];
7373

74-
let arguments = (ins IndexElementsAttr : $value);
74+
let arguments = (ins IndexElementsAttr : $values);
7575

7676
let results = (outs Tosa_Shape : $output);
7777

mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,8 +237,8 @@ Value getTosaConstShape(PatternRewriter &rewriter, Location loc,
237237

238238
SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape);
239239

240-
bool getConstShapeValue(Operation *op,
241-
llvm::SmallVector<int64_t> &result_shape);
240+
bool getConstShapeValues(Operation *op,
241+
llvm::SmallVector<int64_t> &result_shape);
242242

243243
// returns a small vector of int64_t values that attr contains
244244
SmallVector<int64_t> convertFromIntAttr(const DenseElementsAttr &attr,

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1578,7 +1578,7 @@ class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
15781578
}
15791579

15801580
SmallVector<int64_t> scale;
1581-
if (!tosa::getConstShapeValue(op.getScale().getDefiningOp(), scale)) {
1581+
if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale)) {
15821582
return failure();
15831583
}
15841584

@@ -1799,9 +1799,9 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
17991799
Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);
18001800

18011801
SmallVector<int64_t> scale, offset, border;
1802-
if (!tosa::getConstShapeValue(op.getScale().getDefiningOp(), scale) ||
1803-
!tosa::getConstShapeValue(op.getOffset().getDefiningOp(), offset) ||
1804-
!tosa::getConstShapeValue(op.getBorder().getDefiningOp(), border)) {
1802+
if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale) ||
1803+
!tosa::getConstShapeValues(op.getOffset().getDefiningOp(), offset) ||
1804+
!tosa::getConstShapeValues(op.getBorder().getDefiningOp(), border)) {
18051805
return rewriter.notifyMatchFailure(
18061806
op, "tosa.resize scale/offset/border should have compile time "
18071807
"constant values.");

mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,8 +243,8 @@ class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
243243
}
244244

245245
llvm::SmallVector<int64_t> newShape;
246-
if (!tosa::getConstShapeValue(reshape.getShape().getDefiningOp(),
247-
newShape)) {
246+
if (!tosa::getConstShapeValues(reshape.getShape().getDefiningOp(),
247+
newShape)) {
248248
return failure();
249249
}
250250

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -884,7 +884,7 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
884884

885885
OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }
886886

887-
OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
887+
OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }
888888

889889
#define REDUCE_FOLDER(OP) \
890890
OpFoldResult OP::fold(FoldAdaptor adaptor) { \
@@ -947,7 +947,7 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
947947
return {};
948948

949949
llvm::SmallVector<int64_t> shapeVec;
950-
if (!tosa::getConstShapeValue(getShape().getDefiningOp(), shapeVec))
950+
if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeVec))
951951
return {};
952952

953953
return operand.reshape(

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1162,8 +1162,8 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
11621162

11631163
SmallVector<int64_t> paddingValues;
11641164
// If the paddings value is not a constant, all dimensions must be dynamic.
1165-
if (!tosa::getConstShapeValue(adaptor.getPadding().getDefiningOp(),
1166-
paddingValues)) {
1165+
if (!tosa::getConstShapeValues(adaptor.getPadding().getDefiningOp(),
1166+
paddingValues)) {
11671167
outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
11681168
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
11691169
return success();
@@ -1235,8 +1235,8 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
12351235
SmallVector<int64_t> start;
12361236
SmallVector<int64_t> size;
12371237

1238-
if (!tosa::getConstShapeValue(adaptor.getStart().getDefiningOp(), start) ||
1239-
!tosa::getConstShapeValue(adaptor.getSize().getDefiningOp(), size)) {
1238+
if (!tosa::getConstShapeValues(adaptor.getStart().getDefiningOp(), start) ||
1239+
!tosa::getConstShapeValues(adaptor.getSize().getDefiningOp(), size)) {
12401240
auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
12411241
SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
12421242
inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
@@ -1544,8 +1544,8 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
15441544
ShapeAdaptor inputShape(adaptor.getInput1().getType());
15451545
Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
15461546
llvm::SmallVector<int64_t> newShapeValue;
1547-
if (!tosa::getConstShapeValue(adaptor.getShape().getDefiningOp(),
1548-
newShapeValue)) {
1547+
if (!tosa::getConstShapeValues(adaptor.getShape().getDefiningOp(),
1548+
newShapeValue)) {
15491549
auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
15501550
SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
15511551
inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
@@ -1594,7 +1594,7 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
15941594
RankedTensorType outputType = getType();
15951595

15961596
SmallVector<int64_t> shapeValues;
1597-
if (!tosa::getConstShapeValue(getShape().getDefiningOp(), shapeValues)) {
1597+
if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeValues)) {
15981598
// skip following checks if shape is not constant
15991599
return mlir::success();
16001600
}
@@ -1899,11 +1899,12 @@ LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
18991899
return failure();
19001900

19011901
SmallVector<int64_t> scaleInt, offsetInt, borderInt;
1902-
if (!tosa::getConstShapeValue(adaptor.getScale().getDefiningOp(), scaleInt) ||
1903-
!tosa::getConstShapeValue(adaptor.getOffset().getDefiningOp(),
1904-
offsetInt) ||
1905-
!tosa::getConstShapeValue(adaptor.getBorder().getDefiningOp(),
1906-
borderInt)) {
1902+
if (!tosa::getConstShapeValues(adaptor.getScale().getDefiningOp(),
1903+
scaleInt) ||
1904+
!tosa::getConstShapeValues(adaptor.getOffset().getDefiningOp(),
1905+
offsetInt) ||
1906+
!tosa::getConstShapeValues(adaptor.getBorder().getDefiningOp(),
1907+
borderInt)) {
19071908
return failure();
19081909
}
19091910

@@ -1943,9 +1944,9 @@ LogicalResult tosa::ResizeOp::verify() {
19431944
SmallVector<int64_t> scaleValues;
19441945
SmallVector<int64_t> offsetValues;
19451946
SmallVector<int64_t> borderValues;
1946-
if (!tosa::getConstShapeValue(getScale().getDefiningOp(), scaleValues) ||
1947-
!tosa::getConstShapeValue(getOffset().getDefiningOp(), offsetValues) ||
1948-
!tosa::getConstShapeValue(getBorder().getDefiningOp(), borderValues)) {
1947+
if (!tosa::getConstShapeValues(getScale().getDefiningOp(), scaleValues) ||
1948+
!tosa::getConstShapeValues(getOffset().getDefiningOp(), offsetValues) ||
1949+
!tosa::getConstShapeValues(getBorder().getDefiningOp(), borderValues)) {
19491950
// Skip following checks if shape is not constant
19501951
return success();
19511952
}
@@ -3034,14 +3035,14 @@ OpTrait::tosa::verifyTosaShapeOperatorWithSameRanks(Operation *op) {
30343035

30353036
LogicalResult tosa::ConstShapeOp::verify() {
30363037
// check one dimensional rank
3037-
auto valuesRank = getValue().getType().getRank();
3038+
auto valuesRank = getValues().getType().getRank();
30383039
if (valuesRank != 1)
3039-
return emitOpError("expect elements in attribute value with rank 1");
3040+
return emitOpError("expect elements in attribute values with rank 1");
30403041
// check that number of elements in values attr equal to rank of result shape
3041-
auto count = getValue().getNumElements();
3042+
auto count = getValues().getNumElements();
30423043
auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
30433044
if (!(count == rank || (count == 1 && rank == 0))) {
3044-
return emitOpError("expect number of elements in attribute value (")
3045+
return emitOpError("expect number of elements in attribute values (")
30453046
<< count << ") to be equal to the rank (" << rank
30463047
<< ") for the result shape type";
30473048
}

mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -399,8 +399,8 @@ std::optional<Value> TosaReduceTransposes::buildMappedToValue(
399399

400400
// Do not insert a TransposeOp, instead we fold the reshape and its attribute.
401401
llvm::SmallVector<int64_t> newShape;
402-
if (!tosa::getConstShapeValue(reshapeOp.getShape().getDefiningOp(),
403-
newShape)) {
402+
if (!tosa::getConstShapeValues(reshapeOp.getShape().getDefiningOp(),
403+
newShape)) {
404404
// this mean shape is not constant
405405
return std::nullopt;
406406
}

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
342342
bool levelCheckResize(Operation *op) {
343343
if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
344344
SmallVector<int64_t> scale;
345-
if (!tosa::getConstShapeValue(resize.getScale().getDefiningOp(), scale)) {
345+
if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(),
346+
scale)) {
346347
return false;
347348
}
348349
const int64_t scaleYN = scale[0];
@@ -736,7 +737,7 @@ bool checkErrorIfResize(Operation *op) {
736737
}
737738

738739
SmallVector<int64_t> scale;
739-
if (!tosa::getConstShapeValue(resize.getScale().getDefiningOp(), scale)) {
740+
if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(), scale)) {
740741
return false;
741742
}
742743

@@ -761,8 +762,8 @@ bool checkErrorIfResize(Operation *op) {
761762

762763
SmallVector<int64_t> offset;
763764
SmallVector<int64_t> border;
764-
if (!tosa::getConstShapeValue(resize.getOffset().getDefiningOp(), offset) ||
765-
!tosa::getConstShapeValue(resize.getBorder().getDefiningOp(), border)) {
765+
if (!tosa::getConstShapeValues(resize.getOffset().getDefiningOp(), offset) ||
766+
!tosa::getConstShapeValues(resize.getBorder().getDefiningOp(), border)) {
766767
return false;
767768
}
768769

mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,13 +178,13 @@ SmallVector<int64_t> mlir::tosa::convertFromMlirShape(ArrayRef<int64_t> shape) {
178178
}));
179179
}
180180

181-
bool mlir::tosa::getConstShapeValue(Operation *op,
182-
llvm::SmallVector<int64_t> &result_shape) {
181+
bool mlir::tosa::getConstShapeValues(Operation *op,
182+
llvm::SmallVector<int64_t> &result_shape) {
183183
if (!op) {
184184
return false;
185185
}
186186
if (auto constOp = mlir::dyn_cast<tosa::ConstShapeOp>(op)) {
187-
Attribute constOpAttr = constOp->getAttr("value");
187+
Attribute constOpAttr = constOp->getAttr("values");
188188
DenseElementsAttr elementsAttr = cast<DenseElementsAttr>(constOpAttr);
189189
for (int i = 0; i < elementsAttr.size(); i++) {
190190
int64_t val = elementsAttr.getValues<int64_t>()[i];

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %
2424
%reduce = tosa.reduce_max %arg0 {axis = 1 : i32} : (tensor<10x10xf32>) -> tensor<10x1xf32>
2525
%1 = tosa.add %reduce, %arg1 : (tensor<10x1xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
2626
%0 = tosa.add %1, %arg2 : (tensor<10x10xf32>, tensor<*xf32>) -> tensor<*xf32>
27-
%s = tosa.const_shape {value = dense<[10, 10]> : tensor<2xindex>} : () -> !tosa.shape<2>
27+
%s = tosa.const_shape {values = dense<[10, 10]> : tensor<2xindex>} : () -> !tosa.shape<2>
2828
%2 = tosa.reshape %0, %s : (tensor<*xf32>, !tosa.shape<2>) -> tensor<10x10xf32>
2929
return %2 : tensor<10x10xf32>
3030
}

0 commit comments

Comments
 (0)