Skip to content

Commit c6876b4

Browse files
authored
Update input names from input to input1 for Table, Reverse, Slice (#109807)
- For input naming consistency, updated the inputs to input1 for Table, Reverse and Slice operator Signed-off-by: Jerry Ge <[email protected]>
1 parent be6aed9 commit c6876b4

File tree

5 files changed

+19
-19
lines changed

5 files changed

+19
-19
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -881,7 +881,7 @@ def Tosa_TableOp : Tosa_InferShapedTypeOp<"table"> {
881881
}];
882882

883883
let arguments = (ins
884-
Tosa_Tensor: $input,
884+
Tosa_Tensor: $input1,
885885
Tosa_Tensor1D: $table
886886
);
887887

@@ -890,7 +890,7 @@ def Tosa_TableOp : Tosa_InferShapedTypeOp<"table"> {
890890
);
891891

892892
let assemblyFormat = [{
893-
$input `,` $table attr-dict `:` `(` type($input) `,` type($table) `)` `->` type($output)
893+
$input1 `,` $table attr-dict `:` `(` type($input1) `,` type($table) `)` `->` type($output)
894894
}];
895895

896896
let hasVerifier = 1;
@@ -1640,7 +1640,7 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [
16401640
}];
16411641

16421642
let arguments = (ins
1643-
Tosa_Tensor:$input,
1643+
Tosa_Tensor:$input1,
16441644
I32Attr:$axis
16451645
);
16461646

@@ -1667,7 +1667,7 @@ def Tosa_SliceOp : Tosa_InferShapedTypeOp<"slice"> {
16671667
}];
16681668

16691669
let arguments = (ins
1670-
Tosa_Tensor:$input,
1670+
Tosa_Tensor:$input1,
16711671
DenseI64ArrayAttr:$start,
16721672
DenseI64ArrayAttr:$size
16731673
);

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1830,7 +1830,7 @@ class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
18301830
LogicalResult matchAndRewrite(tosa::ReverseOp op,
18311831
PatternRewriter &rewriter) const final {
18321832
auto loc = op.getLoc();
1833-
Value input = op.getInput();
1833+
Value input = op.getInput1();
18341834
auto inputTy = cast<ShapedType>(input.getType());
18351835
auto resultTy = cast<ShapedType>(op.getType());
18361836
auto axis = op.getAxis();
@@ -2161,7 +2161,7 @@ class TableConverter : public OpRewritePattern<tosa::TableOp> {
21612161
LogicalResult matchAndRewrite(tosa::TableOp op,
21622162
PatternRewriter &rewriter) const final {
21632163
auto loc = op.getLoc();
2164-
Value input = op.getInput();
2164+
Value input = op.getInput1();
21652165
Value table = op.getTable();
21662166
auto inputTy = cast<ShapedType>(input.getType());
21672167
auto tableTy = cast<ShapedType>(table.getType());

mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) {
144144
for (; currRhsDim < rhsShape.size(); currRhsDim++) {
145145
assert(rhsShape[currRhsDim] == 1);
146146
}
147-
147+
148148
return lhsType.clone(intermediateShape);
149149
}
150150

@@ -264,7 +264,7 @@ class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
264264
matchAndRewrite(tosa::SliceOp sliceOp, OpAdaptor adaptor,
265265
ConversionPatternRewriter &rewriter) const final {
266266
Location loc = sliceOp.getLoc();
267-
Value input = adaptor.getInput();
267+
Value input = adaptor.getInput1();
268268
ShapedType resultType = cast<ShapedType>(sliceOp.getType());
269269
if (llvm::isa<UnrankedTensorType>(resultType))
270270
return failure();

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,7 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
380380

381381
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
382382
PatternRewriter &rewriter) const override {
383-
Value sliceInput = sliceOp.getInput();
383+
Value sliceInput = sliceOp.getInput1();
384384
auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
385385
if (!concatOp)
386386
return rewriter.notifyMatchFailure(
@@ -919,11 +919,11 @@ OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
919919
}
920920

921921
OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
922-
auto operand = getInput();
922+
auto operand = getInput1();
923923
auto operandTy = llvm::cast<ShapedType>(operand.getType());
924924
auto axis = getAxis();
925925
auto operandAttr =
926-
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput());
926+
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput1());
927927
if (operandAttr)
928928
return operandAttr;
929929

@@ -936,24 +936,24 @@ OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
936936
}
937937

938938
OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
939-
auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
939+
auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
940940
auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
941941

942942
if (!inputTy || !outputTy)
943943
return {};
944944

945945
if (inputTy == outputTy && inputTy.hasStaticShape())
946-
return getInput();
946+
return getInput1();
947947

948-
if (!adaptor.getInput())
948+
if (!adaptor.getInput1())
949949
return {};
950950

951951
// Cannot create an ElementsAttr from non-int/float/index types
952952
if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
953953
!outputTy.getElementType().isIntOrIndexOrFloat())
954954
return {};
955955

956-
auto operand = llvm::cast<ElementsAttr>(adaptor.getInput());
956+
auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
957957
if (operand.isSplat() && outputTy.hasStaticShape()) {
958958
return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
959959
}

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -850,7 +850,7 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
850850
}
851851

852852
LogicalResult tosa::SliceOp::verify() {
853-
auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
853+
auto inputType = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
854854
if (!inputType)
855855
return success();
856856

@@ -869,7 +869,7 @@ LogicalResult tosa::TableOp::inferReturnTypeComponents(
869869
MLIRContext *context, ::std::optional<Location> location,
870870
TableOp::Adaptor adaptor,
871871
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
872-
ShapeAdaptor inputShape(adaptor.getInput().getType());
872+
ShapeAdaptor inputShape(adaptor.getInput1().getType());
873873

874874
if (!inputShape.hasRank()) {
875875
inferredReturnShapes.push_back(ShapedTypeComponents());
@@ -882,7 +882,7 @@ LogicalResult tosa::TableOp::inferReturnTypeComponents(
882882
}
883883

884884
LogicalResult tosa::TableOp::verify() {
885-
TensorType inputType = getInput().getType();
885+
TensorType inputType = getInput1().getType();
886886
TensorType outputType = getOutput().getType();
887887

888888
if (inputType.hasRank() && outputType.hasRank() &&
@@ -1973,7 +1973,7 @@ void IfOp::print(OpAsmPrinter &p) {
19731973
}
19741974

19751975
LogicalResult ReverseOp::verify() {
1976-
TensorType inputType = getInput().getType();
1976+
TensorType inputType = getInput1().getType();
19771977
TensorType outputType = getOutput().getType();
19781978
int32_t reverseAxis = getAxis();
19791979

0 commit comments

Comments
 (0)