Skip to content

Commit 0bdbc1c

Browse files
Revert "[TOSA] Add Tosa_Shape type and ConstShapeOp (llvm#122547)"
This reverts commit f09db6a.
1 parent 7614a51 commit 0bdbc1c

File tree

18 files changed

+33
-425
lines changed

18 files changed

+33
-425
lines changed

mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,11 @@ add_mlir_doc(TosaOps TosaOps Dialects/ -gen-op-doc)
33
add_mlir_interface(TosaInterfaces)
44

55
set(LLVM_TARGET_DEFINITIONS TosaOps.td)
6-
mlir_tablegen(TosaOpsTypesBase.h.inc -gen-typedef-decls -typedefs-dialect=tosa)
7-
mlir_tablegen(TosaOpsTypesBase.cpp.inc -gen-typedef-defs -typedefs-dialect=tosa)
86
mlir_tablegen(TosaAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=tosa)
97
mlir_tablegen(TosaAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=tosa)
108
add_public_tablegen_target(MLIRTosaAttributesIncGen)
119

1210
set(LLVM_TARGET_DEFINITIONS TosaDialectBytecode.td)
1311
mlir_tablegen(TosaDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Tosa")
1412
add_public_tablegen_target(MLIRTosaDialectBytecodeIncGen)
13+

mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ def Tosa_Dialect : Dialect {
4545
let cppNamespace = "mlir::tosa";
4646
let hasConstantMaterializer = 1;
4747
let useDefaultAttributePrinterParser = 1;
48-
let useDefaultTypePrinterParser = 1;
4948
}
5049

5150
//===----------------------------------------------------------------------===//
@@ -218,21 +217,12 @@ def TosaElementwiseOperator : NativeOpTrait<"TosaElementwiseOperator"> {
218217
let cppNamespace = "mlir::OpTrait::tosa";
219218
}
220219

221-
//===----------------------------------------------------------------------===//
222-
// TOSA Operator Trait.
223-
//===----------------------------------------------------------------------===//
224-
// Op operands with TOSA shape types must be compile time resolvable
225-
def TosaResolvableShapeOperands : NativeOpTrait<"TosaResolvableShapeOperands"> {
226-
let cppNamespace = "mlir::OpTrait::tosa";
227-
}
228-
229220
//===----------------------------------------------------------------------===//
230221
// TOSA Operator Class.
231222
//===----------------------------------------------------------------------===//
232223

233224
class Tosa_Op<string mnemonic, list<Trait> traits = []> :
234-
Op<Tosa_Dialect, mnemonic, !listconcat(traits, [TosaOpInterface,
235-
TosaResolvableShapeOperands])> {
225+
Op<Tosa_Dialect, mnemonic, !listconcat(traits, [TosaOpInterface])> {
236226
}
237227

238228
class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -90,55 +90,14 @@ template <typename ConcreteType>
9090
class TosaElementwiseOperator
9191
: public TraitBase<ConcreteType, TosaElementwiseOperator> {};
9292

93-
LogicalResult verifyTosaResolvableShapeOperands(Operation *op);
94-
/// This class verifies that tosa shape operands are compile time resolvable
95-
template <typename ConcreteType>
96-
class TosaResolvableShapeOperands
97-
: public TraitBase<ConcreteType, TosaResolvableShapeOperands> {
98-
public:
99-
static LogicalResult verifyTrait(Operation *op) {
100-
return verifyTosaResolvableShapeOperands(op);
101-
}
102-
};
103-
104-
LogicalResult verifyTosaShapeOperator(Operation *op);
105-
/// This class indicates that op operates on tosa shape types
106-
template <typename ConcreteType>
107-
class TosaShapeOperator : public TraitBase<ConcreteType, TosaShapeOperator> {
108-
public:
109-
static LogicalResult verifyTrait(Operation *op) {
110-
return verifyTosaShapeOperator(op);
111-
}
112-
};
113-
114-
LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op);
115-
/// This class indicates that op operates on tosa shape types
116-
template <typename ConcreteType>
117-
class TosaShapeOperatorWithSameRanks
118-
: public TraitBase<ConcreteType, TosaShapeOperatorWithSameRanks> {
119-
public:
120-
static LogicalResult verifyTrait(Operation *op) {
121-
return verifyTosaShapeOperatorWithSameRanks(op);
122-
}
123-
};
124-
12593
} // namespace tosa
12694
} // namespace OpTrait
12795

128-
namespace tosa {
129-
130-
bool isa_tosa_shape_type(mlir::Type t);
131-
132-
} // namespace tosa
133-
13496
} // namespace mlir
13597

13698
#define GET_ATTRDEF_CLASSES
13799
#include "mlir/Dialect/Tosa/IR/TosaAttributes.h.inc"
138100

139-
#define GET_TYPEDEF_CLASSES
140-
#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.h.inc"
141-
142101
#define GET_OP_CLASSES
143102
#include "mlir/Dialect/Tosa/IR/TosaOps.h.inc"
144103

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1689,16 +1689,12 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
16891689

16901690
let arguments = (ins
16911691
Tosa_Tensor:$input1,
1692-
Tosa_Shape:$multiples);
1692+
DenseI64ArrayAttr:$multiples);
16931693

16941694
let results = (outs
16951695
Tosa_Tensor:$output
16961696
);
16971697

