Skip to content

Commit d3811d8

Browse files
committed
Reimplementing target description concept using DLTI attribute
1 parent 7d9634e commit d3811d8

File tree

14 files changed

+1148
-8
lines changed

14 files changed

+1148
-8
lines changed

mlir/include/mlir/Dialect/DLTI/DLTI.h

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ namespace mlir {
2121
namespace impl {
2222
class DataLayoutEntryStorage;
2323
class DataLayoutSpecStorage;
24+
class TargetSystemDescSpecAttrStorage;
25+
class TargetDeviceDescSpecAttrStorage;
2426
} // namespace impl
2527

2628
//===----------------------------------------------------------------------===//
@@ -124,6 +126,150 @@ class DataLayoutSpecAttr
124126
static constexpr StringLiteral name = "builtin.data_layout_spec";
125127
};
126128

129+
//===----------------------------------------------------------------------===//
130+
// TargetSystemDescSpecAttr
131+
//===----------------------------------------------------------------------===//
132+
133+
/// A system description attribute is a list of device descriptors, each
134+
/// having a unique device ID
135+
class TargetSystemDescSpecAttr
136+
: public Attribute::AttrBase<TargetSystemDescSpecAttr, Attribute,
137+
impl::TargetSystemDescSpecAttrStorage,
138+
TargetSystemDescSpecInterface::Trait> {
139+
public:
140+
using Base::Base;
141+
142+
/// The keyword used for this attribute in custom syntax.
143+
constexpr const static StringLiteral kAttrKeyword = "tsd_spec";
144+
145+
/// Returns a system descriptor attribute from the given system descriptor
146+
static TargetSystemDescSpecAttr
147+
get(MLIRContext *context, ArrayRef<TargetDeviceDescSpecInterface> entries);
148+
149+
/// Returns the list of entries.
150+
TargetDeviceDescSpecListRef getEntries() const;
151+
152+
/// Return the device descriptor that matches the given device ID
153+
TargetDeviceDescSpecInterface getDeviceDescForDeviceID(uint32_t deviceID);
154+
155+
/// Returns the specification containing the given list of keys. If the list
156+
/// contains duplicate keys or is otherwise invalid, reports errors using the
157+
/// given callback and returns null.
158+
static TargetSystemDescSpecAttr
159+
getChecked(function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
160+
ArrayRef<TargetDeviceDescSpecInterface> entries);
161+
162+
/// Checks that the given list of entries does not contain duplicate keys.
163+
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
164+
ArrayRef<TargetDeviceDescSpecInterface> entries);
165+
166+
/// Parses an instance of this attribute.
167+
static TargetSystemDescSpecAttr parse(AsmParser &parser);
168+
169+
/// Prints this attribute.
170+
void print(AsmPrinter &os) const;
171+
172+
static constexpr StringLiteral name = "builtin.target_system_description";
173+
};
174+
175+
//===----------------------------------------------------------------------===//
176+
// TargetDeviceDescSpecAttr
177+
//===----------------------------------------------------------------------===//
178+
179+
class TargetDeviceDescSpecAttr
180+
: public Attribute::AttrBase<TargetDeviceDescSpecAttr, Attribute,
181+
impl::TargetDeviceDescSpecAttrStorage,
182+
TargetDeviceDescSpecInterface::Trait> {
183+
public:
184+
using Base::Base;
185+
186+
/// The keyword used for this attribute in custom syntax.
187+
constexpr const static StringLiteral kAttrKeyword = "tdd_spec";
188+
189+
/// Returns a system descriptor attribute from the given system descriptor
190+
static TargetDeviceDescSpecAttr
191+
get(MLIRContext *context, ArrayRef<DataLayoutEntryInterface> entries);
192+
193+
/// Returns the specification containing the given list of keys. If the list
194+
/// contains duplicate keys or is otherwise invalid, reports errors using the
195+
/// given callback and returns null.
196+
static TargetDeviceDescSpecAttr
197+
getChecked(function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
198+
ArrayRef<DataLayoutEntryInterface> entries);
199+
200+
/// Checks that the given list of entries does not contain duplicate keys.
201+
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
202+
ArrayRef<DataLayoutEntryInterface> entries);
203+
204+
/// Returns the list of entries.
205+
DataLayoutEntryListRef getEntries() const;
206+
207+
/// Parses an instance of this attribute.
208+
static TargetDeviceDescSpecAttr parse(AsmParser &parser);
209+
210+
/// Prints this attribute.
211+
void print(AsmPrinter &os) const;
212+
213+
/// Returns the device ID identifier.
214+
StringAttr getDeviceIDIdentifier(MLIRContext *context);
215+
216+
/// Returns the device type identifier.
217+
StringAttr getDeviceTypeIdentifier(MLIRContext *context);
218+
219+
/// Returns max vector op width identifier.
220+
StringAttr getMaxVectorOpWidthIdentifier(MLIRContext *context);
221+
222+
/// Returns canonicalizer max iterations identifier.
223+
StringAttr getCanonicalizerMaxIterationsIdentifier(MLIRContext *context);
224+
225+
/// Returns canonicalizer max num rewrites identifier.
226+
StringAttr getCanonicalizerMaxNumRewritesIdentifier(MLIRContext *context);
227+
228+
/// Returns L1 cache size identifier
229+
StringAttr getL1CacheSizeInBytesIdentifier(MLIRContext *context);
230+
231+
/// Returns the interface spec for device ID
232+
/// Since we verify that the spec contains device ID the function
233+
/// will return a valid spec.
234+
DataLayoutEntryInterface getSpecForDeviceID(MLIRContext *context);
235+
236+
/// Returns the interface spec for device type
237+
/// Since we verify that the spec contains device type the function
238+
/// will return a valid spec.
239+
DataLayoutEntryInterface getSpecForDeviceType(MLIRContext *context);
240+
241+
/// Returns the interface spec for max vector op width
242+
/// Since max vector op width is an optional property, this function will
243+
/// return a valid spec if the property is defined, otherwise it
244+
/// will return an empty spec.
245+
DataLayoutEntryInterface getSpecForMaxVectorOpWidth(MLIRContext *context);
246+
247+
/// Returns the interface spec for L1 cache size
248+
/// Since L1 cache size is an optional property, this function will
249+
/// return a valid spec if the property is defined, otherwise it
250+
/// will return an empty spec.
251+
DataLayoutEntryInterface getSpecForL1CacheSizeInBytes(MLIRContext *context);
252+
253+
/// Returns the interface spec for canonicalizer max iterations.
254+
/// Since this is an optional property, this function will
255+
/// return a valid spec if the property is defined, otherwise it
256+
/// will return an empty spec.
257+
DataLayoutEntryInterface
258+
getSpecForCanonicalizerMaxIterations(MLIRContext *context);
259+
260+
/// Returns the interface spec for canonicalizer max num rewrites.
261+
/// Since this is an optional property, this function will
262+
/// return a valid spec if the property is defined, otherwise it
263+
/// will return an empty spec.
264+
DataLayoutEntryInterface
265+
getSpecForCanonicalizerMaxNumRewrites(MLIRContext *context);
266+
267+
/// Return the value of device ID
268+
uint32_t getDeviceID(MLIRContext *context);
269+
270+
static constexpr StringLiteral name = "builtin.target_device_description";
271+
};
272+
127273
} // namespace mlir
128274

