@@ -193,6 +193,31 @@ static MlirStringRef toMlirStringRef(const std::string &s) {
193
193
return mlirStringRefCreate (s.data (), s.size ());
194
194
}
195
195
196
+ // / Create a block, using the current location context if no locations are
197
+ // / specified.
198
+ static MlirBlock createBlock (const py::sequence &pyArgTypes,
199
+ const std::optional<py::sequence> &pyArgLocs) {
200
+ SmallVector<MlirType> argTypes;
201
+ argTypes.reserve (pyArgTypes.size ());
202
+ for (const auto &pyType : pyArgTypes)
203
+ argTypes.push_back (pyType.cast <PyType &>());
204
+
205
+ SmallVector<MlirLocation> argLocs;
206
+ if (pyArgLocs) {
207
+ argLocs.reserve (pyArgLocs->size ());
208
+ for (const auto &pyLoc : *pyArgLocs)
209
+ argLocs.push_back (pyLoc.cast <PyLocation &>());
210
+ } else if (!argTypes.empty ()) {
211
+ argLocs.assign (argTypes.size (), DefaultingPyLocation::resolve ());
212
+ }
213
+
214
+ if (argTypes.size () != argLocs.size ())
215
+ throw py::value_error ((" Expected " + Twine (argTypes.size ()) +
216
+ " locations, got: " + Twine (argLocs.size ()))
217
+ .str ());
218
+ return mlirBlockCreate (argTypes.size (), argTypes.data (), argLocs.data ());
219
+ }
220
+
196
221
// / Wrapper for the global LLVM debugging flag.
197
222
struct PyGlobalDebugFlag {
198
223
static void set (py::object &o, bool enable) { mlirEnableGlobalDebug (enable); }
@@ -364,21 +389,10 @@ class PyBlockList {
364
389
throw SetPyError (PyExc_IndexError, " attempt to access out of bounds block" );
365
390
}
366
391
367
- PyBlock appendBlock (const py::args &pyArgTypes) {
392
+ PyBlock appendBlock (const py::args &pyArgTypes,
393
+ const std::optional<py::sequence> &pyArgLocs) {
368
394
operation->checkValid ();
369
- llvm::SmallVector<MlirType, 4 > argTypes;
370
- llvm::SmallVector<MlirLocation, 4 > argLocs;
371
- argTypes.reserve (pyArgTypes.size ());
372
- argLocs.reserve (pyArgTypes.size ());
373
- for (auto &pyArg : pyArgTypes) {
374
- argTypes.push_back (pyArg.cast <PyType &>());
375
- // TODO: Pass in a proper location here.
376
- argLocs.push_back (
377
- mlirLocationUnknownGet (mlirTypeGetContext (argTypes.back ())));
378
- }
379
-
380
- MlirBlock block =
381
- mlirBlockCreate (argTypes.size (), argTypes.data (), argLocs.data ());
395
+ MlirBlock block = createBlock (pyArgTypes, pyArgLocs);
382
396
mlirRegionAppendOwnedBlock (region, block);
383
397
return PyBlock (operation, block);
384
398
}
@@ -388,7 +402,8 @@ class PyBlockList {
388
402
.def (" __getitem__" , &PyBlockList::dunderGetItem)
389
403
.def (" __iter__" , &PyBlockList::dunderIter)
390
404
.def (" __len__" , &PyBlockList::dunderLen)
391
- .def (" append" , &PyBlockList::appendBlock, kAppendBlockDocstring );
405
+ .def (" append" , &PyBlockList::appendBlock, kAppendBlockDocstring ,
406
+ py::arg (" arg_locs" ) = std::nullopt);
392
407
}
393
408
394
409
private:
@@ -2966,27 +2981,17 @@ void mlir::python::populateIRCore(py::module &m) {
2966
2981
" Returns a forward-optimized sequence of operations." )
2967
2982
.def_static (
2968
2983
" create_at_start" ,
2969
- [](PyRegion &parent, py::list pyArgTypes) {
2984
+ [](PyRegion &parent, const py::list &pyArgTypes,
2985
+ const std::optional<py::sequence> &pyArgLocs) {
2970
2986
parent.checkValid ();
2971
- llvm::SmallVector<MlirType, 4 > argTypes;
2972
- llvm::SmallVector<MlirLocation, 4 > argLocs;
2973
- argTypes.reserve (pyArgTypes.size ());
2974
- argLocs.reserve (pyArgTypes.size ());
2975
- for (auto &pyArg : pyArgTypes) {
2976
- argTypes.push_back (pyArg.cast <PyType &>());
2977
- // TODO: Pass in a proper location here.
2978
- argLocs.push_back (
2979
- mlirLocationUnknownGet (mlirTypeGetContext (argTypes.back ())));
2980
- }
2981
-
2982
- MlirBlock block = mlirBlockCreate (argTypes.size (), argTypes.data (),
2983
- argLocs.data ());
2987
+ MlirBlock block = createBlock (pyArgTypes, pyArgLocs);
2984
2988
mlirRegionInsertOwnedBlock (parent, 0 , block);
2985
2989
return PyBlock (parent.getParentOperation (), block);
2986
2990
},
2987
2991
py::arg (" parent" ), py::arg (" arg_types" ) = py::list (),
2992
+ py::arg (" arg_locs" ) = std::nullopt,
2988
2993
" Creates and returns a new Block at the beginning of the given "
2989
- " region (with given argument types)." )
2994
+ " region (with given argument types and locations )." )
2990
2995
.def (
2991
2996
" append_to" ,
2992
2997
[](PyBlock &self, PyRegion ®ion) {
@@ -2998,50 +3003,30 @@ void mlir::python::populateIRCore(py::module &m) {
2998
3003
" Append this block to a region, transferring ownership if necessary" )
2999
3004
.def (
3000
3005
" create_before" ,
3001
- [](PyBlock &self, py::args pyArgTypes) {
3006
+ [](PyBlock &self, const py::args &pyArgTypes,
3007
+ const std::optional<py::sequence> &pyArgLocs) {
3002
3008
self.checkValid ();
3003
- llvm::SmallVector<MlirType, 4 > argTypes;
3004
- llvm::SmallVector<MlirLocation, 4 > argLocs;
3005
- argTypes.reserve (pyArgTypes.size ());
3006
- argLocs.reserve (pyArgTypes.size ());
3007
- for (auto &pyArg : pyArgTypes) {
3008
- argTypes.push_back (pyArg.cast <PyType &>());
3009
- // TODO: Pass in a proper location here.
3010
- argLocs.push_back (
3011
- mlirLocationUnknownGet (mlirTypeGetContext (argTypes.back ())));
3012
- }
3013
-
3014
- MlirBlock block = mlirBlockCreate (argTypes.size (), argTypes.data (),
3015
- argLocs.data ());
3009
+ MlirBlock block = createBlock (pyArgTypes, pyArgLocs);
3016
3010
MlirRegion region = mlirBlockGetParentRegion (self.get ());
3017
3011
mlirRegionInsertOwnedBlockBefore (region, self.get (), block);
3018
3012
return PyBlock (self.getParentOperation (), block);
3019
3013
},
3014
+ py::arg (" arg_locs" ) = std::nullopt,
3020
3015
" Creates and returns a new Block before this block "
3021
- " (with given argument types)." )
3016
+ " (with given argument types and locations )." )
3022
3017
.def (
3023
3018
" create_after" ,
3024
- [](PyBlock &self, py::args pyArgTypes) {
3019
+ [](PyBlock &self, const py::args &pyArgTypes,
3020
+ const std::optional<py::sequence> &pyArgLocs) {
3025
3021
self.checkValid ();
3026
- llvm::SmallVector<MlirType, 4 > argTypes;
3027
- llvm::SmallVector<MlirLocation, 4 > argLocs;
3028
- argTypes.reserve (pyArgTypes.size ());
3029
- argLocs.reserve (pyArgTypes.size ());
3030
- for (auto &pyArg : pyArgTypes) {
3031
- argTypes.push_back (pyArg.cast <PyType &>());
3032
-
3033
- // TODO: Pass in a proper location here.
3034
- argLocs.push_back (
3035
- mlirLocationUnknownGet (mlirTypeGetContext (argTypes.back ())));
3036
- }
3037
- MlirBlock block = mlirBlockCreate (argTypes.size (), argTypes.data (),
3038
- argLocs.data ());
3022
+ MlirBlock block = createBlock (pyArgTypes, pyArgLocs);
3039
3023
MlirRegion region = mlirBlockGetParentRegion (self.get ());
3040
3024
mlirRegionInsertOwnedBlockAfter (region, self.get (), block);
3041
3025
return PyBlock (self.getParentOperation (), block);
3042
3026
},
3027
+ py::arg (" arg_locs" ) = std::nullopt,
3043
3028
" Creates and returns a new Block after this block "
3044
- " (with given argument types)." )
3029
+ " (with given argument types and locations )." )
3045
3030
.def (
3046
3031
" __iter__" ,
3047
3032
[](PyBlock &self) {
0 commit comments