Skip to content

Commit ac7c15b

Browse files
Jerry-GeTai78641
authored andcommitted
Add StatefulOps to TOSA Dialect
This patch adds tosa.variable, tosa.variable.read and tosa.variable.write operators and tests. Signed-off-by: Jerry Ge <[email protected]> Change-Id: I647e2e5c3762d7890b03f6aa7c09a29198b7d355
1 parent 78c8db8 commit ac7c15b

File tree

6 files changed

+17
-40
lines changed

6 files changed

+17
-40
lines changed

mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ std::unique_ptr<Pass> createTosaToLinalgNamed();
3535
void addTosaToLinalgPasses(
3636
OpPassManager &pm, const TosaToLinalgOptions &options,
3737
// Note: Default to 'none' level unless otherwise specified.
38-
tosa::ValidationOptions const &validationOptions =
39-
tosa::ValidationOptions().setLevel(tosa::TosaLevelEnum::None));
38+
tosa::TosaValidationOptions const &validationOptions = {
39+
tosa::TosaProfileEnum::Undefined, false, tosa::TosaLevelEnum::None});
4040

4141
/// Populates conversion passes from TOSA dialect to Linalg dialect.
4242
void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns);

mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def Tosa_VariableWriteOp : Tosa_Op<"variable.write", []> {
114114
}];
115115

116116
let arguments = (ins
117-
FlatSymbolRefAttr:$name,
117+
SymbolNameAttr:$name,
118118
AnyType:$value
119119
);
120120

@@ -134,7 +134,7 @@ def Tosa_VariableReadOp : Tosa_Op<"variable.read", []> {
134134
}];
135135

136136
let arguments = (ins
137-
FlatSymbolRefAttr:$name
137+
SymbolNameAttr:$name
138138
);
139139

140140
let results = (outs

mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,6 @@ struct ValidationOptions {
6868
}
6969
};
7070

71-
std::unique_ptr<OperationPass<ModuleOp>> createTosaValidationPass(
72-
ValidationOptions const &options = ValidationOptions());
73-
7471
#define GEN_PASS_REGISTRATION
7572
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
7673

mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,6 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
9595
This pass validates if input TOSA operations match the specification for given
9696
criteria, e.g. TOSA profile.
9797
}];
98-
let constructor = "createTosaValidationPass()";
9998

10099
let options = [
101100
Option<"profile", "profile", "mlir::tosa::TosaProfileEnum",

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ std::unique_ptr<Pass> mlir::tosa::createTosaToLinalg() {
7676

7777
void mlir::tosa::addTosaToLinalgPasses(
7878
OpPassManager &pm, const TosaToLinalgOptions &options,
79-
tosa::ValidationOptions const &validationOptions) {
79+
tosa::TosaValidationOptions const &validationOptions) {
8080
// Optional decompositions are designed to benefit linalg.
8181
if (!options.disableTosaDecompositions)
8282
pm.addNestedPass<func::FuncOp>(tosa::createTosaOptionalDecompositions());
@@ -90,7 +90,6 @@ void mlir::tosa::addTosaToLinalgPasses(
9090
pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass(
9191
{options.aggressiveReduceConstant}));
9292
pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
93-
pm.addNestedPass<func::FuncOp>(
94-
tosa::createTosaValidationPass(validationOptions));
93+
pm.addNestedPass<func::FuncOp>(tosa::createTosaValidation(validationOptions));
9594
pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalg());
9695
}

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 11 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,10 @@ static constexpr tosa_level_t TOSA_LEVEL_NONE = {0, 0, 0, 0};
9999
struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
100100
public:
101101
explicit TosaValidation() { populateConstantOperandChecks(); }
102-
explicit TosaValidation(const ValidationOptions &options) : TosaValidation() {
102+
explicit TosaValidation(const TosaValidationOptions &options)
103+
: TosaValidation() {
103104
this->profile = options.profile;
104-
this->StrictOperationSpecAlignment = options.strictOperationSpecAlignment;
105+
this->StrictOperationSpecAlignment = options.StrictOperationSpecAlignment;
105106
this->level = options.level;
106107
}
107108
void runOnOperation() final;
@@ -409,7 +410,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
409410

410411
SmallVector<std::function<LogicalResult(Operation *)>> const_checkers;
411412
tosa_level_t tosa_level;
412-
std::unordered_map<std::string, mlir::Type> variables_map;
413+
DenseMap<const mlir::StringAttr *, mlir::Type> variables_map;
413414
};
414415

415416
LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
@@ -445,26 +446,17 @@ inline bool CompatibleTypes(const mlir::Type &type,
445446

446447
bool TosaValidation::CheckVariable(Operation *op) {
447448
if (isa<mlir::tosa::VariableOp>(op)) {
448-
auto name_attr = dyn_cast<mlir::StringAttr>(op->getAttr("name"));
449-
if (!name_attr) {
450-
op->emitOpError() << "Name attribute is not StringAttr";
451-
return false;
452-
}
453-
std::string name = name_attr.getValue().str();
449+
auto name_attr = cast<mlir::StringAttr>(op->getAttr("name"));
454450

455-
if (variables_map.count(name)) {
451+
if (variables_map.count(&name_attr)) {
456452
op->emitOpError() << "name has already been declared";
457453
return false;
458454
}
459455

460-
auto type_attr = dyn_cast<mlir::TypeAttr>(op->getAttr("type"));
461-
if (!type_attr) {
462-
op->emitOpError() << "type attribute is not TypeAttr";
463-
return false;
464-
}
456+
auto type_attr = cast<mlir::TypeAttr>(op->getAttr("type"));
465457
mlir::Type type = type_attr.getValue();
466458

467-
variables_map[name] = type;
459+
variables_map[&name_attr] = type;
468460
}
469461

470462
return true;
@@ -473,19 +465,14 @@ bool TosaValidation::CheckVariable(Operation *op) {
473465
bool TosaValidation::CheckVariableReadOrWrite(Operation *op) {
474466
if (isa<mlir::tosa::VariableReadOp>(op) ||
475467
isa<mlir::tosa::VariableWriteOp>(op)) {
476-
auto name_attr = dyn_cast<mlir::FlatSymbolRefAttr>(op->getAttr("name"));
477-
if (!name_attr) {
478-
op->emitOpError() << "name attribute is not FlatSymbolRefAttr";
479-
return false;
480-
}
481-
std::string name = name_attr.getValue().str();
468+
auto name_attr = cast<mlir::StringAttr>(op->getAttr("name"));
482469

483-
if (!variables_map.count(name)) {
470+
if (!variables_map.count(&name_attr)) {
484471
op->emitOpError() << "name has not been declared";
485472
return false;
486473
}
487474

488-
auto var_type = variables_map[name];
475+
auto var_type = variables_map[&name_attr];
489476

490477
for (auto v : op->getOperands()) {
491478
auto type = v.getType();
@@ -542,8 +529,3 @@ void TosaValidation::runOnOperation() {
542529
});
543530
}
544531
} // namespace
545-
546-
std::unique_ptr<OperationPass<ModuleOp>>
547-
mlir::tosa::createTosaValidationPass(ValidationOptions const &options) {
548-
return std::make_unique<TosaValidation>(options);
549-
}

0 commit comments

Comments
 (0)