Skip to content

Commit 3dd5833

Browse files
[mlir][Transforms] TypeConverter: Mark conversion/materialization functions as "const"
Functions that materialize IR or convert types can be const. Caching data structures inside the TypeConverter are marked as `mutable`. Differential Revision: https://reviews.llvm.org/D157597
1 parent 6ebeecf commit 3dd5833

File tree

2 files changed

+59
-50
lines changed

2 files changed

+59
-50
lines changed

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -219,76 +219,78 @@ class TypeConverter {
219219
/// conversion exists, success otherwise. If the new set of types is empty,
220220
/// the type is removed and any usages of the existing value are expected to
221221
/// be removed during conversion.
222-
LogicalResult convertType(Type t, SmallVectorImpl<Type> &results);
222+
LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) const;
223223

224224
/// This hook simplifies defining 1-1 type conversions. This function returns
225225
/// the type to convert to on success, and a null type on failure.
226-
Type convertType(Type t);
226+
Type convertType(Type t) const;
227227

228228
/// Attempts a 1-1 type conversion, expecting the result type to be
229229
/// `TargetType`. Returns the converted type cast to `TargetType` on success,
230230
/// and a null type on conversion or cast failure.
231-
template <typename TargetType>
232-
TargetType convertType(Type t) {
231+
template <typename TargetType> TargetType convertType(Type t) const {
233232
return dyn_cast_or_null<TargetType>(convertType(t));
234233
}
235234

236235
/// Convert the given set of types, filling 'results' as necessary. This
237236
/// returns failure if the conversion of any of the types fails, success
238237
/// otherwise.
239-
LogicalResult convertTypes(TypeRange types, SmallVectorImpl<Type> &results);
238+
LogicalResult convertTypes(TypeRange types,
239+
SmallVectorImpl<Type> &results) const;
240240

241241
/// Return true if the given type is legal for this type converter, i.e. the
242242
/// type converts to itself.
243-
bool isLegal(Type type);
243+
bool isLegal(Type type) const;
244+
244245
/// Return true if all of the given types are legal for this type converter.
245246
template <typename RangeT>
246247
std::enable_if_t<!std::is_convertible<RangeT, Type>::value &&
247248
!std::is_convertible<RangeT, Operation *>::value,
248249
bool>
249-
isLegal(RangeT &&range) {
250+
isLegal(RangeT &&range) const {
250251
return llvm::all_of(range, [this](Type type) { return isLegal(type); });
251252
}
252253
/// Return true if the given operation has legal operand and result types.
253-
bool isLegal(Operation *op);
254+
bool isLegal(Operation *op) const;
254255

255256
/// Return true if the types of block arguments within the region are legal.
256-
bool isLegal(Region *region);
257+
bool isLegal(Region *region) const;
257258

258259
/// Return true if the inputs and outputs of the given function type are
259260
/// legal.
260-
bool isSignatureLegal(FunctionType ty);
261+
bool isSignatureLegal(FunctionType ty) const;
261262

262263
/// This method allows for converting a specific argument of a signature. It
263264
/// takes as inputs the original argument input number, type.
264265
/// On success, it populates 'result' with any new mappings.
265266
LogicalResult convertSignatureArg(unsigned inputNo, Type type,
266-
SignatureConversion &result);
267+
SignatureConversion &result) const;
267268
LogicalResult convertSignatureArgs(TypeRange types,
268269
SignatureConversion &result,
269-
unsigned origInputOffset = 0);
270+
unsigned origInputOffset = 0) const;
270271

271272
/// This function converts the type signature of the given block, by invoking
272273
/// 'convertSignatureArg' for each argument. This function should return a
273274
/// valid conversion for the signature on success, std::nullopt otherwise.
274-
std::optional<SignatureConversion> convertBlockSignature(Block *block);
275+
std::optional<SignatureConversion> convertBlockSignature(Block *block) const;
275276

