Skip to content

[mlir] Fix use-after-return in #117513 #120968

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 23, 2024

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Dec 23, 2024

Fix a use-after-return in #117513. Free-standing lambdas should not be defined inside of the LLVMTypeConverter constructor because they go out of scope.

@llvmbot
Copy link
Member

llvmbot commented Dec 23, 2024

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

Fix a use-after-free in #117513. Free-standing lambdas should not be defined inside of the LLVMTypeConverter constructor because they go out of scope.


Full diff: https://github.com/llvm/llvm-project/pull/120968.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h (+35-35)
  • (modified) mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (+88-72)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index d79b90f840ce83..38b5e492a8ed8f 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -161,6 +161,41 @@ class LLVMTypeConverter : public TypeConverter {
   /// Check if a memref type can be converted to a bare pointer.
   static bool canConvertToBarePtr(BaseMemRefType type);
 
+  /// Convert a memref type into a list of LLVM IR types that will form the
+  /// memref descriptor. If `unpackAggregates` is true the `sizes` and `strides`
+  /// arrays in the descriptors are unpacked to individual index-typed elements,
+  /// else they are kept as rank-sized arrays of index type. In particular,
+  /// the list will contain:
+  /// - two pointers to the memref element type, followed by
+  /// - an index-typed offset, followed by
+  /// - (if unpackAggregates = true)
+  ///    - one index-typed size per dimension of the memref, followed by
+  ///    - one index-typed stride per dimension of the memref.
+  /// - (if unpackArrregates = false)
+  ///   - one rank-sized array of index-type for the size of each dimension
+  ///   - one rank-sized array of index-type for the stride of each dimension
+  ///
+  /// For example, memref<?x?xf32> is converted to the following list:
+  /// - `!llvm<"float*">` (allocated pointer),
+  /// - `!llvm<"float*">` (aligned pointer),
+  /// - `i64` (offset),
+  /// - `i64`, `i64` (sizes),
+  /// - `i64`, `i64` (strides).
+  /// These types can be recomposed to a memref descriptor struct.
+  SmallVector<Type, 5> getMemRefDescriptorFields(MemRefType type,
+                                                 bool unpackAggregates) const;
+
+  /// Convert an unranked memref type into a list of non-aggregate LLVM IR types
+  /// that will form the unranked memref descriptor. In particular, this list
+  /// contains:
+  /// - an integer rank, followed by
+  /// - a pointer to the memref descriptor struct.
+  /// For example, memref<*xf32> is converted to the following list:
+  /// i64 (rank)
+  /// !llvm<"i8*"> (type-erased pointer).
+  /// These types can be recomposed to a unranked memref descriptor struct.
+  SmallVector<Type, 2> getUnrankedMemRefDescriptorFields() const;
+
 protected:
   /// Pointer to the LLVM dialect.
   LLVM::LLVMDialect *llvmDialect;
@@ -213,41 +248,6 @@ class LLVMTypeConverter : public TypeConverter {
   /// Convert a memref type into an LLVM type that captures the relevant data.
   Type convertMemRefType(MemRefType type) const;
 
-  /// Convert a memref type into a list of LLVM IR types that will form the
-  /// memref descriptor. If `unpackAggregates` is true the `sizes` and `strides`
-  /// arrays in the descriptors are unpacked to individual index-typed elements,
-  /// else they are kept as rank-sized arrays of index type. In particular,
-  /// the list will contain:
-  /// - two pointers to the memref element type, followed by
-  /// - an index-typed offset, followed by
-  /// - (if unpackAggregates = true)
-  ///    - one index-typed size per dimension of the memref, followed by
-  ///    - one index-typed stride per dimension of the memref.
-  /// - (if unpackArrregates = false)
-  ///   - one rank-sized array of index-type for the size of each dimension
-  ///   - one rank-sized array of index-type for the stride of each dimension
-  ///
-  /// For example, memref<?x?xf32> is converted to the following list:
-  /// - `!llvm<"float*">` (allocated pointer),
-  /// - `!llvm<"float*">` (aligned pointer),
-  /// - `i64` (offset),
-  /// - `i64`, `i64` (sizes),
-  /// - `i64`, `i64` (strides).
-  /// These types can be recomposed to a memref descriptor struct.
-  SmallVector<Type, 5> getMemRefDescriptorFields(MemRefType type,
-                                                 bool unpackAggregates) const;
-
-  /// Convert an unranked memref type into a list of non-aggregate LLVM IR types
-  /// that will form the unranked memref descriptor. In particular, this list
-  /// contains:
-  /// - an integer rank, followed by
-  /// - a pointer to the memref descriptor struct.
-  /// For example, memref<*xf32> is converted to the following list:
-  /// i64 (rank)
-  /// !llvm<"i8*"> (type-erased pointer).
-  /// These types can be recomposed to a unranked memref descriptor struct.
-  SmallVector<Type, 2> getUnrankedMemRefDescriptorFields() const;
-
   /// Convert an unranked memref type to an LLVM type that captures the
   /// runtime rank and a pointer to the static ranked memref desc
   Type convertUnrankedMemRefType(UnrankedMemRefType type) const;
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index e2ab0ed6f66cc5..1a7951282d3f78 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -44,6 +44,74 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
                                      const DataLayoutAnalysis *analysis)
     : LLVMTypeConverter(ctx, LowerToLLVMOptions(ctx), analysis) {}
 
+/// Helper function that checks if the given value range is a bare pointer.
+static bool isBarePointer(ValueRange values) {
+  return values.size() == 1 &&
+         isa<LLVM::LLVMPointerType>(values.front().getType());
+};
+
+/// Pack SSA values into an unranked memref descriptor struct.
+static Value packUnrankedMemRefDesc(OpBuilder &builder,
+                                    UnrankedMemRefType resultType,
+                                    ValueRange inputs, Location loc,
+                                    const LLVMTypeConverter &converter) {
+  // Note: Bare pointers are not supported for unranked memrefs because a
+  // memref descriptor cannot be built just from a bare pointer.
+  if (TypeRange(inputs) != converter.getUnrankedMemRefDescriptorFields())
+    return Value();
+  return UnrankedMemRefDescriptor::pack(builder, loc, converter, resultType,
+                                        inputs);
+}
+
+/// Pack SSA values into a ranked memref descriptor struct.
+static Value packRankedMemRefDesc(OpBuilder &builder, MemRefType resultType,
+                                  ValueRange inputs, Location loc,
+                                  const LLVMTypeConverter &converter) {
+  assert(resultType && "expected non-null result type");
+  if (isBarePointer(inputs))
+    return MemRefDescriptor::fromStaticShape(builder, loc, converter,
+                                             resultType, inputs[0]);
+  if (TypeRange(inputs) ==
+      converter.getMemRefDescriptorFields(resultType,
+                                          /*unpackAggregates=*/true))
+    return MemRefDescriptor::pack(builder, loc, converter, resultType, inputs);
+  // The inputs are neither a bare pointer nor an unpacked memref descriptor.
+  // This materialization function cannot be used.
+  return Value();
+}
+
+/// MemRef descriptor elements -> UnrankedMemRefType
+static Value unrankedMemRefMaterialization(OpBuilder &builder,
+                                           UnrankedMemRefType resultType,
+                                           ValueRange inputs, Location loc,
+                                           const LLVMTypeConverter &converter) {
+  // An argument materialization must return a value of type
+  // `resultType`, so insert a cast from the memref descriptor type
+  // (!llvm.struct) to the original memref type.
+  Value packed =
+      packUnrankedMemRefDesc(builder, resultType, inputs, loc, converter);
+  if (!packed)
+    return Value();
+  return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
+      .getResult(0);
+};
+
+/// MemRef descriptor elements -> MemRefType
+static Value rankedMemRefMaterialization(OpBuilder &builder,
+                                         MemRefType resultType,
+                                         ValueRange inputs, Location loc,
+                                         const LLVMTypeConverter &converter) {
+  // An argument materialization must return a value of type `resultType`,
+  // so insert a cast from the memref descriptor type (!llvm.struct) to the
+  // original memref type.
+  Value packed =
+      packRankedMemRefDesc(builder, resultType, inputs, loc, converter);
+  if (!packed)
+    return Value();
+  return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
+      .getResult(0);
+}
+
 /// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
 LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
                                      const LowerToLLVMOptions &options,
@@ -166,81 +234,29 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
         .getResult(0);
   });
 
