Skip to content

Commit 5a33f7e

Browse files
committed
[TOSA] Add Tosa_Shape type and ConstShapeOp
Adds: 1. tosa shape type to Tosa dialect e.g., !tosa.shape<4> is a type for rank-4 shape values (size-4 array of index values) 2. const_shape operator 3. trait TosaShapeOperator, added to tosa shape operators, and a verifier that all operands and results of operator are tosa shapes 4. trait TosaResolvableShapeOperands, added to all tosa operators, and a verifier that every tosa shape operand is produced by a tosa shape operator (indicated by trait TosaShapeOperator) 5. trait TosaShapeOperatorWithSameRanks, added to Tosa_ElementwiseShapeOp and a verifier that all operands and result shapes have same ranks 5. changed TileOp's multiples from attribute to input, of !tosa.shape type. 6. add folder for tosa ConstShape operator Signed-off-by: Jerry Ge <[email protected]> Signed-off-by: Tai Ly <[email protected]> Change-Id: I0213f99f5816b648f732b01fe8bd196956f1dfc8
1 parent 386dec2 commit 5a33f7e

File tree

18 files changed

+416
-32
lines changed

18 files changed

+416
-32
lines changed

mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@ add_mlir_doc(TosaOps TosaOps Dialects/ -gen-op-doc)
33
add_mlir_interface(TosaInterfaces)
44

55
set(LLVM_TARGET_DEFINITIONS TosaOps.td)
6+
mlir_tablegen(TosaOpsTypesBase.h.inc -gen-typedef-decls -typedefs-dialect=tosa)
7+
mlir_tablegen(TosaOpsTypesBase.cpp.inc -gen-typedef-defs -typedefs-dialect=tosa)
68
mlir_tablegen(TosaAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=tosa)
79
mlir_tablegen(TosaAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=tosa)
810
add_public_tablegen_target(MLIRTosaAttributesIncGen)
911

1012
set(LLVM_TARGET_DEFINITIONS TosaDialectBytecode.td)
1113
mlir_tablegen(TosaDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Tosa")
1214
add_public_tablegen_target(MLIRTosaDialectBytecodeIncGen)
13-

mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def Tosa_Dialect : Dialect {
4545
let cppNamespace = "mlir::tosa";
4646
let hasConstantMaterializer = 1;
4747
let useDefaultAttributePrinterParser = 1;
48+
let useDefaultTypePrinterParser = 1;
4849
}
4950

5051
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,14 +90,55 @@ template <typename ConcreteType>
9090
class TosaElementwiseOperator
9191
: public TraitBase<ConcreteType, TosaElementwiseOperator> {};
9292

93+
LogicalResult verifyTosaResolvableShapeOperands(Operation *op);
94+
/// This class verifies that tosa shape operands are compile time resolvable
95+
template <typename ConcreteType>
96+
class TosaResolvableShapeOperands
97+
: public TraitBase<ConcreteType, TosaResolvableShapeOperands> {
98+
public:
99+
static LogicalResult verifyTrait(Operation *op) {
100+
return verifyTosaResolvableShapeOperands(op);
101+
}
102+
};
103+
104+
LogicalResult verifyTosaShapeOperator(Operation *op);
105+
/// This class indicates that op operates on tosa shape types
106+
template <typename ConcreteType>
107+
class TosaShapeOperator : public TraitBase<ConcreteType, TosaShapeOperator> {
108+
public:
109+
static LogicalResult verifyTrait(Operation *op) {
110+
return verifyTosaShapeOperator(op);
111+
}
112+
};
113+
114+
LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op);
115+
/// This class indicates that op operates on tosa shape types
116+
template <typename ConcreteType>
117+
class TosaShapeOperatorWithSameRanks
118+
: public TraitBase<ConcreteType, TosaShapeOperatorWithSameRanks> {
119+
public:
120+
static LogicalResult verifyTrait(Operation *op) {
121+
return verifyTosaShapeOperatorWithSameRanks(op);
122+
}
123+
};
124+
93125
} // namespace tosa
94126
} // namespace OpTrait
95127

128+
namespace tosa {
129+
130+
bool isa_tosa_shape_type(mlir::Type t);
131+
132+
} // namespace tosa
133+
96134
} // namespace mlir
97135

98136
#define GET_ATTRDEF_CLASSES
99137
#include "mlir/Dialect/Tosa/IR/TosaAttributes.h.inc"
100138

139+
#define GET_TYPEDEF_CLASSES
140+
#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.h.inc"
141+
101142
#define GET_OP_CLASSES
102143
#include "mlir/Dialect/Tosa/IR/TosaOps.h.inc"
103144

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1689,12 +1689,16 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
16891689