129275
#include "mlir/Dialect/DLTI/DLTIDialect.h.inc"

mlir/include/mlir/Dialect/DLTI/DLTIBase.td

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,13 @@ def DLTI_Dialect : Dialect {
2727
constexpr const static ::llvm::StringLiteral
2828
kDataLayoutAttrName = "dlti.dl_spec";
2929

30+
// Top level attribute name for target system description
31+
constexpr const static ::llvm::StringLiteral
32+
kTargetSystemDescAttrName = "dlti.tsd_spec";
33+
34+
constexpr const static ::llvm::StringLiteral
35+
kTargetDeviceDescAttrName = "dlti.tdd_spec";
36+
3037
// Constants used in entries.
3138
constexpr const static ::llvm::StringLiteral
3239
kDataLayoutEndiannessKey = "dlti.endianness";
@@ -48,6 +55,25 @@ def DLTI_Dialect : Dialect {
4855

4956
constexpr const static ::llvm::StringLiteral
5057
kDataLayoutStackAlignmentKey = "dlti.stack_alignment";
58+
59+
// Constants used in target description part of DLTI
60+
constexpr const static ::llvm::StringLiteral
61+
kTargetDeviceIDKey = "dlti.device_id";
62+
63+
constexpr const static ::llvm::StringLiteral
64+
kTargetDeviceTypeKey = "dlti.device_type";
65+
66+
constexpr const static ::llvm::StringLiteral
67+
kTargetDeviceMaxVectorOpWidthKey = "dlti.max_vector_op_width";
68+
69+
constexpr const static ::llvm::StringLiteral
70+
kTargetDeviceCanonicalizerMaxIterationsKey = "dlti.canonicalizer_max_iterations";
71+
72+
constexpr const static ::llvm::StringLiteral
73+
kTargetDeviceCanonicalizerMaxNumRewritesKey = "dlti.canonicalizer_max_num_rewrites";
74+
75+
constexpr const static ::llvm::StringLiteral
76+
kTargetDeviceL1CacheSizeInBytesKey = "dlti.L1_cache_size_in_bytes";
5177
}];
5278

