@@ -1092,44 +1092,50 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
1092
1092
SmallVectorImpl<Value> &remapped) {
1093
1093
remapped.reserve (llvm::size (values));
1094
1094
1095
- SmallVector<Type, 1 > legalTypes;
1096
1095
for (const auto &it : llvm::enumerate (values)) {
1097
1096
Value operand = it.value ();
1098
1097
Type origType = operand.getType ();
1098
+ Location operandLoc = inputLoc ? *inputLoc : operand.getLoc ();
1099
1099
1100
- // If a converter was provided, get the desired legal types for this
1101
- // operand.
1102
- Type desiredType;
1103
- if (currentTypeConverter) {
1104
- // If there is no legal conversion, fail to match this pattern.
1105
- legalTypes.clear ();
1106
- if (failed (currentTypeConverter->convertType (origType, legalTypes))) {
1107
- Location operandLoc = inputLoc ? *inputLoc : operand.getLoc ();
1108
- notifyMatchFailure (operandLoc, [=](Diagnostic &diag) {
1109
- diag << " unable to convert type for " << valueDiagTag << " #"
1110
- << it.index () << " , type was " << origType;
1111
- });
1112
- return failure ();
1113
- }
1114
- // TODO: There currently isn't any mechanism to do 1->N type conversion
1115
- // via the PatternRewriter replacement API, so for now we just ignore it.
1116
- if (legalTypes.size () == 1 )
1117
- desiredType = legalTypes.front ();
1118
- } else {
1119
- // TODO: What we should do here is just set `desiredType` to `origType`
1120
- // and then handle the necessary type conversions after the conversion
1121
- // process has finished. Unfortunately a lot of patterns currently rely on
1122
- // receiving the new operands even if the types change, so we keep the
1123
- // original behavior here for now until all of the patterns relying on
1124
- // this get updated.
1100
+ if (!currentTypeConverter) {
1101
+ // The current pattern does not have a type converter. I.e., it does not
1102
+ // distinguish between legal and illegal types. For each operand, simply
1103
+ // pass through the most recently mapped value.
1104
+ remapped.push_back (mapping.lookupOrDefault (operand));
1105
+ continue ;
1106
+ }
1107
+
1108
+ // If there is no legal conversion, fail to match this pattern.
1109
+ SmallVector<Type, 1 > legalTypes;
1110
+ if (failed (currentTypeConverter->convertType (origType, legalTypes))) {
1111
+ notifyMatchFailure (operandLoc, [=](Diagnostic &diag) {
1112
+ diag << " unable to convert type for " << valueDiagTag << " #"
1113
+ << it.index () << " , type was " << origType;
1114
+ });
1115
+ return failure ();
1125
1116
}
1126
- Value newOperand = mapping.lookupOrDefault (operand, desiredType);
1127
1117
1128
- // Handle the case where the conversion was 1->1 and the new operand type
1129
- // isn't legal.
1130
- Type newOperandType = newOperand.getType ();
1131
- if (currentTypeConverter && desiredType && newOperandType != desiredType) {
1132
- Location operandLoc = inputLoc ? *inputLoc : operand.getLoc ();
1118
+ if (legalTypes.size () != 1 ) {
1119
+ // TODO: Parts of the dialect conversion infrastructure do not support
1120
+ // 1->N type conversions yet. Therefore, if a type is converted to 0 or
1121
+ // multiple types, the only thing that we can do for now is passing
1122
+ // through the most recently mapped value. Fixing this requires
1123
+ // improvements to the `ConversionValueMapping` (to be able to store 1:N
1124
+ // mappings) and to the `ConversionPattern` adaptor handling (to be able
1125
+ // to pass multiple remapped values for a single operand to the adaptor).
1126
+ remapped.push_back (mapping.lookupOrDefault (operand));
1127
+ continue ;
1128
+ }
1129
+
1130
+ // Handle 1->1 type conversions.
1131
+ Type desiredType = legalTypes.front ();
1132
+ // Try to find a mapped value with the desired type. (Or the operand itself
1133
+ // if the value is not mapped at all.)
1134
+ Value newOperand = mapping.lookupOrDefault (operand, desiredType);
1135
+ if (newOperand.getType () != desiredType) {
1136
+ // If the looked up value's type does not have the desired type, it means
1137
+ // that the value was replaced with a value of different type and no
1138
+ // source materialization was created yet.
1133
1139
Value castValue = buildUnresolvedMaterialization (
1134
1140
MaterializationKind::Target, computeInsertPoint (newOperand),
1135
1141
operandLoc, /* inputs=*/ newOperand, /* outputType=*/ desiredType,
0 commit comments