Skip to content

Commit 36c2681

Browse files
committed
address reviewer comments
1 parent 1b51f55 commit 36c2681

File tree

9 files changed

+212
-117
lines changed

9 files changed

+212
-117
lines changed

mlir/include/mlir-c/Dialect/GPU.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,12 @@ MLIR_CAPI_EXPORTED bool mlirAttributeIsAGPUObjectAttr(MlirAttribute attr);
3535

3636
MLIR_CAPI_EXPORTED MlirAttribute
3737
mlirGPUObjectAttrGet(MlirContext mlirCtx, MlirAttribute target, uint32_t format,
38-
MlirStringRef objectStrRef, MlirAttribute mlirObjectProps,
39-
MlirAttribute mlirKernelsAttr);
38+
MlirStringRef objectStrRef, MlirAttribute mlirObjectProps);
39+
40+
MLIR_CAPI_EXPORTED MlirAttribute mlirGPUObjectAttrGetWithKernels(
41+
MlirContext mlirCtx, MlirAttribute target, uint32_t format,
42+
MlirStringRef objectStrRef, MlirAttribute mlirObjectProps,
43+
MlirAttribute mlirKernelsAttr);
4044

4145
MLIR_CAPI_EXPORTED MlirAttribute
4246
mlirGPUObjectAttrGetTarget(MlirAttribute mlirObjectAttr);

mlir/include/mlir/Dialect/GPU/IR/CompilationAttrs.td

