Skip to content

[mlir][Transforms] Merge 1:1 and 1:N type converters #113032

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ std::unique_ptr<Pass> createLowerForeachToSCFPass();
//===----------------------------------------------------------------------===//

/// Type converter for iter_space and iterator.
struct SparseIterationTypeConverter : public OneToNTypeConverter {
struct SparseIterationTypeConverter : public TypeConverter {
SparseIterationTypeConverter();
};

Expand Down
62 changes: 48 additions & 14 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,9 @@ class TypeConverter {
/// conversion has finished.
///
/// Note: Target materializations may optionally accept an additional Type
/// parameter, which is the original type of the SSA value.
/// parameter, which is the original type of the SSA value. Furthermore, `T`
/// can be a TypeRange; in that case, the function must return a
/// SmallVector<Value>.

/// This method registers a materialization that will be called when
/// converting (potentially multiple) block arguments that were the result of
Expand Down Expand Up @@ -210,6 +212,9 @@ class TypeConverter {
/// will be invoked with: outputType = "t3", inputs = "v2",
// originalType = "t1". Note that the original type "t1" cannot be recovered
/// from just "t3" and "v2"; that's why the originalType parameter exists.
///
/// Note: During a 1:N conversion, the result types can be a TypeRange. In
/// that case the materialization produces a SmallVector<Value>.
template <typename FnT, typename T = typename llvm::function_traits<
std::decay_t<FnT>>::template arg_t<1>>
void addTargetMaterialization(FnT &&callback) {
Expand Down Expand Up @@ -316,6 +321,11 @@ class TypeConverter {
Value materializeTargetConversion(OpBuilder &builder, Location loc,
Type resultType, ValueRange inputs,
Type originalType = {}) const;
SmallVector<Value> materializeTargetConversion(OpBuilder &builder,
Location loc,
TypeRange resultType,
ValueRange inputs,
Type originalType = {}) const;

/// Convert an attribute present `attr` from within the type `type` using
/// the registered conversion functions. If no applicable conversion has been
Expand All @@ -340,9 +350,9 @@ class TypeConverter {

/// The signature of the callback used to materialize a target conversion.
///
/// Arguments: builder, result type, inputs, location, original type
using TargetMaterializationCallbackFn =
std::function<Value(OpBuilder &, Type, ValueRange, Location, Type)>;
/// Arguments: builder, result types, inputs, location, original type
using TargetMaterializationCallbackFn = std::function<SmallVector<Value>(
OpBuilder &, TypeRange, ValueRange, Location, Type)>;

/// The signature of the callback used to convert a type attribute.
using TypeAttributeConversionCallbackFn =
Expand Down Expand Up @@ -409,32 +419,56 @@ class TypeConverter {
/// callback.
///
/// With callback of form:
/// `Value(OpBuilder &, T, ValueRange, Location, Type)`
/// - Value(OpBuilder &, T, ValueRange, Location, Type)
/// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location, Type)
template <typename T, typename FnT>
std::enable_if_t<
std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location, Type>,
TargetMaterializationCallbackFn>
wrapTargetMaterialization(FnT &&callback) const {
return [callback = std::forward<FnT>(callback)](
OpBuilder &builder, Type resultType, ValueRange inputs,
Location loc, Type originalType) -> Value {
if (T derivedType = dyn_cast<T>(resultType))
return callback(builder, derivedType, inputs, loc, originalType);
return Value();
OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
Location loc, Type originalType) -> SmallVector<Value> {
SmallVector<Value> result;
if constexpr (std::is_same<T, TypeRange>::value) {
// This is a 1:N target materialization. Return the produces values
// directly.
result = callback(builder, resultTypes, inputs, loc, originalType);
} else if constexpr (std::is_assignable<Type, T>::value) {
// This is a 1:1 target materialization. Invoke the callback only if a
// single SSA value is requested.
if (resultTypes.size() == 1) {
// Invoke the callback only if the type class of the callback matches
// the requested result type.
if (T derivedType = dyn_cast<T>(resultTypes.front())) {
// 1:1 materializations produce single values, but we store 1:N
// target materialization functions in the type converter. Wrap the
// result value in a SmallVector<Value>.
Value val =
callback(builder, derivedType, inputs, loc, originalType);
if (val)
result.push_back(val);
}
}
} else {
static_assert(sizeof(T) == 0, "T must be a Type or a TypeRange");
}
return result;
};
}
/// With callback of form:
/// `Value(OpBuilder &, T, ValueRange, Location)`
/// - Value(OpBuilder &, T, ValueRange, Location)
/// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location)
template <typename T, typename FnT>
std::enable_if_t<
std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location>,
TargetMaterializationCallbackFn>
wrapTargetMaterialization(FnT &&callback) const {
return wrapTargetMaterialization<T>(
[callback = std::forward<FnT>(callback)](
OpBuilder &builder, T resultType, ValueRange inputs, Location loc,
Type originalType) -> Value {
return callback(builder, resultType, inputs, loc);
OpBuilder &builder, T resultTypes, ValueRange inputs, Location loc,
Type originalType) {
return callback(builder, resultTypes, inputs, loc);
});
}

Expand Down
45 changes: 1 addition & 44 deletions mlir/include/mlir/Transforms/OneToNTypeConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,49 +33,6 @@

namespace mlir {

/// Extends `TypeConverter` with 1:N target materializations. Such
/// materializations have to provide the "reverse" of 1:N type conversions,
/// i.e., they need to materialize N values with target types into one value
/// with a source type (which isn't possible in the base class currently).
class OneToNTypeConverter : public TypeConverter {
public:
/// Callback that expresses user-provided materialization logic from the given
/// value to N values of the given types. This is useful for expressing target
/// materializations for 1:N type conversions, which materialize one value in
/// a source type as N values in target types.
using OneToNMaterializationCallbackFn =
std::function<std::optional<SmallVector<Value>>(OpBuilder &, TypeRange,
Value, Location)>;

/// Creates the mapping of the given range of original types to target types
/// of the conversion and stores that mapping in the given (signature)
/// conversion. This function simply calls
/// `TypeConverter::convertSignatureArgs` and exists here with a different
/// name to reflect the broader semantic.
LogicalResult computeTypeMapping(TypeRange types,
SignatureConversion &result) const {
return convertSignatureArgs(types, result);
}

/// Applies one of the user-provided 1:N target materializations. If several
/// exists, they are tried out in the reverse order in which they have been
/// added until the first one succeeds. If none succeeds, the functions
/// returns `std::nullopt`.
std::optional<SmallVector<Value>>
materializeTargetConversion(OpBuilder &builder, Location loc,
TypeRange resultTypes, Value input) const;

/// Adds a 1:N target materialization to the converter. Such materializations
/// build IR that converts N values with target types into 1 value of the
/// source type.
void addTargetMaterialization(OneToNMaterializationCallbackFn &&callback) {
oneToNTargetMaterializations.emplace_back(std::move(callback));
}

private:
SmallVector<OneToNMaterializationCallbackFn> oneToNTargetMaterializations;
};

/// Stores a 1:N mapping of types and provides several useful accessors. This
/// class extends `SignatureConversion`, which already supports 1:N type
/// mappings but lacks some accessors into the mapping as well as access to the
Expand Down Expand Up @@ -295,7 +252,7 @@ class OneToNOpConversionPattern : public OneToNConversionPattern {
/// not fail if some ops or types remain unconverted (i.e., the conversion is
/// only "partial").
LogicalResult
applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter,
const FrozenRewritePatternSet &patterns);

/// Add a pattern to the given pattern list to convert the signature of a
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -921,7 +921,7 @@ struct VectorLegalizationPass
: public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
void runOnOperation() override {
auto *context = &getContext();
OneToNTypeConverter converter;
TypeConverter converter;
RewritePatternSet patterns(context);
converter.addConversion([](Type type) { return type; });
converter.addConversion(
Expand Down
26 changes: 22 additions & 4 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2831,11 +2831,29 @@ Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
Location loc, Type resultType,
ValueRange inputs,
Type originalType) const {
SmallVector<Value> result = materializeTargetConversion(
builder, loc, TypeRange(resultType), inputs, originalType);
if (result.empty())
return nullptr;
assert(result.size() == 1 && "expected single result");
return result.front();
}

SmallVector<Value> TypeConverter::materializeTargetConversion(
OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs,
Type originalType) const {
for (const TargetMaterializationCallbackFn &fn :
llvm::reverse(targetMaterializations))
if (Value result = fn(builder, resultType, inputs, loc, originalType))
return result;
return nullptr;
llvm::reverse(targetMaterializations)) {
SmallVector<Value> result =
fn(builder, resultTypes, inputs, loc, originalType);
if (result.empty())
continue;
assert(TypeRange(result) == resultTypes &&
"callback produced incorrect number of values or values with "
"incorrect types");
return result;
}
return {};
}

std::optional<TypeConverter::SignatureConversion>
Expand Down
44 changes: 14 additions & 30 deletions mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,6 @@
using namespace llvm;
using namespace mlir;

std::optional<SmallVector<Value>>
OneToNTypeConverter::materializeTargetConversion(OpBuilder &builder,
Location loc,
TypeRange resultTypes,
Value input) const {
for (const OneToNMaterializationCallbackFn &fn :
llvm::reverse(oneToNTargetMaterializations)) {
if (std::optional<SmallVector<Value>> result =
fn(builder, resultTypes, input, loc))
return *result;
}
return std::nullopt;
}

TypeRange OneToNTypeMapping::getConvertedTypes(unsigned originalTypeNo) const {
TypeRange convertedTypes = getConvertedTypes();
if (auto mapping = getInputMapping(originalTypeNo))
Expand Down Expand Up @@ -268,20 +254,20 @@ Block *OneToNPatternRewriter::applySignatureConversion(
LogicalResult
OneToNConversionPattern::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
auto *typeConverter = getTypeConverter();

// Construct conversion mapping for results.
Operation::result_type_range originalResultTypes = op->getResultTypes();
OneToNTypeMapping resultMapping(originalResultTypes);
if (failed(typeConverter->computeTypeMapping(originalResultTypes,
resultMapping)))
if (failed(typeConverter->convertSignatureArgs(originalResultTypes,
resultMapping)))
return failure();

// Construct conversion mapping for operands.
Operation::operand_type_range originalOperandTypes = op->getOperandTypes();
OneToNTypeMapping operandMapping(originalOperandTypes);
if (failed(typeConverter->computeTypeMapping(originalOperandTypes,
operandMapping)))
if (failed(typeConverter->convertSignatureArgs(originalOperandTypes,
operandMapping)))
return failure();

// Cast operands to target types.
Expand Down Expand Up @@ -318,7 +304,7 @@ namespace mlir {
// inserted by this pass are annotated with a string attribute that also
// documents which kind of the cast (source, argument, or target).
LogicalResult
applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter,
const FrozenRewritePatternSet &patterns) {
#ifndef NDEBUG
// Remember existing unrealized casts. This data structure is only used in
Expand Down Expand Up @@ -370,15 +356,13 @@ applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
// Target materialization.
assert(!areOperandTypesLegal && areResultsTypesLegal &&
operands.size() == 1 && "found unexpected target cast");
std::optional<SmallVector<Value>> maybeResults =
typeConverter.materializeTargetConversion(
rewriter, castOp->getLoc(), resultTypes, operands.front());
if (!maybeResults) {
materializedResults = typeConverter.materializeTargetConversion(
rewriter, castOp->getLoc(), resultTypes, operands.front());
if (materializedResults.empty()) {
emitError(castOp->getLoc())
<< "failed to create target materialization";
return failure();
}
materializedResults = maybeResults.value();
} else {
// Source and argument materializations.
assert(areOperandTypesLegal && !areResultsTypesLegal &&
Expand Down Expand Up @@ -427,18 +411,18 @@ class FunctionOpInterfaceSignatureConversion : public OneToNConversionPattern {
const OneToNTypeMapping &resultMapping,
ValueRange convertedOperands) const override {
auto funcOp = cast<FunctionOpInterface>(op);
auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
auto *typeConverter = getTypeConverter();

// Construct mapping for function arguments.
OneToNTypeMapping argumentMapping(funcOp.getArgumentTypes());
if (failed(typeConverter->computeTypeMapping(funcOp.getArgumentTypes(),
argumentMapping)))
if (failed(typeConverter->convertSignatureArgs(funcOp.getArgumentTypes(),
argumentMapping)))
return failure();

// Construct mapping for function results.
OneToNTypeMapping funcResultMapping(funcOp.getResultTypes());
if (failed(typeConverter->computeTypeMapping(funcOp.getResultTypes(),
funcResultMapping)))
if (failed(typeConverter->convertSignatureArgs(funcOp.getResultTypes(),
funcResultMapping)))
return failure();

// Nothing to do if the op doesn't have any non-identity conversions for its
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,14 @@ populateDecomposeTuplesTestPatterns(const TypeConverter &typeConverter,
///
/// This function has been copied (with small adaptions) from
/// TestDecomposeCallGraphTypes.cpp.
static std::optional<SmallVector<Value>>
buildGetTupleElementOps(OpBuilder &builder, TypeRange resultTypes, Value input,
Location loc) {
static SmallVector<Value> buildGetTupleElementOps(OpBuilder &builder,
TypeRange resultTypes,
ValueRange inputs,
Location loc) {
if (inputs.size() != 1)
return {};
Value input = inputs.front();

TupleType inputType = dyn_cast<TupleType>(input.getType());
if (!inputType)
return {};
Expand Down Expand Up @@ -222,7 +227,7 @@ void TestOneToNTypeConversionPass::runOnOperation() {
auto *context = &getContext();

// Assemble type converter.
OneToNTypeConverter typeConverter;
TypeConverter typeConverter;

typeConverter.addConversion([](Type type) { return type; });
typeConverter.addConversion(
Expand All @@ -234,6 +239,11 @@ void TestOneToNTypeConversionPass::runOnOperation() {
typeConverter.addArgumentMaterialization(buildMakeTupleOp);
typeConverter.addSourceMaterialization(buildMakeTupleOp);
typeConverter.addTargetMaterialization(buildGetTupleElementOps);
// Test the other target materialization variant that takes the original type
// as additional argument. This materialization function always fails.
typeConverter.addTargetMaterialization(
[](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
Location loc, Type originalType) -> SmallVector<Value> { return {}; });

// Assemble patterns.
RewritePatternSet patterns(context);
Expand Down
Loading