276277
/// Materialize a conversion from a set of types into one result type by
277278
/// generating a cast sequence of some kind. See the respective
278279
/// `add*Materialization` for more information on the context for these
279280
/// methods.
280281
Value materializeArgumentConversion(OpBuilder &builder, Location loc,
281-
Type resultType, ValueRange inputs) {
282+
Type resultType,
283+
ValueRange inputs) const {
282284
return materializeConversion(argumentMaterializations, builder, loc,
283285
resultType, inputs);
284286
}
285287
Value materializeSourceConversion(OpBuilder &builder, Location loc,
286-
Type resultType, ValueRange inputs) {
288+
Type resultType, ValueRange inputs) const {
287289
return materializeConversion(sourceMaterializations, builder, loc,
288290
resultType, inputs);
289291
}
290292
Value materializeTargetConversion(OpBuilder &builder, Location loc,
291-
Type resultType, ValueRange inputs) {
293+
Type resultType, ValueRange inputs) const {
292294
return materializeConversion(targetMaterializations, builder, loc,
293295
resultType, inputs);
294296
}
@@ -297,7 +299,8 @@ class TypeConverter {
297299
/// the registered conversion functions. If no applicable conversion has been
298300
/// registered, return std::nullopt. Note that the empty attribute/`nullptr`
299301
/// is a valid return value for this function.
300-
std::optional<Attribute> convertTypeAttribute(Type type, Attribute attr);
302+
std::optional<Attribute> convertTypeAttribute(Type type,
303+
Attribute attr) const;
301304

302305
private:
303306
/// The signature of the callback used to convert a type. If the new set of
@@ -316,16 +319,17 @@ class TypeConverter {
316319

317320
/// Attempt to materialize a conversion using one of the provided
318321
/// materialization functions.
319-
Value materializeConversion(
320-
MutableArrayRef<MaterializationCallbackFn> materializations,
321-
OpBuilder &builder, Location loc, Type resultType, ValueRange inputs);
322+
Value
323+
materializeConversion(ArrayRef<MaterializationCallbackFn> materializations,
324+
OpBuilder &builder, Location loc, Type resultType,
325+
ValueRange inputs) const;
322326

323327
/// Generate a wrapper for the given callback. This allows for accepting
324328
/// different callback forms, that all compose into a single version.
325329
/// With callback of form: `std::optional<Type>(T)`
326330
template <typename T, typename FnT>
327331
std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn>
328-
wrapCallback(FnT &&callback) {
332+
wrapCallback(FnT &&callback) const {
329333
return wrapCallback<T>(
330334
[callback = std::forward<FnT>(callback)](
331335
T type, SmallVectorImpl<Type> &results, ArrayRef<Type>) {
@@ -343,7 +347,7 @@ class TypeConverter {
343347
template <typename T, typename FnT>
344348
std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &>,
345349
ConversionCallbackFn>
346-
wrapCallback(FnT &&callback) {
350+
wrapCallback(FnT &&callback) const {
347351
return wrapCallback<T>(
348352
[callback = std::forward<FnT>(callback)](
349353
T type, SmallVectorImpl<Type> &results, ArrayRef<Type>) {
@@ -356,7 +360,7 @@ class TypeConverter {
356360
std::enable_if_t<
357361
std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &, ArrayRef<Type>>,
358362
ConversionCallbackFn>
359-
wrapCallback(FnT &&callback) {
363+
wrapCallback(FnT &&callback) const {
360364
return [callback = std::forward<FnT>(callback)](
361365
Type type, SmallVectorImpl<Type> &results,
362366
ArrayRef<Type> callStack) -> std::optional<LogicalResult> {
@@ -378,7 +382,7 @@ class TypeConverter {
378382
/// may take any subclass of `Type` and the wrapper will check for the target
379383
/// type to be of the expected class before calling the callback.
380384
template <typename T, typename FnT>
381-
MaterializationCallbackFn wrapMaterialization(FnT &&callback) {
385+
MaterializationCallbackFn wrapMaterialization(FnT &&callback) const {
382386
return [callback = std::forward<FnT>(callback)](
383387
OpBuilder &builder, Type resultType, ValueRange inputs,
384388
Location loc) -> std::optional<Value> {
@@ -394,7 +398,7 @@ class TypeConverter {
394398
/// callback.
395399
template <typename T, typename A, typename FnT>
396400
TypeAttributeConversionCallbackFn
397-
wrapTypeAttributeConversion(FnT &&callback) {
401+
wrapTypeAttributeConversion(FnT &&callback) const {
398402
return [callback = std::forward<FnT>(callback)](
399403
Type type, Attribute attr) -> AttributeConversionResult {
400404
if (T derivedType = dyn_cast<T>(type)) {
@@ -428,13 +432,13 @@ class TypeConverter {
428432
/// A set of cached conversions to avoid recomputing in the common case.
429433
/// Direct 1-1 conversions are the most common, so this cache stores the
430434
/// successful 1-1 conversions as well as all failed conversions.
431-
DenseMap<Type, Type> cachedDirectConversions;
435+
mutable DenseMap<Type, Type> cachedDirectConversions;
432436
/// This cache stores the successful 1->N conversions, where N != 1.
433-
DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
437+
mutable DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
434438

435439
/// Stores the types that are being converted in the case when convertType
436440
/// is being called recursively to convert nested types.
437-
SmallVector<Type, 2> conversionCallStack;
441+
mutable SmallVector<Type, 2> conversionCallStack;
438442
};
439443

440444
//===----------------------------------------------------------------------===//

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2906,7 +2906,7 @@ void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
29062906
}
29072907

29082908
LogicalResult TypeConverter::convertType(Type t,
2909-
SmallVectorImpl<Type> &results) {
2909+
SmallVectorImpl<Type> &results) const {
29102910
auto existingIt = cachedDirectConversions.find(t);
29112911
if (existingIt != cachedDirectConversions.end()) {
29122912
if (existingIt->second)
@@ -2925,7 +2925,7 @@ LogicalResult TypeConverter::convertType(Type t,
29252925
conversionCallStack.push_back(t);
29262926
auto popConversionCallStack =
29272927
llvm::make_scope_exit([this]() { conversionCallStack.pop_back(); });
2928-
for (ConversionCallbackFn &converter : llvm::reverse(conversions)) {
2928+
for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
29292929
if (std::optional<LogicalResult> result =
29302930
converter(t, results, conversionCallStack)) {
29312931
if (!succeeded(*result)) {
@@ -2943,7 +2943,7 @@ LogicalResult TypeConverter::convertType(Type t,
29432943
return failure();
29442944
}
29452945

2946-
Type TypeConverter::convertType(Type t) {
2946+
Type TypeConverter::convertType(Type t) const {
29472947
// Use the multi-type result version to convert the type.
29482948
SmallVector<Type, 1> results;
29492949
if (failed(convertType(t, results)))
@@ -2953,31 +2953,35 @@ Type TypeConverter::convertType(Type t) {
29532953
return results.size() == 1 ? results.front() : nullptr;
29542954
}
29552955

2956-
LogicalResult TypeConverter::convertTypes(TypeRange types,
2957-
SmallVectorImpl<Type> &results) {
2956+
LogicalResult
2957+
TypeConverter::convertTypes(TypeRange types,
2958+
SmallVectorImpl<Type> &results) const {
29582959
for (Type type : types)
29592960
if (failed(convertType(type, results)))
29602961
return failure();
29612962
return success();
29622963
}
29632964

2964-
bool TypeConverter::isLegal(Type type) { return convertType(type) == type; }
2965-
bool TypeConverter::isLegal(Operation *op) {
2965+
bool TypeConverter::isLegal(Type type) const {
2966+
return convertType(type) == type;
2967+
}
2968+
bool TypeConverter::isLegal(Operation *op) const {
29662969
return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
29672970
}
29682971

2969-
bool TypeConverter::isLegal(Region *region) {
2972+
bool TypeConverter::isLegal(Region *region) const {
29702973
return llvm::all_of(*region, [this](Block &block) {
29712974
return isLegal(block.getArgumentTypes());
29722975
});
29732976
}
29742977

2975-
bool TypeConverter::isSignatureLegal(FunctionType ty) {
2978+
bool TypeConverter::isSignatureLegal(FunctionType ty) const {
29762979
return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
29772980
}
29782981

2979-
LogicalResult TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
2980-
SignatureConversion &result) {
2982+
LogicalResult
2983+
TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
2984+
SignatureConversion &result) const {
29812985
// Try to convert the given input type.
29822986
SmallVector<Type, 1> convertedTypes;
29832987
if (failed(convertType(type, convertedTypes)))
@@ -2991,26 +2995,27 @@ LogicalResult TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
29912995
result.addInputs(inputNo, convertedTypes);
29922996
return success();
29932997
}
2994-
LogicalResult TypeConverter::convertSignatureArgs(TypeRange types,
2995-
SignatureConversion &result,
2996-
unsigned origInputOffset) {
2998+
LogicalResult
2999+
TypeConverter::convertSignatureArgs(TypeRange types,
3000+
SignatureConversion &result,
3001+
unsigned origInputOffset) const {
29973002
for (unsigned i = 0, e = types.size(); i != e; ++i)
29983003
if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
29993004
return failure();
30003005
return success();
30013006
}
30023007

30033008
Value TypeConverter::materializeConversion(
3004-
MutableArrayRef<MaterializationCallbackFn> materializations,
3005-
OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) {
3006-
for (MaterializationCallbackFn &fn : llvm::reverse(materializations))
3009+
ArrayRef<MaterializationCallbackFn> materializations, OpBuilder &builder,
3010+
Location loc, Type resultType, ValueRange inputs) const {
3011+
for (const MaterializationCallbackFn &fn : llvm::reverse(materializations))
30073012
if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
30083013
return *result;
30093014
return nullptr;
30103015
}
30113016

3012-
auto TypeConverter::convertBlockSignature(Block *block)
3013-
-> std::optional<SignatureConversion> {
3017+
std::optional<TypeConverter::SignatureConversion>
3018+
TypeConverter::convertBlockSignature(Block *block) const {
30143019
SignatureConversion conversion(block->getNumArguments());
30153020
if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion)))
30163021
return std::nullopt;
@@ -3052,9 +3057,9 @@ Attribute TypeConverter::AttributeConversionResult::getResult() const {
30523057
return impl.getPointer();
30533058
}
30543059

3055-
std::optional<Attribute> TypeConverter::convertTypeAttribute(Type type,
3056-
Attribute attr) {
3057-
for (TypeAttributeConversionCallbackFn &fn :
3060+
std::optional<Attribute>
3061+
TypeConverter::convertTypeAttribute(Type type, Attribute attr) const {
3062+
for (const TypeAttributeConversionCallbackFn &fn :
30583063
llvm::reverse(typeAttributeConversions)) {
30593064
AttributeConversionResult res = fn(type, attr);
30603065
if (res.hasResult())

0 commit comments

Comments
 (0)