Skip to content

Commit 7964793

Browse files
committed
Addressing review comments - removing methods for L1 size and max vector width
1 parent 45093d6 commit 7964793

File tree

9 files changed

+96
-266
lines changed

9 files changed

+96
-266
lines changed

mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -183,13 +183,6 @@ def DLTI_TargetDeviceSpecAttr :
183183
let mnemonic = "target_device_spec";
184184
let genVerifyDecl = 1;
185185
let assemblyFormat = "`<` $entries `>`";
186-
let extraClassDeclaration = [{
187-
/// Returns max vector op width identifier.
188-
StringAttr getMaxVectorOpWidthIdentifier();
189-
190-
/// Returns L1 cache size identifier
191-
StringAttr getL1CacheSizeInBytesIdentifier();
192-
}];
193186
}
194187

195188
#endif // MLIR_DIALECT_DLTI_DLTIATTRS_TD

mlir/include/mlir/Dialect/DLTI/DLTIBase.td

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,6 @@ def DLTI_Dialect : Dialect {
5555

5656
constexpr const static ::llvm::StringLiteral
5757
kDataLayoutStackAlignmentKey = "dlti.stack_alignment";
58-
59-
// Constants used in target description part of DLTI.
60-
constexpr const static ::llvm::StringLiteral
61-
kTargetDeviceMaxVectorOpWidthKey = "dlti.max_vector_op_width";
62-
63-
constexpr const static ::llvm::StringLiteral
64-
kTargetDeviceL1CacheSizeInBytesKey = "dlti.L1_cache_size_in_bytes";
6558
}];
6659

6760
let useDefaultAttributePrinterParser = 1;

mlir/include/mlir/Interfaces/DataLayoutInterfaces.h

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,10 @@ Attribute getDefaultGlobalMemorySpace(DataLayoutEntryInterface entry);
9191
/// DataLayoutInterface if specified, otherwise returns the default.
9292
uint64_t getDefaultStackAlignment(DataLayoutEntryInterface entry);
9393

94-
/// Return max vector op width from the specified DataLayoutEntry. If the
95-
/// property is missing from the entry, then return std::nullopt.
96-
std::optional<int64_t> getMaxVectorOpWidth(DataLayoutEntryInterface entry);
97-
98-
/// Return L1 cache size in bytes from the specified DataLayoutEntry. If the
99-
/// property is missing from the entry, then return std::nullopt.
100-
std::optional<int64_t> getL1CacheSizeInBytes(DataLayoutEntryInterface entry);
94+
/// Returns the value of the property from the specified DataLayoutEntry. If the
95+
/// property is missing from the entry, returns std::nullopt.
96+
std::optional<int64_t>
97+
getDevicePropertyValueAsInt(DataLayoutEntryInterface entry);
10198

10299
/// Given a list of data layout entries, returns a new list containing the
103100
/// entries with keys having the given type ID, i.e. belonging to the same type
@@ -247,15 +244,11 @@ class DataLayout {
247244
/// unspecified.
248245
uint64_t getStackAlignment() const;
249246

250-
/// Returns for max vector op width if the property is defined for the given
251-
/// device ID, otherwise return std::nullopt.
252-
std::optional<int64_t>
253-
getMaxVectorOpWidth(TargetSystemSpecInterface::DeviceID) const;
254-
255-
/// Returns for L1 cache size if the property is defined for the given
256-
/// device ID, otherwise return std::nullopt.
247+
/// Returns the value of the specified property if the property is defined for
248+
/// the given device ID, otherwise returns std::nullopt.
257249
std::optional<int64_t>
258-
getL1CacheSizeInBytes(TargetSystemSpecInterface::DeviceID) const;
250+
getDevicePropertyValueAsInt(TargetSystemSpecInterface::DeviceID,
251+
StringAttr propertyName) const;
259252

260253
private:
261254
/// Combined layout spec at the given scope.

mlir/include/mlir/Interfaces/DataLayoutInterfaces.td

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -237,21 +237,7 @@ def TargetDeviceSpecInterface : AttrInterface<"TargetDeviceSpecInterface"> {
237237
/*args=*/(ins "::mlir::Location":$loc),
238238
/*methodBody=*/"",
239239
/*defaultImplementation=*/[{ return ::mlir::success(); }]
240-
>,
241-
InterfaceMethod<
242-
/*description=*/"Returns max vector op width identifier. ",
243-
/*retTy=*/"::mlir::StringAttr",
244-
/*methodName=*/"getMaxVectorOpWidthIdentifier",
245-
/*args=*/(ins),
246-
/*methodBody=*/""
247-
>,
248-
InterfaceMethod<
249-
/*description=*/"Returns L1 cache size identifier identifier. ",
250-
/*retTy=*/"::mlir::StringAttr",
251-
/*methodName=*/"getL1CacheSizeInBytesIdentifier",
252-
/*args=*/(ins),
253-
/*methodBody=*/""
254-
>,
240+
>
255241
];
256242
}
257243

