Skip to content

Commit 824a2bf

Browse files
committed
Removed tf attributes and allowed for no attributes
1 parent 886bbc5 commit 824a2bf

File tree

5 files changed

+94
-66
lines changed

5 files changed

+94
-66
lines changed

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1584,14 +1584,13 @@ def EmitC_ClassOp
15841584
for its data fields (`emitc.field`) and methods (`emitc.func`).
15851585
It creates a distinct scope, isolating its contents from the surrounding
15861586
MLIR region, similar to how C++ classes encapsulate their internals.
1587-
All the class memebrs need to be default initalizable.
15881587

15891588
Example:
15901589
```mlir
1591-
emitc.class @MymainClass {
1592-
emitc.field @another_feature : !emitc.array<1xf32> = {tf_saved_model.index_path = ["another_feature"]}
1593-
emitc.field @some_feature : !emitc.array<1xf32> = {tf_saved_model.index_path = ["some_feature"]}
1594-
emitc.field @output_0 : !emitc.array<1xf32> = {tf_saved_model.index_path = ["output_0"]}
1590+
emitc.class @mainClass {
1591+
emitc.field @another_feature : !emitc.array<1xf32> = {emitc.opaque = ["another_feature"]}
1592+
emitc.field @some_feature : !emitc.array<1xf32> = {emitc.opaque = ["some_feature"]}
1593+
emitc.field @output_0 : !emitc.array<1xf32> = {emitc.opaque = ["output_0"]}
15951594

15961595
emitc.func @execute() {
15971596
%0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
@@ -1634,15 +1633,22 @@ def EmitC_FieldOp : EmitC_Op<"field", [Symbol]> {
16341633
let summary = "A field within a class";
16351634
let description = [{
16361635
The `emitc.field` operation declares a named field within an `emitc.class`
1637-
operation. The field's type must be an EmitC type. An optional initial value can be provided.
1636+
operation. The field's type must be an EmitC type. The initial value is optional.
1637+
If the argument has attributes, these become the initial value, else we end up with no initial value.
16381638

16391639
Example with initial values:
16401640

16411641
```mlir
1642-
emitc.class @MyModelClass {
1643-
emitc.field @another_feature : !emitc.array<1xf32> = {tf_saved_model.index_path = ["another_feature"]}
1644-
emitc.field @some_feature : !emitc.array<1xf32> = {tf_saved_model.index_path = ["some_feature"]}
1645-
emitc.field @output_0 : !emitc.array<1xf32> = {tf_saved_model.index_path = ["output_0"]}
1642+
emitc.class @modelClass {
1643+
emitc.field @another_feature : !emitc.array<1xf32> = {emitc.opaque = ["another_feature"]}
1644+
emitc.field @some_feature : !emitc.array<1xf32> = {emitc.opaque = ["some_feature"]}
1645+
emitc.field @output_0 : !emitc.array<1xf32> = {emitc.opaque = ["output_0"]}
1646+
}
1647+
```
1648+
Example with no initial value:
1649+
```mlir
1650+
emitc.class @modelClass {
1651+
emitc.field @another_feature : !emitc.array<1xf32>
16461652
}
16471653
```
16481654
}];

mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,17 @@
1212
#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
1313
#include "mlir/IR/Attributes.h"
1414
#include "mlir/IR/Builders.h"
15+
#include "mlir/IR/BuiltinAttributes.h"
1516
#include "mlir/IR/PatternMatch.h"
1617
#include "mlir/IR/TypeRange.h"
18+
#include "mlir/IR/Value.h"
1719
#include "mlir/Pass/Pass.h"
1820
#include "mlir/Transforms/DialectConversion.h"
1921
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22+
#include "llvm/ADT/StringRef.h"
2023
#include "llvm/Support/GraphWriter.h"
2124
#include "llvm/Support/LogicalResult.h"
25+
#include <string>
2226

2327
namespace mlir {
2428
namespace emitc {
@@ -67,7 +71,7 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
6771
if (funcOp->getParentOfType<emitc::ClassOp>()) {
6872
return failure();
6973
}
70-
auto className = "My" + funcOp.getSymNameAttr().str() + "Class";
74+
auto className = funcOp.getSymNameAttr().str() + "Class";
7175
mlir::emitc::ClassOp newClassOp =
7276
rewriter.create<emitc::ClassOp>(funcOp.getLoc(), className);
7377

@@ -76,25 +80,33 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
7680
rewriter.setInsertionPointToStart(&newClassOp.getBody().front());
7781

7882
auto argAttrs = funcOp.getArgAttrs();
79-
if (argAttrs) {
80-
for (const auto &[arg, val] :
81-
llvm::zip(*argAttrs, funcOp.getArguments())) {
82-
if (auto namedAttr =
83-
dyn_cast<mlir::DictionaryAttr>(arg).getNamed(attributeName)) {
84-
Attribute nv = namedAttr->getValue();
85-
StringAttr fieldName =
86-
cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]);
87-
TypeAttr typeAttr = TypeAttr::get(val.getType());
88-
fields.push_back({fieldName, typeAttr});
89-
90-
rewriter.create<emitc::FieldOp>(funcOp.getLoc(), fieldName, typeAttr,
91-
/* attributes*/ arg);
83+
size_t idx = 0;
84+
85+
for (const BlockArgument &val : funcOp.getArguments()) {
86+
StringAttr fieldName;
87+
Attribute argAttr = nullptr;
88+
89+
if (argAttrs && idx < argAttrs->size()) {
90+
if (DictionaryAttr dictAttr =
91+
dyn_cast<mlir::DictionaryAttr>((*argAttrs)[idx])) {
92+
if (auto namedAttr = dictAttr.getNamed(attributeName)) {
93+
Attribute nv = namedAttr->getValue();
94+
fieldName = cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]);
95+
argAttr = (*argAttrs)[idx];
96+
}
9297
}
9398
}
94-
} else {
95-
funcOp->emitOpError("arguments should have attributes so we can "
96-
"initialize class fields.");
97-
return failure();
99+
100+
if (!fieldName) {
101+
fieldName = rewriter.getStringAttr("fieldName" + std::to_string(idx));
102+
}
103+
104+
TypeAttr typeAttr = TypeAttr::get(val.getType());
105+
fields.push_back({fieldName, typeAttr});
106+
rewriter.create<emitc::FieldOp>(funcOp.getLoc(), fieldName, typeAttr,
107+
argAttr);
108+
109+
++idx;
98110
}
99111

100112
rewriter.setInsertionPointToEnd(&newClassOp.getBody().front());
@@ -112,7 +124,7 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
112124
rewriter.setInsertionPointToStart(&newFuncOp.getBody().front());
113125
std::vector<Value> newArguments;
114126
for (auto [fieldName, attr] : fields) {
115-
auto arg =
127+
GetFieldOp arg =
116128
rewriter.create<emitc::GetFieldOp>(loc, attr.getValue(), fieldName);
117129
newArguments.push_back(arg);
118130
}
@@ -122,14 +134,13 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
122134
rewriter.replaceAllUsesWith(oldArg, newArg);
123135
}
124136

125-
while (!newFuncOp.getArguments().empty()) {
126-
if (failed(newFuncOp.eraseArgument(0))) {
127-
break;
128-
}
137+
llvm::BitVector argsToErase(newFuncOp.getNumArguments(), true);
138+
if (failed(newFuncOp.eraseArguments(argsToErase))) {
139+
newFuncOp->emitOpError("Failed to erase all arguments using BitVector.");
129140
}
130141

131142
rewriter.replaceOp(funcOp, newClassOp);
132-
return funcOp->use_empty() ? success() : failure();
143+
return success();
133144
}
134145
};
135146

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
// RUN: mlir-opt --wrap-emitc-func-in-class='named-attribute=tf_saved_model.index_path' %s | FileCheck %s
1+
// RUN: mlir-opt --wrap-emitc-func-in-class='named-attribute=emitc.opaque' %s | FileCheck %s
22

3-
module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
4-
emitc.func @Model(%arg0: !emitc.array<1xf32> {tf_saved_model.index_path = ["another_feature"]}, %arg1: !emitc.array<1xf32> {tf_saved_model.index_path = ["some_feature"]}, %arg2: !emitc.array<1xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "serving_default_another_feature:0,serving_default_some_feature:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
3+
module attributes { } {
4+
emitc.func @model(%arg0: !emitc.array<1xf32> {emitc.opaque = ["another_feature"]}, %arg1: !emitc.array<1xf32> {emitc.opaque = ["some_feature"]}, %arg2: !emitc.array<1xf32> {emitc.opaque = ["output_0"]}) attributes { } {
55
%0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
66
%1 = subscript %arg1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
77
%2 = load %1 : <f32>
@@ -14,24 +14,26 @@ module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted."
1414
}
1515
}
1616

17-
// CHECK: module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
18-
// CHECK: emitc.class @MyModelClass {
19-
// CHECK: emitc.field @another_feature : !emitc.array<1xf32> = {tf_saved_model.index_path = ["another_feature"]}
20-
// CHECK: emitc.field @some_feature : !emitc.array<1xf32> = {tf_saved_model.index_path = ["some_feature"]}
21-
// CHECK: emitc.field @output_0 : !emitc.array<1xf32> = {tf_saved_model.index_path = ["output_0"]}
22-
// CHECK: emitc.func @execute() {
23-
// CHECK: %{{[0-9]+}} = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
24-
// CHECK: %{{[0-9]+}} = get_field @another_feature : !emitc.array<1xf32>
25-
// CHECK: %{{[0-9]+}} = get_field @some_feature : !emitc.array<1xf32>
26-
// CHECK: %{{[0-9]+}} = get_field @output_0 : !emitc.array<1xf32>
27-
// CHECK: %{{[0-9]+}} = subscript %{{[0-9]+}}[%{{[0-9]+}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
28-
// CHECK: %{{[0-9]+}} = load %{{[0-9]+}} : <f32>
29-
// CHECK: %{{[0-9]+}} = subscript %{{[0-9]+}}[%{{[0-9]+}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
30-
// CHECK: %{{[0-9]+}} = load %{{[0-9]+}} : <f32>
31-
// CHECK: %{{[0-9]+}} = add %{{[0-9]+}}, %{{[0-9]+}} : (f32, f32) -> f32
32-
// CHECK: %{{[0-9]+}} = subscript %{{[0-9]+}}[%{{[0-9]+}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
33-
// CHECK: assign %{{[0-9]+}} : f32 to %{{[0-9]+}} : <f32>
34-
// CHECK: return
35-
// CHECK: }
36-
// CHECK: }
37-
// CHECK: }
17+
18+
// CHECK: module {
19+
// CHECK-NEXT: emitc.class @modelClass {
20+
// CHECK-NEXT: emitc.field @another_feature : !emitc.array<1xf32> = {emitc.opaque = ["another_feature"]}
21+
// CHECK-NEXT: emitc.field @some_feature : !emitc.array<1xf32> = {emitc.opaque = ["some_feature"]}
22+
// CHECK-NEXT: emitc.field @output_0 : !emitc.array<1xf32> = {emitc.opaque = ["output_0"]}
23+
// CHECK-NEXT: emitc.func @execute() {
24+
// CHECK-NEXT: %{{[0-9]+}} = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
25+
// CHECK-NEXT: %{{[0-9]+}} = get_field @another_feature : !emitc.array<1xf32>
26+
// CHECK-NEXT: %{{[0-9]+}} = get_field @some_feature : !emitc.array<1xf32>
27+
// CHECK-NEXT: %{{[0-9]+}} = get_field @output_0 : !emitc.array<1xf32>
28+
// CHECK-NEXT: %{{[0-9]+}} = subscript %{{[0-9]+}}[%{{[0-9]+}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
29+
// CHECK-NEXT: %{{[0-9]+}} = load %{{[0-9]+}} : <f32>
30+
// CHECK-NEXT: %{{[0-9]+}} = subscript %{{[0-9]+}}[%{{[0-9]+}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
31+
// CHECK-NEXT: %{{[0-9]+}} = load %{{[0-9]+}} : <f32>
32+
// CHECK-NEXT: %{{[0-9]+}} = add %{{[0-9]+}}, %{{[0-9]+}} : (f32, f32) -> f32
33+
// CHECK-NEXT: %{{[0-9]+}} = subscript %{{[0-9]+}}[%{{[0-9]+}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
34+
// CHECK-NEXT: assign %{{[0-9]+}} : f32 to %{{[0-9]+}} : <f32>
35+
// CHECK-NEXT: return
36+
// CHECK-NEXT: }
37+
// CHECK-NEXT: }
38+
// CHECK-NEXT: }
39+

mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_neg.mlir

Lines changed: 0 additions & 8 deletions
This file was deleted.
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// RUN: mlir-opt --wrap-emitc-func-in-class='named-attribute=emitc.opaque' %s | FileCheck %s
2+
3+
emitc.func @foo(%arg0 : !emitc.array<1xf32>) {
4+
emitc.call_opaque "bar" (%arg0) : (!emitc.array<1xf32>) -> ()
5+
emitc.return
6+
}
7+
8+
// CHECK: module {
9+
// CHECK-NEXT: emitc.class @fooClass {
10+
// CHECK-NEXT: emitc.field @fieldName0 : !emitc.array<1xf32>
11+
// CHECK-NEXT: emitc.func @execute() {
12+
// CHECK-NEXT: %0 = get_field @fieldName0 : !emitc.array<1xf32>
13+
// CHECK-NEXT: call_opaque "bar"(%0) : (!emitc.array<1xf32>) -> ()
14+
// CHECK-NEXT: return
15+
// CHECK-NEXT: }
16+
// CHECK-NEXT: }
17+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)