Skip to content

Commit bd9e32e

Browse files
oneDNN JSON to MLIR converter. (#82)
The converter uses the internal oneDNN JSON parser. Currently, 3 operations, listed in the _opBuilders map, are supported. The other operations to be added in further commits.
1 parent f40e433 commit bd9e32e

13 files changed

+980
-81
lines changed

CMakeLists.txt

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
################################################################################
22
# Copyright (C) 2024 Intel Corporation
3-
#
3+
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
55
# you may not use this file except in compliance with the License.
66
# You may obtain a copy of the License at
7-
#
7+
#
88
# http://www.apache.org/licenses/LICENSE-2.0
9-
#
9+
#
1010
# Unless required by applicable law or agreed to in writing,
1111
# software distributed under the License is distributed on an "AS IS" BASIS,
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -95,8 +95,7 @@ if(GC_ENABLE_BINDINGS_PYTHON)
9595
endif()
9696

9797
set(GC_LIB_LINKED_LIBS
98-
MLIRLinalgx
99-
MLIRMicrokernel
98+
GCPasses
10099
MLIROneDNNGraph
101100
)
102101
add_library(graph_compiler SHARED ${GC_LIB_SOURCES})

cmake/functions.cmake

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,13 @@ function(gc_fetch_content
6565
FetchContent_Populate(${name})
6666
FetchContent_GetProperties(${name})
6767
set(${name}_POPULATED TRUE PARENT_SCOPE)
68-
set(${name}_SOURCE_DIR ${${name}_SOURCE_DIR} PARENT_SCOPE)
69-
set(${name}_BINARY_DIR ${${name}_BINARY_DIR} PARENT_SCOPE)
7068
endif ()
7169
else ()
7270
FetchContent_MakeAvailable(${name})
7371
endif ()
72+
73+
set(${name}_SOURCE_DIR ${${name}_SOURCE_DIR} PARENT_SCOPE)
74+
set(${name}_BINARY_DIR ${${name}_BINARY_DIR} PARENT_SCOPE)
7475
endif ()
7576
endfunction()
7677

cmake/onednn.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ if (NOT DEFINED DNNL_INCLUDES)
1414
set(DNNL_INCLUDES
1515
${dnnl_BINARY_DIR}/include
1616
${dnnl_SOURCE_DIR}/include
17-
${dnnl_SOURCE_DIR}/src/graph/backend/elyzor/include
17+
${dnnl_SOURCE_DIR}/src
1818
)
1919
set_property(GLOBAL PROPERTY DNNL_INCLUDES ${DNNL_INCLUDES})
2020

src/dnnl/CMakeLists.txt

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,20 @@
1+
# Copyright (C) 2024 Intel Corporation
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions
13+
# and limitations under the License.
14+
# SPDX-License-Identifier: Apache-2.0
15+
116
include(onednn)
17+
include(functions)
218

319
gc_add_path(GC_LIB_SOURCES GLOB "*.cpp")
420
gc_add_path(GC_LIB_INCLUDES ${DNNL_INCLUDES})

src/dnnl/JsonParser.cpp

Lines changed: 287 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
/*
2+
* Copyright (C) 2024 Intel Corporation
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing,
11+
* software distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions
14+
* and limitations under the License.
15+
*
16+
* SPDX-License-Identifier: Apache-2.0
17+
*/
18+
19+
#include <limits>
20+
#include <memory>
21+
#include <sstream>
22+
#include <string_view>
23+
24+
#include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h"
25+
26+
#include "mlir/ExecutionEngine/OptUtils.h"
27+
#include "mlir/IR/Builders.h"
28+
#include "mlir/IR/MLIRContext.h"
29+
#include "mlir/InitAllPasses.h"
30+
31+
#include "JsonParser.h"
32+
33+
mlir::ModuleOp JsonParser::parse() {
34+
std::vector<size_t> inputPorts;
35+
bool hasInputPorts = false;
36+
bool hasOutputPorts = false;
37+
_reader.begin_object();
38+
39+
while (_reader.next_object_item(&_str)) {
40+
if (_str == "version") {
41+
_reader.read_string(&_str);
42+
// TODO: Check if the version supported
43+
} else if (_str == "engine_kind") {
44+
_reader.read_string(&_str);
45+
if (_str != "cpu") {
46+
throwErr<std::logic_error>("Unsupported engine: ");
47+
}
48+
} else if (_str == "fpmath_mode") {
49+
_reader.read_string(&_str);
50+
if ((_str != "strict") && (_str != "any")) {
51+
throwErr<std::logic_error>(
52+
"Unsupported fpmath_mode: ",
53+
". Only 'strict' and 'any' are currently supported.");
54+
}
55+
} else if (_str == "input_ports") {
56+
hasInputPorts = true;
57+
readNumArray(inputPorts);
58+
} else if (_str == "output_ports") {
59+
hasOutputPorts = true;
60+
readNumArray(_outputIds);
61+
} else if (_str == "graph") {
62+
_reader.begin_array();
63+
while (_reader.next_array_item()) {
64+
readOp();
65+
}
66+
} else {
67+
throwUnrecognizedKey();
68+
}
69+
}
70+
71+
// Check if the input_ports match the expected inputs
72+
if (hasInputPorts) {
73+
if (inputPorts.size() != _inputIds.size()) {
74+
_str = std::to_string(_inputIds.size());
75+
throwErr<std::invalid_argument>(
76+
"Length mismatch between input_ports and inputs: ");
77+
}
78+
for (auto id : _inputIds) {
79+
// The order of the inputs could be different
80+
if (std::find(inputPorts.begin(), inputPorts.end(), id) ==
81+
inputPorts.end()) {
82+
_str = std::to_string(id);
83+
throwErr<std::invalid_argument>("Input not found in input_ports: ");
84+
}
85+
}
86+
}
87+
88+
if (!hasOutputPorts) {
89+
// If output_ports is not specified, using the last operation's outputs.
90+
_outputIds = _uaS;
91+
}
92+
93+
// The function return values.
94+
std::vector<mlir::Value> outputs;
95+
outputs.reserve(_outputIds.size());
96+
for (auto id : _outputIds) {
97+
auto entry = _valueMap.find(id);
98+
if (entry == _valueMap.end()) {
99+
_str = std::to_string(id);
100+
throwErr<std::invalid_argument>("Output value not found: ");
101+
}
102+
outputs.push_back(entry->second);
103+
}
104+
auto ret = _builder.create<mlir::func::ReturnOp>(_loc, outputs);
105+
106+
// Creating the final function and moving the entry block.
107+
mlir::OpBuilder builder(_builder.getContext());
108+
auto module = builder.create<mlir::ModuleOp>(_loc);
109+
auto func = builder.create<mlir::func::FuncOp>(
110+
_loc, "main",
111+
builder.getFunctionType(_entryBlock->getArgumentTypes(),
112+
ret->getOperandTypes()));
113+
auto entry = func.addEntryBlock();
114+
_entryBlock->moveBefore(entry);
115+
entry->erase();
116+
module.push_back(func);
117+
return module;
118+
}
119+
120+
void JsonParser::readOp() {
121+
OpBuilderFn builderFn = nullptr;
122+
123+
_uaS.clear();
124+
_operands.clear();
125+
_attributes.clear();
126+
_resultTypes.clear();
127+
_reader.begin_object();
128+
129+
while (_reader.next_object_item(&_str)) {
130+
if (_str == "id") {
131+
// ignore
132+
_reader.read_number(&_uS);
133+
} else if (_str == "name") {
134+
// ignore
135+
_reader.read_string(&_str);
136+
} else if (_str == "kind") {
137+
_reader.read_string(&_str);
138+
auto fn = _opBuilders.find(_str);
139+
if (fn == _opBuilders.end()) {
140+
throwErr<std::logic_error>("Unsupported operation: ");
141+
}
142+
builderFn = fn->second;
143+
} else if (_str == "attrs") {
144+
_reader.begin_object();
145+
while (_reader.next_object_item(&_str)) {
146+
auto name = mlir::StringAttr::get(_builder.getContext(), _str);
147+
_attributes.emplace_back(name, readAttr());
148+
}
149+
} else if (_str == "inputs") {
150+
_reader.begin_array();
151+
while (_reader.next_array_item()) {
152+
auto type = readTensorType();
153+
auto entry = _valueMap.find(_uS);
154+
if (entry == _valueMap.end()) {
155+
// If not found, then this is a function argument.
156+
auto value = _entryBlock->addArgument(type, _loc);
157+
_valueMap[_uS] = value;
158+
_operands.push_back(value);
159+
_inputIds.push_back(_uS);
160+
} else {
161+
if (entry->second.getType() != type) {
162+
_str = std::to_string(_uS);
163+
throwErr<std::invalid_argument>("Type mismatch for input: ");
164+
}
165+
_operands.push_back(entry->second);
166+
}
167+
}
168+
} else if (_str == "outputs") {
169+
_reader.begin_array();
170+
while (_reader.next_array_item()) {
171+
_resultTypes.push_back(readTensorType());
172+
_uaS.push_back(_uS);
173+
}
174+
} else {
175+
throwUnrecognizedKey();
176+
}
177+
}
178+
179+
if (builderFn == nullptr) {
180+
throwErr<std::invalid_argument>("Operation kind is not specified");
181+
}
182+
183+
auto outputs = builderFn(*this);
184+
assert(outputs.size() == _uaS.size());
185+
auto id = _uaS.begin();
186+
auto value = outputs.begin();
187+
188+
for (; id != _uaS.end(); ++id, ++value) {
189+
if (!_valueMap.emplace(*id, *value).second) {
190+
_str = std::to_string(*id);
191+
throwErr<std::invalid_argument>("Duplicate output id: ");
192+
}
193+
}
194+
}
195+
196+
inline mlir::Attribute JsonParser::readAttr() {
197+
_reader.begin_object();
198+
readKey("type");
199+
_reader.read_string(&_str);
200+
readKey("value", &_str2);
201+
202+
mlir::Attribute attr;
203+
204+
if (_str == "bool") {
205+
_reader.read_number(&_uS);
206+
attr = _builder.getBoolAttr(_uS != 0);
207+
} else if (_str == "s64") {
208+
_reader.read_number(&_i64);
209+
attr = _builder.getI64IntegerAttr(_i64);
210+
} else if (_str == "f32") {
211+
_reader.read_number(&_f32);
212+
attr = _builder.getF32FloatAttr(_f32);
213+
} else if (_str == "s64[]") {
214+
_ia64.clear();
215+
readNumArray(_ia64);
216+
attr = _builder.getI64ArrayAttr(_ia64);
217+
} else if (_str == "f32[]") {
218+
_fa32.clear();
219+
readNumArray(_fa32);
220+
attr = _builder.getF32ArrayAttr(_fa32);
221+
} else if (_str == "string") {
222+
_reader.read_string(&_str);
223+
attr = _builder.getStringAttr(_str);
224+
} else {
225+
throwErr<std::logic_error>("Unsupported attribute type: ");
226+
}
227+
228+
if (_reader.next_object_item(&_str)) {
229+
throwUnrecognizedKey();
230+
}
231+
232+
return attr;
233+
}
234+
235+
mlir::Type JsonParser::readTensorType() {
236+
GetTypeFn getTypeFn = nullptr;
237+
_ia64.clear();
238+
_reader.begin_object();
239+
240+
while (_reader.next_object_item(&_str)) {
241+
if (_str == "id") {
242+
_reader.read_number(&_uS);
243+
} else if (_str == "dtype") {
244+
_reader.read_string(&_str);
245+
auto fn = _dtypes.find(_str);
246+
if (fn == _dtypes.end()) {
247+
throwErr<std::logic_error>("Unsupported dtype: ");
248+
}
249+
getTypeFn = fn->second;
250+
} else if (_str == "shape") {
251+
readNumArray(_ia64);
252+
} else if (_str == "stride") {
253+
_ia642.clear();
254+
readNumArray(_ia642);
255+
if ((_ia642.size() > 1) ||
256+
((_ia642.size() == 1) &&
257+
(_ia642[0] != std::numeric_limits<int64_t>::min()))) {
258+
// TODO: Add support for strides
259+
throwErr<std::logic_error>("Unsupported stride value: ");
260+
}
261+
} else if (_str == "layout_type") {
262+
_reader.read_string(&_str);
263+
if ((_str != "undef") && (_str != "any")) {
264+
throwErr<std::logic_error>("Unsupported layout_type: ");
265+
}
266+
} else if (_str == "property_type") {
267+
_reader.read_string(&_str);
268+
if ((_str != "undef") && (_str != "constant")) {
269+
throwErr<std::logic_error>("Unsupported property_type: ");
270+
}
271+
} else {
272+
throwUnrecognizedKey();
273+
}
274+
}
275+
276+
if (getTypeFn == nullptr) {
277+
_str.clear();
278+
throwErr<std::invalid_argument>("dtype is not specified");
279+
}
280+
281+
if ((_ia64.size() == 1) &&
282+
(_ia64[0] == std::numeric_limits<int64_t>::min())) {
283+
return mlir::UnrankedTensorType::get(getTypeFn(_builder));
284+
}
285+
286+
return mlir::RankedTensorType::get(_ia64, getTypeFn(_builder));
287+
}

0 commit comments

Comments
 (0)