5379
let useDefaultAttributePrinterParser = 1;
@@ -71,6 +97,24 @@ def DLTI_DataLayoutSpecAttr : DialectAttr<
7197
let convertFromStorage = "$_self";
7298
}
7399

100+
def DLTI_TargetSystemDescSpecAttr : DialectAttr<
101+
DLTI_Dialect,
102+
CPred<"::llvm::isa<::mlir::TargetSystemDescSpecAttr>($_self)">,
103+
"Target system description part of DLTI"> {
104+
let storageType = "::mlir::TargetSystemDescSpecAttr";
105+
let returnType = "::mlir::TargetSystemDescSpecAttr";
106+
let convertFromStorage = "$_self";
107+
}
108+
109+
def DLTI_TargetDeviceDescSpecAttr : DialectAttr<
110+
DLTI_Dialect,
111+
CPred<"::llvm::isa<::mlir::TargetDeviceDescSpecAttr>($_self)">,
112+
"Target device description part of DLTI"> {
113+
let storageType = "::mlir::TargetDeviceDescSpecAttr";
114+
let returnType = "::mlir::TargetDeviceDescSpecAttr";
115+
let convertFromStorage = "$_self";
116+
}
117+
74118
def HasDefaultDLTIDataLayout : NativeOpTrait<"HasDefaultDLTIDataLayout"> {
75119
let cppNamespace = "::mlir";
76120
}

mlir/include/mlir/Dialect/DLTI/Traits.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ class DataLayoutSpecAttr;
1818
namespace impl {
1919
LogicalResult verifyHasDefaultDLTIDataLayoutTrait(Operation *op);
2020
DataLayoutSpecInterface getDataLayoutSpec(Operation *op);
21+
TargetSystemDescSpecInterface getTargetSystemDescSpec(Operation *op);
2122
} // namespace impl
2223

2324
/// Trait to be used by operations willing to use the implementation of the
@@ -37,6 +38,12 @@ class HasDefaultDLTIDataLayout
3738
DataLayoutSpecInterface getDataLayoutSpec() {
3839
return impl::getDataLayoutSpec(this->getOperation());
3940
}
41+
42+
/// Returns the target system description specification as provided by DLTI
43+
/// dialect
44+
TargetSystemDescSpecInterface getTargetSystemDescSpec() {
45+
return impl::getTargetSystemDescSpec(this->getOperation());
46+
}
4047
};
4148
} // namespace mlir
4249

