Skip to content

Commit 0251588

Browse files
committed
[mlir][tosa] Align Variable ops to match with TOSA v1.0 spec
* updated SymbolNameAttr:$name to I32Attr:uid * updated description discrepancies * note: in the TOSA spec, we had var_shape attr, but it's already included in the TypeAttr:$type in MLIR Signed-off-by: Jerry Ge <[email protected]> Change-Id: I4cd0348cd4e306dbc2e0e53a89a9404d91fb44d4
1 parent 3f30207 commit 0251588

File tree

9 files changed

+74
-63
lines changed

9 files changed

+74
-63
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,13 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {
8686
let summary = "Defines a variable";
8787

8888
let description = [{
89-
Defines a new TOSA variable. This is a mutable value.
89+
Defines a new TOSA variable.
90+
This is a persistent mutable value across multiple TOSA graph invocations.
9091
Modifications are expressed using read/write semantics.
9192
}];
9293

9394
let arguments = (ins
94-
SymbolNameAttr:$name,
95+
I32Attr:$uid,
9596
TypeAttr:$type,
9697
OptionalAttr<AnyAttr>:$initial_value
9798
);
@@ -102,7 +103,7 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {
102103
];
103104

104105
let assemblyFormat = [{
105-
$name
106+
$uid
106107
attr-dict
107108
custom<TypeOrAttr>($type, $initial_value)
108109
}];
@@ -115,11 +116,11 @@ def Tosa_VariableWriteOp : Tosa_Op<"variable.write", []> {
115116
let summary = "write_buffer operator";
116117

117118
let description = [{
118-
Assigns a value to pseudo-buffer resource holding a mutable tensor.
119+
Assigns a value to the pseudo-buffer resource holding a persistent mutable tensor.
119120
}];
120121

121122
let arguments = (ins
122-
SymbolNameAttr:$name,
123+
I32Attr:$uid,
123124
AnyType:$value
124125
);
125126

@@ -129,7 +130,7 @@ def Tosa_VariableWriteOp : Tosa_Op<"variable.write", []> {
129130
];
130131

131132
let assemblyFormat = [{
132-
$name attr-dict `,` $value `:` type($value)
133+
$uid attr-dict `,` $value `:` type($value)
133134
}];
134135
}
135136

@@ -140,11 +141,11 @@ def Tosa_VariableReadOp : Tosa_Op<"variable.read", []> {
140141
let summary = "read_buffer operator";
141142

142143
let description = [{
143-
Reads the value from a pseudo-buffer resource holding a mutable tensor.
144+
Reads the value from a pseudo-buffer resource holding a persistent mutable tensor.
144145
}];
145146

146147
let arguments = (ins
147-
SymbolNameAttr:$name
148+
I32Attr:$uid
148149
);
149150

150151
let results = (outs
@@ -157,7 +158,7 @@ def Tosa_VariableReadOp : Tosa_Op<"variable.read", []> {
157158
];
158159

159160
let assemblyFormat = [{
160-
$name attr-dict `:` type($value)
161+
$uid attr-dict `:` type($value)
161162
}];
162163
}
163164

mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,12 @@ class VariableOpConverter : public OpRewritePattern<tosa::VariableOp> {
2626

2727
LogicalResult matchAndRewrite(tosa::VariableOp op,
2828
PatternRewriter &rewriter) const final {
29+
30+
std::string uid = std::to_string(op.getUid());
31+
llvm::StringRef uidStringRef(uid);
32+
2933
auto newVariable = rewriter.create<mlir::ml_program::GlobalOp>(
30-
op.getLoc(), op.getName(), op.getType(), /*is_mutable=*/true,
34+
op.getLoc(), uidStringRef, op.getType(), /*is_mutable=*/true,
3135
op.getInitialValueAttr(), /*sym_visibility=*/nullptr);
3236
newVariable.setPrivate();
3337
rewriter.replaceOp(op, newVariable);
@@ -42,8 +46,12 @@ class VariableWriteOpConverter
4246

4347
LogicalResult matchAndRewrite(tosa::VariableWriteOp op,
4448
PatternRewriter &rewriter) const final {
49+
50+
std::string uid = std::to_string(op.getUid());
51+
llvm::StringRef uidStringRef(uid);
52+
4553
auto globalSymbolRef =
46-
SymbolRefAttr::get(rewriter.getContext(), op.getName());
54+
SymbolRefAttr::get(rewriter.getContext(), uidStringRef);
4755
auto newVariableWrite = rewriter.create<ml_program::GlobalStoreOp>(
4856
op.getLoc(), globalSymbolRef, op.getValue());
4957
rewriter.replaceOp(op, newVariableWrite);
@@ -57,8 +65,11 @@ class VariableReadOpConverter : public OpRewritePattern<tosa::VariableReadOp> {
5765

5866
LogicalResult matchAndRewrite(tosa::VariableReadOp op,
5967
PatternRewriter &rewriter) const final {
68+
std::string uid = std::to_string(op.getUid());
69+
llvm::StringRef uidStringRef(uid);
70+
6071
auto globalSymbolRef =
61-
SymbolRefAttr::get(rewriter.getContext(), op.getName());
72+
SymbolRefAttr::get(rewriter.getContext(), uidStringRef);
6273
auto newVariableRead = rewriter.create<ml_program::GlobalLoadOp>(
6374
op.getLoc(), op.getType(), globalSymbolRef);
6475
rewriter.replaceOp(op, newVariableRead);

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
449449

450450
SmallVector<std::function<LogicalResult(Operation *)>> constCheckers;
451451
TosaLevel tosaLevel;
452-
DenseMap<StringAttr, mlir::Type> variablesMap;
452+
DenseMap<IntegerAttr, mlir::Type> variablesMap;
453453
TosaProfileCompliance profileComp;
454454
tosa::TargetEnv targetEnv;
455455
};
@@ -677,17 +677,17 @@ inline bool CompatibleTypes(const mlir::Type &type,
677677

678678
bool TosaValidation::CheckVariable(Operation *op) {
679679
if (isa<mlir::tosa::VariableOp>(op)) {
680-
auto nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
680+
mlir::IntegerAttr uidAttr = cast<mlir::IntegerAttr>(op->getAttr("uid"));
681681

682-
if (variablesMap.count(nameAttr)) {
682+
if (variablesMap.count(uidAttr)) {
683683
op->emitOpError() << "name has already been declared";
684684
return false;
685685
}
686686

687687
auto typeAttr = cast<mlir::TypeAttr>(op->getAttr("type"));
688688
mlir::Type type = typeAttr.getValue();
689689

690-
variablesMap[nameAttr] = type;
690+
variablesMap[uidAttr] = type;
691691
}
692692

693693
return true;
@@ -696,14 +696,13 @@ bool TosaValidation::CheckVariable(Operation *op) {
696696
bool TosaValidation::CheckVariableReadOrWrite(Operation *op) {
697697
if (isa<mlir::tosa::VariableReadOp>(op) ||
698698
isa<mlir::tosa::VariableWriteOp>(op)) {
699-
auto nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
700-
701-
if (!variablesMap.count(nameAttr)) {
699+
mlir::IntegerAttr uidAttr = cast<mlir::IntegerAttr>(op->getAttr("uid"));
700+
if (!variablesMap.count(uidAttr)) {
702701
op->emitOpError() << "name has not been declared";
703702
return false;
704703
}
705704

706-
auto varType = variablesMap[nameAttr];
705+
auto varType = variablesMap[uidAttr];
707706

708707
for (auto v : op->getOperands()) {
709708
auto type = v.getType();

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55

66
// check that -tosa-validate of stateful ops kick in
77
func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi32>) -> () {
8-
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
8+
tosa.variable 1 = dense<-1> : tensor<2x4x8xi32>
99
// expected-error@+1 {{'tosa.variable.write' op operand type does not equal variable type}}
10-
tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi32>
10+
tosa.variable.write 1, %arg0 : tensor<1x4x8xi32>
1111
return
1212
}
1313

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
// RUN: mlir-opt --tosa-to-mlprogram %s -o -| FileCheck %s
22

33
module {
4-
// CHECK: ml_program.global private mutable @var_x(dense<7.000000e+00> : tensor<1xf32>) : tensor<1xf32>
5-
tosa.variable @var_x = dense<7.000000e+00> : tensor<1xf32>
4+
// CHECK: ml_program.global private mutable @"1"(dense<7.000000e+00> : tensor<1xf32>) : tensor<1xf32>
5+
tosa.variable 1 = dense<7.000000e+00> : tensor<1xf32>
66
func.func @test_stateful_ops(%arg0: tensor<1xf32>) -> (tensor<1xf32>) {
7-
// CHECK: ml_program.global_store @var_x = %arg0 : tensor<1xf32>
8-
tosa.variable.write @var_x, %arg0 : tensor<1xf32>
9-
// CHECK: %[[LOAD:.+]] = ml_program.global_load @var_x : tensor<1xf32>
10-
%0 = tosa.variable.read @var_x : tensor<1xf32>
7+
// CHECK: ml_program.global_store @"1" = %arg0 : tensor<1xf32>
8+
tosa.variable.write 1, %arg0 : tensor<1xf32>
9+
// CHECK: %[[LOAD:.+]] = ml_program.global_load @"1" : tensor<1xf32>
10+
%0 = tosa.variable.read 1 : tensor<1xf32>
1111
return %0 : tensor<1xf32>
1212
}
1313
}

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -617,45 +617,45 @@ func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>, %arg1: ten
617617
// -----
618618

619619
func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi32>) -> () {
620-
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
620+
tosa.variable 1 = dense<-1> : tensor<2x4x8xi32>
621621
// expected-error@+1 {{'tosa.variable' op name has already been declared}}
622-
tosa.variable @stored_var : tensor<1x4x8xi32>
622+
tosa.variable 1 : tensor<1x4x8xi32>
623623
return
624624
}
625625

626626
// -----
627627

628628
func.func @test_variable_read_type(%arg0: tensor<2x4x8xi32>) -> () {
629-
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
629+
tosa.variable 1 = dense<-1> : tensor<2x4x8xi32>
630630
// expected-error@+1 {{'tosa.variable.read' op result type does not equal variable type}}
631-
%0 = tosa.variable.read @stored_var : tensor<2x4x8xi16>
631+
%0 = tosa.variable.read 1 : tensor<2x4x8xi16>
632632
return
633633
}
634634

635635
// -----
636636

637637
func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi32>) -> () {
638-
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
638+
tosa.variable 1 = dense<-1> : tensor<2x4x8xi32>
639639
// expected-error@+1 {{'tosa.variable.read' op result type does not equal variable type}}
640-
%0 = tosa.variable.read @stored_var : tensor<1x4x8xi32>
640+
%0 = tosa.variable.read 1 : tensor<1x4x8xi32>
641641
return
642642
}
643643

644644
// -----
645645

646646
func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
647-
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
647+
tosa.variable 1 = dense<-1> : tensor<2x4x8xi32>
648648
// expected-error@+1 {{'tosa.variable.write' op operand type does not equal variable type}}
649-
tosa.variable.write @stored_var, %arg0 : tensor<2x4x8xi16>
649+
tosa.variable.write 1, %arg0 : tensor<2x4x8xi16>
650650
return
651651
}
652652

653653
// -----
654654

655655
func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi32>) -> () {
656-
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
656+
tosa.variable 1 = dense<-1> : tensor<2x4x8xi32>
657657
// expected-error@+1 {{'tosa.variable.write' op operand type does not equal variable type}}
658-
tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi32>
658+
tosa.variable.write 1, %arg0 : tensor<1x4x8xi32>
659659
return
660660
}
661661

mlir/test/Dialect/Tosa/invalid_extension.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,18 @@ func.func @test_fft2d(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x8xf32>) -> (te
1414
// -----
1515
func.func @test_variable_read_type(%arg0: tensor<2x4x8xi32>) -> () {
1616
// expected-error@+1 {{'tosa.variable' op illegal: requires [variable] but not enabled in target}}
17-
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
17+
tosa.variable 1 = dense<-1> : tensor<2x4x8xi32>
1818
// expected-error@+1 {{'tosa.variable.read' op illegal: requires [variable]}}
19-
%0 = tosa.variable.read @stored_var : tensor<2x4x8xi16>
19+
%0 = tosa.variable.read 1 : tensor<2x4x8xi16>
2020
return
2121
}
2222

2323
// -----
2424
func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
2525
// expected-error@+1 {{'tosa.variable' op illegal: requires [variable] but not enabled in target}}
26-
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
26+
tosa.variable 1 = dense<-1> : tensor<2x4x8xi32>
2727
// expected-error@+1 {{'tosa.variable.write' op illegal: requires [variable]}}
28-
tosa.variable.write @stored_var, %arg0 : tensor<2x4x8xi16>
28+
tosa.variable.write 1, %arg0 : tensor<2x4x8xi16>
2929
return
3030
}
3131

mlir/test/Dialect/Tosa/level_check.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1151,11 +1151,11 @@ func.func @test_scatter_tensor_size_invalid(%arg0: tensor<13x210000000x3xf32>, %
11511151
// -----
11521152

11531153
func.func @test_variable_read_write_tensor_size_invalid() -> () {
1154-
tosa.variable @stored_var = dense<3.14> : tensor<536870912xf32>
1154+
tosa.variable 1 = dense<3.14> : tensor<536870912xf32>
11551155
// expected-error@+1 {{'tosa.variable.read' op failed level check: result tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
1156-
%0 = tosa.variable.read @stored_var : tensor<536870912xf32>
1156+
%0 = tosa.variable.read 1 : tensor<536870912xf32>
11571157
// expected-error@+1 {{'tosa.variable.write' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
1158-
tosa.variable.write @stored_var, %0 : tensor<536870912xf32>
1158+
tosa.variable.write 1, %0 : tensor<536870912xf32>
11591159
return
11601160
}
11611161

@@ -1219,11 +1219,11 @@ func.func @test_cond_if_rank_invalid(%arg0: tensor<1x1x1x1x1x1x1x1xf32>, %arg1:
12191219

12201220
func.func @test_variable_read_write_rank_invalid() -> () {
12211221
// expected-error@+1 {{'tosa.variable' op failed level check: attribute rank(shape) <= MAX_RANK}}
1222-
tosa.variable @stored_var = dense<3.14> : tensor<1x1x1x1x1x1x1x1xf32>
1222+
tosa.variable 1 = dense<3.14> : tensor<1x1x1x1x1x1x1x1xf32>
12231223
// expected-error@+1 {{'tosa.variable.read' op failed level check: result rank(shape) <= MAX_RANK}}
1224-
%0 = tosa.variable.read @stored_var : tensor<1x1x1x1x1x1x1x1xf32>
1224+
%0 = tosa.variable.read 1 : tensor<1x1x1x1x1x1x1x1xf32>
12251225
// expected-error@+1 {{'tosa.variable.write' op failed level check: operand rank(shape) <= MAX_RANK}}
1226-
tosa.variable.write @stored_var, %0 : tensor<1x1x1x1x1x1x1x1xf32>
1226+
tosa.variable.write 1, %0 : tensor<1x1x1x1x1x1x1x1xf32>
12271227
return
12281228
}
12291229

mlir/test/Dialect/Tosa/variables.mlir

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,28 +6,28 @@
66
// CHECK-LABEL: @test_variable_scalar(
77
// CHECK-SAME: %[[ADD_VAL:.*]]: tensor<f32>) {
88
func.func @test_variable_scalar(%arg0: tensor<f32>) -> () {
9-
// CHECK: tosa.variable @stored_var = dense<3.140000e+00> : tensor<f32>
10-
tosa.variable @stored_var = dense<3.14> : tensor<f32>
11-
// CHECK: %[[STORED_VAL:.*]] = tosa.variable.read @stored_var : tensor<f32>
12-
%0 = tosa.variable.read @stored_var : tensor<f32>
13-
// CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
9+
// CHECK: tosa.variable 1 = dense<3.140000e+00> : tensor<f32>
10+
tosa.variable 1 = dense<3.14> : tensor<f32>
11+
// CHECK: %[[VAR_1:.*]] = tosa.variable.read 1 : tensor<f32>
12+
%0 = tosa.variable.read 1 : tensor<f32>
13+
// CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[VAR_1]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
1414
%1 = "tosa.add"(%arg0, %0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
15-
// CHECK: tosa.variable.write @stored_var, %[[RESULT_ADD]] : tensor<f32>
16-
tosa.variable.write @stored_var, %1 : tensor<f32>
15+
// CHECK: tosa.variable.write 1, %[[RESULT_ADD]] : tensor<f32>
16+
tosa.variable.write 1, %1 : tensor<f32>
1717
return
1818
}
1919

2020
// -----
2121
// CHECK-LABEL: @test_variable_tensor(
2222
// CHECK-SAME: %[[ADD_VAL:.*]]: tensor<2x4x8xi32>) {
2323
func.func @test_variable_tensor(%arg0: tensor<2x4x8xi32>) -> () {
24-
// CHECK: tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
25-
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
26-
// CHECK: %[[STORED_VAL:.*]] = tosa.variable.read @stored_var : tensor<2x4x8xi32>
27-
%0 = tosa.variable.read @stored_var : tensor<2x4x8xi32>
28-
// CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
24+
// CHECK: tosa.variable 1 = dense<-1> : tensor<2x4x8xi32>
25+
tosa.variable 1 = dense<-1> : tensor<2x4x8xi32>
26+
// CHECK: %[[VAL_1:.*]] = tosa.variable.read 1 : tensor<2x4x8xi32>
27+
%0 = tosa.variable.read 1 : tensor<2x4x8xi32>
28+
// CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[VAL_1]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
2929
%1 = "tosa.add"(%arg0, %0) : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
30-
// CHECK: tosa.variable.write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32>
31-
tosa.variable.write @stored_var, %1 : tensor<2x4x8xi32>
30+
// CHECK: tosa.variable.write 1, %[[RESULT_ADD]] : tensor<2x4x8xi32>
31+
tosa.variable.write 1, %1 : tensor<2x4x8xi32>
3232
return
33-
}
33+
}

0 commit comments

Comments
 (0)