Skip to content

Commit f8184d4

Browse files
committed
[mlir] Lookup the latest value with a legal type when remapping values.
The current condition implies that the target materialization will be called even if the type is the new operand type is legal, but slightly different. For example, if there is a bufferization pattern that changes memref layout, then target materialization for an illegal type (TensorType) would be called. Differential Revision: https://reviews.llvm.org/D93126
1 parent f2661be commit f8184d4

File tree

1 file changed

+59
-50
lines changed

1 file changed

+59
-50
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 59 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,15 @@ namespace {
105105
/// functionality, i.e. we will traverse if the mapped value also has a mapping.
106106
struct ConversionValueMapping {
107107
/// Lookup a mapped value within the map. If a mapping for the provided value
108-
/// does not exist then return the provided value. If `desiredType` is
109-
/// non-null, returns the most recently mapped value with that type. If an
110-
/// operand of that type does not exist, defaults to normal behavior.
111-
Value lookupOrDefault(Value from, Type desiredType = nullptr) const;
108+
/// does not exist then return the provided value.
109+
Value lookupOrDefault(Value from) const;
110+
111+
/// Lookup the latest legal value within the map. If a mapping for the
112+
/// provided value does not exist then return the provided value. If
113+
/// `converter` is non-null, returns the most recently mapped value with the
114+
/// legal type. If an operand of that type does not exist, defaults to normal
115+
/// behavior.
116+
Value lookupLatestLegal(Value from, TypeConverter *converter) const;
112117

113118
/// Lookup a mapped value within the map, or return null if a mapping does not
114119
/// exist. If a mapping exists, this follows the same behavior of
@@ -127,22 +132,24 @@ struct ConversionValueMapping {
127132
};
128133
} // end anonymous namespace
129134

130-
Value ConversionValueMapping::lookupOrDefault(Value from,
131-
Type desiredType) const {
132-
// If there was no desired type, simply find the leaf value.
133-
if (!desiredType) {
134-
// If this value had a valid mapping, unmap that value as well in the case
135-
// that it was also replaced.
136-
while (auto mappedValue = mapping.lookupOrNull(from))
137-
from = mappedValue;
138-
return from;
139-
}
135+
Value ConversionValueMapping::lookupOrDefault(Value from) const {
136+
// If this value had a valid mapping, unmap that value as well in the case
137+
// that it was also replaced.
138+
while (auto mappedValue = mapping.lookupOrNull(from))
139+
from = mappedValue;
140+
return from;
141+
}
140142

141-
// Otherwise, try to find the deepest value that has the desired type.
142-
Value desiredValue;
143+
Value ConversionValueMapping::lookupLatestLegal(
144+
Value from, TypeConverter *converter) const {
145+
if (!converter)
146+
return lookupOrDefault(from);
147+
148+
// Otherwise, try to find the deepest value that has the legal type.
149+
Value legalValue;
143150
do {
144-
if (from.getType() == desiredType)
145-
desiredValue = from;
151+
if (converter->isLegal(from.getType()))
152+
legalValue = from;
146153

147154
Value mappedValue = mapping.lookupOrNull(from);
148155
if (!mappedValue)
@@ -151,7 +158,7 @@ Value ConversionValueMapping::lookupOrDefault(Value from,
151158
} while (true);
152159

153160
// If the desired value was found use it, otherwise default to the leaf value.
154-
return desiredValue ? desiredValue : from;
161+
return legalValue ? legalValue : from;
155162
}
156163

157164
Value ConversionValueMapping::lookupOrNull(Value from) const {
@@ -1039,47 +1046,49 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
10391046
Value operand = it.value();
10401047
Type origType = operand.getType();
10411048

1042-
// If a converter was provided, get the desired legal types for this
1043-
// operand.
1044-
Type desiredType;
1049+
Value newOperand = mapping.lookupLatestLegal(operand, converter);
1050+
1051+
// Handle the case where the conversion was 1->1 and the new operand type
1052+
// isn't legal.
1053+
Type newOperandType = newOperand.getType();
10451054
if (converter) {
1046-
// If there is no legal conversion, fail to match this pattern.
1047-
legalTypes.clear();
1048-
if (failed(converter->convertType(origType, legalTypes))) {
1049-
return notifyMatchFailure(loc, [=](Diagnostic &diag) {
1050-
diag << "unable to convert type for operand #" << it.index()
1051-
<< ", type was " << origType;
1052-
});
1055+
if (!converter->isLegal(newOperandType)) {
1056+
legalTypes.clear();
1057+
1058+
// If there is no legal conversion, fail to match this pattern.
1059+
if (failed(converter->convertType(origType, legalTypes))) {
1060+
return notifyMatchFailure(loc, [=](Diagnostic &diag) {
1061+
diag << "unable to convert type for operand #" << it.index()
1062+
<< ", type was " << origType;
1063+
});
1064+
}
1065+
// TODO: There currently isn't any mechanism to do 1->N type conversion
1066+
// via the PatternRewriter replacement API, so for now we just ignore
1067+
// it.
1068+
if (legalTypes.size() != 1) {
1069+
remapped.push_back(newOperand);
1070+
continue;
1071+
}
1072+
Type desiredType = legalTypes.front();
1073+
newOperand = converter->materializeTargetConversion(
1074+
rewriter, loc, desiredType, newOperand);
1075+
if (!newOperand) {
1076+
return notifyMatchFailure(loc, [=](Diagnostic &diag) {
1077+
diag << "unable to materialize a conversion for "
1078+
"operand #"
1079+
<< it.index() << ", from " << newOperandType << " to "
1080+
<< desiredType;
1081+
});
1082+
}
10531083
}
1054-
// TODO: There currently isn't any mechanism to do 1->N type conversion
1055-
// via the PatternRewriter replacement API, so for now we just ignore it.
1056-
if (legalTypes.size() == 1)
1057-
desiredType = legalTypes.front();
10581084
} else {
10591085
// TODO: What we should do here is just set `desiredType` to `origType`
10601086
// and then handle the necessary type conversions after the conversion
10611087
// process has finished. Unfortunately a lot of patterns currently rely on
10621088
// receiving the new operands even if the types change, so we keep the
10631089
// original behavior here for now until all of the patterns relying on
10641090
// this get updated.
1065-
}
1066-
Value newOperand = mapping.lookupOrDefault(operand, desiredType);
1067-
1068-
// Handle the case where the conversion was 1->1 and the new operand type
1069-
// isn't legal.
1070-
Type newOperandType = newOperand.getType();
1071-
if (converter && desiredType && newOperandType != desiredType) {
10721091
// Attempt to materialize a conversion for this new value.
1073-
newOperand = converter->materializeTargetConversion(
1074-
rewriter, loc, desiredType, newOperand);
1075-
if (!newOperand) {
1076-
return notifyMatchFailure(loc, [=](Diagnostic &diag) {
1077-
diag << "unable to materialize a conversion for "
1078-
"operand #"
1079-
<< it.index() << ", from " << newOperandType << " to "
1080-
<< desiredType;
1081-
});
1082-
}
10831092
}
10841093
remapped.push_back(newOperand);
10851094
}

0 commit comments

Comments
 (0)