Skip to content

[mlir][tosa] Update value to values for ConstOp and ConstShapeOp #129943

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2338,7 +2338,7 @@ def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
// Operator: const
//===----------------------------------------------------------------------===//
def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure,
AllShapesMatch<["value", "output"]>,
AllShapesMatch<["values", "output"]>,
FirstAttrDerivedResultType]> {
let summary = "Constant op.";

Expand All @@ -2350,12 +2350,12 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure,

```mlir
// Generic form
%out = "tosa.const"() {value = dense<0> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
%out = "tosa.const"() {values = dense<0> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
```
}];

let arguments = (ins
ElementsAttr:$value
ElementsAttr:$values
);

let results = (outs
Expand Down
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,11 @@ def Tosa_ConstShapeOp : Tosa_ShapeOp<"const_shape", [ConstantLike, Pure]> {

```mlir
// Generic form
%out = "tosa.const_shape"() {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
%out = "tosa.const_shape"() {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
```
}];

let arguments = (ins IndexElementsAttr : $value);
let arguments = (ins IndexElementsAttr : $values);

let results = (outs Tosa_Shape : $output);

Expand Down
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,8 @@ Value getTosaConstShape(PatternRewriter &rewriter, Location loc,

SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape);

bool getConstShapeValue(Operation *op,
llvm::SmallVector<int64_t> &result_shape);
bool getConstShapeValues(Operation *op,
llvm::SmallVector<int64_t> &result_shape);

// returns a small vector of int64_t values that attr contains
SmallVector<int64_t> convertFromIntAttr(const DenseElementsAttr &attr,
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class ConstOpConverter : public OpRewritePattern<tosa::ConstOp> {

LogicalResult matchAndRewrite(tosa::ConstOp op,
PatternRewriter &rewriter) const final {
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, op.getValue());
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, op.getValues());
return success();
}
};
Expand Down
8 changes: 4 additions & 4 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1578,7 +1578,7 @@ class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
}

SmallVector<int64_t> scale;
if (!tosa::getConstShapeValue(op.getScale().getDefiningOp(), scale)) {
if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale)) {
return failure();
}

Expand Down Expand Up @@ -1799,9 +1799,9 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);

SmallVector<int64_t> scale, offset, border;
if (!tosa::getConstShapeValue(op.getScale().getDefiningOp(), scale) ||
!tosa::getConstShapeValue(op.getOffset().getDefiningOp(), offset) ||
!tosa::getConstShapeValue(op.getBorder().getDefiningOp(), border)) {
if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale) ||
!tosa::getConstShapeValues(op.getOffset().getDefiningOp(), offset) ||
!tosa::getConstShapeValues(op.getBorder().getDefiningOp(), border)) {
return rewriter.notifyMatchFailure(
op, "tosa.resize scale/offset/border should have compile time "
"constant values.");
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,8 @@ class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
}

llvm::SmallVector<int64_t> newShape;
if (!tosa::getConstShapeValue(reshape.getShape().getDefiningOp(),
newShape)) {
if (!tosa::getConstShapeValues(reshape.getShape().getDefiningOp(),
newShape)) {
return failure();
}

Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -882,9 +882,9 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
return {};
}

OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }

OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }

#define REDUCE_FOLDER(OP) \
OpFoldResult OP::fold(FoldAdaptor adaptor) { \
Expand Down Expand Up @@ -947,7 +947,7 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
return {};

llvm::SmallVector<int64_t> shapeVec;
if (!tosa::getConstShapeValue(getShape().getDefiningOp(), shapeVec))
if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeVec))
return {};

