Skip to content

Commit dfaacc2

Browse files
committed
Addressing review comments
Representing TargetSystemSpec as a set of key-value pairs where key is the DeviceID (string) and the value is TargetDeviceSpec.
1 parent 344c5e0 commit dfaacc2

File tree

10 files changed

+243
-356
lines changed

10 files changed

+243
-356
lines changed

mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td

Lines changed: 16 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -116,24 +116,21 @@ def DLTI_TargetSystemSpecAttr :
116116
}];
117117
let description = [{
118118
A system specification describes the overall system containing
119-
multiple devices, with each device having a unique ID
119+
multiple devices, with each device having a unique ID (string)
120120
and its corresponding TargetDeviceSpec object.
121121

122122
Example:
123123
dlti.target_system_spec =
124-
#dlti.target_device_spec<
125-
#dlti.dl_entry<"dlti.device_id", 0: ui32>,
126-
#dlti.dl_entry<"dlti.device_type", "CPU">>,
127-
#dlti.target_device_spec<
128-
#dlti.dl_entry<"dlti.device_id", 1: ui32>,
129-
#dlti.dl_entry<"dlti.device_type", "GPU">,
130-
#dlti.dl_entry<"dlti.max_vector_op_width", 64 : ui32>>,
131-
#dlti.target_device_spec<
132-
#dlti.dl_entry<"dlti.device_id", 2: ui32>,
133-
#dlti.dl_entry<"dlti.device_type", "XPU">>>
124+
#dlti.target_system_spec<
125+
"CPU": #dlti.target_device_spec<
126+
#dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096: ui32>>,
127+
"GPU": #dlti.target_device_spec<
128+
#dlti.dl_entry<"dlti.max_vector_op_width", 64 : ui32>>,
129+
"XPU": #dlti.target_device_spec<
130+
#dlti.dl_entry<"dlti.max_vector_op_width", 4096 : ui32>>>
134131
}];
135132
let parameters = (ins
136-
ArrayRefParameter<"TargetDeviceSpecInterface", "">:$entries
133+
ArrayRefParameter<"DeviceIDTargetDeviceSpecPair", "">:$entries
137134
);
138135
let mnemonic = "target_system_spec";
139136
let genVerifyDecl = 1;
@@ -142,15 +139,15 @@ def DLTI_TargetSystemSpecAttr :
142139
/// Return the device specification that matches the given device ID
143140
std::optional<TargetDeviceSpecInterface>
144141
getDeviceSpecForDeviceID(
145-
TargetDeviceSpecInterface::DeviceID deviceID);
142+
TargetSystemSpecInterface::DeviceID deviceID);
146143
}];
147144
let extraClassDefinition = [{
148145
std::optional<TargetDeviceSpecInterface>
149146
$cppClass::getDeviceSpecForDeviceID(
150-
TargetDeviceSpecInterface::DeviceID deviceID) {
151-
for (TargetDeviceSpecInterface entry : getEntries()) {
152-
if (entry.getDeviceID() == deviceID)
153-
return entry;
147+
TargetSystemSpecInterface::DeviceID deviceID) {
148+
for (const auto& entry : getEntries()) {
149+
if (entry.first == deviceID)
150+
return entry.second;
154151
}
155152
return std::nullopt;
156153
}
@@ -173,15 +170,12 @@ def DLTI_TargetDeviceSpecAttr :
173170
}];
174171
let description = [{
175172
Each device specification describes a single device and its
176-
hardware properties. Each device specification must have a device_id
177-
and a device_type. In addition, the specification can contain any number
173+
hardware properties. Each device specification can contain any number
178174
of optional hardware properties (e.g., max_vector_op_width below).
179175

180176
Example:
181177
#dlti.target_device_spec<
182-
#dlti.dl_entry<"dlti.device_id", 1: ui32>,
183-
#dlti.dl_entry<"dlti.device_type", "GPU">,
184-
#dlti.dl_entry<"dlti.max_vector_op_width", 64 : ui32>>
178+
#dlti.dl_entry<"dlti.max_vector_op_width", 64 : ui32>>
185179
}];
186180
let parameters = (ins
187181
ArrayRefParameter<"DataLayoutEntryInterface", "">:$entries
@@ -190,28 +184,12 @@ def DLTI_TargetDeviceSpecAttr :
190184
let genVerifyDecl = 1;
191185
let assemblyFormat = "`<` $entries `>`";
192186
let extraClassDeclaration = [{
193-
/// Returns the device ID identifier.
194-
StringAttr getDeviceIDIdentifier();
195-
196-
/// Returns the device type identifier.
197-
StringAttr getDeviceTypeIdentifier();
198-
199187
/// Returns max vector op width identifier.
200188
StringAttr getMaxVectorOpWidthIdentifier();
201189

202190
/// Returns L1 cache size identifier
203191
StringAttr getL1CacheSizeInBytesIdentifier();
204192

205-
/// Returns the interface spec for device ID
206-
/// Since we verify that the spec contains device ID the function
207-
/// will return a valid spec.
208-
DataLayoutEntryInterface getSpecForDeviceID();
209-
210-
/// Returns the interface spec for device type
211-
/// Since we verify that the spec contains device type the function
212-
/// will return a valid spec.
213-
DataLayoutEntryInterface getSpecForDeviceType();
214-
215193
/// Returns the interface spec for max vector op width
216194
/// Since max vector op width is an optional property, this function will
217195
/// return a valid spec if the property is defined, otherwise it
@@ -223,9 +201,6 @@ def DLTI_TargetDeviceSpecAttr :
223201
/// return a valid spec if the property is defined, otherwise it
224202
/// will return an empty spec.
225203
DataLayoutEntryInterface getSpecForL1CacheSizeInBytes();
226-
227-
/// Return the value of device ID
228-
TargetDeviceSpecInterface::DeviceID getDeviceID();
229204
}];
230205
}
231206

mlir/include/mlir/Dialect/DLTI/DLTIBase.td

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,6 @@ def DLTI_Dialect : Dialect {
5757
kDataLayoutStackAlignmentKey = "dlti.stack_alignment";
5858

5959
// Constants used in target description part of DLTI.
60-
constexpr const static ::llvm::StringLiteral
61-
kTargetDeviceIDKey = "dlti.device_id";
62-
63-
constexpr const static ::llvm::StringLiteral
64-
kTargetDeviceTypeKey = "dlti.device_type";
65-
6660
constexpr const static ::llvm::StringLiteral
6761
kTargetDeviceMaxVectorOpWidthKey = "dlti.max_vector_op_width";
6862

mlir/include/mlir/Interfaces/DataLayoutInterfaces.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ using DataLayoutEntryKey = llvm::PointerUnion<Type, StringAttr>;
3131
using DataLayoutEntryList = llvm::SmallVector<DataLayoutEntryInterface, 4>;
3232
using DataLayoutEntryListRef = llvm::ArrayRef<DataLayoutEntryInterface>;
3333
using TargetDeviceSpecListRef = llvm::ArrayRef<TargetDeviceSpecInterface>;
34+
using DeviceIDTargetDeviceSpecPair =
35+
std::pair<StringAttr, TargetDeviceSpecInterface>;
36+
using DeviceIDTargetDeviceSpecPairListRef =
37+
llvm::ArrayRef<DeviceIDTargetDeviceSpecPair>;
3438
class DataLayoutOpInterface;
3539
class DataLayoutSpecInterface;
3640
class ModuleOp;
@@ -246,12 +250,12 @@ class DataLayout {
246250
/// Returns for max vector op width if the property is defined for the given
247251
/// device ID, otherwise return std::nullopt.
248252
std::optional<uint32_t>
249-
getMaxVectorOpWidth(TargetDeviceSpecInterface::DeviceID) const;
253+
getMaxVectorOpWidth(TargetSystemSpecInterface::DeviceID) const;
250254

251255
/// Returns for L1 cache size if the property is defined for the given
252256
/// device ID, otherwise return std::nullopt.
253257
std::optional<uint32_t>
254-
getL1CacheSizeInBytes(TargetDeviceSpecInterface::DeviceID) const;
258+
getL1CacheSizeInBytes(TargetSystemSpecInterface::DeviceID) const;
255259

256260
private:
257261
/// Combined layout spec at the given scope.

mlir/include/mlir/Interfaces/DataLayoutInterfaces.td

Lines changed: 7 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -238,20 +238,6 @@ def TargetDeviceSpecInterface : AttrInterface<"TargetDeviceSpecInterface"> {
238238
/*methodBody=*/"",
239239
/*defaultImplementation=*/[{ return ::mlir::success(); }]
240240
>,
241-
InterfaceMethod<
242-
/*description=*/"Returns the entry related to Device ID. The function"
243-
"will crash if the entry is missing.",
244-
/*retTy=*/"::mlir::DataLayoutEntryInterface",
245-
/*methodName=*/"getSpecForDeviceID",
246-
/*args=*/(ins)
247-
>,
248-
InterfaceMethod<
249-
/*description=*/"Returns the entry related to Device Type. "
250-
"The function will crash if the entry is missing.",
251-
/*retTy=*/"::mlir::DataLayoutEntryInterface",
252-
/*methodName=*/"getSpecForDeviceType",
253-
/*args=*/(ins)
254-
>,
255241
InterfaceMethod<
256242
/*description=*/"Returns the entry related to the given identifier, if "
257243
"present. Otherwise, return empty spec.",
@@ -265,19 +251,8 @@ def TargetDeviceSpecInterface : AttrInterface<"TargetDeviceSpecInterface"> {
265251
/*retTy=*/"::mlir::DataLayoutEntryInterface",
266252
/*methodName=*/"getSpecForL1CacheSizeInBytes",
267253
/*args=*/(ins)
268-
>,
269-
InterfaceMethod<
270-
/*description=*/"Returns the entry related to the given identifier, if "
271-
"present.",
272-
/*retTy=*/"uint32_t",
273-
/*methodName=*/"getDeviceID",
274-
/*args=*/(ins)
275-
>,
254+
>
276255
];
277-
278-
let extraClassDeclaration = [{
279-
using DeviceID = uint32_t;
280-
}];
281256
}
282257

283258
def TargetSystemSpecInterface : AttrInterface<"TargetSystemSpecInterface"> {
@@ -300,7 +275,7 @@ def TargetSystemSpecInterface : AttrInterface<"TargetSystemSpecInterface"> {
300275
let methods = [
301276
InterfaceMethod<
302277
/*description=*/"Returns the list of layout entries.",
303-
/*retTy=*/"llvm::ArrayRef<::mlir::TargetDeviceSpecInterface>",
278+
/*retTy=*/"llvm::ArrayRef<DeviceIDTargetDeviceSpecPair>",
304279
/*methodName=*/"getEntries",
305280
/*args=*/(ins)
306281
>,
@@ -309,7 +284,7 @@ def TargetSystemSpecInterface : AttrInterface<"TargetSystemSpecInterface"> {
309284
"ID",
310285
/*retTy=*/"std::optional<::mlir::TargetDeviceSpecInterface>",
311286
/*methodName=*/"getDeviceSpecForDeviceID",
312-
/*args=*/(ins "TargetDeviceSpecInterface::DeviceID":$deviceID)
287+
/*args=*/(ins "StringAttr":$deviceID)
313288
>,
314289
InterfaceMethod<
315290
/*description=*/"Verifies the validity of the specification and "
@@ -323,6 +298,10 @@ def TargetSystemSpecInterface : AttrInterface<"TargetSystemSpecInterface"> {
323298
}]
324299
>
325300
];
301+
302+
let extraClassDeclaration = [{
303+
using DeviceID = StringAttr;
304+
}];
326305
}
327306

328307
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)