Lines changed: 24 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -62,18 +62,23 @@ def GPU_KernelAttr : GPU_Attr<"Kernel", "kernel"> {
6262
];
6363
let genVerifyDecl = 1;
6464
let extraClassDeclaration = [{
65+
/// Compare two kernels based on the name.
66+
bool operator<(const KernelAttr& other) const {
67+
return getName().getValue() < other.getName().getValue();
68+
}
69+
6570
/// Returns the metadata attribute corresponding to `key` or `nullptr`
6671
/// if missing.
6772
Attribute getAttr(StringRef key) const {
68-
auto attrs = getMetadata();
73+
DictionaryAttr attrs = getMetadata();
6974
return attrs ? attrs.get(key) : nullptr;
7075
}
7176
template <typename ConcreteAttr>
7277
ConcreteAttr getAttr(StringRef key) const {
7378
return llvm::dyn_cast_or_null<ConcreteAttr>(getAttr(key));
7479
}
7580
Attribute getAttr(StringAttr key) const {
76-
auto attrs = getMetadata();
81+
DictionaryAttr attrs = getMetadata();
7782
return attrs ? attrs.get(key) : nullptr;
7883
}
7984
template <typename ConcreteAttr>
@@ -83,18 +88,18 @@ def GPU_KernelAttr : GPU_Attr<"Kernel", "kernel"> {
8388

8489
/// Returns the attribute dictionary at position `index`.
8590
DictionaryAttr getArgAttrDict(unsigned index) {
86-
auto argArray = getArgAttrs();
91+
ArrayAttr argArray = getArgAttrs();
8792
return argArray ? llvm::cast<DictionaryAttr>(argArray[index]) : nullptr;
8893
}
8994

9095
/// Return the specified attribute, if present, for the argument at 'index',
9196
/// null otherwise.
9297
Attribute getArgAttr(unsigned index, StringAttr name) {
93-
auto argDict = getArgAttrDict(index);
98+
DictionaryAttr argDict = getArgAttrDict(index);
9499
return argDict ? argDict.get(name) : nullptr;
95100
}
96101
Attribute getArgAttr(unsigned index, StringRef name) {
97-
auto argDict = getArgAttrDict(index);
102+
DictionaryAttr argDict = getArgAttrDict(index);
98103
return argDict ? argDict.get(name) : nullptr;
99104
}
100105

@@ -114,54 +119,38 @@ def GPU_KernelTableAttr : GPU_Attr<"KernelTable", "kernel_table"> {
114119

115120
Examples:
116121
```mlir
117-
#gpu.kernel_table<{kernel0 = #gpu.kernel<...>}>
122+
#gpu.kernel_table<[#gpu.kernel<kernel0, ...>]>
118123
```
119124
}];
120125
let parameters = (ins
121-
"DictionaryAttr":$kernel_table
126+
OptionalArrayRefParameter<"KernelAttr", "array of kernels">:$kernel_table
122127
);
123128
let assemblyFormat = [{
124-
`<` $kernel_table `>`
129+
`<` (`[` qualified($kernel_table)^ `]`)? `>`
125130
}];
126131
let builders = [
127-
AttrBuilderWithInferredContext<(ins "DictionaryAttr":$kernel_table), [{
128-
assert(kernel_table && "invalid kernel table");
129-
return $_get(kernel_table.getContext(), kernel_table);
130-
}]>
132+
AttrBuilder<(ins "ArrayRef<KernelAttr>":$kernels,
133+
CArg<"bool", "false">:$isSorted)>
131134
];
132135
let skipDefaultBuilders = 1;
133136
let genVerifyDecl = 1;
134137
let extraClassDeclaration = [{
135-
/// Helper iterator class for traversing the kernel table.
136-
struct KernelIterator
137-
: llvm::mapped_iterator_base<KernelIterator,
138-
llvm::ArrayRef<NamedAttribute>::iterator,
139-
std::pair<StringAttr, KernelAttr>> {
140-
using llvm::mapped_iterator_base<
141-
KernelIterator, llvm::ArrayRef<NamedAttribute>::iterator,
142-
std::pair<StringAttr, KernelAttr>>::mapped_iterator_base;
143-
/// Map the iterator to the kernel name and a KernelAttribute.
144-
std::pair<StringAttr, KernelAttr> mapElement(NamedAttribute attr) const {
145-
return {attr.getName(), llvm::cast<KernelAttr>(attr.getValue())};
146-
}
147-
};
148-
auto begin() const {
149-
return KernelIterator(getKernelTable().begin());
138+
llvm::ArrayRef<KernelAttr>::iterator begin() const {
139+
return getKernelTable().begin();
150140
}
151-
auto end() const {
152-
return KernelIterator(getKernelTable().end());
141+
llvm::ArrayRef<KernelAttr>::iterator end() const {
142+
return getKernelTable().end();
153143
}
154144
size_t size() const {
155145
return getKernelTable().size();
156146
}
147+
bool empty() const {
148+
return getKernelTable().empty();
149+
}
157150

158151
/// Returns the kernel with name `key` or `nullptr` if not present.
159-
KernelAttr lookup(StringRef key) const {
160-
return getKernelTable().getAs<KernelAttr>(key);
161-
}
162-
KernelAttr lookup(StringAttr key) const {
163-
return getKernelTable().getAs<KernelAttr>(key);
164-
}
152+
KernelAttr lookup(StringRef key) const;
153+
KernelAttr lookup(StringAttr key) const;
165154
}];
166155
}
167156

mlir/lib/CAPI/Dialect/GPU.cpp

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,23 @@ bool mlirAttributeIsAGPUObjectAttr(MlirAttribute attr) {
3737

3838
MlirAttribute mlirGPUObjectAttrGet(MlirContext mlirCtx, MlirAttribute target,
3939
uint32_t format, MlirStringRef objectStrRef,
40-
MlirAttribute mlirObjectProps,
41-
MlirAttribute mlirKernelsAttr) {
40+
MlirAttribute mlirObjectProps) {
41+
MLIRContext *ctx = unwrap(mlirCtx);
42+
llvm::StringRef object = unwrap(objectStrRef);
43+
DictionaryAttr objectProps;
44+
if (mlirObjectProps.ptr != nullptr)
45+
objectProps = llvm::cast<DictionaryAttr>(unwrap(mlirObjectProps));
46+
return wrap(gpu::ObjectAttr::get(
47+
ctx, unwrap(target), static_cast<gpu::CompilationTarget>(format),
48+
StringAttr::get(ctx, object), objectProps, nullptr));
49+
}
50+
51+
MlirAttribute mlirGPUObjectAttrGetWithKernels(MlirContext mlirCtx,
52+
MlirAttribute target,
53+
uint32_t format,
54+
MlirStringRef objectStrRef,
55+
MlirAttribute mlirObjectProps,
56+
MlirAttribute mlirKernelsAttr) {
4257
MLIRContext *ctx = unwrap(mlirCtx);
4358
llvm::StringRef object = unwrap(objectStrRef);
4459
DictionaryAttr objectProps;

mlir/lib/Dialect/GPU/IR/GPUDialect.cpp

Lines changed: 51 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2201,7 +2201,7 @@ KernelAttr KernelAttr::appendMetadata(ArrayRef<NamedAttribute> attrs) const {
22012201
if (attrs.empty())
22022202
return *this;
22032203
NamedAttrList attrList;
2204-
if (auto dict = getMetadata())
2204+
if (DictionaryAttr dict = getMetadata())
22052205
attrList.append(dict);
22062206
attrList.append(attrs);
22072207
return KernelAttr::get(getName(), getFunctionType(), getArgAttrs(),
@@ -2227,23 +2227,62 @@ LogicalResult KernelAttr::verify(function_ref<InFlightDiagnostic()> emitError,
22272227
// GPU KernelTableAttr
22282228
//===----------------------------------------------------------------------===//
22292229

2230+
KernelTableAttr KernelTableAttr::get(MLIRContext *context,
2231+
ArrayRef<KernelAttr> kernels,
2232+
bool isSorted) {
2233+
// Note that `is_sorted` is always only invoked once even with assertions ON.
2234+
assert((!isSorted || llvm::is_sorted(kernels)) &&
2235+
"expected a sorted kernel array");
2236+
// Immediately return the attribute if the array is sorted.
2237+
if (isSorted || llvm::is_sorted(kernels))
2238+
return Base::get(context, kernels);
2239+
// Sort the array.
2240+
SmallVector<KernelAttr> kernelsTmp(kernels);
2241+
llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2242+
return Base::get(context, kernelsTmp);
2243+
}
2244+
2245+
KernelTableAttr
2246+
KernelTableAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
2247+
MLIRContext *context, ArrayRef<KernelAttr> kernels,
2248+
bool isSorted) {
2249+
// Note that `is_sorted` is always only invoked once even with assertions ON.
2250+
assert((!isSorted || llvm::is_sorted(kernels)) &&
2251+
"expected a sorted kernel array");
2252+
// Immediately return the attribute if the array is sorted.
2253+
if (isSorted || llvm::is_sorted(kernels))
2254+
return Base::getChecked(emitError, context, kernels);
2255+
// Sort the array.
2256+
SmallVector<KernelAttr> kernelsTmp(kernels);
2257+
llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2258+
return Base::getChecked(emitError, context, kernelsTmp);
2259+
}
2260+
22302261
LogicalResult
22312262
KernelTableAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2232-
DictionaryAttr dict) {
2233-
if (!dict)
2234-
return emitError() << "table cannot be null";
2235-
for (NamedAttribute attr : dict) {
2236-
auto kernel = llvm::dyn_cast<KernelAttr>(attr.getValue());
2237-
if (!kernel)
2238-
return emitError()
2239-
<< "all the dictionary values must be `#gpu.kernel` attributes";
2240-
if (kernel.getName() != attr.getName())
2241-
return emitError() << "expected kernel to be named `" << attr.getName()
2242-
<< "` but got `" << kernel.getName() << "`";
2263+
ArrayRef<KernelAttr> kernels) {
2264+
if (kernels.size() < 2)
2265+
return success();
2266+
// Check that the kernels are uniquely named.
2267+
if (std::adjacent_find(kernels.begin(), kernels.end(),
2268+
[](KernelAttr l, KernelAttr r) {
2269+
return l.getName() == r.getName();
2270+
}) != kernels.end()) {
2271+
return emitError() << "expected all kernels to be uniquely named";
22432272
}
22442273
return success();
22452274
}
22462275

2276+
KernelAttr KernelTableAttr::lookup(StringRef key) const {
2277+
auto it = impl::findAttrSorted(begin(), end(), key);
2278+
return it.second ? *it.first : KernelAttr();
2279+
}
2280+
2281+
KernelAttr KernelTableAttr::lookup(StringAttr key) const {
2282+
auto it = impl::findAttrSorted(begin(), end(), key);
2283+
return it.second ? *it.first : KernelAttr();
2284+
}
2285+
22472286
//===----------------------------------------------------------------------===//
22482287
// GPU target options
22492288
//===----------------------------------------------------------------------===//

mlir/lib/Target/LLVM/ROCDL/Target.cpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -508,13 +508,10 @@ ROCDLTargetAttrImpl::createObject(Attribute attribute, Operation *module,
508508
// supported.
509509
if (format > gpu::CompilationTarget::Binary)
510510
format = gpu::CompilationTarget::Binary;
511-
512511
DictionaryAttr properties{};
513512
Builder builder(attribute.getContext());
514-
return builder.getAttr<gpu::ObjectAttr>(
515-
attribute,
516-
format > gpu::CompilationTarget::Binary ? gpu::CompilationTarget::Binary
517-
: format,
518-
builder.getStringAttr(StringRef(object.data(), object.size())),
519-
properties, nullptr);
513+
StringAttr objectStr =
514+
builder.getStringAttr(StringRef(object.data(), object.size()));
515+
return builder.getAttr<gpu::ObjectAttr>(attribute, format, objectStr,
516+
properties, nullptr);
520517
}

0 commit comments

Comments
 (0)