-  // Helper function that checks if the given value range is a bare pointer.
-  auto isBarePointer = [](ValueRange values) {
-    return values.size() == 1 &&
-           isa<LLVM::LLVMPointerType>(values.front().getType());
-  };
-
-  // TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter
-  // must be passed explicitly.
-  auto packUnrankedMemRefDesc =
-      [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
-          Location loc, LLVMTypeConverter &converter) -> Value {
-    // Note: Bare pointers are not supported for unranked memrefs because a
-    // memref descriptor cannot be built just from a bare pointer.
-    if (TypeRange(inputs) != converter.getUnrankedMemRefDescriptorFields())
-      return Value();
-    return UnrankedMemRefDescriptor::pack(builder, loc, converter, resultType,
-                                          inputs);
-  };
-
-  // MemRef descriptor elements -> UnrankedMemRefType
-  auto unrakedMemRefMaterialization = [&](OpBuilder &builder,
-                                          UnrankedMemRefType resultType,
-                                          ValueRange inputs, Location loc) {
-    // An argument materialization must return a value of type
-    // `resultType`, so insert a cast from the memref descriptor type
-    // (!llvm.struct) to the original memref type.
-    Value packed =
-        packUnrankedMemRefDesc(builder, resultType, inputs, loc, *this);
-    if (!packed)
-      return Value();
-    return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
-        .getResult(0);
-  };
-
-  // TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter
-  // must be passed explicitly.
-  auto packRankedMemRefDesc = [&](OpBuilder &builder, MemRefType resultType,
-                                  ValueRange inputs, Location loc,
-                                  LLVMTypeConverter &converter) -> Value {
-    assert(resultType && "expected non-null result type");
-    if (isBarePointer(inputs))
-      return MemRefDescriptor::fromStaticShape(builder, loc, converter,
-                                               resultType, inputs[0]);
-    if (TypeRange(inputs) ==
-        converter.getMemRefDescriptorFields(resultType,
-                                            /*unpackAggregates=*/true))
-      return MemRefDescriptor::pack(builder, loc, converter, resultType,
-                                    inputs);
-    // The inputs are neither a bare pointer nor an unpacked memref descriptor.
-    // This materialization function cannot be used.
-    return Value();
-  };
-
-  // MemRef descriptor elements -> MemRefType
-  auto rankedMemRefMaterialization = [&](OpBuilder &builder,
-                                         MemRefType resultType,
-                                         ValueRange inputs, Location loc) {
-    // An argument materialization must return a value of type `resultType`,
-    // so insert a cast from the memref descriptor type (!llvm.struct) to the
-    // original memref type.
-    Value packed =
-        packRankedMemRefDesc(builder, resultType, inputs, loc, *this);
-    if (!packed)
-      return Value();
-    return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
-        .getResult(0);
-  };
-
   // Argument materializations convert from the new block argument types
   // (multiple SSA values that make up a memref descriptor) back to the
   // original block argument type.
