Skip to content

Commit b315250

Browse files
committed
[mlir][tosa] Align Variable ops to match with TOSA v1.0 spec
* updated AnyType:$value to Tosa_Tensor:$input1 and Tosa_Tensor:$output1 for VariableWrite and VriableRead Operators * updated description discrepancies * note: in the TOSA spec, we had var_shape attr, but it's already included in the TypeAttr:$type in MLIR Signed-off-by: Jerry Ge <[email protected]> Change-Id: I4cd0348cd4e306dbc2e0e53a89a9404d91fb44d4
1 parent f1ac2af commit b315250

File tree

3 files changed

+11
-11
lines changed

3 files changed

+11
-11
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {
8686
let summary = "Defines a variable";
8787

8888
let description = [{
89-
Defines a new TOSA variable. This is a mutable value.
89+
Defines a new TOSA variable.
90+
This is a persistent mutable value across multiple TOSA graph invocations.
9091
Modifications are expressed using read/write semantics.
9192
}];
9293

@@ -115,12 +116,12 @@ def Tosa_VariableWriteOp : Tosa_Op<"variable.write", []> {
115116
let summary = "write_buffer operator";
116117

117118
let description = [{
118-
Assigns a value to pseudo-buffer resource holding a mutable tensor.
119+
Assigns a value to the pseudo-buffer resource holding a persistent mutable tensor.
119120
}];
120121

121122
let arguments = (ins
122123
SymbolNameAttr:$name,
123-
AnyType:$value
124+
Tosa_Tensor:$input1
124125
);
125126

126127
list<Availability> availability = [
@@ -129,7 +130,7 @@ def Tosa_VariableWriteOp : Tosa_Op<"variable.write", []> {
129130
];
130131

131132
let assemblyFormat = [{
132-
$name attr-dict `,` $value `:` type($value)
133+
$name attr-dict `,` $input1 `:` type($input1)
133134
}];
134135
}
135136

@@ -140,15 +141,15 @@ def Tosa_VariableReadOp : Tosa_Op<"variable.read", []> {
140141
let summary = "read_buffer operator";
141142

142143
let description = [{
143-
Reads the value from a pseudo-buffer resource holding a mutable tensor.
144+
Reads the value from a pseudo-buffer resource holding a persistent mutable tensor.
144145
}];
145146

146147
let arguments = (ins
147148
SymbolNameAttr:$name
148149
);
149150

150151
let results = (outs
151-
AnyType:$value
152+
Tosa_Tensor:$output1
152153
);
153154

154155
list<Availability> availability = [
@@ -157,7 +158,7 @@ def Tosa_VariableReadOp : Tosa_Op<"variable.read", []> {
157158
];
158159

159160
let assemblyFormat = [{
160-
$name attr-dict `:` type($value)
161+
$name attr-dict `:` type($output1)
161162
}];
162163
}
163164

mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class VariableWriteOpConverter
4545
auto globalSymbolRef =
4646
SymbolRefAttr::get(rewriter.getContext(), op.getName());
4747
auto newVariableWrite = rewriter.create<ml_program::GlobalStoreOp>(
48-
op.getLoc(), globalSymbolRef, op.getValue());
48+
op.getLoc(), globalSymbolRef, op.getInput1());
4949
rewriter.replaceOp(op, newVariableWrite);
5050
return success();
5151
}

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -765,7 +765,7 @@ inline bool CompatibleTypes(const mlir::Type &type,
765765

766766
bool TosaValidation::CheckVariable(Operation *op) {
767767
if (isa<mlir::tosa::VariableOp>(op)) {
768-
auto nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
768+
mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
769769

770770
if (variablesMap.count(nameAttr)) {
771771
op->emitOpError() << "name has already been declared";
@@ -784,8 +784,7 @@ bool TosaValidation::CheckVariable(Operation *op) {
784784
bool TosaValidation::CheckVariableReadOrWrite(Operation *op) {
785785
if (isa<mlir::tosa::VariableReadOp>(op) ||
786786
isa<mlir::tosa::VariableWriteOp>(op)) {
787-
auto nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
788-
787+
mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
789788
if (!variablesMap.count(nameAttr)) {
790789
op->emitOpError() << "name has not been declared";
791790
return false;

0 commit comments

Comments
 (0)