Skip to content

Commit af5e12f

Browse files
Jerry-Gepsunn
authored andcommitted
[mlir][tosa] Align Variable ops to match with TOSA v1.0 spec (llvm#130680)
- updated AnyType:$value to Tosa_Tensor:$input1 and Tosa_Tensor:$output1 for VariableWrite and VriableRead Operators - updated description discrepancies - note: in the TOSA spec, we had var_shape attr, but it's already included in the TypeAttr:$type in MLIR TF_REFSPEC: refs/changes/61/1023261/3 (cherry picked from commit 8e9ff8e) Signed-off-by: Jerry Ge <[email protected]> Change-Id: I4cd0348cd4e306dbc2e0e53a89a9404d91fb44d4
1 parent b43b545 commit af5e12f

File tree

10 files changed

+51
-46
lines changed

10 files changed

+51
-46
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,9 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {
109109
}
110110

111111
//===----------------------------------------------------------------------===//
112-
// Operator: variable.write
112+
// Operator: variable_write
113113
//===----------------------------------------------------------------------===//
114-
def Tosa_VariableWriteOp : Tosa_Op<"variable.write", []> {
114+
def Tosa_VariableWriteOp : Tosa_Op<"variable_write", []> {
115115
let summary = "write_buffer operator";
116116

117117
let description = [{
@@ -120,7 +120,7 @@ def Tosa_VariableWriteOp : Tosa_Op<"variable.write", []> {
120120

121121
let arguments = (ins
122122
SymbolNameAttr:$name,
123-
AnyType:$value
123+
Tosa_Tensor:$input1
124124
);
125125

126126
list<Availability> availability = [
@@ -129,14 +129,14 @@ def Tosa_VariableWriteOp : Tosa_Op<"variable.write", []> {
129129
];
130130

131131
let assemblyFormat = [{
132-
$name attr-dict `,` $value `:` type($value)
132+
$name attr-dict `,` $input1 `:` type($input1)
133133
}];
134134
}
135135

136136
//===----------------------------------------------------------------------===//
137-
// Operator: variable.read
137+
// Operator: variable_read
138138
//===----------------------------------------------------------------------===//
139-
def Tosa_VariableReadOp : Tosa_Op<"variable.read", []> {
139+
def Tosa_VariableReadOp : Tosa_Op<"variable_read", []> {
140140
let summary = "read_buffer operator";
141141

142142
let description = [{
@@ -148,7 +148,7 @@ def Tosa_VariableReadOp : Tosa_Op<"variable.read", []> {
148148
);
149149

150150
let results = (outs
151-
AnyType:$value
151+
Tosa_Tensor:$output1
152152
);
153153

154154
list<Availability> availability = [
@@ -157,7 +157,7 @@ def Tosa_VariableReadOp : Tosa_Op<"variable.read", []> {
157157
];
158158

159159
let assemblyFormat = [{
160-
$name attr-dict `:` type($value)
160+
$name attr-dict `:` type($output1)
161161
}];
162162
}
163163

mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class VariableWriteOpConverter
4545
auto globalSymbolRef =
4646
SymbolRefAttr::get(rewriter.getContext(), op.getName());
4747
auto newVariableWrite = rewriter.create<ml_program::GlobalStoreOp>(
48-
op.getLoc(), globalSymbolRef, op.getValue());
48+
op.getLoc(), globalSymbolRef, op.getInput1());
4949
rewriter.replaceOp(op, newVariableWrite);
5050
return success();
5151
}

mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,12 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableOp op) {
226226
return failure();
227227
}
228228

229+
template <>
230+
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableWriteOp op) {
231+
addValue(op.getInput1());
232+
return success();
233+
}
234+
229235
template <>
230236
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::IfOp op) {
231237
addValue(op.getCondition());
@@ -280,6 +286,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
280286
POPULATE_PROFILE_INFO_CUSTOM(Rescale)
281287
POPULATE_PROFILE_INFO_CUSTOM(MatMul)
282288
POPULATE_PROFILE_INFO_CUSTOM(Variable)
289+
POPULATE_PROFILE_INFO_CUSTOM(VariableWrite)
283290
POPULATE_PROFILE_INFO_CUSTOM(If)
284291
POPULATE_PROFILE_INFO_CUSTOM(While)
285292

@@ -334,7 +341,6 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
334341
POPULATE_PROFILE_INFO_COMMON(Reverse)
335342
POPULATE_PROFILE_INFO_COMMON(Identity)
336343
POPULATE_PROFILE_INFO_COMMON(VariableRead)
337-
POPULATE_PROFILE_INFO_COMMON(VariableWrite)
338344

339345
// Type Invariant Extension, a capability extension that is independent
340346
// of the data type, meaning any compatible type can be used. No type

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -768,7 +768,7 @@ inline bool CompatibleTypes(const mlir::Type &type,
768768

769769
bool TosaValidation::CheckVariable(Operation *op) {
770770
if (isa<mlir::tosa::VariableOp>(op)) {
771-
auto nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
771+
mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
772772

773773
if (variablesMap.count(nameAttr)) {
774774
op->emitOpError() << "name has already been declared";
@@ -787,8 +787,7 @@ bool TosaValidation::CheckVariable(Operation *op) {
787787
bool TosaValidation::CheckVariableReadOrWrite(Operation *op) {
788788
if (isa<mlir::tosa::VariableReadOp>(op) ||
789789
isa<mlir::tosa::VariableWriteOp>(op)) {
790-
auto nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
791-
790+
mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
792791
if (!variablesMap.count(nameAttr)) {
793792
op->emitOpError() << "name has not been declared";
794793
return false;

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
// check that -tosa-validate of stateful ops kick in
77
func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
88
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
9-
// expected-error@+1 {{'tosa.variable.write' op operand type does not equal variable type}}
10-
tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi8>
9+
// expected-error@+1 {{'tosa.variable_write' op operand type does not equal variable type}}
10+
tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
1111
return
1212
}
1313

mlir/test/Conversion/TosaToMLProgram/tosa-to-mlprogram.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ module {
55
tosa.variable @var_x = dense<7.000000e+00> : tensor<1xf32>
66
func.func @test_stateful_ops(%arg0: tensor<1xf32>) -> (tensor<1xf32>) {
77
// CHECK: ml_program.global_store @var_x = %arg0 : tensor<1xf32>
8-
tosa.variable.write @var_x, %arg0 : tensor<1xf32>
8+
tosa.variable_write @var_x, %arg0 : tensor<1xf32>
99
// CHECK: %[[LOAD:.+]] = ml_program.global_load @var_x : tensor<1xf32>
10-
%0 = tosa.variable.read @var_x : tensor<1xf32>
10+
%0 = tosa.variable_read @var_x : tensor<1xf32>
1111
return %0 : tensor<1xf32>
1212
}
1313
}

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -585,35 +585,35 @@ func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi8>) -> () {
585585

586586
func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
587587
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
588-
// expected-error@+1 {{'tosa.variable.read' op result type does not equal variable type}}
589-
%0 = tosa.variable.read @stored_var : tensor<2x4x8xi16>
588+
// expected-error@+1 {{'tosa.variable_read' op illegal: operand/result data types not supported}}
589+
%0 = tosa.variable_read @stored_var : tensor<2x4x8xi16>
590590
return
591591
}
592592

593593
// -----
594594

595595
func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
596596
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
597-
// expected-error@+1 {{'tosa.variable.read' op result type does not equal variable type}}
598-
%0 = tosa.variable.read @stored_var : tensor<1x4x8xi32>
597+
// expected-error@+1 {{'tosa.variable_read' op illegal: operand/result data types not supported}}
598+
%0 = tosa.variable_read @stored_var : tensor<1x4x8xi32>
599599
return
600600
}
601601

602602
// -----
603603

604604
func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
605605
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
606-
// expected-error@+1 {{'tosa.variable.write' op operand type does not equal variable type}}
607-
tosa.variable.write @stored_var, %arg0 : tensor<2x4x8xi16>
606+
// expected-error@+1 {{'tosa.variable_write' op illegal: operand/result data types not supported}}
607+
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi16>
608608
return
609609
}
610610

611611
// -----
612612

613613
func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
614614
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
615-
// expected-error@+1 {{'tosa.variable.write' op operand type does not equal variable type}}
616-
tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi8>
615+
// expected-error@+1 {{'tosa.variable_write' op operand type does not equal variable type}}
616+
tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
617617
return
618618
}
619619

mlir/test/Dialect/Tosa/invalid_extension.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -313,17 +313,17 @@ func.func @test_identity(%arg0: tensor<13x21x3xi4>) -> tensor<13x21x3xi4> {
313313
func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
314314
// expected-error@+1 {{'tosa.variable' op illegal: requires [variable] but not enabled in target}}
315315
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
316-
// expected-error@+1 {{'tosa.variable.read' op illegal: requires [variable]}}
317-
%0 = tosa.variable.read @stored_var : tensor<2x4x8xi16>
316+
// expected-error@+1 {{'tosa.variable_read' op illegal: requires [variable]}}
317+
%0 = tosa.variable_read @stored_var : tensor<2x4x8xi8>
318318
return
319319
}
320320

321321
// -----
322-
func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
322+
func.func @test_variable_write_type(%arg0: tensor<2x4x8xi8>) -> () {
323323
// expected-error@+1 {{'tosa.variable' op illegal: requires [variable] but not enabled in target}}
324324
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
325-
// expected-error@+1 {{'tosa.variable.write' op illegal: requires [variable]}}
326-
tosa.variable.write @stored_var, %arg0 : tensor<2x4x8xi16>
325+
// expected-error@+1 {{'tosa.variable_write' op illegal: requires [variable]}}
326+
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi8>
327327
return
328328
}
329329

mlir/test/Dialect/Tosa/level_check.mlir

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1097,10 +1097,10 @@ func.func @test_scatter_tensor_size_invalid(%arg0: tensor<13x260000000x3xf32>, %
10971097

10981098
func.func @test_variable_read_write_tensor_size_invalid() -> () {
10991099
tosa.variable @stored_var = dense<3.14> : tensor<536870912xf32>
1100-
// expected-error@+1 {{'tosa.variable.read' op failed level check: result tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
1101-
%0 = tosa.variable.read @stored_var : tensor<536870912xf32>
1102-
// expected-error@+1 {{'tosa.variable.write' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
1103-
tosa.variable.write @stored_var, %0 : tensor<536870912xf32>
1100+
// expected-error@+1 {{'tosa.variable_read' op failed level check: result tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
1101+
%0 = tosa.variable_read @stored_var : tensor<536870912xf32>
1102+
// expected-error@+1 {{'tosa.variable_write' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
1103+
tosa.variable_write @stored_var, %0 : tensor<536870912xf32>
11041104
return
11051105
}
11061106

@@ -1165,10 +1165,10 @@ func.func @test_cond_if_rank_invalid(%arg0: tensor<1x1x1x1x1x1x1x1xf32>, %arg1:
11651165
func.func @test_variable_read_write_rank_invalid() -> () {
11661166
// expected-error@+1 {{'tosa.variable' op failed level check: attribute rank(shape) <= MAX_RANK}}
11671167
tosa.variable @stored_var = dense<3.14> : tensor<1x1x1x1x1x1x1x1xf32>
1168-
// expected-error@+1 {{'tosa.variable.read' op failed level check: result rank(shape) <= MAX_RANK}}
1169-
%0 = tosa.variable.read @stored_var : tensor<1x1x1x1x1x1x1x1xf32>
1170-
// expected-error@+1 {{'tosa.variable.write' op failed level check: operand rank(shape) <= MAX_RANK}}
1171-
tosa.variable.write @stored_var, %0 : tensor<1x1x1x1x1x1x1x1xf32>
1168+
// expected-error@+1 {{'tosa.variable_read' op failed level check: result rank(shape) <= MAX_RANK}}
1169+
%0 = tosa.variable_read @stored_var : tensor<1x1x1x1x1x1x1x1xf32>
1170+
// expected-error@+1 {{'tosa.variable_write' op failed level check: operand rank(shape) <= MAX_RANK}}
1171+
tosa.variable_write @stored_var, %0 : tensor<1x1x1x1x1x1x1x1xf32>
11721172
return
11731173
}
11741174

mlir/test/Dialect/Tosa/variables.mlir

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@
88
func.func @test_variable_scalar(%arg0: tensor<f32>) -> () {
99
// CHECK: tosa.variable @stored_var = dense<3.140000e+00> : tensor<f32>
1010
tosa.variable @stored_var = dense<3.14> : tensor<f32>
11-
// CHECK: %[[STORED_VAL:.*]] = tosa.variable.read @stored_var : tensor<f32>
12-
%0 = tosa.variable.read @stored_var : tensor<f32>
11+
// CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<f32>
12+
%0 = tosa.variable_read @stored_var : tensor<f32>
1313
// CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
1414
%1 = "tosa.add"(%arg0, %0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
15-
// CHECK: tosa.variable.write @stored_var, %[[RESULT_ADD]] : tensor<f32>
16-
tosa.variable.write @stored_var, %1 : tensor<f32>
15+
// CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<f32>
16+
tosa.variable_write @stored_var, %1 : tensor<f32>
1717
return
1818
}
1919

@@ -23,11 +23,11 @@ func.func @test_variable_scalar(%arg0: tensor<f32>) -> () {
2323
func.func @test_variable_tensor(%arg0: tensor<2x4x8xi32>) -> () {
2424
// CHECK: tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
2525
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
26-
// CHECK: %[[STORED_VAL:.*]] = tosa.variable.read @stored_var : tensor<2x4x8xi32>
27-
%0 = tosa.variable.read @stored_var : tensor<2x4x8xi32>
26+
// CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<2x4x8xi32>
27+
%0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
2828
// CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
2929
%1 = "tosa.add"(%arg0, %0) : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
30-
// CHECK: tosa.variable.write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32>
31-
tosa.variable.write @stored_var, %1 : tensor<2x4x8xi32>
30+
// CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32>
31+
tosa.variable_write @stored_var, %1 : tensor<2x4x8xi32>
3232
return
3333
}

0 commit comments

Comments
 (0)