-  addArgumentMaterialization(unrakedMemRefMaterialization);
-  addArgumentMaterialization(rankedMemRefMaterialization);
-  addSourceMaterialization(unrakedMemRefMaterialization);
-  addSourceMaterialization(rankedMemRefMaterialization);
+  addArgumentMaterialization([&](OpBuilder &builder,
+                                 UnrankedMemRefType resultType,
+                                 ValueRange inputs, Location loc) {
+    return unrankedMemRefMaterialization(builder, resultType, inputs, loc,
+                                         *this);
+  });
+  addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
+                                 ValueRange inputs, Location loc) {
+    return rankedMemRefMaterialization(builder, resultType, inputs, loc, *this);
+  });
+  addSourceMaterialization([&](OpBuilder &builder,
+                               UnrankedMemRefType resultType, ValueRange inputs,
+                               Location loc) {
+    return unrankedMemRefMaterialization(builder, resultType, inputs, loc,
+                                         *this);
+  });
+  addSourceMaterialization([&](OpBuilder &builder, MemRefType resultType,
+                               ValueRange inputs, Location loc) {
+    return rankedMemRefMaterialization(builder, resultType, inputs, loc, *this);
+  });
 
   // Bare pointer -> Packed MemRef descriptor
   addTargetMaterialization([&](OpBuilder &builder, Type resultType,

@llvmbot
Copy link
Member

llvmbot commented Dec 23, 2024

@llvm/pr-subscribers-mlir-llvm

Author: Matthias Springer (matthias-springer)

Changes

Fix a use-after-free in #117513. Free-standing lambdas should not be defined inside of the LLVMTypeConverter constructor because they go out of scope.


Full diff: https://github.com/llvm/llvm-project/pull/120968.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h (+35-35)
  • (modified) mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (+88-72)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index d79b90f840ce83..38b5e492a8ed8f 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -161,6 +161,41 @@ class LLVMTypeConverter : public TypeConverter {
   /// Check if a memref type can be converted to a bare pointer.
   static bool canConvertToBarePtr(BaseMemRefType type);
 
+  /// Convert a memref type into a list of LLVM IR types that will form the
+  /// memref descriptor. If `unpackAggregates` is true the `sizes` and `strides`
+  /// arrays in the descriptors are unpacked to individual index-typed elements,
+  /// else they are kept as rank-sized arrays of index type. In particular,
+  /// the list will contain:
+  /// - two pointers to the memref element type, followed by
+  /// - an index-typed offset, followed by
+  /// - (if unpackAggregates = true)
+  ///    - one index-typed size per dimension of the memref, followed by
+  ///    - one index-typed stride per dimension of the memref.
+  /// - (if unpackArrregates = false)
+  ///   - one rank-sized array of index-type for the size of each dimension
+  ///   - one rank-sized array of index-type for the stride of each dimension
+  ///
+  /// For example, memref<?x?xf32> is converted to the following list:
+  /// - `!llvm<"float*">` (allocated pointer),
+  /// - `!llvm<"float*">` (aligned pointer),
+  /// - `i64` (offset),
+  /// - `i64`, `i64` (sizes),
+  /// - `i64`, `i64` (strides).
+  /// These types can be recomposed to a memref descriptor struct.
+  SmallVector<Type, 5> getMemRefDescriptorFields(MemRefType type,
+                                                 bool unpackAggregates) const;
+
+  /// Convert an unranked memref type into a list of non-aggregate LLVM IR types
+  /// that will form the unranked memref descriptor. In particular, this list
+  /// contains:
+  /// - an integer rank, followed by
+  /// - a pointer to the memref descriptor struct.
+  /// For example, memref<*xf32> is converted to the following list:
+  /// i64 (rank)
+  /// !llvm<"i8*"> (type-erased pointer).
+  /// These types can be recomposed to a unranked memref descriptor struct.
+  SmallVector<Type, 2> getUnrankedMemRefDescriptorFields() const;
+
 protected:
   /// Pointer to the LLVM dialect.
   LLVM::LLVMDialect *llvmDialect;
@@ -213,41 +248,6 @@ class LLVMTypeConverter : public TypeConverter {
   /// Convert a memref type into an LLVM type that captures the relevant data.
   Type convertMemRefType(MemRefType type) const;
 
-  /// Convert a memref type into a list of LLVM IR types that will form the
-  /// memref descriptor. If `unpackAggregates` is true the `sizes` and `strides`
-  /// arrays in the descriptors are unpacked to individual index-typed elements,
-  /// else they are kept as rank-sized arrays of index type. In particular,
-  /// the list will contain:
-  /// - two pointers to the memref element type, followed by
-  /// - an index-typed offset, followed by
-  /// - (if unpackAggregates = true)
-  ///    - one index-typed size per dimension of the memref, followed by
-  ///    - one index-typed stride per dimension of the memref.
-  /// - (if unpackArrregates = false)
-  ///   - one rank-sized array of index-type for the size of each dimension
-  ///   - one rank-sized array of index-type for the stride of each dimension
-  ///
-  /// For example, memref<?x?xf32> is converted to the following list:
-  /// - `!llvm<"float*">` (allocated pointer),
-  /// - `!llvm<"float*">` (aligned pointer),
-  /// - `i64` (offset),
-  /// - `i64`, `i64` (sizes),
-  /// - `i64`, `i64` (strides).
-  /// These types can be recomposed to a memref descriptor struct.
-  SmallVector<Type, 5> getMemRefDescriptorFields(MemRefType type,
-                                                 bool unpackAggregates) const;
-
-  /// Convert an unranked memref type into a list of non-aggregate LLVM IR types
-  /// that will form the unranked memref descriptor. In particular, this list
-  /// contains:
-  /// - an integer rank, followed by
-  /// - a pointer to the memref descriptor struct.
-  /// For example, memref<*xf32> is converted to the following list:
-  /// i64 (rank)
-  /// !llvm<"i8*"> (type-erased pointer).
-  /// These types can be recomposed to a unranked memref descriptor struct.
-  SmallVector<Type, 2> getUnrankedMemRefDescriptorFields() const;
-
   /// Convert an unranked memref type to an LLVM type that captures the
   /// runtime rank and a pointer to the static ranked memref desc
   Type convertUnrankedMemRefType(UnrankedMemRefType type) const;
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index e2ab0ed6f66cc5..1a7951282d3f78 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -44,6 +44,74 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
                                      const DataLayoutAnalysis *analysis)
     : LLVMTypeConverter(ctx, LowerToLLVMOptions(ctx), analysis) {}
 
+/// Helper function that checks if the given value range is a bare pointer.
+static bool isBarePointer(ValueRange values) {
+  return values.size() == 1 &&
+         isa<LLVM::LLVMPointerType>(values.front().getType());
+};
+
+/// Pack SSA values into an unranked memref descriptor struct.
+static Value packUnrankedMemRefDesc(OpBuilder &builder,
+                                    UnrankedMemRefType resultType,
+                                    ValueRange inputs, Location loc,
+                                    const LLVMTypeConverter &converter) {
+  // Note: Bare pointers are not supported for unranked memrefs because a
+  // memref descriptor cannot be built just from a bare pointer.
+  if (TypeRange(inputs) != converter.getUnrankedMemRefDescriptorFields())
+    return Value();
+  return UnrankedMemRefDescriptor::pack(builder, loc, converter, resultType,
+                                        inputs);
+}
+
+/// Pack SSA values into a ranked memref descriptor struct.
+static Value packRankedMemRefDesc(OpBuilder &builder, MemRefType resultType,
+                                  ValueRange inputs, Location loc,
+                                  const LLVMTypeConverter &converter) {
+  assert(resultType && "expected non-null result type");
+  if (isBarePointer(inputs))
+    return MemRefDescriptor::fromStaticShape(builder, loc, converter,
+                                             resultType, inputs[0]);
+  if (TypeRange(inputs) ==
+      converter.getMemRefDescriptorFields(resultType,
+                                          /*unpackAggregates=*/true))
+    return MemRefDescriptor::pack(builder, loc, converter, resultType, inputs);
+  // The inputs are neither a bare pointer nor an unpacked memref descriptor.
+  // This materialization function cannot be used.
+  return Value();
+}
+
+/// MemRef descriptor elements -> UnrankedMemRefType
+static Value unrankedMemRefMaterialization(OpBuilder &builder,
+                                           UnrankedMemRefType resultType,
+                                           ValueRange inputs, Location loc,
+                                           const LLVMTypeConverter &converter) {
+  // An argument materialization must return a value of type
+  // `resultType`, so insert a cast from the memref descriptor type
+  // (!llvm.struct) to the original memref type.
+  Value packed =
+      packUnrankedMemRefDesc(builder, resultType, inputs, loc, converter);
+  if (!packed)
+    return Value();
+  return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
+      .getResult(0);
+};
+
+/// MemRef descriptor elements -> MemRefType
+static Value rankedMemRefMaterialization(OpBuilder &builder,
+                                         MemRefType resultType,
+                                         ValueRange inputs, Location loc,
+                                         const LLVMTypeConverter &converter) {
+  // An argument materialization must return a value of type `resultType`,
+  // so insert a cast from the memref descriptor type (!llvm.struct) to the
+  // original memref type.
+  Value packed =
+      packRankedMemRefDesc(builder, resultType, inputs, loc, converter);
+  if (!packed)
+    return Value();
+  return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
+      .getResult(0);
+}
+
 /// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
 LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
                                      const LowerToLLVMOptions &options,
@@ -166,81 +234,29 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
         .getResult(0);
   });
 
-  // Helper function that checks if the given value range is a bare pointer.
-  auto isBarePointer = [](ValueRange values) {
-    return values.size() == 1 &&
-           isa<LLVM::LLVMPointerType>(values.front().getType());
-  };
-
-  // TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter
-  // must be passed explicitly.
-  auto packUnrankedMemRefDesc =
-      [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
-          Location loc, LLVMTypeConverter &converter) -> Value {
-    // Note: Bare pointers are not supported for unranked memrefs because a
-    // memref descriptor cannot be built just from a bare pointer.
-    if (TypeRange(inputs) != converter.getUnrankedMemRefDescriptorFields())
-      return Value();
-    return UnrankedMemRefDescriptor::pack(builder, loc, converter, resultType,
-                                          inputs);
-  };
-
-  // MemRef descriptor elements -> UnrankedMemRefType
-  auto unrakedMemRefMaterialization = [&](OpBuilder &builder,
-                                          UnrankedMemRefType resultType,
-                                          ValueRange inputs, Location loc) {
-    // An argument materialization must return a value of type
-    // `resultType`, so insert a cast from the memref descriptor type
-    // (!llvm.struct) to the original memref type.
-    Value packed =
-        packUnrankedMemRefDesc(builder, resultType, inputs, loc, *this);
-    if (!packed)
-      return Value();
-    return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
-        .getResult(0);
-  };
-
-  // TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter
-  // must be passed explicitly.
-  auto packRankedMemRefDesc = [&](OpBuilder &builder, MemRefType resultType,
-                                  ValueRange inputs, Location loc,
-                                  LLVMTypeConverter &converter) -> Value {
-    assert(resultType && "expected non-null result type");
-    if (isBarePointer(inputs))
-      return MemRefDescriptor::fromStaticShape(builder, loc, converter,
-                                               resultType, inputs[0]);
-    if (TypeRange(inputs) ==
-        converter.getMemRefDescriptorFields(resultType,
-                                            /*unpackAggregates=*/true))
-      return MemRefDescriptor::pack(builder, loc, converter, resultType,
-                                    inputs);
-    // The inputs are neither a bare pointer nor an unpacked memref descriptor.
-    // This materialization function cannot be used.
-    return Value();
-  };
-
-  // MemRef descriptor elements -> MemRefType
-  auto rankedMemRefMaterialization = [&](OpBuilder &builder,
-                                         MemRefType resultType,
-                                         ValueRange inputs, Location loc) {
-    // An argument materialization must return a value of type `resultType`,
-    // so insert a cast from the memref descriptor type (!llvm.struct) to the
-    // original memref type.
-    Value packed =
-        packRankedMemRefDesc(builder, resultType, inputs, loc, *this);
-    if (!packed)
-      return Value();
-    return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
-        .getResult(0);
-  };
-
   // Argument materializations convert from the new block argument types
   // (multiple SSA values that make up a memref descriptor) back to the
   // original block argument type.
-  addArgumentMaterialization(unrakedMemRefMaterialization);
-  addArgumentMaterialization(rankedMemRefMaterialization);
-  addSourceMaterialization(unrakedMemRefMaterialization);
-  addSourceMaterialization(rankedMemRefMaterialization);
+  addArgumentMaterialization([&](OpBuilder &builder,
+                                 UnrankedMemRefType resultType,
+                                 ValueRange inputs, Location loc) {
+    return unrankedMemRefMaterialization(builder, resultType, inputs, loc,
+                                         *this);
+  });
+  addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
+                                 ValueRange inputs, Location loc) {
+    return rankedMemRefMaterialization(builder, resultType, inputs, loc, *this);
+  });
+  addSourceMaterialization([&](OpBuilder &builder,
+                               UnrankedMemRefType resultType, ValueRange inputs,
+                               Location loc) {
+    return unrankedMemRefMaterialization(builder, resultType, inputs, loc,
+                                         *this);
+  });
+  addSourceMaterialization([&](OpBuilder &builder, MemRefType resultType,
+                               ValueRange inputs, Location loc) {
+    return rankedMemRefMaterialization(builder, resultType, inputs, loc, *this);
+  });
 
   // Bare pointer -> Packed MemRef descriptor
   addTargetMaterialization([&](OpBuilder &builder, Type resultType,

@matthias-springer matthias-springer changed the title [mlir] Fix use-after-free in #117513 [mlir] Fix use-after-return in #117513 Dec 23, 2024
@matthias-springer matthias-springer merged commit df31fd8 into main Dec 23, 2024
9 of 11 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/fix_build_12_23 branch December 23, 2024 14:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants