-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][Transforms] Dialect conversion: Fix missing source materialization #97903
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -707,10 +707,9 @@ class UnresolvedMaterializationRewrite : public OperationRewrite { | |
UnresolvedMaterializationRewrite( | ||
ConversionPatternRewriterImpl &rewriterImpl, | ||
UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr, | ||
MaterializationKind kind = MaterializationKind::Target, | ||
Type origOutputType = nullptr) | ||
MaterializationKind kind = MaterializationKind::Target) | ||
: OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op), | ||
converterAndKind(converter, kind), origOutputType(origOutputType) {} | ||
converterAndKind(converter, kind) {} | ||
|
||
static bool classof(const IRRewrite *rewrite) { | ||
return rewrite->getKind() == Kind::UnresolvedMaterialization; | ||
|
@@ -734,17 +733,11 @@ class UnresolvedMaterializationRewrite : public OperationRewrite { | |
return converterAndKind.getInt(); | ||
} | ||
|
||
/// Return the original illegal output type of the input values. | ||
Type getOrigOutputType() const { return origOutputType; } | ||
|
||
private: | ||
/// The corresponding type converter to use when resolving this | ||
/// materialization, and the kind of this materialization. | ||
llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind> | ||
converterAndKind; | ||
|
||
/// The original output type. This is only used for argument conversions. | ||
Type origOutputType; | ||
}; | ||
} // namespace | ||
|
||
|
@@ -860,12 +853,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { | |
Block *insertBlock, | ||
Block::iterator insertPt, Location loc, | ||
ValueRange inputs, Type outputType, | ||
Type origOutputType, | ||
const TypeConverter *converter); | ||
|
||
Value buildUnresolvedArgumentMaterialization(Block *block, Location loc, | ||
ValueRange inputs, | ||
Type origOutputType, | ||
Type outputType, | ||
const TypeConverter *converter); | ||
|
||
|
@@ -1388,20 +1379,28 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( | |
if (replArgs.size() == 1 && | ||
(!converter || replArgs[0].getType() == origArg.getType())) { | ||
newArg = replArgs.front(); | ||
mapping.map(origArg, newArg); | ||
} else { | ||
Type origOutputType = origArg.getType(); | ||
|
||
// Legalize the argument output type. | ||
Type outputType = origOutputType; | ||
if (Type legalOutputType = converter->convertType(outputType)) | ||
outputType = legalOutputType; | ||
|
||
newArg = buildUnresolvedArgumentMaterialization( | ||
newBlock, origArg.getLoc(), replArgs, origOutputType, outputType, | ||
converter); | ||
// Build argument materialization: new block arguments -> old block | ||
// argument type. | ||
Value argMat = buildUnresolvedArgumentMaterialization( | ||
newBlock, origArg.getLoc(), replArgs, origArg.getType(), converter); | ||
mapping.map(origArg, argMat); | ||
|
||
// Build target materialization: old block argument type -> legal type. | ||
// Note: This function returns an "empty" type if no valid conversion to | ||
// a legal type exists. In that case, we continue the conversion with the | ||
// original block argument type. | ||
Type legalOutputType = converter->convertType(origArg.getType()); | ||
if (legalOutputType && legalOutputType != origArg.getType()) { | ||
newArg = buildUnresolvedTargetMaterialization( | ||
origArg.getLoc(), argMat, legalOutputType, converter); | ||
mapping.map(argMat, newArg); | ||
} else { | ||
newArg = argMat; | ||
} | ||
} | ||
|
||
mapping.map(origArg, newArg); | ||
appendRewrite<ReplaceBlockArgRewrite>(block, origArg); | ||
argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg); | ||
} | ||
|
@@ -1424,7 +1423,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( | |
/// of input operands. | ||
Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization( | ||
MaterializationKind kind, Block *insertBlock, Block::iterator insertPt, | ||
Location loc, ValueRange inputs, Type outputType, Type origOutputType, | ||
Location loc, ValueRange inputs, Type outputType, | ||
const TypeConverter *converter) { | ||
// Avoid materializing an unnecessary cast. | ||
if (inputs.size() == 1 && inputs.front().getType() == outputType) | ||
|
@@ -1435,16 +1434,15 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization( | |
OpBuilder builder(insertBlock, insertPt); | ||
auto convertOp = | ||
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs); | ||
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind, | ||
origOutputType); | ||
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind); | ||
return convertOp.getResult(0); | ||
} | ||
Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization( | ||
Block *block, Location loc, ValueRange inputs, Type origOutputType, | ||
Type outputType, const TypeConverter *converter) { | ||
Block *block, Location loc, ValueRange inputs, Type outputType, | ||
const TypeConverter *converter) { | ||
return buildUnresolvedMaterialization(MaterializationKind::Argument, block, | ||
block->begin(), loc, inputs, outputType, | ||
origOutputType, converter); | ||
converter); | ||
} | ||
Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization( | ||
Location loc, Value input, Type outputType, | ||
|
@@ -1456,7 +1454,7 @@ Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization( | |
|
||
return buildUnresolvedMaterialization(MaterializationKind::Target, | ||
insertBlock, insertPt, loc, input, | ||
outputType, outputType, converter); | ||
outputType, converter); | ||
} | ||
|
||
//===----------------------------------------------------------------------===// | ||
|
@@ -2672,19 +2670,28 @@ static void computeNecessaryMaterializations( | |
ConversionPatternRewriterImpl &rewriterImpl, | ||
DenseMap<Value, SmallVector<Value>> &inverseMapping, | ||
SetVector<UnresolvedMaterializationRewrite *> &necessaryMaterializations) { | ||
// Helper function to check if the given value or a not yet materialized | ||
// replacement of the given value is live. | ||
// Note: `inverseMapping` maps from replaced values to original values. | ||
auto isLive = [&](Value value) { | ||
auto findFn = [&](Operation *user) { | ||
auto matIt = materializationOps.find(user); | ||
if (matIt != materializationOps.end()) | ||
return !necessaryMaterializations.count(matIt->second); | ||
return rewriterImpl.isOpIgnored(user); | ||
}; | ||
// This value may be replacing another value that has a live user. | ||
for (Value inv : inverseMapping.lookup(value)) | ||
if (llvm::find_if_not(inv.getUsers(), findFn) != inv.user_end()) | ||
// A worklist is needed because a value may have gone through a chain of | ||
// replacements and each of the replaced values may have live users. | ||
SmallVector<Value> worklist; | ||
worklist.push_back(value); | ||
while (!worklist.empty()) { | ||
Value next = worklist.pop_back_val(); | ||
if (llvm::find_if_not(next.getUsers(), findFn) != next.user_end()) | ||
return true; | ||
// Or have live users itself. | ||
return llvm::find_if_not(value.getUsers(), findFn) != value.user_end(); | ||
// This value may be replacing another value that has a live user. | ||
llvm::append_range(worklist, inverseMapping.lookup(next)); | ||
} | ||
return false; | ||
}; | ||
|
||
llvm::unique_function<Value(Value, Value, Type)> lookupRemappedValue = | ||
|
@@ -2844,18 +2851,10 @@ static LogicalResult legalizeUnresolvedMaterialization( | |
switch (mat.getMaterializationKind()) { | ||
case MaterializationKind::Argument: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am slightly confused by the comment at line 2804 which states that this code only deals with target materializations (I am interpreting this as a materializations to the target type system, not specifically target conversions). It seems me either the comment needs to be updated or the fallback path can be removed or changed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Honestly, that comment makes no sense to me. This part of the code base deals exclusively with materializations; there are no type conversions anymore. When this comment was added, the implementation already handled both argument and target materializations. Source materializations are not handled here because they never show up as "unresolved materializations". There isn't even a
Interestingly this comment is right before the |
||
// Try to materialize an argument conversion. | ||
// FIXME: The current argument materialization hook expects the original | ||
// output type, even though it doesn't use that as the actual output type | ||
// of the generated IR. The output type is just used as an indicator of | ||
// the type of materialization to do. This behavior is really awkward in | ||
// that it diverges from the behavior of the other hooks, and can be | ||
// easily misunderstood. We should clean up the argument hooks to better | ||
// represent the desired invariants we actually care about. | ||
newMaterialization = converter->materializeArgumentConversion( | ||
rewriter, op->getLoc(), mat.getOrigOutputType(), inputOperands); | ||
rewriter, op->getLoc(), outputType, inputOperands); | ||
if (newMaterialization) | ||
break; | ||
|
||
// If an argument materialization failed, fallback to trying a target | ||
// materialization. | ||
[[fallthrough]]; | ||
|
@@ -2865,6 +2864,8 @@ static LogicalResult legalizeUnresolvedMaterialization( | |
break; | ||
} | ||
if (newMaterialization) { | ||
assert(newMaterialization.getType() == outputType && | ||
"materialization callback produced value of incorrect type"); | ||
replaceMaterialization(rewriterImpl, opResult, newMaterialization, | ||
inverseMapping); | ||
return success(); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
// RUN: mlir-opt %s -transform-interpreter | FileCheck %s | ||
|
||
// CHECK-LABEL: func @complex_block_signature_conversion( | ||
// CHECK: %[[cst:.*]] = complex.constant | ||
// CHECK: %[[complex_llvm:.*]] = builtin.unrealized_conversion_cast %[[cst]] : complex<f64> to !llvm.struct<(f64, f64)> | ||
// Note: Some blocks are omitted. | ||
// CHECK: llvm.br ^[[block1:.*]](%[[complex_llvm]] | ||
// CHECK: ^[[block1]](%[[arg:.*]]: !llvm.struct<(f64, f64)>): | ||
// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[arg]] : !llvm.struct<(f64, f64)> to complex<f64> | ||
// CHECK: llvm.br ^[[block2:.*]] | ||
// CHECK: ^[[block2]]: | ||
// CHECK: "test.consumer_of_complex"(%[[cast]]) : (complex<f64>) -> () | ||
func.func @complex_block_signature_conversion() { | ||
%cst = complex.constant [0.000000e+00, 0.000000e+00] : complex<f64> | ||
%true = arith.constant true | ||
%0 = scf.if %true -> complex<f64> { | ||
scf.yield %cst : complex<f64> | ||
} else { | ||
scf.yield %cst : complex<f64> | ||
} | ||
|
||
// Regression test to ensure that the a source materialization is inserted. | ||
// The operand of "test.consumer_of_complex" must not change. | ||
"test.consumer_of_complex"(%0) : (complex<f64>) -> () | ||
return | ||
} | ||
|
||
module attributes {transform.with_named_sequence} { | ||
transform.named_sequence @__transform_main(%toplevel_module: !transform.any_op {transform.readonly}) { | ||
%func = transform.structured.match ops{["func.func"]} in %toplevel_module | ||
: (!transform.any_op) -> !transform.any_op | ||
transform.apply_conversion_patterns to %func { | ||
transform.apply_conversion_patterns.dialect_to_llvm "cf" | ||
transform.apply_conversion_patterns.func.func_to_llvm | ||
transform.apply_conversion_patterns.scf.scf_to_control_flow | ||
} with type_converter { | ||
transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter | ||
} { | ||
legal_dialects = ["llvm"], | ||
partial_conversion | ||
} : !transform.any_op | ||
transform.yield | ||
} | ||
} |
Uh oh!
There was an error while loading. Please reload this page.