1698-
let extraClassDeclaration = [{
1699-
LogicalResult getConstantMultiples(llvm::SmallVector<int64_t> &multiples);
1700-
}];
1701-
17021698
let hasFolder = 1;
17031699
let hasVerifier = 1;
17041700
}
@@ -2110,6 +2106,4 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [
21102106

21112107
include "mlir/Dialect/Tosa/IR/TosaUtilOps.td"
21122108

2113-
include "mlir/Dialect/Tosa/IR/TosaShapeOps.td"
2114-
21152109
#endif // TOSA_OPS

mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td

Lines changed: 0 additions & 77 deletions
This file was deleted.

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 0 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,8 @@
1313
#ifndef TOSA_TYPES_BASE
1414
#define TOSA_TYPES_BASE
1515

16-
include "mlir/IR/AttrTypeBase.td"
1716
include "mlir/IR/OpBase.td"
1817

19-
include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
20-
2118
//===----------------------------------------------------------------------===//
2219
// Tosa Type Definitions.
2320
//===----------------------------------------------------------------------===//
@@ -218,66 +215,4 @@ def Tosa_Buffer : MemRefOf<[Tosa_AnyNumber]>;
218215
def Tosa_TupleBuffer : NestedTupleOf<[Tosa_Buffer]>;
219216
def Tosa_BufOrTuple : AnyTypeOf<[Tosa_Buffer, Tosa_TupleBuffer]>;
220217

221-
//===----------------------------------------------------------------------===//
222-
// Tosa Type Definitions.
223-
//===----------------------------------------------------------------------===//
224-
225-
// The base class for Tosa dialect types.
226-
class Tosa_Type<string name, string typeMnemonic, list<Trait> traits = []>
227-
: TypeDef<Tosa_Dialect, name, traits> {
228-
let mnemonic = typeMnemonic;
229-
}
230-
231-
//===----------------------------------------------------------------------===//
232-
// ShapeType
233-
//===----------------------------------------------------------------------===//
234-
def Tosa_Shape : Tosa_Type<"shape", "shape"> {
235-
let summary = "Shape with static rank and Index element type";
236-
let description = [{
237-
Syntax:
238-
239-
``` shape - type :: = `shape` `<` rank `>`
240-
``` Values with shape type represents a shape with a fixed rank and a list
241-
of dimensions
242-
.Rank must be zero or a positive integer
243-
.Each dimension is represented by the builtin
244-
Index type.
245-
246-
Examples:
247-
248-
```mlir
249-
// Shape with rank of four, for example, [1, 1, 8, 16]:
250-
!tosa
251-
.shape<4>
252-
253-
// Shape with rank of one, for example, [16]:
254-
!tosa
255-
.shape<1>
256-
257-
// Shape with rank zero, for example, [] (i.e., shape of scalar values):
258-
!tosa.shape<0>
259-
```
260-
}];
261-
let parameters = (ins "int" : $rank);
262-
let builders = [TypeBuilder<(ins "int" : $rank)>];
263-
let assemblyFormat = "`<` $rank `>`";
264-
265-
let genVerifyDecl = 1;
266-
}
267-
268-
def IsTosaShapeType : CPred<"mlir::tosa::isa_tosa_shape_type($_self)">;
269-
270-
// Whether a Tosa Shape type has a rank equal to the specified rank.
271-
class IsTosaShapeOfRankPred<int rank> : And<[
272-
IsTosaShapeType,
273-
CPred<[{::llvm::cast<::mlir::tosa::shapeType>($_self).getRank() == }] # rank>
274-
]>;
275-
276-
class TosaShapeOfRank<int rank>
277-
: Type<IsTosaShapeOfRankPred<rank>, "Tosa shape type of rank " #rank>;
278-
279-
def Rank1TosaShape : TosaShapeOfRank<1>;
280-
def Rank2TosaShape : TosaShapeOfRank<2>;
281-
def Rank4TosaShape : TosaShapeOfRank<4>;
282-
283218
#endif // TOSA_TYPES_BASE

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1892,9 +1892,7 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
18921892
auto elementTy = inputTy.getElementType();
18931893
int64_t rank = inputTy.getRank();
18941894

1895-
SmallVector<int64_t> multiples;
1896-
if (failed(op.getConstantMultiples(multiples)))
1897-
return failure();
1895+
ArrayRef<int64_t> multiples = op.getMultiples();
18981896

18991897
// Broadcast the newly added dimensions to their appropriate multiple.
19001898
SmallVector<int64_t, 2> genericShape;

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@ struct TosaToLinalg : public impl::TosaToLinalgBase<TosaToLinalg> {
5555
target.addLegalOp<tosa::ApplyScaleOp>();
5656
target.addLegalOp<tosa::IfOp>();
5757
target.addLegalOp<tosa::ConstOp>();
58-
target.addLegalOp<tosa::ConstShapeOp>();
5958
target.addLegalOp<tosa::WhileOp>();
6059
target.addLegalOp<tosa::ConcatOp>();
6160
target.addLegalOp<tosa::SliceOp>();

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -808,8 +808,6 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
808808

809809
OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
810810

811-
OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
812-
813811
#define REDUCE_FOLDER(OP) \
814812
OpFoldResult OP::fold(FoldAdaptor adaptor) { \
815813
ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType()); \
@@ -987,20 +985,9 @@ OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
987985
}
988986

989987
OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
990-
if (getInput1().getType() == getType()) {
991-
if (auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
992-
adaptor.getMultiples())) {
993-
if (multiples.isSplat() &&
994-
multiples.getSplatValue<APInt>().getSExtValue() == 1)
995-
return getInput1();
996-
if (auto int_array_attr =
997-
llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
998-
if (llvm::all_of(int_array_attr.getValues<APInt>(),
999-
[](APInt v) { return v.getSExtValue() == 1; }))
1000-
return getInput1();
1001-
}
1002-
}
1003-
}
988+
bool allOnes = llvm::all_of(getMultiples(), [](int64_t v) { return v == 1; });
989+
if (allOnes && getInput1().getType() == getType())
990+
return getInput1();
1004991
return {};
1005992
}
1006993

0 commit comments

Comments
 (0)