Skip to content

Commit f09db6a

Browse files
authored
[TOSA] Add Tosa_Shape type and ConstShapeOp (llvm#122547)
Adds: 1. tosa shape type to Tosa dialect e.g., !tosa.shape<4> is a type for rank-4 shape values (size-4 array of index values) 2. const_shape operator 3. trait TosaShapeOperator, added to tosa shape operators, and a verifier that all operands and results of operator are tosa shapes 4. trait TosaResolvableShapeOperands, added to all tosa operators, and a verifier that every tosa shape operand is produced by a tosa shape operator (indicated by trait TosaShapeOperator) 5. trait TosaShapeOperatorWithSameRanks, added to Tosa_ElementwiseShapeOp and a verifier that all operands and result shapes have same ranks 5. changed TileOp's multiples from attribute to input, of !tosa.shape type. 6. add folder for tosa ConstShape operator This patch was originally authored by Tai Ly <[email protected]> Signed-off-by: Jerry Ge <[email protected]> Signed-off-by: Tai Ly <[email protected]>
1 parent 31249e2 commit f09db6a

File tree

18 files changed

+425
-33
lines changed

18 files changed

+425
-33
lines changed

mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@ add_mlir_doc(TosaOps TosaOps Dialects/ -gen-op-doc)
33
add_mlir_interface(TosaInterfaces)
44

55
set(LLVM_TARGET_DEFINITIONS TosaOps.td)
6+
mlir_tablegen(TosaOpsTypesBase.h.inc -gen-typedef-decls -typedefs-dialect=tosa)
7+
mlir_tablegen(TosaOpsTypesBase.cpp.inc -gen-typedef-defs -typedefs-dialect=tosa)
68
mlir_tablegen(TosaAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=tosa)
79
mlir_tablegen(TosaAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=tosa)
810
add_public_tablegen_target(MLIRTosaAttributesIncGen)
911

1012
set(LLVM_TARGET_DEFINITIONS TosaDialectBytecode.td)
1113
mlir_tablegen(TosaDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Tosa")
1214
add_public_tablegen_target(MLIRTosaDialectBytecodeIncGen)
13-

mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def Tosa_Dialect : Dialect {
4545
let cppNamespace = "mlir::tosa";
4646
let hasConstantMaterializer = 1;
4747
let useDefaultAttributePrinterParser = 1;
48+
let useDefaultTypePrinterParser = 1;
4849
}
4950

5051
//===----------------------------------------------------------------------===//
@@ -217,12 +218,21 @@ def TosaElementwiseOperator : NativeOpTrait<"TosaElementwiseOperator"> {
217218
let cppNamespace = "mlir::OpTrait::tosa";
218219
}
219220

221+
//===----------------------------------------------------------------------===//
222+
// TOSA Operator Trait.
223+
//===----------------------------------------------------------------------===//
224+
// Op operands with TOSA shape types must be compile time resolvable
225+
def TosaResolvableShapeOperands : NativeOpTrait<"TosaResolvableShapeOperands"> {
226+
let cppNamespace = "mlir::OpTrait::tosa";
227+
}
228+
220229
//===----------------------------------------------------------------------===//
221230
// TOSA Operator Class.
222231
//===----------------------------------------------------------------------===//
223232

224233
class Tosa_Op<string mnemonic, list<Trait> traits = []> :
225-
Op<Tosa_Dialect, mnemonic, !listconcat(traits, [TosaOpInterface])> {
234+
Op<Tosa_Dialect, mnemonic, !listconcat(traits, [TosaOpInterface,
235+
TosaResolvableShapeOperands])> {
226236
}
227237

228238
class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,14 +90,55 @@ template <typename ConcreteType>
9090
class TosaElementwiseOperator
9191
: public TraitBase<ConcreteType, TosaElementwiseOperator> {};
9292

93+
LogicalResult verifyTosaResolvableShapeOperands(Operation *op);
94+
/// This class verifies that tosa shape operands are compile time resolvable
95+
template <typename ConcreteType>
96+
class TosaResolvableShapeOperands
97+
: public TraitBase<ConcreteType, TosaResolvableShapeOperands> {
98+
public:
99+
static LogicalResult verifyTrait(Operation *op) {
100+
return verifyTosaResolvableShapeOperands(op);
101+
}
102+
};
103+
104+
LogicalResult verifyTosaShapeOperator(Operation *op);
105+
/// This class indicates that op operates on tosa shape types
106+
template <typename ConcreteType>
107+
class TosaShapeOperator : public TraitBase<ConcreteType, TosaShapeOperator> {
108+
public:
109+
static LogicalResult verifyTrait(Operation *op) {
110+
return verifyTosaShapeOperator(op);
111+
}
112+
};
113+
114+
LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op);
115+
/// This class indicates that op operates on tosa shape types
116+
template <typename ConcreteType>
117+
class TosaShapeOperatorWithSameRanks
118+
: public TraitBase<ConcreteType, TosaShapeOperatorWithSameRanks> {
119+
public:
120+
static LogicalResult verifyTrait(Operation *op) {
121+
return verifyTosaShapeOperatorWithSameRanks(op);
122+
}
123+
};
124+
93125
} // namespace tosa
94126
} // namespace OpTrait
95127

128+
namespace tosa {
129+
130+
bool isa_tosa_shape_type(mlir::Type t);
131+
132+
} // namespace tosa
133+
96134
} // namespace mlir
97135

98136
#define GET_ATTRDEF_CLASSES
99137
#include "mlir/Dialect/Tosa/IR/TosaAttributes.h.inc"
100138

139+
#define GET_TYPEDEF_CLASSES
140+
#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.h.inc"
141+
101142
#define GET_OP_CLASSES
102143
#include "mlir/Dialect/Tosa/IR/TosaOps.h.inc"
103144

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1689,12 +1689,16 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
16891689

16901690
let arguments = (ins
16911691
Tosa_Tensor:$input1,
1692-
DenseI64ArrayAttr:$multiples);
1692+
Tosa_Shape:$multiples);
16931693

16941694
let results = (outs
16951695
Tosa_Tensor:$output
16961696
);
16971697

1698+
let extraClassDeclaration = [{
1699+
LogicalResult getConstantMultiples(llvm::SmallVector<int64_t> &multiples);
1700+
}];
1701+
16981702
let hasFolder = 1;
16991703
let hasVerifier = 1;
17001704
}
@@ -2106,4 +2110,6 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [
21062110

21072111
include "mlir/Dialect/Tosa/IR/TosaUtilOps.td"
21082112

2113+
include "mlir/Dialect/Tosa/IR/TosaShapeOps.td"
2114+
21092115
#endif // TOSA_OPS
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
//===-- TosaShapeOps.td - TOSA dialect utility operations --*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines shape operators for the TOSA dialect.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef TOSA_SHAPE_OPS
14+
#define TOSA_SHAPE_OPS
15+
16+
include "mlir/IR/OpBase.td"
17+
18+
include "mlir/Interfaces/SideEffectInterfaces.td"
19+
include "mlir/Interfaces/InferTypeOpInterface.td"
20+
include "mlir/Interfaces/LoopLikeInterface.td"
21+
include "mlir/Dialect/Tosa/IR/TosaInterfaces.td"
22+
23+
include "mlir/Dialect/Tosa/IR/TosaTypesBase.td"
24+
include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
25+
26+
// Op trait: operator has operands and results with TOSA shape type
27+
def TosaShapeOperator : NativeOpTrait<"TosaShapeOperator"> {
28+
let cppNamespace = "mlir::OpTrait::tosa";
29+
}
30+
31+
class Tosa_ShapeOp<string mnemonic, list<Trait> traits = []>
32+
: Tosa_Op<mnemonic, !listconcat(traits, [TosaShapeOperator, Pure])> {
33+
34+
let assemblyFormat =
35+
"operands attr-dict `:` functional-type(operands, results)";
36+
37+
let hasFolder = 1;
38+
}
39+
40+
// op trait: shape operator has same ranks for operands and results
41+
def TosaShapeOperatorWithSameRanks
42+
: NativeOpTrait<"TosaShapeOperatorWithSameRanks"> {
43+
let cppNamespace = "mlir::OpTrait::tosa";
44+
}
45+
46+
class Tosa_ElementwiseShapeOp<string mnemonic, list<Trait> traits = []>
47+
: Tosa_ShapeOp<mnemonic,
48+
!listconcat(traits, [TosaShapeOperatorWithSameRanks])> {
49+
}
50+
51+
52+
//===----------------------------------------------------------------------===//
53+
// Operator: ConstShape
54+
//===----------------------------------------------------------------------===//
55+
def Tosa_ConstShapeOp : Tosa_ShapeOp<"const_shape", [ConstantLike, Pure]> {
56+
let summary = "Constant Shape op.";
57+
58+
let description = [{
59+
A node containing constant data for use as the input to an shape operation. May
60+
hold data only in index data type.
61+
62+
Example:
63+
64+
```mlir
65+
// Generic form
66+
%out = "tosa.const_shape"() {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
67+
```
68+
}];
69+
70+
let arguments = (ins IndexElementsAttr : $value);
71+
72+
let results = (outs Tosa_Shape : $output);
73+
74+
let hasVerifier = 1;
75+
}
76+
77+
#endif // TOSA_SHAPE_OPS

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,11 @@
1313
#ifndef TOSA_TYPES_BASE
1414
#define TOSA_TYPES_BASE
1515

16+
include "mlir/IR/AttrTypeBase.td"
1617
include "mlir/IR/OpBase.td"
1718

19+
include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
20+
1821
//===----------------------------------------------------------------------===//
1922
// Tosa Type Definitions.
2023
//===----------------------------------------------------------------------===//
@@ -215,4 +218,66 @@ def Tosa_Buffer : MemRefOf<[Tosa_AnyNumber]>;
215218
def Tosa_TupleBuffer : NestedTupleOf<[Tosa_Buffer]>;
216219
def Tosa_BufOrTuple : AnyTypeOf<[Tosa_Buffer, Tosa_TupleBuffer]>;
217220

221+
//===----------------------------------------------------------------------===//
222+
// Tosa Type Definitions.
223+
//===----------------------------------------------------------------------===//
224+
225+
// The base class for Tosa dialect types.
226+
class Tosa_Type<string name, string typeMnemonic, list<Trait> traits = []>
227+
: TypeDef<Tosa_Dialect, name, traits> {
228+
let mnemonic = typeMnemonic;
229+
}
230+
231+
//===----------------------------------------------------------------------===//
232+
// ShapeType
233+
//===----------------------------------------------------------------------===//
234+
def Tosa_Shape : Tosa_Type<"shape", "shape"> {
235+
let summary = "Shape with static rank and Index element type";
236+
let description = [{
237+
Syntax:
238+
239+
``` shape - type :: = `shape` `<` rank `>`
240+
``` Values with shape type represents a shape with a fixed rank and a list
241+
of dimensions
242+
.Rank must be zero or a positive integer
243+
.Each dimension is represented by the builtin
244+
Index type.
245+
246+
Examples:
247+
248+
```mlir
249+
// Shape with rank of four, for example, [1, 1, 8, 16]:
250+
!tosa
251+
.shape<4>
252+
253+
// Shape with rank of one, for example, [16]:
254+
!tosa
255+
.shape<1>
256+
257+
// Shape with rank zero, for example, [] (i.e., shape of scalar values):
258+
!tosa.shape<0>
259+
```
260+
}];
261+
let parameters = (ins "int" : $rank);
262+
let builders = [TypeBuilder<(ins "int" : $rank)>];
263+
let assemblyFormat = "`<` $rank `>`";
264+
265+
let genVerifyDecl = 1;
266+
}
267+
268+
def IsTosaShapeType : CPred<"mlir::tosa::isa_tosa_shape_type($_self)">;
269+
270+
// Whether a Tosa Shape type has a rank equal to the specified rank.
271+
class IsTosaShapeOfRankPred<int rank> : And<[
272+
IsTosaShapeType,
273+
CPred<[{::llvm::cast<::mlir::tosa::shapeType>($_self).getRank() == }] # rank>
274+
]>;
275+
276+
class TosaShapeOfRank<int rank>
277+
: Type<IsTosaShapeOfRankPred<rank>, "Tosa shape type of rank " #rank>;
278+
279+
def Rank1TosaShape : TosaShapeOfRank<1>;
280+
def Rank2TosaShape : TosaShapeOfRank<2>;
281+
def Rank4TosaShape : TosaShapeOfRank<4>;
282+
218283
#endif // TOSA_TYPES_BASE

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1886,7 +1886,9 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
18861886
auto elementTy = inputTy.getElementType();
18871887
int64_t rank = inputTy.getRank();
18881888

1889-
ArrayRef<int64_t> multiples = op.getMultiples();
1889+
SmallVector<int64_t> multiples;
1890+
if (failed(op.getConstantMultiples(multiples)))
1891+
return failure();
18901892

18911893
// Broadcast the newly added dimensions to their appropriate multiple.
18921894
SmallVector<int64_t, 2> genericShape;

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ struct TosaToLinalg : public impl::TosaToLinalgBase<TosaToLinalg> {
5555
target.addLegalOp<tosa::ApplyScaleOp>();
5656
target.addLegalOp<tosa::IfOp>();
5757
target.addLegalOp<tosa::ConstOp>();
58+
target.addLegalOp<tosa::ConstShapeOp>();
5859
target.addLegalOp<tosa::WhileOp>();
5960
target.addLegalOp<tosa::ConcatOp>();
6061
target.addLegalOp<tosa::SliceOp>();

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -808,6 +808,8 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
808808

809809
OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
810810

811+
OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
812+
811813
#define REDUCE_FOLDER(OP) \
812814
OpFoldResult OP::fold(FoldAdaptor adaptor) { \
813815
ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType()); \
@@ -985,9 +987,20 @@ OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
985987
}
986988

987989
OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
988-
bool allOnes = llvm::all_of(getMultiples(), [](int64_t v) { return v == 1; });
989-
if (allOnes && getInput1().getType() == getType())
990-
return getInput1();
990+
if (getInput1().getType() == getType()) {
991+
if (auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
992+
adaptor.getMultiples())) {
993+
if (multiples.isSplat() &&
994+
multiples.getSplatValue<APInt>().getSExtValue() == 1)
995+
return getInput1();
996+
if (auto int_array_attr =
997+
llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
998+
if (llvm::all_of(int_array_attr.getValues<APInt>(),
999+
[](APInt v) { return v.getSExtValue() == 1; }))
1000+
return getInput1();
1001+
}
1002+
}
1003+
}
9911004
return {};
9921005
}
9931006

0 commit comments

Comments
 (0)