Skip to content

Commit a3a55d6

Browse files
committed
address reviewer comments
1 parent e8c8e66 commit a3a55d6

File tree

6 files changed

+44
-35
lines changed

6 files changed

+44
-35
lines changed

mlir/include/mlir/Interfaces/DataLayoutInterfaces.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ using DataLayoutEntryList = llvm::SmallVector<DataLayoutEntryInterface, 4>;
3434
using DataLayoutEntryListRef = llvm::ArrayRef<DataLayoutEntryInterface>;
3535
using TargetDeviceSpecListRef = llvm::ArrayRef<TargetDeviceSpecInterface>;
3636
using TargetDeviceSpecEntry = std::pair<StringAttr, TargetDeviceSpecInterface>;
37+
using DataLayoutIdentifiedEntryMap =
38+
::llvm::DenseMap<::mlir::StringAttr, ::mlir::DataLayoutEntryInterface>;
3739
class DataLayoutOpInterface;
3840
class DataLayoutSpecInterface;
3941
class ModuleOp;

mlir/include/mlir/Interfaces/DataLayoutInterfaces.td

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -655,13 +655,15 @@ def DataLayoutTypeInterface : TypeInterface<"DataLayoutTypeInterface"> {
655655
InterfaceMethod<
656656
/*desc=*/"Returns true if the two lists of entries are compatible, that "
657657
"is, that `newLayout` spec entries can be nested in an op with "
658-
"`oldLayout` spec entries. `newSpec` is provided to further"
659-
"query data from the spec, e.g., the default address space.",
658+
"`oldLayout` spec entries. `newSpec` and `identified` are"
659+
"provided to further query data from the combined spec, e.g.,"
660+
"the default address space.",
660661
/*retTy=*/"bool",
661662
/*methodName=*/"areCompatible",
662663
/*args=*/(ins "::mlir::DataLayoutEntryListRef":$oldLayout,
663664
"::mlir::DataLayoutEntryListRef":$newLayout,
664-
"::mlir::DataLayoutSpecInterface":$newSpec),
665+
"::mlir::DataLayoutSpecInterface":$newSpec,
666+
"const ::mlir::DataLayoutIdentifiedEntryMap&":$identified),
665667
/*methodBody=*/"",
666668
/*defaultImplementation=*/[{ return true; }]
667669
>,

