Skip to content

Commit 4d0d295

Browse files
committed
[mlir][python] Allow specifying block arg locations
Currently blocks are always created with UnknownLoc's for their arguments. This adds an `arg_locs` argument to all block creation APIs, which takes an optional sequence of locations to use, one per block argument. If no locations are supplied, the current Location context is used. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D150084
1 parent b9031d3 commit 4d0d295

File tree

3 files changed

+90
-75
lines changed

3 files changed

+90
-75
lines changed

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 45 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,31 @@ static MlirStringRef toMlirStringRef(const std::string &s) {
193193
return mlirStringRefCreate(s.data(), s.size());
194194
}
195195

196+
/// Create a block, using the current location context if no locations are
197+
/// specified.
198+
static MlirBlock createBlock(const py::sequence &pyArgTypes,
199+
const std::optional<py::sequence> &pyArgLocs) {
200+
SmallVector<MlirType> argTypes;
201+
argTypes.reserve(pyArgTypes.size());
202+
for (const auto &pyType : pyArgTypes)
203+
argTypes.push_back(pyType.cast<PyType &>());
204+
205+
SmallVector<MlirLocation> argLocs;
206+
if (pyArgLocs) {
207+
argLocs.reserve(pyArgLocs->size());
208+
for (const auto &pyLoc : *pyArgLocs)
209+
argLocs.push_back(pyLoc.cast<PyLocation &>());
210+
} else if (!argTypes.empty()) {
211+
argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve());
212+
}
213+
214+
if (argTypes.size() != argLocs.size())
215+
throw py::value_error(("Expected " + Twine(argTypes.size()) +
216+
" locations, got: " + Twine(argLocs.size()))
217+
.str());
218+
return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
219+
}
220+
196221
/// Wrapper for the global LLVM debugging flag.
197222
struct PyGlobalDebugFlag {
198223
static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
@@ -364,21 +389,10 @@ class PyBlockList {
364389
throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
365390
}
366391

367-
PyBlock appendBlock(const py::args &pyArgTypes) {
392+
PyBlock appendBlock(const py::args &pyArgTypes,
393+
const std::optional<py::sequence> &pyArgLocs) {
368394
operation->checkValid();
369-
llvm::SmallVector<MlirType, 4> argTypes;
370-
llvm::SmallVector<MlirLocation, 4> argLocs;
371-
argTypes.reserve(pyArgTypes.size());
372-
argLocs.reserve(pyArgTypes.size());
373-
for (auto &pyArg : pyArgTypes) {
374-
argTypes.push_back(pyArg.cast<PyType &>());
375-
// TODO: Pass in a proper location here.
376-
argLocs.push_back(
377-
mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
378-
}
379-
380-
MlirBlock block =
381-
mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
395+
MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
382396
mlirRegionAppendOwnedBlock(region, block);
383397
return PyBlock(operation, block);
384398
}
@@ -388,7 +402,8 @@ class PyBlockList {
388402
.def("__getitem__", &PyBlockList::dunderGetItem)
389403
.def("__iter__", &PyBlockList::dunderIter)
390404
.def("__len__", &PyBlockList::dunderLen)
391-
.def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
405+
.def("append", &PyBlockList::appendBlock, kAppendBlockDocstring,
406+
py::arg("arg_locs") = std::nullopt);
392407
}
393408

394409
private:
@@ -2966,27 +2981,17 @@ void mlir::python::populateIRCore(py::module &m) {
29662981
"Returns a forward-optimized sequence of operations.")
29672982
.def_static(
29682983
"create_at_start",
2969-
[](PyRegion &parent, py::list pyArgTypes) {
2984+
[](PyRegion &parent, const py::list &pyArgTypes,
2985+
const std::optional<py::sequence> &pyArgLocs) {
29702986
parent.checkValid();
2971-
llvm::SmallVector<MlirType, 4> argTypes;
2972-
llvm::SmallVector<MlirLocation, 4> argLocs;
2973-
argTypes.reserve(pyArgTypes.size());
2974-
argLocs.reserve(pyArgTypes.size());
2975-
for (auto &pyArg : pyArgTypes) {
2976-
argTypes.push_back(pyArg.cast<PyType &>());
2977-
// TODO: Pass in a proper location here.
2978-
argLocs.push_back(
2979-
mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
2980-
}
2981-
2982-
MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(),
2983-
argLocs.data());
2987+
MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
29842988
mlirRegionInsertOwnedBlock(parent, 0, block);
29852989
return PyBlock(parent.getParentOperation(), block);
29862990
},
29872991
py::arg("parent"), py::arg("arg_types") = py::list(),
2992+
py::arg("arg_locs") = std::nullopt,
29882993
"Creates and returns a new Block at the beginning of the given "
2989-
"region (with given argument types).")
2994+
"region (with given argument types and locations).")
29902995
.def(
29912996
"append_to",
29922997
[](PyBlock &self, PyRegion &region) {
@@ -2998,50 +3003,30 @@ void mlir::python::populateIRCore(py::module &m) {
29983003
"Append this block to a region, transferring ownership if necessary")
29993004
.def(
30003005
"create_before",
3001-
[](PyBlock &self, py::args pyArgTypes) {
3006+
[](PyBlock &self, const py::args &pyArgTypes,
3007+
const std::optional<py::sequence> &pyArgLocs) {
30023008
self.checkValid();
3003-
llvm::SmallVector<MlirType, 4> argTypes;
3004-
llvm::SmallVector<MlirLocation, 4> argLocs;
3005-
argTypes.reserve(pyArgTypes.size());
3006-
argLocs.reserve(pyArgTypes.size());
3007-
for (auto &pyArg : pyArgTypes) {
3008-
argTypes.push_back(pyArg.cast<PyType &>());
3009-
// TODO: Pass in a proper location here.
3010-
argLocs.push_back(
3011-
mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
3012-
}
3013-
3014-
MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(),
3015-
argLocs.data());
3009+
MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
30163010
MlirRegion region = mlirBlockGetParentRegion(self.get());
30173011
mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
30183012
return PyBlock(self.getParentOperation(), block);
30193013
},
3014+
py::arg("arg_locs") = std::nullopt,
30203015
"Creates and returns a new Block before this block "
3021-
"(with given argument types).")
3016+
"(with given argument types and locations).")
30223017
.def(
30233018
"create_after",
3024-
[](PyBlock &self, py::args pyArgTypes) {
3019+
[](PyBlock &self, const py::args &pyArgTypes,
3020+
const std::optional<py::sequence> &pyArgLocs) {
30253021
self.checkValid();
3026-
llvm::SmallVector<MlirType, 4> argTypes;
3027-
llvm::SmallVector<MlirLocation, 4> argLocs;
3028-
argTypes.reserve(pyArgTypes.size());
3029-
argLocs.reserve(pyArgTypes.size());
3030-
for (auto &pyArg : pyArgTypes) {
3031-
argTypes.push_back(pyArg.cast<PyType &>());
3032-
3033-
// TODO: Pass in a proper location here.
3034-
argLocs.push_back(
3035-
mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
3036-
}
3037-
MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(),
3038-
argLocs.data());
3022+
MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
30393023
MlirRegion region = mlirBlockGetParentRegion(self.get());
30403024
mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
30413025
return PyBlock(self.getParentOperation(), block);
30423026
},
3027+
py::arg("arg_locs") = std::nullopt,
30433028
"Creates and returns a new Block after this block "
3044-
"(with given argument types).")
3029+
"(with given argument types and locations).")
30453030
.def(
30463031
"__iter__",
30473032
[](PyBlock &self) {

mlir/python/mlir/dialects/_func_ops_ext.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,15 +90,15 @@ def entry_block(self):
9090
raise IndexError('External function does not have a body')
9191
return self.regions[0].blocks[0]
9292

93-
def add_entry_block(self):
93+
def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None):
9494
"""
9595
Add an entry block to the function body using the function signature to
9696
infer block arguments.
9797
Returns the newly created block
9898
"""
9999
if not self.is_external:
100100
raise IndexError('The function already has an entry block!')
101-
self.body.blocks.append(*self.type.inputs)
101+
self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs)
102102
return self.body.blocks[0]
103103

104104
@property

mlir/test/python/ir/blocks.py

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,28 +18,28 @@ def run(f):
1818

1919

2020
# CHECK-LABEL: TEST: testBlockCreation
21-
# CHECK: func @test(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16)
21+
# CHECK: func @test(%[[ARG0:.*]]: i32 loc("arg0"), %[[ARG1:.*]]: i16 loc("arg1"))
2222
# CHECK: cf.br ^bb1(%[[ARG1]] : i16)
23-
# CHECK: ^bb1(%[[PHI0:.*]]: i16):
23+
# CHECK: ^bb1(%[[PHI0:.*]]: i16 loc("middle")):
2424
# CHECK: cf.br ^bb2(%[[ARG0]] : i32)
25-
# CHECK: ^bb2(%[[PHI1:.*]]: i32):
25+
# CHECK: ^bb2(%[[PHI1:.*]]: i32 loc("successor")):
2626
# CHECK: return
2727
@run
2828
def testBlockCreation():
2929
with Context() as ctx, Location.unknown():
30-
module = Module.create()
30+
module = builtin.ModuleOp()
3131
with InsertionPoint(module.body):
3232
f_type = FunctionType.get(
3333
[IntegerType.get_signless(32),
3434
IntegerType.get_signless(16)], [])
3535
f_op = func.FuncOp("test", f_type)
36-
entry_block = f_op.add_entry_block()
36+
entry_block = f_op.add_entry_block([Location.name("arg0"), Location.name("arg1")])
3737
i32_arg, i16_arg = entry_block.arguments
38-
successor_block = entry_block.create_after(i32_arg.type)
38+
successor_block = entry_block.create_after(i32_arg.type, arg_locs=[Location.name("successor")])
3939
with InsertionPoint(successor_block) as successor_ip:
4040
assert successor_ip.block == successor_block
4141
func.ReturnOp([])
42-
middle_block = successor_block.create_before(i16_arg.type)
42+
middle_block = successor_block.create_before(i16_arg.type, arg_locs=[Location.name("middle")])
4343

4444
with InsertionPoint(entry_block) as entry_ip:
4545
assert entry_ip.block == entry_block
@@ -48,27 +48,57 @@ def testBlockCreation():
4848
with InsertionPoint(middle_block) as middle_ip:
4949
assert middle_ip.block == middle_block
5050
cf.BranchOp([i32_arg], dest=successor_block)
51-
print(module.operation)
51+
module.print(enable_debug_info=True)
5252
# Ensure region back references are coherent.
5353
assert entry_block.region == middle_block.region == successor_block.region
5454

5555

56+
# CHECK-LABEL: TEST: testBlockCreationArgLocs
57+
@run
58+
def testBlockCreationArgLocs():
59+
with Context() as ctx:
60+
ctx.allow_unregistered_dialects = True
61+
f32 = F32Type.get()
62+
op = Operation.create("test", regions=1, loc=Location.unknown())
63+
blocks = op.regions[0].blocks
64+
65+
with Location.name("default_loc"):
66+
blocks.append(f32)
67+
blocks.append()
68+
# CHECK: ^bb0(%{{.+}}: f32 loc("default_loc")):
69+
# CHECK-NEXT: ^bb1:
70+
op.print(enable_debug_info=True)
71+
72+
try:
73+
blocks.append(f32)
74+
except RuntimeError as err:
75+
# CHECK: Missing loc: An MLIR function requires a Location but none was provided
76+
print("Missing loc:", err)
77+
78+
try:
79+
blocks.append(f32, f32, arg_locs=[Location.unknown()])
80+
except ValueError as err:
81+
# CHECK: Wrong loc count: Expected 2 locations, got: 1
82+
print("Wrong loc count:", err)
83+
84+
5685
# CHECK-LABEL: TEST: testFirstBlockCreation
57-
# CHECK: func @test(%{{.*}}: f32)
86+
# CHECK: func @test(%{{.*}}: f32 loc("arg_loc"))
5887
# CHECK: return
5988
@run
6089
def testFirstBlockCreation():
6190
with Context() as ctx, Location.unknown():
62-
module = Module.create()
91+
module = builtin.ModuleOp()
6392
f32 = F32Type.get()
6493
with InsertionPoint(module.body):
6594
f = func.FuncOp("test", ([f32], []))
66-
entry_block = Block.create_at_start(f.operation.regions[0], [f32])
95+
entry_block = Block.create_at_start(f.operation.regions[0],
96+
[f32], [Location.name("arg_loc")])
6797
with InsertionPoint(entry_block):
6898
func.ReturnOp([])
6999

70-
print(module)
71-
assert module.operation.verify()
100+
module.print(enable_debug_info=True)
101+
assert module.verify()
72102
assert f.body.blocks[0] == entry_block
73103

74104

0 commit comments

Comments
 (0)