@@ -480,25 +466,14 @@ def DataLayoutOpInterface : OpInterface<"DataLayoutOpInterface"> {
480466
}]
481467
>,
482468
StaticInterfaceMethod<
483-
/*description=*/"Returns the max vector op width, if the property is "
484-
"defined. Otherwise, it returns std::nullopt.",
485-
/*retTy=*/"std::optional<int64_t>",
486-
/*methodName=*/"getMaxVectorOpWidth",
487-
/*args=*/(ins "::mlir::DataLayoutEntryInterface":$entry),
488-
/*methodBody=*/"",
489-
/*defaultImplementation=*/[{
490-
return ::mlir::detail::getMaxVectorOpWidth(entry);
491-
}]
492-
>,
493-
StaticInterfaceMethod<
494-
/*description=*/"Returns the L1 cache size in bytes, if the property is "
469+
/*description=*/"Returns the value of the property, if the property is "
495470
"defined. Otherwise, it returns std::nullopt.",
496471
/*retTy=*/"std::optional<int64_t>",
497-
/*methodName=*/"getL1CacheSizeInBytes",
472+
/*methodName=*/"getDevicePropertyValueAsInt",
498473
/*args=*/(ins "::mlir::DataLayoutEntryInterface":$entry),
499474
/*methodBody=*/"",
500475
/*defaultImplementation=*/[{
501-
return ::mlir::detail::getL1CacheSizeInBytes(entry);
476+
return ::mlir::detail::getDevicePropertyValueAsInt(entry);
502477
}]
503478
>
504479
];

mlir/lib/Dialect/DLTI/DLTI.cpp

Lines changed: 1 addition & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -356,42 +356,11 @@ TargetDeviceSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
356356
if (!ids.insert(id).second)
357357
return emitError() << "repeated layout entry key: " << id.getValue();
358358
}
359-
360-
// Check that required keys are of right type.
361-
StringRef entryName = entry.getKey().get<StringAttr>().strref();
362-
if (entryName == DLTIDialect::kTargetDeviceL1CacheSizeInBytesKey) {
363-
IntegerAttr value =
364-
llvm::dyn_cast_if_present<IntegerAttr>(entry.getValue());
365-
if (!value || !value.getType().isInteger())
366-
return emitError() << "target_device_spec requires value of key: "
367-
<< DLTIDialect::kTargetDeviceL1CacheSizeInBytesKey
368-
<< " to be of integer type";
369-
} else if (entryName == DLTIDialect::kTargetDeviceMaxVectorOpWidthKey) {
370-
IntegerAttr value =
371-
llvm::dyn_cast_if_present<IntegerAttr>(entry.getValue());
372-
if (!value || !value.getType().isInteger())
373-
return emitError() << "target_device_spec requires value of key: "
374-
<< DLTIDialect::kTargetDeviceMaxVectorOpWidthKey
375-
<< " to be of integer type";
376-
} else {
377-
return emitError() << "unknown target device spec key name: "
378-
<< entryName;
379-
}
380359
}
381360

382361
return success();
383362
}
384363

385-
StringAttr TargetDeviceSpecAttr::getMaxVectorOpWidthIdentifier() {
386-
return Builder(getContext())
387-
.getStringAttr(DLTIDialect::kTargetDeviceMaxVectorOpWidthKey);
388-
}
389-
390-
StringAttr TargetDeviceSpecAttr::getL1CacheSizeInBytesIdentifier() {
391-
return Builder(getContext())
392-
.getStringAttr(DLTIDialect::kTargetDeviceL1CacheSizeInBytesKey);
393-
}
394-
395364
//===----------------------------------------------------------------------===//
396365
// TargetSystemSpecAttr
397366
//===----------------------------------------------------------------------===//
@@ -457,34 +426,12 @@ class TargetDataLayoutInterface : public DataLayoutDialectInterface {
457426
};
458427
} // namespace
459428

