Skip to content

[mlir][tosa] Align Variable ops to match with TOSA v1.0 spec #130680

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,9 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {
}

//===----------------------------------------------------------------------===//
// Operator: variable.write
// Operator: variable_write
//===----------------------------------------------------------------------===//
def Tosa_VariableWriteOp : Tosa_Op<"variable.write", []> {
def Tosa_VariableWriteOp : Tosa_Op<"variable_write", []> {
let summary = "write_buffer operator";

let description = [{
Expand All @@ -120,7 +120,7 @@ def Tosa_VariableWriteOp : Tosa_Op<"variable.write", []> {

let arguments = (ins
SymbolNameAttr:$name,
AnyType:$value
Tosa_Tensor:$input1
);

list<Availability> availability = [
Expand All @@ -129,14 +129,14 @@ def Tosa_VariableWriteOp : Tosa_Op<"variable.write", []> {
];

let assemblyFormat = [{
$name attr-dict `,` $value `:` type($value)
$name attr-dict `,` $input1 `:` type($input1)
}];
}

//===----------------------------------------------------------------------===//
// Operator: variable.read
// Operator: variable_read
//===----------------------------------------------------------------------===//
def Tosa_VariableReadOp : Tosa_Op<"variable.read", []> {
def Tosa_VariableReadOp : Tosa_Op<"variable_read", []> {
let summary = "read_buffer operator";

let description = [{
Expand All @@ -148,7 +148,7 @@ def Tosa_VariableReadOp : Tosa_Op<"variable.read", []> {
);

let results = (outs
AnyType:$value
Tosa_Tensor:$output1
);

list<Availability> availability = [
Expand All @@ -157,7 +157,7 @@ def Tosa_VariableReadOp : Tosa_Op<"variable.read", []> {
];

let assemblyFormat = [{
$name attr-dict `:` type($value)
$name attr-dict `:` type($output1)
}];
}

Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class VariableWriteOpConverter
auto globalSymbolRef =
SymbolRefAttr::get(rewriter.getContext(), op.getName());
auto newVariableWrite = rewriter.create<ml_program::GlobalStoreOp>(
op.getLoc(), globalSymbolRef, op.getValue());
op.getLoc(), globalSymbolRef, op.getInput1());
rewriter.replaceOp(op, newVariableWrite);
return success();
}
Expand Down
8 changes: 7 additions & 1 deletion mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,12 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableOp op) {
return failure();
}

template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableWriteOp op) {
addValue(op.getInput1());
return success();
}

template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::IfOp op) {
addValue(op.getCondition());
Expand Down Expand Up @@ -280,6 +286,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
POPULATE_PROFILE_INFO_CUSTOM(Rescale)
POPULATE_PROFILE_INFO_CUSTOM(MatMul)
POPULATE_PROFILE_INFO_CUSTOM(Variable)
POPULATE_PROFILE_INFO_CUSTOM(VariableWrite)
POPULATE_PROFILE_INFO_CUSTOM(If)
POPULATE_PROFILE_INFO_CUSTOM(While)

Expand Down Expand Up @@ -334,7 +341,6 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
POPULATE_PROFILE_INFO_COMMON(Reverse)
POPULATE_PROFILE_INFO_COMMON(Identity)
POPULATE_PROFILE_INFO_COMMON(VariableRead)
POPULATE_PROFILE_INFO_COMMON(VariableWrite)

// Type Invariant Extension, a capability extension that is independent
// of the data type, meaning any compatible type can be used. No type
Expand Down
5 changes: 2 additions & 3 deletions mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -767,7 +767,7 @@ inline bool CompatibleTypes(const mlir::Type &type,

bool TosaValidation::CheckVariable(Operation *op) {
if (isa<mlir::tosa::VariableOp>(op)) {
auto nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));

if (variablesMap.count(nameAttr)) {
op->emitOpError() << "name has already been declared";
Expand All @@ -786,8 +786,7 @@ bool TosaValidation::CheckVariable(Operation *op) {
bool TosaValidation::CheckVariableReadOrWrite(Operation *op) {
if (isa<mlir::tosa::VariableReadOp>(op) ||
isa<mlir::tosa::VariableWriteOp>(op)) {
auto nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));

mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
if (!variablesMap.count(nameAttr)) {
op->emitOpError() << "name has not been declared";
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
// check that -tosa-validate of stateful ops kick in
func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
// expected-error@+1 {{'tosa.variable.write' op operand type does not equal variable type}}
tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi8>
// expected-error@+1 {{'tosa.variable_write' op operand type does not equal variable type}}
tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
return
}

Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Conversion/TosaToMLProgram/tosa-to-mlprogram.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ module {
tosa.variable @var_x = dense<7.000000e+00> : tensor<1xf32>
func.func @test_stateful_ops(%arg0: tensor<1xf32>) -> (tensor<1xf32>) {
// CHECK: ml_program.global_store @var_x = %arg0 : tensor<1xf32>
tosa.variable.write @var_x, %arg0 : tensor<1xf32>
tosa.variable_write @var_x, %arg0 : tensor<1xf32>
// CHECK: %[[LOAD:.+]] = ml_program.global_load @var_x : tensor<1xf32>
%0 = tosa.variable.read @var_x : tensor<1xf32>
%0 = tosa.variable_read @var_x : tensor<1xf32>
return %0 : tensor<1xf32>
}
}
16 changes: 8 additions & 8 deletions mlir/test/Dialect/Tosa/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -699,35 +699,35 @@ func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi8>) -> () {

func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
// expected-error@+1 {{'tosa.variable.read' op result type does not equal variable type}}
%0 = tosa.variable.read @stored_var : tensor<2x4x8xi16>
// expected-error@+1 {{'tosa.variable_read' op illegal: operand/result data types not supported}}
%0 = tosa.variable_read @stored_var : tensor<2x4x8xi16>
return
}

// -----

func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
// expected-error@+1 {{'tosa.variable.read' op result type does not equal variable type}}
%0 = tosa.variable.read @stored_var : tensor<1x4x8xi32>
// expected-error@+1 {{'tosa.variable_read' op illegal: operand/result data types not supported}}
%0 = tosa.variable_read @stored_var : tensor<1x4x8xi32>
return
}

// -----

func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
// expected-error@+1 {{'tosa.variable.write' op operand type does not equal variable type}}
tosa.variable.write @stored_var, %arg0 : tensor<2x4x8xi16>
// expected-error@+1 {{'tosa.variable_write' op illegal: operand/result data types not supported}}
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi16>
return
}

// -----

func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
// expected-error@+1 {{'tosa.variable.write' op operand type does not equal variable type}}
tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi8>
// expected-error@+1 {{'tosa.variable_write' op operand type does not equal variable type}}
tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
return
}

Expand Down
10 changes: 5 additions & 5 deletions mlir/test/Dialect/Tosa/invalid_extension.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -313,17 +313,17 @@ func.func @test_identity(%arg0: tensor<13x21x3xi4>) -> tensor<13x21x3xi4> {
func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
// expected-error@+1 {{'tosa.variable' op illegal: requires [variable] but not enabled in target}}
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
// expected-error@+1 {{'tosa.variable.read' op illegal: requires [variable]}}
%0 = tosa.variable.read @stored_var : tensor<2x4x8xi16>
// expected-error@+1 {{'tosa.variable_read' op illegal: requires [variable]}}
%0 = tosa.variable_read @stored_var : tensor<2x4x8xi8>
return
}

// -----
func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
func.func @test_variable_write_type(%arg0: tensor<2x4x8xi8>) -> () {
// expected-error@+1 {{'tosa.variable' op illegal: requires [variable] but not enabled in target}}
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
// expected-error@+1 {{'tosa.variable.write' op illegal: requires [variable]}}
tosa.variable.write @stored_var, %arg0 : tensor<2x4x8xi16>
// expected-error@+1 {{'tosa.variable_write' op illegal: requires [variable]}}
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi8>
return
}

Expand Down
16 changes: 8 additions & 8 deletions mlir/test/Dialect/Tosa/level_check.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1089,10 +1089,10 @@ func.func @test_scatter_tensor_size_invalid(%arg0: tensor<13x210000000x3xf32>, %

func.func @test_variable_read_write_tensor_size_invalid() -> () {
tosa.variable @stored_var = dense<3.14> : tensor<536870912xf32>
// expected-error@+1 {{'tosa.variable.read' op failed level check: result tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
%0 = tosa.variable.read @stored_var : tensor<536870912xf32>
// expected-error@+1 {{'tosa.variable.write' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
tosa.variable.write @stored_var, %0 : tensor<536870912xf32>
// expected-error@+1 {{'tosa.variable_read' op failed level check: result tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
%0 = tosa.variable_read @stored_var : tensor<536870912xf32>
// expected-error@+1 {{'tosa.variable_write' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
tosa.variable_write @stored_var, %0 : tensor<536870912xf32>
return
}

Expand Down Expand Up @@ -1157,10 +1157,10 @@ func.func @test_cond_if_rank_invalid(%arg0: tensor<1x1x1x1x1x1x1x1xf32>, %arg1:
func.func @test_variable_read_write_rank_invalid() -> () {
// expected-error@+1 {{'tosa.variable' op failed level check: attribute rank(shape) <= MAX_RANK}}
tosa.variable @stored_var = dense<3.14> : tensor<1x1x1x1x1x1x1x1xf32>
// expected-error@+1 {{'tosa.variable.read' op failed level check: result rank(shape) <= MAX_RANK}}
%0 = tosa.variable.read @stored_var : tensor<1x1x1x1x1x1x1x1xf32>
// expected-error@+1 {{'tosa.variable.write' op failed level check: operand rank(shape) <= MAX_RANK}}
tosa.variable.write @stored_var, %0 : tensor<1x1x1x1x1x1x1x1xf32>
// expected-error@+1 {{'tosa.variable_read' op failed level check: result rank(shape) <= MAX_RANK}}
%0 = tosa.variable_read @stored_var : tensor<1x1x1x1x1x1x1x1xf32>
// expected-error@+1 {{'tosa.variable_write' op failed level check: operand rank(shape) <= MAX_RANK}}
tosa.variable_write @stored_var, %0 : tensor<1x1x1x1x1x1x1x1xf32>
return
}

Expand Down
16 changes: 8 additions & 8 deletions mlir/test/Dialect/Tosa/variables.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
func.func @test_variable_scalar(%arg0: tensor<f32>) -> () {
// CHECK: tosa.variable @stored_var = dense<3.140000e+00> : tensor<f32>
tosa.variable @stored_var = dense<3.14> : tensor<f32>
// CHECK: %[[STORED_VAL:.*]] = tosa.variable.read @stored_var : tensor<f32>
%0 = tosa.variable.read @stored_var : tensor<f32>
// CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<f32>
%0 = tosa.variable_read @stored_var : tensor<f32>
// CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
%1 = "tosa.add"(%arg0, %0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: tosa.variable.write @stored_var, %[[RESULT_ADD]] : tensor<f32>
tosa.variable.write @stored_var, %1 : tensor<f32>
// CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<f32>
tosa.variable_write @stored_var, %1 : tensor<f32>
return
}

Expand All @@ -23,11 +23,11 @@ func.func @test_variable_scalar(%arg0: tensor<f32>) -> () {
func.func @test_variable_tensor(%arg0: tensor<2x4x8xi32>) -> () {
// CHECK: tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
// CHECK: %[[STORED_VAL:.*]] = tosa.variable.read @stored_var : tensor<2x4x8xi32>
%0 = tosa.variable.read @stored_var : tensor<2x4x8xi32>
// CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<2x4x8xi32>
%0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
// CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
%1 = "tosa.add"(%arg0, %0) : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
// CHECK: tosa.variable.write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32>
tosa.variable.write @stored_var, %1 : tensor<2x4x8xi32>
// CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32>
tosa.variable_write @stored_var, %1 : tensor<2x4x8xi32>
return
}