Skip to content

Commit a782922

Browse files
committed
[mlir][SubElementInterfaces] Prefer calling the derived get if possible
This allows for better supporting attributes/types that override the default builders.
1 parent f530e6e commit a782922

File tree

1 file changed

+23
-3
lines changed

1 file changed

+23
-3
lines changed

mlir/include/mlir/IR/SubElementInterfaces.h

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,8 @@ template <typename T>
220220
struct is_tuple : public std::false_type {};
221221
template <typename... Ts>
222222
struct is_tuple<std::tuple<Ts...>> : public std::true_type {};
223+
template <typename T, typename... Ts>
224+
using has_get_method = decltype(T::get(std::declval<Ts>()...));
223225

224226
/// This function provides the underlying implementation for the
225227
/// SubElementInterface walk method, using the key type of the derived
@@ -239,6 +241,23 @@ void walkImmediateSubElementsImpl(T derived,
239241
}
240242
}
241243

244+
/// This function invokes the proper `get` method for a type `T` with the given
245+
/// values.
246+
template <typename T, typename... Ts>
247+
T constructSubElementReplacement(MLIRContext *ctx, Ts &&...params) {
248+
// Prefer a direct `get` method if one exists.
249+
if constexpr (llvm::is_detected<has_get_method, T, Ts...>::value) {
250+
(void)ctx;
251+
return T::get(std::forward<Ts>(params)...);
252+
} else if constexpr (llvm::is_detected<has_get_method, T, MLIRContext *,
253+
Ts...>::value) {
254+
return T::get(ctx, std::forward<Ts>(params)...);
255+
} else {
256+
// Otherwise, pass to the base get.
257+
return T::Base::get(ctx, std::forward<Ts>(params)...);
258+
}
259+
}
260+
242261
/// This function provides the underlying implementation for the
243262
/// SubElementInterface replace method, using the key type of the derived
244263
/// attribute/type to interact with the individual parameters.
@@ -260,12 +279,13 @@ T replaceImmediateSubElementsImpl(T derived, ArrayRef<Attribute> &replAttrs,
260279
if constexpr (is_tuple<decltype(key)>::value) {
261280
return std::apply(
262281
[&](auto &&...params) {
263-
return T::Base::get(derived.getContext(),
264-
std::forward<decltype(params)>(params)...);
282+
return constructSubElementReplacement<T>(
283+
derived.getContext(),
284+
std::forward<decltype(params)>(params)...);
265285
},
266286
newKey);
267287
} else {
268-
return T::Base::get(derived.getContext(), newKey);
288+
return constructSubElementReplacement<T>(derived.getContext(), newKey);
269289
}
270290
}
271291
}

0 commit comments

Comments
 (0)