460-
namespace {
461-
/// An interface to check entries of a target device spec.
462-
class SystemDescSpecInterface : public DataLayoutDialectInterface {
463-
public:
464-
using DataLayoutDialectInterface::DataLayoutDialectInterface;
465-
466-
LogicalResult verifyEntry(TargetDeviceSpecInterface entry,
467-
Location loc) const final {
468-
469-
for (DataLayoutEntryInterface dl_entry : entry.getEntries()) {
470-
StringRef entryName = dl_entry.getKey().get<StringAttr>().strref();
471-
// Check that the key name is known to us. Although, we may allow keys
472-
// unknown to us.
473-
if (entryName != DLTIDialect::kTargetDeviceMaxVectorOpWidthKey &&
474-
entryName != DLTIDialect::kTargetDeviceL1CacheSizeInBytesKey)
475-
return emitError(loc) << "unknown target desc key name: " << entryName;
476-
}
477-
return success();
478-
}
479-
};
480-
} // namespace
481-
482429
void DLTIDialect::initialize() {
483430
addAttributes<
484431
#define GET_ATTRDEF_LIST
485432
#include "mlir/Dialect/DLTI/DLTIAttrs.cpp.inc"
486433
>();
487-
addInterfaces<TargetDataLayoutInterface, SystemDescSpecInterface>();
434+
addInterfaces<TargetDataLayoutInterface>();
488435
}
489436

