Skip to content

[TOSA] Add StatefulOps to TOSA Dialect #66843

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ std::unique_ptr<Pass> createTosaToLinalgNamed();
void addTosaToLinalgPasses(
OpPassManager &pm, const TosaToLinalgOptions &options,
// Note: Default to 'none' level unless otherwise specified.
tosa::ValidationOptions const &validationOptions =
tosa::ValidationOptions().setLevel(tosa::TosaLevelEnum::None));
tosa::TosaValidationOptions const &validationOptions = {
tosa::TosaProfileEnum::Undefined, false, tosa::TosaLevelEnum::None});

/// Populates conversion passes from TOSA dialect to Linalg dialect.
void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns);
Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ class PatternRewriter;

namespace tosa {

ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
Attribute &attr);
void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
Attribute attr);

#include "mlir/Dialect/Tosa/IR/TosaInterfaces.h.inc"

} // namespace tosa
Expand Down
67 changes: 67 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,71 @@ def Tosa_YieldOp : Tosa_Op<"yield", [
let assemblyFormat = "$inputs attr-dict `:` type($inputs)";
}

//===----------------------------------------------------------------------===//
// Operator: variable
//===----------------------------------------------------------------------===//
def Tosa_VariableOp : Tosa_Op<"variable", []> {
let summary = "Defines a variable";

let description = [{
Defines a new TOSA variable. This is a mutable value.
Modifications are expressed using read/write semantics.
}];

let arguments = (ins
SymbolNameAttr:$name,
TypeAttr:$type,
OptionalAttr<AnyAttr>:$initial_value
);

let assemblyFormat = [{
$name
attr-dict
custom<TypeOrAttr>($type, $initial_value)
}];
}

//===----------------------------------------------------------------------===//
// Operator: variable.write
//===----------------------------------------------------------------------===//
def Tosa_VariableWriteOp : Tosa_Op<"variable.write", []> {
let summary = "write_buffer operator";

let description = [{
Assigns a value to pseudo-buffer resource holding a mutable tensor.
}];

let arguments = (ins
SymbolNameAttr:$name,
AnyType:$value
);

let assemblyFormat = [{
$name attr-dict `,` $value `:` type($value)
}];
}

//===----------------------------------------------------------------------===//
// Operator: variable.read
//===----------------------------------------------------------------------===//
def Tosa_VariableReadOp : Tosa_Op<"variable.read", []> {
let summary = "read_buffer operator";

let description = [{
Reads the value from a pseudo-buffer resource holding a mutable tensor.
}];

let arguments = (ins
SymbolNameAttr:$name
);

let results = (outs
AnyType:$value
);

let assemblyFormat = [{
$name attr-dict `:` type($value)
}];
}

#endif // TOSA_UTIL_OPS
3 changes: 0 additions & 3 deletions mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,6 @@ struct ValidationOptions {
}
};

std::unique_ptr<Pass> createTosaValidationPass(
ValidationOptions const &options = ValidationOptions());

#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"

Expand Down
3 changes: 1 addition & 2 deletions mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,12 @@ def TosaLevelType : I32EnumAttr<"TosaLevelEnum", "Tosa level",
let cppNamespace = "mlir::tosa";
}

def TosaValidation : Pass<"tosa-validate", "func::FuncOp"> {
def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
let summary = "Validates TOSA dialect";
let description = [{
This pass validates if input TOSA operations match the specification for given
criteria, e.g. TOSA profile.
}];
let constructor = "createTosaValidationPass()";

let options = [
Option<"profile", "profile", "mlir::tosa::TosaProfileEnum",
Expand Down
5 changes: 2 additions & 3 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ std::unique_ptr<Pass> mlir::tosa::createTosaToLinalg() {

void mlir::tosa::addTosaToLinalgPasses(
OpPassManager &pm, const TosaToLinalgOptions &options,
tosa::ValidationOptions const &validationOptions) {
tosa::TosaValidationOptions const &validationOptions) {
// Optional decompositions are designed to benefit linalg.
if (!options.disableTosaDecompositions)
pm.addNestedPass<func::FuncOp>(tosa::createTosaOptionalDecompositions());
Expand All @@ -90,7 +90,6 @@ void mlir::tosa::addTosaToLinalgPasses(
pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass(
{options.aggressiveReduceConstant}));
pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
pm.addNestedPass<func::FuncOp>(
tosa::createTosaValidationPass(validationOptions));
pm.addNestedPass<func::FuncOp>(tosa::createTosaValidation(validationOptions));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TosaValidation pass seems to have been updated to work on mlir::ModuleOp instead of func::FuncOp, but this addition still uses func::FuncOp. Can this be updated?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will do

pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalg());
}
43 changes: 43 additions & 0 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,49 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
return nullptr;
}

//===----------------------------------------------------------------------===//
// Parsers and printers
//===----------------------------------------------------------------------===//

ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
Attribute &attr) {
if (succeeded(parser.parseOptionalEqual())) {
if (failed(parser.parseAttribute(attr))) {
return parser.emitError(parser.getCurrentLocation())
<< "expected attribute";
}
if (auto typedAttr = attr.dyn_cast<TypedAttr>()) {
typeAttr = TypeAttr::get(typedAttr.getType());
}
return success();
}

Type type;
if (failed(parser.parseColonType(type))) {
return parser.emitError(parser.getCurrentLocation()) << "expected type";
}
typeAttr = TypeAttr::get(type);

return success();
}

void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
Attribute attr) {
bool needsSpace = false;
auto typedAttr = attr.dyn_cast_or_null<TypedAttr>();
if (!typedAttr || typedAttr.getType() != type.getValue()) {
p << ": ";
p.printAttribute(type);
needsSpace = true; // subsequent attr value needs a space separator
}
if (attr) {
if (needsSpace)
p << ' ';
p << "= ";
p.printAttribute(attr);
}
}

