Skip to content

Commit 8f928e6

Browse files
Tai78641Jerry-Ge
authored andcommitted
[mlir][tosa] Add FP8 support
Add FP8 support to following TOSA operators: ARGMAX AVGPOOL CONV2D CONV3D DEPTHWISE_CONV2D MATMUL MAX_POOL2D TRANSPOSE_CONV2D CONST CAST CONCAT PAD DIM RESHAPE REVERSE SLICE TILE TRANSPOSE GATHER SCATTER Also added verifiers as needed to check input/output element types and renamed inputs of transpose_conv2d and select to match spec. Signed-off-by: Tai Ly <[email protected]> Signed-off-by: Jerry Ge <[email protected]> Change-Id: I56adfabb2396e38b7ed3479e4fd680b740bdb4e4
1 parent 7c24041 commit 8f928e6

File tree

13 files changed

+806
-129
lines changed

13 files changed

+806
-129
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 110 additions & 71 deletions
Large diffs are not rendered by default.

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,16 +74,25 @@ def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"uint8", [8], 0>,
7474
Tosa_QuantizedType<"int16", [16, 0], 1>,
7575
Tosa_QuantizedType<"int32", [32, 0], 1>]>;
7676

77+
def Tosa_F8 : AnyTypeOf<[
78+
F8E4M3FN,
79+
F8E5M2]>;
80+
7781
//===----------------------------------------------------------------------===//
7882
// Multi-category types.
7983
//===----------------------------------------------------------------------===//
8084
def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
8185
"number">;
8286

87+
// Add F8 type support to Tosa_AnyNumber
88+
def Tosa_AnyNumber_Extended : AnyTypeOf<[Tosa_AnyNumber, Tosa_F8],
89+
"number_extended">;
90+
8391
// For weight tensors from tosa::Conv2DOp, tosa::Conv3DOp,
8492
// tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp
8593
def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
86-
Tosa_QuantizedInt, AnyFloat]>;
94+
Tosa_QuantizedInt, AnyFloat, Tosa_F8]>;
95+
8796

8897
//===----------------------------------------------------------------------===//
8998
// TOSA Tensor Conformance
@@ -130,9 +139,11 @@ def Tosa_FloatTensor : TosaTensorOf<[AnyFloat]>;
130139

131140
// Either ranked or unranked tensor of TOSA supported element types.
132141
def Tosa_Tensor : TosaTensorOf<[Tosa_AnyNumber]>;
142+
def Tosa_Tensor_Extended : TosaTensorOf<[Tosa_AnyNumber_Extended]>;
133143

134144
// Must be ranked but no further constraints
135-
def Tosa_RankedTensor : TosaRankedTensorOf<[Tosa_AnyNumber]>;
145+
def Tosa_RankedTensor : RankedTensorOf<[Tosa_AnyNumber]>;
146+
def Tosa_RankedTensor_Extended : RankedTensorOf<[Tosa_AnyNumber_Extended]>;
136147

137148
// Any tensor element type allowed in Tosa ops.
138149
def Tosa_ElementType : Type<Or<[Tosa_Int.predicate, Tosa_QuantizedInt.predicate,
@@ -145,23 +156,35 @@ class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
145156
// Tensor types with constrained ranks.
146157
//===----------------------------------------------------------------------===//
147158

148-
def Tosa_Rank0Tensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;
149-
159+
// Scalar tensors: Rank-1 (with only one element)
150160
def Tosa_ScalarTensor : TosaScalarTensorOf<[Tosa_AnyNumber], [1]>;
161+
def Tosa_ScalarTensor_Extended : TosaScalarTensorOf<[Tosa_AnyNumber_Extended], [1]>;
151162
def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>;
152163

153164
// We include unranked tensors as a supported type for all possible tosa
154165
// Tensors as unranked does not guarantee invalid. If unranked tensors exist
155166
// they should be shape propagate used Tosa's shape inference pass and verified
156167
// to not include any remaining unranked tensors.
157168
def Tosa_UnrankedTensor : TosaUnrankedTensorOf<[Tosa_AnyNumber]>;
169+
def Tosa_UnrankedTensorExtended : TosaUnrankedTensorOf<[Tosa_AnyNumber_Extended]>;
158170

159171
def Tosa_Tensor1D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1]>], "1-d tosa-conformant tensor", "::mlir::TensorType">;
160172
def Tosa_Tensor2D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [2]>], "2-d tosa-conformant tensor", "::mlir::TensorType">;
161173
def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [3]>], "3-d tosa-conformant tensor", "::mlir::TensorType">;
162174
def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [4]>], "4-d tosa-conformant tensor", "::mlir::TensorType">;
163175
def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tosa-conformant tensor", "::mlir::TensorType">;
164176

177+
def Tosa_Tensor1D_Extended : AnyTypeOf<[Tosa_UnrankedTensorExtended, TosaTensorRankOf<[Tosa_AnyNumber_Extended], [1]>],
178+
"1-d tosa-conformant tensor extended", "::mlir::TensorType">;
179+
def Tosa_Tensor2D_Extended : AnyTypeOf<[Tosa_UnrankedTensorExtended, TosaTensorRankOf<[Tosa_AnyNumber_Extended], [2]>],
180+
"2-d tosa-conformant tensor extended", "::mlir::TensorType">;
181+
def Tosa_Tensor3D_Extended : AnyTypeOf<[Tosa_UnrankedTensorExtended, TosaTensorRankOf<[Tosa_AnyNumber_Extended], [3]>],
182+
"3-d tosa-conformant tensor extended", "::mlir::TensorType">;
183+
def Tosa_Tensor4D_Extended : AnyTypeOf<[Tosa_UnrankedTensorExtended, TosaTensorRankOf<[Tosa_AnyNumber_Extended], [4]>],
184+
"4-d tosa-conformant tensor extended", "::mlir::TensorType">;
185+
def Tosa_Tensor5D_Extended : AnyTypeOf<[Tosa_UnrankedTensorExtended, TosaTensorRankOf<[Tosa_AnyNumber_Extended], [5]>],
186+
"5-d tosa-conformant tensor extended", "::mlir::TensorType">;
187+
165188
// Ranked tensors up to given rank.
166189
def Tosa_Tensor1Dto4D : AnyTypeOf<[
167190
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>]>;

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,12 @@ void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
6565
}
6666

6767
LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
68-
auto notOp = op.getPred().getDefiningOp<tosa::LogicalNotOp>();
68+
auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
6969
if (!notOp)
7070
return failure();
7171
rewriter.modifyOpInPlace(op, [&]() {
7272
op.getOperation()->setOperands(
73-
{notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
73+
{notOp.getInput1(), op.getInput3(), op.getInput2()});
7474
});
7575
return success();
7676
}
@@ -1118,18 +1118,18 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
11181118
}
11191119

11201120
OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
1121-
if (getOnTrue() == getOnFalse())
1122-
return getOnTrue();
1121+
if (getInput2() == getInput3())
1122+
return getInput2();
11231123

11241124
auto predicate =
1125-
llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getPred());
1125+
llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
11261126
if (!predicate)
11271127
return {};
11281128

11291129
if (!predicate.isSplat())
11301130
return {};
1131-
return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
1132-
: getOnFalse();
1131+
return predicate.getSplatValue<APInt>().getBoolValue() ? getInput2()
1132+
: getInput3();
11331133
}
11341134

11351135
OpFoldResult TileOp::fold(FoldAdaptor adaptor) {

0 commit comments

Comments
 (0)