Skip to content

Commit 8e9ff8e

Browse files
authored
[mlir][tosa] Align Variable ops to match with TOSA v1.0 spec (#130680)
- updated AnyType:$value to Tosa_Tensor:$input1 and Tosa_Tensor:$output1 for VariableWrite and VriableRead Operators - updated description discrepancies - note: in the TOSA spec, we had var_shape attr, but it's already included in the TypeAttr:$type in MLIR Signed-off-by: Jerry Ge <[email protected]>
1 parent 11a3de7 commit 8e9ff8e

File tree

10 files changed

+51
-46
lines changed

10 files changed

+51
-46
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,9 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {
109109
}
110110

111111
//===----------------------------------------------------------------------===//
112-
// Operator: variable.write
112+
// Operator: variable_write
113113
//===----------------------------------------------------------------------===//
114-
def Tosa_VariableWriteOp : Tosa_Op<"variable.write", []> {
114+
def Tosa_VariableWriteOp : Tosa_Op<"variable_write", []> {
115115
let summary = "write_buffer operator";
116116

117117
let description = [{
@@ -120,7 +120,7 @@ def Tosa_VariableWriteOp : Tosa_Op<"variable.write", []> {
120120

121121
let arguments = (ins
122122
SymbolNameAttr:$name,
123-
AnyType:$value
123+
Tosa_Tensor:$input1
124124
);
125125

126126
list<Availability> availability = [
@@ -129,14 +129,14 @@ def Tosa_VariableWriteOp : Tosa_Op<"variable.write", []> {
129129
];
130130

131131
let assemblyFormat = [{
132-
$name attr-dict `,` $value `:` type($value)
132+
$name attr-dict `,` $input1 `:` type($input1)
133133
}];
134134
}
135135

136136
//===----------------------------------------------------------------------===//
137-
// Operator: variable.read
137+
// Operator: variable_read
138138
//===----------------------------------------------------------------------===//
139-
def Tosa_VariableReadOp : Tosa_Op<"variable.read", []> {
139+
def Tosa_VariableReadOp : Tosa_Op<"variable_read", []> {
140140
let summary = "read_buffer operator";
141141

142142
let description = [{
@@ -148,7 +148,7 @@ def Tosa_VariableReadOp : Tosa_Op<"variable.read", []> {
148148
);
149149

150150
let results = (outs
151-
AnyType:$value
151+
Tosa_Tensor:$output1
152152
);
153153

154154
list<Availability> availability = [
@@ -157,7 +157,7 @@ def Tosa_VariableReadOp : Tosa_Op<"variable.read", []> {
157157
];
158158

159159
let assemblyFormat = [{
160-
$name attr-dict `:` type($value)
160+
$name attr-dict `:` type($output1)
161161
}];
162162
}
163163

mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class VariableWriteOpConverter
4545
auto globalSymbolRef =
4646
SymbolRefAttr::get(rewriter.getContext(), op.getName());
4747
auto newVariableWrite = rewriter.create<ml_program::GlobalStoreOp>(
48-
op.getLoc(), globalSymbolRef, op.getValue());
48+
op.getLoc(), globalSymbolRef, op.getInput1());
4949
rewriter.replaceOp(op, newVariableWrite);
5050
return success();
5151
}

mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,12 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableOp op) {
226226
return failure();
227227
}
228228

229+
template <>
230+
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableWriteOp op) {
231+
addValue(op.getInput1());
232+
return success();
233+
}
234+
229235
template <>
230236
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::IfOp op) {
231237
addValue(op.getCondition());
@@ -280,6 +286,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
280286
POPULATE_PROFILE_INFO_CUSTOM(Rescale)
281287
POPULATE_PROFILE_INFO_CUSTOM(MatMul)
282288
POPULATE_PROFILE_INFO_CUSTOM(Variable)
289+
POPULATE_PROFILE_INFO_CUSTOM(VariableWrite)
283290
POPULATE_PROFILE_INFO_CUSTOM(If)
284291
POPULATE_PROFILE_INFO_CUSTOM(While)
285292

@@ -334,7 +341,6 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
334341
POPULATE_PROFILE_INFO_COMMON(Reverse)
335342
POPULATE_PROFILE_INFO_COMMON(Identity)
336343
POPULATE_PROFILE_INFO_COMMON(VariableRead)
337-
POPULATE_PROFILE_INFO_COMMON(VariableWrite)
338344

339345
// Type Invariant Extension, a capability extension that is independent
340346
// of the data type, meaning any compatible type can be used. No type

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -767,7 +767,7 @@ inline bool CompatibleTypes(const mlir::Type &type,
767767

768768
bool TosaValidation::CheckVariable(Operation *op) {
769769
if (isa<mlir::tosa::VariableOp>(op)) {
770-
auto nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
770+
mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
771771

772772
if (variablesMap.count(nameAttr)) {
773773
op->emitOpError() << "name has already been declared";
@@ -786,8 +786,7 @@ bool TosaValidation::CheckVariable(Operation *op) {
786786
bool TosaValidation::CheckVariableReadOrWrite(Operation *op) {
787787
if (isa<mlir::tosa::VariableReadOp>(op) ||
788788
isa<mlir::tosa::VariableWriteOp>(op)) {
789-
auto nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
790-
789+
mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
791790
if (!variablesMap.count(nameAttr)) {
792791
op->emitOpError() << "name has not been declared";
793792
return false;

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
// check that -tosa-validate of stateful ops kick in
77
func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
88
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
9-
// expected-error@+1 {{'tosa.variable.write' op operand type does not equal variable type}}
10-
tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi8>
9+
// expected-error@+1 {{'tosa.variable_write' op operand type does not equal variable type}}
10+
tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
1111
return
1212
}
1313

mlir/test/Conversion/TosaToMLProgram/tosa-to-mlprogram.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ module {
55
tosa.variable @var_x = dense<7.000000e+00> : tensor<1xf32>
66
func.func @test_stateful_ops(%arg0: tensor<1xf32>) -> (tensor<1xf32>) {
77
// CHECK: ml_program.global_store @var_x = %arg0 : tensor<1xf32>
8-
tosa.variable.write @var_x, %arg0 : tensor<1xf32>
8+
tosa.variable_write @var_x, %arg0 : tensor<1xf32>
99
// CHECK: %[[LOAD:.+]] = ml_program.global_load @var_x : tensor<1xf32>
10-
%0 = tosa.variable.read @var_x : tensor<1xf32>
10+
%0 = tosa.variable_read @var_x : tensor<1xf32>
1111
return %0 : tensor<1xf32>
1212
}
1313
}

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -626,35 +626,35 @@ func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi8>) -> () {
626626

627627
func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
628628
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
629-
// expected-error@+1 {{'tosa.variable.read' op result type does not equal variable type}}
630-
%0 = tosa.variable.read @stored_var : tensor<2x4x8xi16>
629+
// expected-error@+1 {{'tosa.variable_read' op illegal: operand/result data types not supported}}
630+
%0 = tosa.variable_read @stored_var : tensor<2x4x8xi16>
631631
return
632632
}
633633

634634
// -----
635635

636636
func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
637637
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
638-
// expected-error@+1 {{'tosa.variable.read' op result type does not equal variable type}}
639-
%0 = tosa.variable.read @stored_var : tensor<1x4x8xi32>
638+
// expected-error@+1 {{'tosa.variable_read' op illegal: operand/result data types not supported}}
639+
%0 = tosa.variable_read @stored_var : tensor<1x4x8xi32>
640640
return
641641
}
642642

643643
// -----
644644

645645
func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
646646
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
647-
// expected-error@+1 {{'tosa.variable.write' op operand type does not equal variable type}}
648-
tosa.variable.write @stored_var, %arg0 : tensor<2x4x8xi16>
647+
// expected-error@+1 {{'tosa.variable_write' op illegal: operand/result data types not supported}}
648+
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi16>
649649
return
650650
}
651651

652652
// -----
653653

654654
func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
655655
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
656-
// expected-error@+1 {{'tosa.variable.write' op operand type does not equal variable type}}
657-
tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi8>
656+
// expected-error@+1 {{'tosa.variable_write' op operand type does not equal variable type}}
657+
tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
658658
return
659659
}
660660

mlir/test/Dialect/Tosa/invalid_extension.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -313,17 +313,17 @@ func.func @test_identity(%arg0: tensor<13x21x3xi4>) -> tensor<13x21x3xi4> {
313313
func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
314314
// expected-error@+1 {{'tosa.variable' op illegal: requires [variable] but not enabled in target}}
315315
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
316-
// expected-error@+1 {{'tosa.variable.read' op illegal: requires [variable]}}
317-
%0 = tosa.variable.read @stored_var : tensor<2x4x8xi16>
316+
// expected-error@+1 {{'tosa.variable_read' op illegal: requires [variable]}}
317+
%0 = tosa.variable_read @stored_var : tensor<2x4x8xi8>
318318
return
319319
}
320320

321321
// -----
322-
func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
322+
func.func @test_variable_write_type(%arg0: tensor<2x4x8xi8>) -> () {
323323
// expected-error@+1 {{'tosa.variable' op illegal: requires [variable] but not enabled in target}}
324324
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
325-
// expected-error@+1 {{'tosa.variable.write' op illegal: requires [variable]}}
326-
tosa.variable.write @stored_var, %arg0 : tensor<2x4x8xi16>
325+
// expected-error@+1 {{'tosa.variable_write' op illegal: requires [variable]}}
326+
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi8>
327327
return
328328
}
329329

mlir/test/Dialect/Tosa/level_check.mlir

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1089,10 +1089,10 @@ func.func @test_scatter_tensor_size_invalid(%arg0: tensor<13x210000000x3xf32>, %
10891089

10901090
func.func @test_variable_read_write_tensor_size_invalid() -> () {
10911091
tosa.variable @stored_var = dense<3.14> : tensor<536870912xf32>
1092-
// expected-error@+1 {{'tosa.variable.read' op failed level check: result tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
1093-
%0 = tosa.variable.read @stored_var : tensor<536870912xf32>
1094-
// expected-error@+1 {{'tosa.variable.write' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
1095-
tosa.variable.write @stored_var, %0 : tensor<536870912xf32>
1092+
// expected-error@+1 {{'tosa.variable_read' op failed level check: result tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
1093+
%0 = tosa.variable_read @stored_var : tensor<536870912xf32>
1094+
// expected-error@+1 {{'tosa.variable_write' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
1095+
tosa.variable_write @stored_var, %0 : tensor<536870912xf32>
10961096
return
10971097
}
10981098

@@ -1157,10 +1157,10 @@ func.func @test_cond_if_rank_invalid(%arg0: tensor<1x1x1x1x1x1x1x1xf32>, %arg1:
11571157
func.func @test_variable_read_write_rank_invalid() -> () {
11581158
// expected-error@+1 {{'tosa.variable' op failed level check: attribute rank(shape) <= MAX_RANK}}
11591159
tosa.variable @stored_var = dense<3.14> : tensor<1x1x1x1x1x1x1x1xf32>
1160-
// expected-error@+1 {{'tosa.variable.read' op failed level check: result rank(shape) <= MAX_RANK}}
1161-
%0 = tosa.variable.read @stored_var : tensor<1x1x1x1x1x1x1x1xf32>
1162-
// expected-error@+1 {{'tosa.variable.write' op failed level check: operand rank(shape) <= MAX_RANK}}
1163-
tosa.variable.write @stored_var, %0 : tensor<1x1x1x1x1x1x1x1xf32>
1160+
// expected-error@+1 {{'tosa.variable_read' op failed level check: result rank(shape) <= MAX_RANK}}
1161+
%0 = tosa.variable_read @stored_var : tensor<1x1x1x1x1x1x1x1xf32>
1162+
// expected-error@+1 {{'tosa.variable_write' op failed level check: operand rank(shape) <= MAX_RANK}}
1163+
tosa.variable_write @stored_var, %0 : tensor<1x1x1x1x1x1x1x1xf32>
11641164
return
11651165
}
11661166

mlir/test/Dialect/Tosa/variables.mlir

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@
88
func.func @test_variable_scalar(%arg0: tensor<f32>) -> () {
99
// CHECK: tosa.variable @stored_var = dense<3.140000e+00> : tensor<f32>
1010
tosa.variable @stored_var = dense<3.14> : tensor<f32>
11-
// CHECK: %[[STORED_VAL:.*]] = tosa.variable.read @stored_var : tensor<f32>
12-
%0 = tosa.variable.read @stored_var : tensor<f32>
11+
// CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<f32>
12+
%0 = tosa.variable_read @stored_var : tensor<f32>
1313
// CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
1414
%1 = "tosa.add"(%arg0, %0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
15-
// CHECK: tosa.variable.write @stored_var, %[[RESULT_ADD]] : tensor<f32>
16-
tosa.variable.write @stored_var, %1 : tensor<f32>
15+
// CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<f32>
16+
tosa.variable_write @stored_var, %1 : tensor<f32>
1717
return
1818
}
1919

@@ -23,11 +23,11 @@ func.func @test_variable_scalar(%arg0: tensor<f32>) -> () {
2323
func.func @test_variable_tensor(%arg0: tensor<2x4x8xi32>) -> () {
2424
// CHECK: tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
2525
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
26-
// CHECK: %[[STORED_VAL:.*]] = tosa.variable.read @stored_var : tensor<2x4x8xi32>
27-
%0 = tosa.variable.read @stored_var : tensor<2x4x8xi32>
26+
// CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<2x4x8xi32>
27+
%0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
2828
// CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
2929
%1 = "tosa.add"(%arg0, %0) : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
30-
// CHECK: tosa.variable.write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32>
31-
tosa.variable.write @stored_var, %1 : tensor<2x4x8xi32>
30+
// CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32>
31+
tosa.variable_write @stored_var, %1 : tensor<2x4x8xi32>
3232
return
3333
}

0 commit comments

Comments
 (0)