//===----------------------------------------------------------------------===//
// TOSA Operator Verifiers.
//===----------------------------------------------------------------------===//
Expand Down
92 changes: 83 additions & 9 deletions mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"

#include <string>
#include <unordered_map>

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/Builders.h"
Expand Down Expand Up @@ -96,12 +99,13 @@ static constexpr tosa_level_t TOSA_LEVEL_NONE = {0, 0, 0, 0};
struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
public:
explicit TosaValidation() { populateConstantOperandChecks(); }
explicit TosaValidation(const ValidationOptions &options) : TosaValidation() {
explicit TosaValidation(const TosaValidationOptions &options)
: TosaValidation() {
this->profile = options.profile;
this->StrictOperationSpecAlignment = options.strictOperationSpecAlignment;
this->StrictOperationSpecAlignment = options.StrictOperationSpecAlignment;
this->level = options.level;
}
void runOnOperation() override;
void runOnOperation() final;

LogicalResult applyConstantOperandCheck(Operation *op) {
for (auto &checker : const_checkers) {
Expand All @@ -113,6 +117,9 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {

LogicalResult applyLevelCheck(Operation *op);

// check variable read/write data types against variable declarations
LogicalResult applyVariableCheck(Operation *op);

private:
void populateConstantOperandChecks() {
const_checkers.emplace_back(checkConstantOperandPad);
Expand Down Expand Up @@ -398,8 +405,12 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
}
}

bool CheckVariable(Operation *op);
bool CheckVariableReadOrWrite(Operation *op);

SmallVector<std::function<LogicalResult(Operation *)>> const_checkers;
tosa_level_t tosa_level;
DenseMap<const mlir::StringAttr *, mlir::Type> variables_map;
};

LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
Expand Down Expand Up @@ -427,6 +438,69 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
return success();
}

inline bool CompatibleTypes(const mlir::Type &type,
const mlir::Type &declared_type) {
// for now, simply use type equality comparison
return type == declared_type;
}

bool TosaValidation::CheckVariable(Operation *op) {
if (isa<mlir::tosa::VariableOp>(op)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use dyn_cast to avoid the dangerous-looking cast below?

auto name_attr = cast<mlir::StringAttr>(op->getAttr("name"));

if (variables_map.count(&name_attr)) {
op->emitOpError() << "name has already been declared";
return false;
}

auto type_attr = cast<mlir::TypeAttr>(op->getAttr("type"));
mlir::Type type = type_attr.getValue();

variables_map[&name_attr] = type;
}

return true;
}

bool TosaValidation::CheckVariableReadOrWrite(Operation *op) {
if (isa<mlir::tosa::VariableReadOp>(op) ||
isa<mlir::tosa::VariableWriteOp>(op)) {
auto name_attr = cast<mlir::StringAttr>(op->getAttr("name"));

if (!variables_map.count(&name_attr)) {
op->emitOpError() << "name has not been declared";
return false;
}

auto var_type = variables_map[&name_attr];

for (auto v : op->getOperands()) {
auto type = v.getType();
if (!CompatibleTypes(type, var_type)) {
op->emitOpError() << "operand type does not equal variable type";
return false;
}
}

for (auto v : op->getResults()) {
auto type = v.getType();
if (!CompatibleTypes(type, var_type)) {
op->emitOpError() << "result type does not equal variable type";
return false;
}
}
}

return true;
}

LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
if (!CheckVariable(op) || !CheckVariableReadOrWrite(op)) {
return failure();
}
return success();
}

void TosaValidation::runOnOperation() {
configLevelAndProfile();
getOperation().walk([&](Operation *op) {
Expand All @@ -440,18 +514,18 @@ void TosaValidation::runOnOperation() {
}
}

// Some uses of TOSA rely on the constant operands of particular operations.
// Some uses of TOSA rely on the constant operands of particular
// operations.
if (StrictOperationSpecAlignment && failed(applyConstantOperandCheck(op)))
signalPassFailure();

// do level checks
if (failed(applyLevelCheck(op)))
signalPassFailure();

// do variable type checks
if (failed(applyVariableCheck(op)))
signalPassFailure();
});
}
} // namespace

std::unique_ptr<Pass>
mlir::tosa::createTosaValidationPass(ValidationOptions const &options) {
return std::make_unique<TosaValidation>(options);
}
45 changes: 45 additions & 0 deletions mlir/test/Dialect/Tosa/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -203,3 +203,48 @@ func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>) -> tensor<
: (tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32>
return %0 : tensor<1x7x7x9xf32>
}

// -----

func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi32>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
// expected-error@+1 {{'tosa.variable' op name has already been declared}}
tosa.variable @stored_var : tensor<1x4x8xi32>
return
}

// -----

func.func @test_variable_read_type(%arg0: tensor<2x4x8xi32>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
// expected-error@+1 {{'tosa.variable.read' op result type does not equal variable type}}
%0 = tosa.variable.read @stored_var : tensor<2x4x8xi16>
return
}

// -----

func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi32>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
// expected-error@+1 {{'tosa.variable.read' op result type does not equal variable type}}
%0 = tosa.variable.read @stored_var : tensor<1x4x8xi32>
return
}

// -----

func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
// expected-error@+1 {{'tosa.variable.write' op operand type does not equal variable type}}
tosa.variable.write @stored_var, %arg0 : tensor<2x4x8xi16>
return
}

// -----

func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi32>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
// expected-error@+1 {{'tosa.variable.write' op operand type does not equal variable type}}
tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi32>
return
}
Loading