Skip to content

Commit d60a314

Browse files
committed
[mlir][tosa] Require signless types in validation and add corresponding conversion pass
Firstly, this commit requires that all types are signless in the strict mode of the validation pass. This is because signless types on operations are required by the TOSA specification. The "strict" mode in the validation pass is the final check for TOSA conformance to the specification, which can often be used for conversion to other formats. In addition, a conversion pass `--tosa-convert-integer-type-to-signless` is provided to allow a user to convert all integer types to signless. The intention is that this pass can be run before the validation pass. Following use of this pass, input/output information should be carried independently by the user. Change-Id: Id7aebf0071c9a7516c77f55062db82760c0da533
1 parent 9fbde32 commit d60a314

File tree

7 files changed

+265
-4
lines changed

7 files changed

+265
-4
lines changed

mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,4 +127,18 @@ def TosaReduceTransposes : Pass<"tosa-reduce-transposes", "func::FuncOp"> {
127127
}];
128128
}
129129

130+
// Pass record for the `tosa-convert-integer-type-to-signless` command-line
// pass; the second `Pass` argument anchors it on `func::FuncOp`.
def TosaConvertIntegerTypeToSignless : Pass<"tosa-convert-integer-type-to-signless", "func::FuncOp"> {
131+
let summary = "Convert integer types to signless";
132+
let description = [{
133+
This pass converts signed or unsigned integer types to signless. It
134+
currently does this greedily for all operators and can also change the
135+
signature of the function. Should the signature of the entrypoint
136+
function change, it will be the responsibility of the user to carry
137+
signedness information of the inputs and outputs independently.
138+
139+
This can be a useful transformation for conversion to other formats
140+
that require strict adherence to the TOSA specification.
141+
}];
142+
}
143+
130144
#endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES

mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_mlir_dialect_library(MLIRTosaTransforms
2+
TosaConvertIntegerTypeToSignless.cpp
23
TosaDecomposeTransposeConv.cpp
34
TosaDecomposeDepthwise.cpp
45
TosaFolders.cpp
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
//===- TosaConvertIntegerTypeToSignless.cpp
2+
//-------------------------------------------===//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===-------------------------------------------------------------------------------===//
9+
10+
// -----------
11+
// Motivation:
12+
// -----------
13+
14+
// The TOSA specification uses a signless type system, which means that
15+
// information about signedness must be encapsulated by the operations
16+
// themselves. For example, tosa.rescale provides the attributes
17+
// `input_unsigned` and `output_unsigned` to indicate whether the input/output
18+
// should be interpreted as unsigned or signed.
19+
20+
// The TOSA dialect, on the other hand, allows the use of signed or unsigned
21+
// types in addition to signless. As such, when converting from TOSA dialect to
22+
// other formats, we need to ensure that we conform to the TOSA specification.
23+
24+
// ---------
25+
// Overview:
26+
// ---------
27+
28+
// This pass converts signed or unsigned integer types to signless. It currently
29+
// does this greedily for all operators and can also change the signature of the
30+
// function. Should the signature of the entrypoint function change, it will be
31+
// the responsibility of the user to carry signedness information of the inputs
32+
// and outputs independently.
33+
34+
#include "mlir/Dialect/Func/IR/FuncOps.h"
35+
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
36+
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
37+
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
38+
#include "mlir/Transforms/DialectConversion.h"
39+
40+
namespace mlir {
41+
namespace tosa {
42+
43+
#define GEN_PASS_DEF_TOSACONVERTINTEGERTYPETOSIGNLESS
44+
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
45+
46+
namespace {
47+
class ToSignlessTensorTypeConverter : public TypeConverter {
48+
static Type convertType(Type type) {
49+
const auto tensorType = dyn_cast<TensorType>(type);
50+
if (!tensorType)
51+
return type;
52+
53+
const auto intType = dyn_cast<IntegerType>(tensorType.getElementType());
54+
if (!intType ||
55+
intType.getSignedness() == IntegerType::SignednessSemantics::Signless)
56+
return type;
57+
58+
const auto signlessType = IntegerType::get(
59+
intType.getContext(), intType.getWidth(), IntegerType::Signless);
60+
return tensorType.cloneWith(std::nullopt, signlessType);
61+
}
62+
63+
public:
64+
explicit ToSignlessTensorTypeConverter() { addConversion(convertType); }
65+
};
66+
67+
class ConvertGenericOpWithIntegerTensorType : public ConversionPattern {
68+
public:
69+
ConvertGenericOpWithIntegerTensorType(TypeConverter &typeConverter,
70+
MLIRContext *context)
71+
: ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {}
72+
73+
LogicalResult
74+
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
75+
ConversionPatternRewriter &rewriter) const final {
76+
// Typically TOSA operators have a single result, but some have an
77+
// arbitrary number. 4 seems like a good balance as an optimization
78+
// hint for storing result types.
79+
constexpr unsigned int numResults = 4;
80+
81+
// Convert integer types to signless
82+
SmallVector<Type, numResults> resultTypes;
83+
if (failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes)))
84+
return failure();
85+
86+
// Create new op with replaced operands and results
87+
auto *newOp = Operation::create(
88+
op->getLoc(), op->getName(), resultTypes, operands, op->getAttrs(),
89+
op->getPropertiesStorage(), op->getSuccessors(), op->getNumRegions());
90+
91+
// Handle regions in e.g. tosa.cond_if and tosa.while_loop
92+
for (auto regions : llvm::zip(op->getRegions(), newOp->getRegions())) {
93+
Region &before = std::get<0>(regions);
94+
Region &parent = std::get<1>(regions);
95+
rewriter.inlineRegionBefore(before, parent, parent.end());
96+
if (failed(rewriter.convertRegionTypes(&parent, *typeConverter)))
97+
return failure();
98+
}
99+
100+
// Replace with rewritten op
101+
rewriter.insert(newOp);
102+
rewriter.replaceOp(op, newOp->getResults());
103+
return success();
104+
}
105+
};
106+
107+
class TosaConvertIntegerTypeToSignless
108+
: public impl::TosaConvertIntegerTypeToSignlessBase<
109+
TosaConvertIntegerTypeToSignless> {
110+
public:
111+
void runOnOperation() override {
112+
MLIRContext *context = &getContext();
113+
ConversionTarget target(*context);
114+
ToSignlessTensorTypeConverter typeConverter;
115+
116+
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
117+
return typeConverter.isSignatureLegal(op.getFunctionType()) &&
118+
typeConverter.isLegal(&op.getBody());
119+
});
120+
target.markUnknownOpDynamicallyLegal([&](Operation *op) {
121+
return typeConverter.isLegal(op->getOperandTypes()) &&
122+
typeConverter.isLegal(op->getResultTypes());
123+
});
124+
125+
RewritePatternSet patterns(context);
126+
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
127+
patterns, typeConverter);
128+
patterns.add<ConvertGenericOpWithIntegerTensorType>(typeConverter, context);
129+
130+
if (failed(
131+
applyFullConversion(getOperation(), target, std::move(patterns))))
132+
signalPassFailure();
133+
}
134+
};
135+
136+
} // namespace
137+
138+
} // namespace tosa
139+
} // namespace mlir

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1320,21 +1320,22 @@ void TosaValidation::runOnOperation() {
13201320

13211321
// validate operator element types:
13221322
// - rescale operator is allowed to have ui8/ui16/ui32
1323-
// operands/results
1323+
// operands/results when strictOpSpecAlignment is false
13241324
// - perform valid element type check at the beginning to
13251325
// protect rest of code against quantized element types
1326-
const bool opIsRescale = isa<tosa::RescaleOp>(op);
1326+
const bool allowUnsigned =
1327+
!strictOpSpecAlignment && isa<tosa::RescaleOp>(op);
13271328
for (Value operand : op->getOperands()) {
13281329
auto elementTy = getElementTypeOrSelf(operand);
1329-
if (!isValidElementType(elementTy, opIsRescale)) {
1330+
if (!isValidElementType(elementTy, allowUnsigned)) {
13301331
op->emitOpError() << "is not profile-aligned: element type "
13311332
<< elementTy << " is not legal";
13321333
return signalPassFailure();
13331334
}
13341335
}
13351336
for (Type resultTy : op->getResultTypes()) {
13361337
auto elementTy = getElementTypeOrSelf(resultTy);
1337-
if (!isValidElementType(elementTy, opIsRescale)) {
1338+
if (!isValidElementType(elementTy, allowUnsigned)) {
13381339
op->emitOpError() << "is not profile-aligned: element type "
13391340
<< elementTy << " is not legal";
13401341
return signalPassFailure();

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2000,6 +2000,7 @@ func.func @test_rescale_input_unsigned(%arg0: tensor<1x1xui8>) -> (tensor<1x1xi8
20002000
%1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
20012001
%2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
20022002
%3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
2003+
// expected-error@+1 {{'tosa.rescale' op is not profile-aligned: element type 'ui8' is not legal}}
20032004
%r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xui8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xi8>
20042005
return %r : tensor<1x1xi8>
20052006
}
@@ -2012,6 +2013,7 @@ func.func @test_rescale_output_unsigned(%arg0: tensor<1x1xi8>) -> (tensor<1x1xui
20122013
%1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
20132014
%2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
20142015
%3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
2016+
// expected-error@+1 {{'tosa.rescale' op is not profile-aligned: element type 'ui8' is not legal}}
20152017
%r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xui8>
20162018
return %r : tensor<1x1xui8>
20172019
}
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
// RUN: mlir-opt --split-input-file --tosa-convert-integer-type-to-signless %s | FileCheck %s
2+
3+
// -----
4+
5+
// CHECK-LABEL: test_rescale_output_unsigned
6+
// CHECK: %arg0: tensor<1x1xi8>
7+
// The ui8 result of tosa.rescale and the ui8 function return type are expected
// to be rewritten to i8, while the `output_unsigned = true` attribute is
// expected to be printed unchanged.
func.func @test_rescale_output_unsigned(%arg0: tensor<1x1xi8>) -> (tensor<1x1xui8>) {
8+
%0 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
9+
%1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
10+
%2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
11+
%3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
12+
// CHECK: %[[RESCALE:.*]] = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xi8>
13+
%r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xui8>
14+
// CHECK: return %[[RESCALE]] : tensor<1x1xi8>
15+
return %r : tensor<1x1xui8>
16+
}
17+
18+
// -----
19+
20+
// CHECK-LABEL: test_rescale_input_unsigned
21+
// CHECK: %arg0: tensor<1x1xi16>
22+
// The ui16 function argument and the matching tosa.rescale input type are
// expected to be rewritten to i16, while the `input_unsigned = true` attribute
// is expected to be printed unchanged.
func.func @test_rescale_input_unsigned(%arg0: tensor<1x1xui16>) -> (tensor<1x1xi8>) {
23+
%0 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
24+
%1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
25+
%2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
26+
%3 = "tosa.const"() <{values = dense<32768> : tensor<1xi16>}> : () -> tensor<1xi16>
27+
// CHECK: %[[RESCALE:.*]] = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi16>, tensor<1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<1x1xi8>
28+
%r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xui16>, tensor<1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<1x1xi8>
29+
// CHECK: return %[[RESCALE]] : tensor<1x1xi8>
30+
return %r : tensor<1x1xi8>
31+
}
32+
33+
// -----
34+
35+
// CHECK-LABEL: test_unsigned_function_signature
36+
// CHECK: %arg0: tensor<1xi8>, %arg1: tensor<1xi8>
37+
// A function with no TOSA ops at all: only the ui8 argument and return types
// in the signature (and on the return op) are expected to become i8.
func.func @test_unsigned_function_signature(%arg0: tensor<1xui8>, %arg1: tensor<1xui8>) -> (tensor<1xui8>, tensor<1xui8>) {
38+
// CHECK: return %arg0, %arg1 : tensor<1xi8>, tensor<1xi8>
39+
return %arg0, %arg1 : tensor<1xui8>, tensor<1xui8>
40+
}
41+
42+
// -----
43+
44+
// CHECK-LABEL: test_no_change
45+
// CHECK: %arg0: tensor<13x21x3xi8>
46+
// Input that already uses signless i8 everywhere; the output is expected to
// keep the same types.
func.func @test_no_change(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> {
47+
%0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<13x21x3xi8>) -> tensor<13x21x3xi8>
48+
// CHECK: return %0 : tensor<13x21x3xi8>
49+
return %0 : tensor<13x21x3xi8>
50+
}
51+
52+
// -----
53+
54+
// CHECK-LABEL: test_regions
55+
// CHECK: %arg0: tensor<i8>, %arg1: tensor<i8>
56+
// Covers an op with regions: the ui8 types on tosa.cond_if, on the ops inside
// both of its regions, and on the function signature are all expected to
// become i8.
// NOTE(review): both region bodies use %arg0/%arg1 from the enclosing function
// rather than their own block arguments %arg3/%arg4 - confirm this is
// intentional.
func.func @test_regions(%arg0: tensor<ui8>, %arg1: tensor<ui8>, %arg2: tensor<i1>) -> tensor<ui8> {
57+
// CHECK: tosa.cond_if %arg2 -> (tensor<i8>)
58+
%0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
59+
^bb0(%arg3: tensor<ui8>, %arg4: tensor<ui8>):
60+
// CHECK: %1 = tosa.add %arg0, %arg1 : (tensor<i8>, tensor<i8>) -> tensor<i8>
61+
%1 = tosa.add %arg0, %arg1 : (tensor<ui8>, tensor<ui8>) -> tensor<ui8>
62+
// CHECK: tosa.yield %1 : tensor<i8>
63+
tosa.yield %1 : tensor<ui8>
64+
}, {
65+
^bb0(%arg3: tensor<ui8>, %arg4: tensor<ui8>):
66+
// CHECK: %1 = tosa.sub %arg0, %arg1 : (tensor<i8>, tensor<i8>) -> tensor<i8>
67+
%1 = tosa.sub %arg0, %arg1 : (tensor<ui8>, tensor<ui8>) -> tensor<ui8>
68+
// CHECK: tosa.yield %1 : tensor<i8>
69+
tosa.yield %1 : tensor<ui8>
70+
}) : (tensor<i1>, tensor<ui8>, tensor<ui8>) -> tensor<ui8>
71+
// CHECK: return %0 : tensor<i8>
72+
return %0 : tensor<ui8>
73+
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//--------------------------------------------------------------------------------------------------
2+
// Test valid IR in terms of the shape and type of tensor, and the argument type of
3+
// operation. Excludes the profile compliance checking since it is performed earlier in the
4+
// validation flow.
5+
//--------------------------------------------------------------------------------------------------
6+
7+
// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" | FileCheck %s
8+
9+
// -----
10+
11+
// CHECK-LABEL: test_rescale_input_unsigned
12+
// tosa.rescale with a ui8 input and no diagnostic annotation: under this
// file's RUN line (which uses -verify-diagnostics) the function is expected to
// validate without errors.
// NOTE(review): the RUN line does not set the strict op-spec-alignment option,
// so this presumably exercises the non-strict mode - confirm the default.
func.func @test_rescale_input_unsigned(%arg0: tensor<1x1xui8>) -> (tensor<1x1xi8>) {
13+
%0 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
14+
%1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
15+
%2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
16+
%3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
17+
%r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xui8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xi8>
18+
return %r : tensor<1x1xi8>
19+
}
20+
21+
// -----
22+
23+
// CHECK-LABEL: test_rescale_output_unsigned
24+
// tosa.rescale with a ui8 result and no diagnostic annotation: under this
// file's RUN line (which uses -verify-diagnostics) the function is expected to
// validate without errors.
// NOTE(review): the RUN line does not set the strict op-spec-alignment option,
// so this presumably exercises the non-strict mode - confirm the default.
func.func @test_rescale_output_unsigned(%arg0: tensor<1x1xi8>) -> (tensor<1x1xui8>) {
25+
%0 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
26+
%1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
27+
%2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
28+
%3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
29+
%r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xui8>
30+
return %r : tensor<1x1xui8>
31+
}

0 commit comments

Comments
 (0)