16901690
let arguments = (ins
16911691
Tosa_Tensor:$input1,
1692-
DenseI64ArrayAttr:$multiples);
1692+
Tosa_Shape:$multiples);
16931693

16941694
let results = (outs
16951695
Tosa_Tensor:$output
16961696
);
16971697

1698+
let extraClassDeclaration = [{
1699+
LogicalResult getConstantMultiples(llvm::SmallVector<int64_t> &multiples);
1700+
}];
1701+
16981702
let hasFolder = 1;
16991703
let hasVerifier = 1;
17001704
}
@@ -2106,4 +2110,6 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [
21062110

21072111
include "mlir/Dialect/Tosa/IR/TosaUtilOps.td"
21082112

2113+
include "mlir/Dialect/Tosa/IR/TosaShapeOps.td"
2114+
21092115
#endif // TOSA_OPS
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
//===-- TosaShapeOps.td - TOSA dialect utility operations --*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines shape operators for the TOSA dialect.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef TOSA_SHAPE_OPS
14+
#define TOSA_SHAPE_OPS
15+
16+
include "mlir/IR/OpBase.td"
17+
18+
include "mlir/Interfaces/SideEffectInterfaces.td"
19+
include "mlir/Interfaces/InferTypeOpInterface.td"
20+
include "mlir/Interfaces/LoopLikeInterface.td"
21+
include "mlir/Dialect/Tosa/IR/TosaInterfaces.td"
22+
23+
include "mlir/Dialect/Tosa/IR/TosaTypesBase.td"
24+
include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
25+
26+
// Op trait: operator has operands and results with TOSA shape type
27+
def TosaShapeOperator : NativeOpTrait<"TosaShapeOperator"> {
28+
let cppNamespace = "mlir::OpTrait::tosa";
29+
}
30+
31+
class Tosa_ShapeOp<string mnemonic, list<Trait> traits = []>
32+
: Tosa_Op<mnemonic, !listconcat(traits, [TosaShapeOperator, Pure])> {
33+
34+
let assemblyFormat =
35+
"operands attr-dict `:` functional-type(operands, results)";
36+
37+
let hasFolder = 1;
38+
}
39+
40+
// op trait: shape operator has same ranks for operands and results
41+
def TosaShapeOperatorWithSameRanks
42+
: NativeOpTrait<"TosaShapeOperatorWithSameRanks"> {
43+
let cppNamespace = "mlir::OpTrait::tosa";
44+
}
45+
46+
class Tosa_ElementwiseShapeOp<string mnemonic, list<Trait> traits = []>
47+
: Tosa_ShapeOp<mnemonic,
48+
!listconcat(traits, [TosaShapeOperatorWithSameRanks])> {
49+
}
50+
51+
52+
//===----------------------------------------------------------------------===//
53+
// Operator: ConstShape
54+
//===----------------------------------------------------------------------===//
55+
def Tosa_ConstShapeOp : Tosa_ShapeOp<"const_shape", [ConstantLike, Pure]> {
56+
let summary = "Constant Shape op.";
57+
58+
let description = [{
59+
A node containing constant data for use as the input to an shape operation. May
60+
hold data only in index data type.
61+
62+
Example:
63+
64+
```mlir
65+
// Generic form
66+
%out = "tosa.const_shape"() {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
67+
```
68+
}];
69+
70+
let arguments = (ins IndexElementsAttr : $value);
71+
72+
let results = (outs Tosa_Shape : $output);
73+
74+
let hasVerifier = 1;
75+
}
76+
77+
#endif // TOSA_SHAPE_OPS

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,11 @@
1313
#ifndef TOSA_TYPES_BASE
1414
#define TOSA_TYPES_BASE
1515

16+
include "mlir/IR/AttrTypeBase.td"
1617
include "mlir/IR/OpBase.td"
1718

19+
include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
20+
1821
//===----------------------------------------------------------------------===//
1922
// Tosa Type Definitions.
2023
//===----------------------------------------------------------------------===//
@@ -215,4 +218,66 @@ def Tosa_Buffer : MemRefOf<[Tosa_AnyNumber]>;
215218
def Tosa_TupleBuffer : NestedTupleOf<[Tosa_Buffer]>;
216219
def Tosa_BufOrTuple : AnyTypeOf<[Tosa_Buffer, Tosa_TupleBuffer]>;
217220

221+
//===----------------------------------------------------------------------===//
222+
// Tosa Type Definitions.
223+
//===----------------------------------------------------------------------===//
224+
225+
// The base class for Tosa dialect types.
226+
class Tosa_Type<string name, string typeMnemonic, list<Trait> traits = []>
227+
: TypeDef<Tosa_Dialect, name, traits> {
228+
let mnemonic = typeMnemonic;
229+
}
230+
231+
//===----------------------------------------------------------------------===//
232+
// ShapeType
233+
//===----------------------------------------------------------------------===//
234+
def Tosa_Shape : Tosa_Type<"shape", "shape"> {
235+
let summary = "Shape with static rank and Index element type";
236+
let description = [{
237+
Syntax:
238+
239+
``` shape - type :: = `shape` `<` rank `>`
240+
``` Values with shape type represents a shape with a fixed rank and a list
241+
of dimensions
242+
.Rank must be zero or a positive integer
243+
.Each dimension is represented by the builtin
244+
Index type.
245+
246+
Examples:
247+
248+
```mlir
249+
// Shape with rank of four, for example, [1, 1, 8, 16]:
250+
!tosa
251+
.shape<4>
252+
253+
// Shape with rank of one, for example, [16]:
254+
!tosa
255+
.shape<1>
256+
257+
// Shape with rank zero, for example, [] (i.e., shape of scalar values):
258+
!tosa.shape<0>
259+
```
260+
}];
261+
let parameters = (ins "int" : $rank);
262+
let builders = [TypeBuilder<(ins "int" : $rank)>];
263+
let assemblyFormat = "`<` $rank `>`";
264+
265+
let genVerifyDecl = 1;
266+
}
267+
268+
def IsTosaShapeType : CPred<"mlir::tosa::isa_tosa_shape_type($_self)">;
269+
270+
// Whether a Tosa Shape type has a rank equal to the specified rank.
271+
class IsTosaShapeOfRankPred<int rank> : And<[
272+
IsTosaShapeType,
273+
CPred<[{::llvm::cast<::mlir::tosa::shapeType>($_self).getRank() == }] # rank>
274+
]>;
275+
276+
class TosaShapeOfRank<int rank>
277+
: Type<IsTosaShapeOfRankPred<rank>, "Tosa shape type of rank " #rank>;
278+
279+
def Rank1TosaShape : TosaShapeOfRank<1>;
280+
def Rank2TosaShape : TosaShapeOfRank<2>;
281+
def Rank4TosaShape : TosaShapeOfRank<4>;
282+
218283
#endif // TOSA_TYPES_BASE

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1886,7 +1886,9 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
18861886
auto elementTy = inputTy.getElementType();
18871887
int64_t rank = inputTy.getRank();
18881888

1889-
ArrayRef<int64_t> multiples = op.getMultiples();
1889+
SmallVector<int64_t> multiples;
1890+
if (failed(op.getConstantMultiples(multiples)))
1891+
return failure();
18901892

18911893
// Broadcast the newly added dimensions to their appropriate multiple.
18921894
SmallVector<int64_t, 2> genericShape;

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ struct TosaToLinalg : public impl::TosaToLinalgBase<TosaToLinalg> {
5555
target.addLegalOp<tosa::ApplyScaleOp>();
5656
target.addLegalOp<tosa::IfOp>();
5757
target.addLegalOp<tosa::ConstOp>();
58+
target.addLegalOp<tosa::ConstShapeOp>();
5859
target.addLegalOp<tosa::WhileOp>();
5960
target.addLegalOp<tosa::ConcatOp>();
6061
target.addLegalOp<tosa::SliceOp>();

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -808,6 +808,8 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
808808

809809
OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
810810

811+
OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
812+
811813
#define REDUCE_FOLDER(OP) \
812814
OpFoldResult OP::fold(FoldAdaptor adaptor) { \
813815
ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType()); \
@@ -985,9 +987,20 @@ OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
985987
}
986988

987989
OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
988-
bool allOnes = llvm::all_of(getMultiples(), [](int64_t v) { return v == 1; });
989-
if (allOnes && getInput1().getType() == getType())
990-
return getInput1();
990+
if (getInput1().getType() == getType()) {
991+
if (auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
992+
adaptor.getMultiples())) {
993+
if (multiples.isSplat() &&
994+
multiples.getSplatValue<APInt>().getSExtValue() == 1)
995+
return getInput1();
996+
if (auto int_array_attr =
997+
llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
998+
if (llvm::all_of(int_array_attr.getValues<APInt>(),
999+
[](APInt v) { return v.getSExtValue() == 1; }))
1000+
return getInput1();
1001+
}
1002+
}
1003+
}
9911004
return {};
9921005
}
9931006

0 commit comments

Comments
 (0)