mlir/include/mlir/IR/BuiltinOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def ModuleOp : Builtin_Op<"module", [
7878
//===------------------------------------------------------------------===//
7979

8080
DataLayoutSpecInterface getDataLayoutSpec();
81+
TargetSystemDescSpecInterface getTargetSystemDescSpec();
8182

8283
//===------------------------------------------------------------------===//
8384
// OpAsmOpInterface Methods

mlir/include/mlir/Interfaces/DataLayoutInterfaces.h

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,17 @@
2323
namespace mlir {
2424
class DataLayout;
2525
class DataLayoutEntryInterface;
26+
class TargetDeviceDescSpecInterface;
27+
class TargetSystemDescSpecInterface;
2628
using DataLayoutEntryKey = llvm::PointerUnion<Type, StringAttr>;
2729
// Using explicit SmallVector size because we cannot infer the size from the
2830
// forward declaration, and we need the typedef in the actual declaration.
2931
using DataLayoutEntryList = llvm::SmallVector<DataLayoutEntryInterface, 4>;
3032
using DataLayoutEntryListRef = llvm::ArrayRef<DataLayoutEntryInterface>;
33+
// using TargetDeviceDescSpecList =
34+
// llvm::SmallVector<TargetDeviceDescSpecInterface, 4>;
35+
using TargetDeviceDescSpecListRef =
36+
llvm::ArrayRef<TargetDeviceDescSpecInterface>;
3137
class DataLayoutOpInterface;
3238
class DataLayoutSpecInterface;
3339
class ModuleOp;
@@ -84,6 +90,24 @@ Attribute getDefaultGlobalMemorySpace(DataLayoutEntryInterface entry);
8490
/// DataLayoutInterface if specified, otherwise returns the default.
8591
uint64_t getDefaultStackAlignment(DataLayoutEntryInterface entry);
8692

93+
/// return max vector op width from the specified DataLayoutEntry. If the
94+
/// property is missing from the entry, then return std::nullopt.
95+
std::optional<uint32_t> getMaxVectorOpWidth(DataLayoutEntryInterface entry);
96+
97+
/// return L1 cache size in bytes from the specified DataLayoutEntry. If the
98+
/// property is missing from the entry, then return std::nullopt.
99+
std::optional<uint32_t> getL1CacheSizeInBytes(DataLayoutEntryInterface entry);
100+
101+
/// return canonicalizer max iterations from the specified DataLayoutEntry.
102+
/// If the property is missing from the entry, then return std::nullopt.
103+
std::optional<int64_t>
104+
getCanonicalizerMaxIterations(DataLayoutEntryInterface entry);
105+
106+
/// returncanonicalizer max num rewrites from the specified DataLayoutEntry.
107+
/// If the property is missing from the entry, then return std::nullopt.
108+
std::optional<int64_t>
109+
getCanonicalizerMaxNumRewrites(DataLayoutEntryInterface entry);
110+
87111
/// Given a list of data layout entries, returns a new list containing the
88112
/// entries with keys having the given type ID, i.e. belonging to the same type
89113
/// class.
@@ -95,6 +119,11 @@ DataLayoutEntryList filterEntriesForType(DataLayoutEntryListRef entries,
95119
DataLayoutEntryInterface
96120
filterEntryForIdentifier(DataLayoutEntryListRef entries, StringAttr id);
97121

122+
/// Given a list of target device entries, returns the entry that has the given
123+
/// identifier as key, if such an entry exists in the list.
124+
TargetDeviceDescSpecInterface
125+
filterEntryForIdentifier(TargetDeviceDescSpecListRef entries, StringAttr id);
126+
98127
/// Verifies that the operation implementing the data layout interface, or a
99128
/// module operation, is valid. This calls the verifier of the spec attribute
100129
/// and checks if the layout is compatible with specs attached to the enclosing
@@ -106,6 +135,12 @@ LogicalResult verifyDataLayoutOp(Operation *op);
106135
/// and dialect interfaces for type and identifier keys respectively.
107136
LogicalResult verifyDataLayoutSpec(DataLayoutSpecInterface spec, Location loc);
108137

138+
/// Verifies that a target system desc spec is valid. This dispatches to
139+
/// individual entry verifiers, and then to the verifiers implemented by the
140+
/// relevant dialect interfaces for identifier keys.
141+
LogicalResult verifyTargetSystemDescSpec(TargetSystemDescSpecInterface spec,
142+
Location loc);
143+
109144
/// Divides the known min value of the numerator by the denominator and rounds
110145
/// the result up to the next integer. Preserves the scalable flag.
111146
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator);
@@ -137,6 +172,13 @@ class DataLayoutDialectInterface
137172
return success();
138173
}
139174

175+
/// Checks whether the given data layout entry is valid and reports any errors
176+
/// at the provided location. Derived classes should override this.
177+
virtual LogicalResult verifyEntry(TargetDeviceDescSpecInterface entry,
178+
Location loc) const {
179+
return success();
180+
}
181+
140182
/// Default implementation of entry combination that combines identical
141183
/// entries and returns null otherwise.
142184
static DataLayoutEntryInterface
@@ -214,10 +256,33 @@ class DataLayout {
214256
/// unspecified.
215257
uint64_t getStackAlignment() const;
216258

259+
/// Returns for max vector op width if the property is defined for the given
260+
/// device ID, otherwise return std::nullopt.
261+
std::optional<uint32_t>
262+
getMaxVectorOpWidth(TargetDeviceDescSpecInterface::DeviceID) const;
263+
264+
/// Returns for L1 cache size if the property is defined for the given
265+
/// device ID, otherwise return std::nullopt.
266+
std::optional<uint32_t>
267+
getL1CacheSizeInBytes(TargetDeviceDescSpecInterface::DeviceID) const;
268+
269+
/// Returns for canonicalizer max iterations if the property is defined for
270+
/// the given device ID, otherwise return std::nullopt.
271+
std::optional<int64_t> getCanonicalizerMaxIterations(
272+
TargetDeviceDescSpecInterface::DeviceID) const;
273+
274+
/// Returns for canonicalizer max rewrites if the property is defined for
275+
/// the given device ID, otherwise return std::nullopt.
276+
std::optional<int64_t> getCanonicalizerMaxNumRewrites(
277+
TargetDeviceDescSpecInterface::DeviceID) const;
278+
217279
private:
218280
/// Combined layout spec at the given scope.
219281
const DataLayoutSpecInterface originalLayout;
220282

283+
/// Combined target system desc spec at the given scope.
284+
const TargetSystemDescSpecInterface originalTargetSystemDesc;
285+
221286
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
222287
/// List of enclosing layout specs.
223288
SmallVector<DataLayoutSpecInterface, 2> layoutStack;

0 commit comments

Comments
 (0)