return operand.reshape(
Expand Down
43 changes: 22 additions & 21 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ static LogicalResult verifyConvOp(T op) {

LogicalResult tosa::ConstOp::verify() {

auto attrType = llvm::dyn_cast<TensorType>(getValueAttr().getType());
auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().getType());
auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());

if (!attrType || !outputType) {
Expand Down Expand Up @@ -1179,8 +1179,8 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(

SmallVector<int64_t> paddingValues;
// If the paddings value is not a constant, all dimensions must be dynamic.
if (!tosa::getConstShapeValue(adaptor.getPadding().getDefiningOp(),
paddingValues)) {
if (!tosa::getConstShapeValues(adaptor.getPadding().getDefiningOp(),
paddingValues)) {
outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
return success();
Expand Down Expand Up @@ -1252,8 +1252,8 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
SmallVector<int64_t> start;
SmallVector<int64_t> size;

if (!tosa::getConstShapeValue(adaptor.getStart().getDefiningOp(), start) ||
!tosa::getConstShapeValue(adaptor.getSize().getDefiningOp(), size)) {
if (!tosa::getConstShapeValues(adaptor.getStart().getDefiningOp(), start) ||
!tosa::getConstShapeValues(adaptor.getSize().getDefiningOp(), size)) {
auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
Expand Down Expand Up @@ -1561,8 +1561,8 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
ShapeAdaptor inputShape(adaptor.getInput1().getType());
Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
llvm::SmallVector<int64_t> newShapeValue;
if (!tosa::getConstShapeValue(adaptor.getShape().getDefiningOp(),
newShapeValue)) {
if (!tosa::getConstShapeValues(adaptor.getShape().getDefiningOp(),
newShapeValue)) {
auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
Expand Down Expand Up @@ -1611,7 +1611,7 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
RankedTensorType outputType = getType();

SmallVector<int64_t> shapeValues;
if (!tosa::getConstShapeValue(getShape().getDefiningOp(), shapeValues)) {
if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeValues)) {
// skip following checks if shape is not constant
return mlir::success();
}
Expand Down Expand Up @@ -1916,11 +1916,12 @@ LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
return failure();

SmallVector<int64_t> scaleInt, offsetInt, borderInt;
if (!tosa::getConstShapeValue(adaptor.getScale().getDefiningOp(), scaleInt) ||
!tosa::getConstShapeValue(adaptor.getOffset().getDefiningOp(),
offsetInt) ||
!tosa::getConstShapeValue(adaptor.getBorder().getDefiningOp(),
borderInt)) {
if (!tosa::getConstShapeValues(adaptor.getScale().getDefiningOp(),
scaleInt) ||
!tosa::getConstShapeValues(adaptor.getOffset().getDefiningOp(),
offsetInt) ||
!tosa::getConstShapeValues(adaptor.getBorder().getDefiningOp(),
borderInt)) {
return failure();
}

Expand Down Expand Up @@ -1960,9 +1961,9 @@ LogicalResult tosa::ResizeOp::verify() {
SmallVector<int64_t> scaleValues;
SmallVector<int64_t> offsetValues;
SmallVector<int64_t> borderValues;
if (!tosa::getConstShapeValue(getScale().getDefiningOp(), scaleValues) ||
!tosa::getConstShapeValue(getOffset().getDefiningOp(), offsetValues) ||
!tosa::getConstShapeValue(getBorder().getDefiningOp(), borderValues)) {
if (!tosa::getConstShapeValues(getScale().getDefiningOp(), scaleValues) ||
!tosa::getConstShapeValues(getOffset().getDefiningOp(), offsetValues) ||
!tosa::getConstShapeValues(getBorder().getDefiningOp(), borderValues)) {
// Skip following checks if shape is not constant
return success();
}
Expand Down Expand Up @@ -3051,14 +3052,14 @@ OpTrait::tosa::verifyTosaShapeOperatorWithSameRanks(Operation *op) {

LogicalResult tosa::ConstShapeOp::verify() {
// check one dimensional rank
auto valuesRank = getValue().getType().getRank();
auto valuesRank = getValues().getType().getRank();
if (valuesRank != 1)
return emitOpError("expect elements in attribute value with rank 1");
// check that number of elements in value attr equal to rank of result shape
auto count = getValue().getNumElements();
return emitOpError("expect elements in attribute values with rank 1");
// check that number of elements in values attr equal to rank of result shape
auto count = getValues().getNumElements();
auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
if (!(count == rank || (count == 1 && rank == 0))) {
return emitOpError("expect number of elements in attribute value (")
return emitOpError("expect number of elements in attribute values (")
<< count << ") to be equal to the rank (" << rank
<< ") for the result shape type";
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {
return rewriter.notifyMatchFailure(op, "result type shape is not static");

auto reductionAxis = op.getAxis();
const auto denseElementsAttr = constOp.getValue();
const auto denseElementsAttr = constOp.getValues();
const auto shapedOldElementsValues =
cast<ShapedType>(denseElementsAttr.getType());

Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -399,8 +399,8 @@ std::optional<Value> TosaReduceTransposes::buildMappedToValue(

// Do not insert a TransposeOp, instead we fold the reshape and its attribute.
llvm::SmallVector<int64_t> newShape;
if (!tosa::getConstShapeValue(reshapeOp.getShape().getDefiningOp(),
newShape)) {
if (!tosa::getConstShapeValues(reshapeOp.getShape().getDefiningOp(),
newShape)) {
// this mean shape is not constant
return std::nullopt;
}
Expand All @@ -418,7 +418,7 @@ std::optional<Value> TosaReduceTransposes::buildMappedToValue(
std::optional<Value> TosaReduceTransposes::buildMappedToValue(
ConstOp constOp, const DenseMap<Value, Value> &valuesMap,
IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(constOp.getValue());
auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(constOp.getValues());
if (!denseAttr)
return std::nullopt;
auto maybeNewDenseAttr = transposeDenseAttribute(denseAttr, hoistedPerms);
Expand Down
9 changes: 5 additions & 4 deletions mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
bool levelCheckResize(Operation *op) {
if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
SmallVector<int64_t> scale;
if (!tosa::getConstShapeValue(resize.getScale().getDefiningOp(), scale)) {
if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(),
scale)) {
return false;
}
const int64_t scaleYN = scale[0];
Expand Down Expand Up @@ -736,7 +737,7 @@ bool checkErrorIfResize(Operation *op) {
}

SmallVector<int64_t> scale;
if (!tosa::getConstShapeValue(resize.getScale().getDefiningOp(), scale)) {
if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(), scale)) {
return false;
}

Expand All @@ -761,8 +762,8 @@ bool checkErrorIfResize(Operation *op) {

SmallVector<int64_t> offset;
SmallVector<int64_t> border;
if (!tosa::getConstShapeValue(resize.getOffset().getDefiningOp(), offset) ||
!tosa::getConstShapeValue(resize.getBorder().getDefiningOp(), border)) {
if (!tosa::getConstShapeValues(resize.getOffset().getDefiningOp(), offset) ||
!tosa::getConstShapeValues(resize.getBorder().getDefiningOp(), border)) {
return false;
}

Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,13 +178,13 @@ SmallVector<int64_t> mlir::tosa::convertFromMlirShape(ArrayRef<int64_t> shape) {
}));
}

bool mlir::tosa::getConstShapeValue(Operation *op,
llvm::SmallVector<int64_t> &result_shape) {
bool mlir::tosa::getConstShapeValues(Operation *op,
llvm::SmallVector<int64_t> &result_shape) {
if (!op) {
return false;
}
if (auto constOp = mlir::dyn_cast<tosa::ConstShapeOp>(op)) {
Attribute constOpAttr = constOp->getAttr("value");
Attribute constOpAttr = constOp->getAttr("values");
DenseElementsAttr elementsAttr = cast<DenseElementsAttr>(constOpAttr);
for (int i = 0; i < elementsAttr.size(); i++) {
int64_t val = elementsAttr.getValues<int64_t>()[i];
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// CHECK-LABEL: func @const_test
func.func @const_test() -> (tensor<i32>) {
// CHECK: [[C3:%.+]] = arith.constant dense<3> : tensor<i32>
%result = "tosa.const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
%result = "tosa.const"() {values = dense<3> : tensor<i32>} : () -> tensor<i32>

// CHECK: return [[C3]]
return %result : tensor<i32>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %
%reduce = tosa.reduce_max %arg0 {axis = 1 : i32} : (tensor<10x10xf32>) -> tensor<10x1xf32>
%1 = tosa.add %reduce, %arg1 : (tensor<10x1xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
%0 = tosa.add %1, %arg2 : (tensor<10x10xf32>, tensor<*xf32>) -> tensor<*xf32>
%s = tosa.const_shape {value = dense<[10, 10]> : tensor<2xindex>} : () -> !tosa.shape<2>
%s = tosa.const_shape {values = dense<[10, 10]> : tensor<2xindex>} : () -> !tosa.shape<2>
%2 = tosa.reshape %0, %s : (tensor<*xf32>, !tosa.shape<2>) -> tensor<10x10xf32>
return %2 : tensor<10x10xf32>
}
Expand Down
Loading