mlir/lib/Dialect/DLTI/DLTI.cpp

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -305,26 +305,8 @@ combineOneSpec(DataLayoutSpecInterface spec,
305305
DenseMap<StringAttr, DataLayoutEntryInterface> newEntriesForID;
306306
spec.bucketEntriesByType(newEntriesForType, newEntriesForID);
307307

308-
// Try overwriting the old entries with the new ones.
309-
for (auto &kvp : newEntriesForType) {
310-
if (!entriesForType.count(kvp.first)) {
311-
entriesForType[kvp.first] = std::move(kvp.second);
312-
continue;
313-
}
314-
315-
Type typeSample = cast<Type>(kvp.second.front().getKey());
316-
assert(&typeSample.getDialect() !=
317-
typeSample.getContext()->getLoadedDialect<BuiltinDialect>() &&
318-
"unexpected data layout entry for built-in type");
319-
320-
auto interface = cast<DataLayoutTypeInterface>(typeSample);
321-
if (!interface.areCompatible(entriesForType.lookup(kvp.first), kvp.second,
322-
spec))
323-
return failure();
324-
325-
overwriteDuplicateEntries(entriesForType[kvp.first], kvp.second);
326-
}
327-
308+
// Combine non-Type DL entries first so they are visible to the
309+
// `type.areCompatible` method, allowing to query global properties.
328310
for (const auto &kvp : newEntriesForID) {
329311
StringAttr id = cast<StringAttr>(kvp.second.getKey());
330312
Dialect *dialect = id.getReferencedDialect();
@@ -333,7 +315,7 @@ combineOneSpec(DataLayoutSpecInterface spec,
333315
continue;
334316
}
335317

336-
// Attempt to combine the enties using the dialect interface. If the
318+
// Attempt to combine the entries using the dialect interface. If the
337319
// dialect is not loaded for some reason, use the default combinator
338320
// that conservatively accepts identical entries only.
339321
entriesForID[id] =
@@ -345,6 +327,26 @@ combineOneSpec(DataLayoutSpecInterface spec,
345327
return failure();
346328
}
347329

330+
// Try overwriting the old entries with the new ones.
331+
for (auto &kvp : newEntriesForType) {
332+
if (!entriesForType.count(kvp.first)) {
333+
entriesForType[kvp.first] = std::move(kvp.second);
334+
continue;
335+
}
336+
337+
Type typeSample = cast<Type>(kvp.second.front().getKey());
338+
assert(&typeSample.getDialect() !=
339+
typeSample.getContext()->getLoadedDialect<BuiltinDialect>() &&
340+
"unexpected data layout entry for built-in type");
341+
342+
auto interface = cast<DataLayoutTypeInterface>(typeSample);
343+
if (!interface.areCompatible(entriesForType.lookup(kvp.first), kvp.second,
344+
spec, entriesForID))
345+
return failure();
346+
347+
overwriteDuplicateEntries(entriesForType[kvp.first], kvp.second);
348+
}
349+
348350
return success();
349351
}
350352

mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -349,9 +349,10 @@ LLVMPointerType::getIndexBitwidth(const DataLayout &dataLayout,
349349
return dataLayout.getTypeIndexBitwidth(get(getContext()));
350350
}
351351

352-
bool LLVMPointerType::areCompatible(DataLayoutEntryListRef oldLayout,
353-
DataLayoutEntryListRef newLayout,
354-
DataLayoutSpecInterface newSpec) const {
352+
bool LLVMPointerType::areCompatible(
353+
DataLayoutEntryListRef oldLayout, DataLayoutEntryListRef newLayout,
354+
DataLayoutSpecInterface newSpec,
355+
const DataLayoutIdentifiedEntryMap &map) const {
355356
for (DataLayoutEntryInterface newEntry : newLayout) {
356357
if (!newEntry.isTypeEntry())
357358
continue;
@@ -598,9 +599,10 @@ static uint64_t extractStructSpecValue(Attribute attr, StructDLEntryPos pos) {
598599
.getValues<uint64_t>()[static_cast<size_t>(pos)];
599600
}
600601

601-
bool LLVMStructType::areCompatible(DataLayoutEntryListRef oldLayout,
602-
DataLayoutEntryListRef newLayout,
603-
DataLayoutSpecInterface newSpec) const {
602+
bool LLVMStructType::areCompatible(
603+
DataLayoutEntryListRef oldLayout, DataLayoutEntryListRef newLayout,
604+
DataLayoutSpecInterface newSpec,
605+
const DataLayoutIdentifiedEntryMap &map) const {
604606
for (DataLayoutEntryInterface newEntry : newLayout) {
605607
if (!newEntry.isTypeEntry())
606608
continue;

mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ static SpecAttr getPointerSpec(DataLayoutEntryListRef params, PtrType type,
4949

5050
bool PtrType::areCompatible(DataLayoutEntryListRef oldLayout,
5151
DataLayoutEntryListRef newLayout,
52-
DataLayoutSpecInterface newSpec) const {
52+
DataLayoutSpecInterface newSpec,
53+
const DataLayoutIdentifiedEntryMap &map) const {
5354
for (DataLayoutEntryInterface newEntry : newLayout) {
5455
if (!newEntry.isTypeEntry())
5556
continue;
@@ -65,9 +66,8 @@ bool PtrType::areCompatible(DataLayoutEntryListRef oldLayout,
6566
return false;
6667
});
6768
if (it == oldLayout.end()) {
68-
Attribute defaultMemorySpace =
69-
mlir::detail::getDefaultMemorySpace(newSpec.getSpecForIdentifier(
70-
newSpec.getDefaultMemorySpaceIdentifier(getContext())));
69+
Attribute defaultMemorySpace = mlir::detail::getDefaultMemorySpace(
70+
map.lookup(newSpec.getDefaultMemorySpaceIdentifier(getContext())));
7171
it = llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
7272
if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
7373
auto ptrTy = llvm::cast<PtrType>(type);

mlir/test/lib/Dialect/Test/TestTypes.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,8 @@ TestTypeWithLayoutType::getIndexBitwidth(const DataLayout &dataLayout,
285285

286286
bool TestTypeWithLayoutType::areCompatible(
287287
DataLayoutEntryListRef oldLayout, DataLayoutEntryListRef newLayout,
288-
DataLayoutSpecInterface newSpec) const {
288+
DataLayoutSpecInterface newSpec,
289+
const DataLayoutIdentifiedEntryMap &map) const {
289290
unsigned old = extractKind(oldLayout, "alignment");
290291
return old == 1 || extractKind(newLayout, "alignment") <= old;
291292
}

0 commit comments

Comments
 (0)