490437
LogicalResult DLTIDialect::verifyOperationAttribute(Operation *op,

mlir/lib/Interfaces/DataLayoutInterfaces.cpp

Lines changed: 26 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -293,21 +293,8 @@ mlir::detail::getDefaultStackAlignment(DataLayoutEntryInterface entry) {
293293
return value.getValue().getZExtValue();
294294
}
295295

296-
// Returns the max vector op width if specified in the given entry. If the entry
297-
// is empty (meaning the spec is missing), returns std::nullopt.
298296
std::optional<int64_t>
299-
mlir::detail::getMaxVectorOpWidth(DataLayoutEntryInterface entry) {
300-
if (entry == DataLayoutEntryInterface())
301-
return std::nullopt;
302-
303-
auto value = cast<IntegerAttr>(entry.getValue());
304-
return value.getValue().getZExtValue();
305-
}
306-
307-
// Returns the L1 cache size if specified in the given entry. If the entry
308-
// is empty (meaning the spec is missing), returns std::nullopt.
309-
std::optional<int64_t>
310-
mlir::detail::getL1CacheSizeInBytes(DataLayoutEntryInterface entry) {
297+
mlir::detail::getDevicePropertyValueAsInt(DataLayoutEntryInterface entry) {
311298
if (entry == DataLayoutEntryInterface())
312299
return std::nullopt;
313300

@@ -348,15 +335,12 @@ static DataLayoutSpecInterface getSpec(Operation *operation) {
348335

349336
static TargetSystemSpecInterface getTargetSystemSpec(Operation *operation) {
350337
if (operation) {
351-
ModuleOp moduleOp;
352-
if (isa<ModuleOp>(operation)) {
353-
moduleOp = llvm::dyn_cast<ModuleOp>(operation);
354-
} else {
338+
ModuleOp moduleOp = dyn_cast<ModuleOp>(operation);
339+
if (!moduleOp)
355340
moduleOp = operation->getParentOfType<ModuleOp>();
356-
}
357341
return moduleOp.getTargetSystemSpec();
358-
} else
359-
return TargetSystemSpecInterface();
342+
}
343+
return TargetSystemSpecInterface();
360344
}
361345

362346
/// Populates `opsWithLayout` with the list of proper ancestors of `leaf` that
@@ -677,44 +661,24 @@ uint64_t mlir::DataLayout::getStackAlignment() const {
677661
return *stackAlignment;
678662
}
679663

680-
std::optional<int64_t> mlir::DataLayout::getMaxVectorOpWidth(
681-
TargetSystemSpecInterface::DeviceID deviceID) const {
682-
checkValid();
683-
DataLayoutEntryInterface entry;
684-
if (originalTargetSystemDesc) {
685-
if (auto device =
686-
originalTargetSystemDesc.getDeviceSpecForDeviceID(deviceID))
687-
entry =
688-
device->getSpecForIdentifier(device->getMaxVectorOpWidthIdentifier());
689-
}
690-
// Currently I am not caching the results because we do not return
691-
// default values of these properties. Instead if the property is
692-
// missing, we return std::nullopt so that the users can resort to
693-
// the default value however they want.
694-
if (auto iface = dyn_cast_or_null<DataLayoutOpInterface>(scope))
695-
return iface.getMaxVectorOpWidth(entry);
696-
else
697-
return detail::getMaxVectorOpWidth(entry);
698-
}
699-
700-
std::optional<int64_t> mlir::DataLayout::getL1CacheSizeInBytes(
701-
TargetSystemSpecInterface::DeviceID deviceID) const {
664+
std::optional<int64_t> mlir::DataLayout::getDevicePropertyValueAsInt(
665+
TargetSystemSpecInterface::DeviceID deviceID,
666+
StringAttr propertyName) const {
702667
checkValid();
703668
DataLayoutEntryInterface entry;
704669
if (originalTargetSystemDesc) {
705-
if (auto device =
670+
if (std::optional<TargetDeviceSpecInterface> device =
706671
originalTargetSystemDesc.getDeviceSpecForDeviceID(deviceID))
707-
entry = device->getSpecForIdentifier(
708-
device->getL1CacheSizeInBytesIdentifier());
672+
entry = device->getSpecForIdentifier(propertyName);
709673
}
710674
// Currently I am not caching the results because we do not return
711675
// default values of these properties. Instead if the property is
712676
// missing, we return std::nullopt so that the users can resort to
713677
// the default value however they want.
714678
if (auto iface = dyn_cast_or_null<DataLayoutOpInterface>(scope))
715-
return iface.getL1CacheSizeInBytes(entry);
679+
return iface.getDevicePropertyValueAsInt(entry);
716680
else
717-
return detail::getL1CacheSizeInBytes(entry);
681+
return detail::getDevicePropertyValueAsInt(entry);
718682
}
719683

720684
//===----------------------------------------------------------------------===//
@@ -824,47 +788,46 @@ LogicalResult mlir::detail::verifyDataLayoutSpec(DataLayoutSpecInterface spec,
824788
LogicalResult
825789
mlir::detail::verifyTargetSystemSpec(TargetSystemSpecInterface spec,
826790
Location loc) {
827-
DenseMap<StringAttr, DataLayoutEntryInterface> device_desc_keys;
828-
DenseSet<TargetSystemSpecInterface::DeviceID> device_ids;
791+
DenseMap<StringAttr, DataLayoutEntryInterface> deviceDescKeys;
792+
DenseSet<TargetSystemSpecInterface::DeviceID> deviceIDs;
829793
for (const auto &entry : spec.getEntries()) {
830-
TargetDeviceSpecInterface tdd_spec = entry.second;
794+
TargetDeviceSpecInterface targetDeviceSpec = entry.second;
831795
// First, verify individual target device desc specs.
832-
if (failed(tdd_spec.verifyEntry(loc)))
796+
if (failed(targetDeviceSpec.verifyEntry(loc)))
833797
return failure();
834798

835799
// Check that device IDs are unique across all entries.
836-
TargetSystemSpecInterface::DeviceID device_id = entry.first;
837-
if (!device_ids.insert(device_id).second) {
800+
TargetSystemSpecInterface::DeviceID deviceID = entry.first;
801+
if (!deviceIDs.insert(deviceID).second) {
838802
return failure();
839803
}
840804

841-
// collect all the keys used by all the tdd_specs.
842-
for (DataLayoutEntryInterface entry : tdd_spec.getEntries()) {
805+
// collect all the keys used by all the target device specs.
806+
for (DataLayoutEntryInterface entry : targetDeviceSpec.getEntries()) {
843807
if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
844-
// tdd_spec does not support Type as a key.
808+
// targetDeviceSpec does not support Type as a key.
845809
return failure();
846810
} else {
847-
device_desc_keys[entry.getKey().get<StringAttr>()] = entry;
811+
deviceDescKeys[entry.getKey().get<StringAttr>()] = entry;
848812
}
849813
}
850814
}
851815

852-
for (const auto &kvp : device_desc_keys) {
853-
StringAttr identifier = kvp.second.getKey().get<StringAttr>();
854-
Dialect *dialect = identifier.getReferencedDialect();
816+
for (const auto &[keyName, keyVal] : deviceDescKeys) {
817+
Dialect *dialect = keyName.getReferencedDialect();
855818

856819
// Ignore attributes that belong to an unknown dialect, the dialect may
857820
// actually implement the relevant interface but we don't know about that.
858821
if (!dialect)
859-
continue;
822+
return failure();
860823

861824
const auto *iface = dyn_cast<DataLayoutDialectInterface>(dialect);
862825
if (!iface) {
863826
return emitError(loc)
864827
<< "the '" << dialect->getNamespace()
865828
<< "' dialect does not support identifier data layout entries";
866829
}
867-
if (failed(iface->verifyEntry(kvp.second, loc)))
830+
if (failed(iface->verifyEntry(keyVal, loc)))
868831
return failure();
869832
}
870833

0 commit comments

Comments
 (0)