@@ -220,6 +220,8 @@ template <typename T>
220
220
struct is_tuple : public std ::false_type {};
221
221
template <typename ... Ts>
222
222
struct is_tuple <std::tuple<Ts...>> : public std::true_type {};
223
+ template <typename T, typename ... Ts>
224
+ using has_get_method = decltype (T::get(std::declval<Ts>()...));
223
225
224
226
// / This function provides the underlying implementation for the
225
227
// / SubElementInterface walk method, using the key type of the derived
@@ -239,6 +241,23 @@ void walkImmediateSubElementsImpl(T derived,
239
241
}
240
242
}
241
243
244
+ // / This function invokes the proper `get` method for a type `T` with the given
245
+ // / values.
246
+ template <typename T, typename ... Ts>
247
+ T constructSubElementReplacement (MLIRContext *ctx, Ts &&...params) {
248
+ // Prefer a direct `get` method if one exists.
249
+ if constexpr (llvm::is_detected<has_get_method, T, Ts...>::value) {
250
+ (void )ctx;
251
+ return T::get (std::forward<Ts>(params)...);
252
+ } else if constexpr (llvm::is_detected<has_get_method, T, MLIRContext *,
253
+ Ts...>::value) {
254
+ return T::get (ctx, std::forward<Ts>(params)...);
255
+ } else {
256
+ // Otherwise, pass to the base get.
257
+ return T::Base::get (ctx, std::forward<Ts>(params)...);
258
+ }
259
+ }
260
+
242
261
// / This function provides the underlying implementation for the
243
262
// / SubElementInterface replace method, using the key type of the derived
244
263
// / attribute/type to interact with the individual parameters.
@@ -260,12 +279,13 @@ T replaceImmediateSubElementsImpl(T derived, ArrayRef<Attribute> &replAttrs,
260
279
if constexpr (is_tuple<decltype (key)>::value) {
261
280
return std::apply (
262
281
[&](auto &&...params ) {
263
- return T::Base::get (derived.getContext (),
264
- std::forward<decltype (params)>(params)...);
282
+ return constructSubElementReplacement<T>(
283
+ derived.getContext (),
284
+ std::forward<decltype (params)>(params)...);
265
285
},
266
286
newKey);
267
287
} else {
268
- return T::Base::get (derived.getContext (), newKey);
288
+ return constructSubElementReplacement<T> (derived.getContext (), newKey);
269
289
}
270
290
}
271
291
}
0 commit comments