Skip to content

Commit 5de799c

Browse files
committed
Generic support for legalizing tosa.custom_op into another dialect
operation.
1 parent 813e43e commit 5de799c

File tree

2 files changed

+16
-5
lines changed

2 files changed

+16
-5
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -508,9 +508,14 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
508508

509509
// tosa::CustomOp
510510
if (auto customOp = dyn_cast<tosa::CustomOp>(op)) {
511-
return llvm::StringSwitch<Value>(customOp.getIdentifierAttr().str())
512-
.Case("atan2", rewriter.create<math::Atan2Op>(loc, resultTypes, args))
513-
.Default(nullptr);
511+
// Only legalize tosa.custom_op's that are marked as implementable with
512+
// 'linalg.generic' by looking at the 'implementation_attrs' attribute
513+
auto implementationAttr = customOp.getImplementationAttrs();
514+
if (implementationAttr == "linalg.generic") {
515+
OperationState state(loc, customOp.getIdentifierAttr(), args,
516+
resultTypes);
517+
return rewriter.create(state)->getResult(0);
518+
}
514519
}
515520

516521
(void)rewriter.notifyMatchFailure(

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1414,9 +1414,12 @@ func.func @select_fp32(%arg0: tensor<1x1x5x5xi1>, %arg1: tensor<1x12x5x5xf32>, %
14141414

14151415
// CHECK-LABEL: @test_custom_ops
14161416
func.func @test_custom_ops(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> () {
1417+
// CHECK: linalg.generic
1418+
// CHECK: math.sin
14171419
// CHECK: linalg.generic
14181420
// CHECK: math.atan2
1419-
%2 = "tosa.custom"(%arg0, %arg1) <{config = "UNDEF", identifier = "atan2", implementation_attrs = "UNDEF"}> : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
1421+
%2 = "tosa.custom"(%arg0) <{config = "UNDEF", identifier = "math.sin", implementation_attrs = "linalg.generic"}> : (tensor<1xf32>) -> tensor<1xf32>
1422+
%3 = "tosa.custom"(%arg0, %arg1) <{config = "UNDEF", identifier = "math.atan2", implementation_attrs = "linalg.generic"}> : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
14201423

14211424
return
14221425
}
@@ -1426,9 +1429,12 @@ func.func @test_custom_ops(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> () {
14261429

14271430
// CHECK-LABEL: @test_custom_ops_dyn
14281431
func.func @test_custom_ops_dyn(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> () {
1432+
// CHECK: linalg.generic
1433+
// CHECK: math.cos
14291434
// CHECK: linalg.generic
14301435
// CHECK: math.atan2
1431-
%2 = "tosa.custom"(%arg0, %arg1) <{config = "UNDEF", identifier = "atan2", implementation_attrs = "UNDEF"}> : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
1436+
%2 = "tosa.custom"(%arg0) <{config = "UNDEF", identifier = "math.cos", implementation_attrs = "linalg.generic"}> : (tensor<?xf32>) -> tensor<?xf32>
1437+
%3 = "tosa.custom"(%arg0, %arg1) <{config = "UNDEF", identifier = "math.atan2", implementation_attrs = "linalg.generic"}> : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
14321438

14331439
return
14341440
}

0 commit comments

Comments
 (0)