Skip to content

[mlir] Align num elements type to LLVM ArrayType #93230

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
May 29, 2024

Conversation

clementval
Copy link
Contributor

MLIR LLMArrayType is using unsigned for the number of elements while LLVM ArrayType is using uint64_t

ArrayType(Type *ElType, uint64_t NumEl);

This leads to silent truncation when we use it for globals in flang.

program test
  integer(8), parameter :: large = 2**30
  real,  dimension(large) :: bigarray
  common /c/ bigarray
  bigarray(999) = 666
end

The above program would result in a segfault since the global would be of size 0 because of the silent truncation.

fir.global common @c_(dense<0> : vector<4294967296xi8>) : !fir.array<4294967296xi8>

became

llvm.mlir.global common @c_(dense<0> : vector<4294967296xi8>) {addr_space = 0 : i32} : !llvm.array<0 x i8>

This patch updates the definition of MLIR ArrayType to take uint64_t as argument of the number of elements to be compatible with LLVM.

@llvmbot llvmbot added mlir:llvm mlir flang Flang issues not falling into any other category labels May 23, 2024
@clementval clementval requested a review from ftynse May 23, 2024 19:01
@llvmbot
Copy link
Member

llvmbot commented May 23, 2024

@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

MLIR LLMArrayType is using unsigned for the number of elements while LLVM ArrayType is using uint64_t

ArrayType(Type *ElType, uint64_t NumEl);

This leads to silent truncation when we use it for globals in flang.

program test
  integer(8), parameter :: large = 2**30
  real,  dimension(large) :: bigarray
  common /c/ bigarray
  bigarray(999) = 666
end

The above program would result in a segfault since the global would be of size 0 because of the silent truncation.

fir.global common @<!-- -->c_(dense&lt;0&gt; : vector&lt;4294967296xi8&gt;) : !fir.array&lt;4294967296xi8&gt;

became

llvm.mlir.global common @<!-- -->c_(dense&lt;0&gt; : vector&lt;4294967296xi8&gt;) {addr_space = 0 : i32} : !llvm.array&lt;0 x i8&gt;

This patch updates the definition of MLIR ArrayType to take uint64_t as argument of the number of elements to be compatible with LLVM.


Full diff: https://github.com/llvm/llvm-project/pull/93230.diff

5 Files Affected:

  • (modified) flang/test/Fir/convert-to-llvm.fir (+6)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td (+2-2)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp (+3-3)
  • (modified) mlir/lib/Target/LLVMIR/ModuleTranslation.cpp (+1-1)
  • (modified) mlir/test/Target/LLVMIR/llvmir.mlir (+6)
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index 21323a5e657c9..9bd4475e98009 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -2698,3 +2698,9 @@ func.func @coordinate_array_unknown_size_1d(%arg0: !fir.ptr<!fir.array<? x i32>>
 // CHECK:           %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
 // CHECK:           llvm.return
 // CHECK:         }
+
+// -----
+
+fir.global common @c_(dense<0> : vector<4294967296xi8>) : !fir.array<4294967296xi8>
+
+// CHECK: llvm.mlir.global common @c_(dense<0> : vector<4294967296xi8>) {addr_space = 0 : i32} : !llvm.array<4294967296 x i8>
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
index b7176aa93ff1f..8f9c2f2f8a0b4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
@@ -40,7 +40,7 @@ def LLVMArrayType : LLVMType<"LLVMArray", "array", [
     ```
   }];
 
-  let parameters = (ins "Type":$elementType, "unsigned":$numElements);
+  let parameters = (ins "Type":$elementType, "uint64_t":$numElements);
   let assemblyFormat = [{
     `<` $numElements `x` custom<PrettyLLVMType>($elementType) `>`
   }];
@@ -49,7 +49,7 @@ def LLVMArrayType : LLVMType<"LLVMArray", "array", [
 
   let builders = [
     TypeBuilderWithInferredContext<(ins "Type":$elementType,
-                                        "unsigned":$numElements)>
+                                        "uint64_t":$numElements)>
   ];
 
   let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index ad1dc4a36b82b..cf3f38b710130 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -154,14 +154,14 @@ bool LLVMArrayType::isValidElementType(Type type) {
       type);
 }
 
-LLVMArrayType LLVMArrayType::get(Type elementType, unsigned numElements) {
+LLVMArrayType LLVMArrayType::get(Type elementType, uint64_t numElements) {
   assert(elementType && "expected non-null subtype");
   return Base::get(elementType.getContext(), elementType, numElements);
 }
 
 LLVMArrayType
 LLVMArrayType::getChecked(function_ref<InFlightDiagnostic()> emitError,
-                          Type elementType, unsigned numElements) {
+                          Type elementType, uint64_t numElements) {
   assert(elementType && "expected non-null subtype");
   return Base::getChecked(emitError, elementType.getContext(), elementType,
                           numElements);
@@ -169,7 +169,7 @@ LLVMArrayType::getChecked(function_ref<InFlightDiagnostic()> emitError,
 
 LogicalResult
 LLVMArrayType::verify(function_ref<InFlightDiagnostic()> emitError,
-                      Type elementType, unsigned numElements) {
+                      Type elementType, uint64_t numElements) {
   if (!isValidElementType(elementType))
     return emitError() << "invalid array element type: " << elementType;
   return success();
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 1ec0736ec08bf..5867d85ddf88d 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -632,7 +632,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
           llvm::ElementCount::get(numElements, /*Scalable=*/isScalable), child);
     if (llvmType->isArrayTy()) {
       auto *arrayType = llvm::ArrayType::get(elementType, numElements);
-      SmallVector<llvm::Constant *, 8> constants(numElements, child);
+      std::vector<llvm::Constant *> constants(numElements, child);
       return llvm::ConstantArray::get(arrayType, constants);
     }
   }
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 97f37939551d8..b080bb52b8c56 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -2396,3 +2396,9 @@ llvm.func @zeroinit_complex_local_aggregate() {
 llvm.linker_options ["/DEFAULTLIB:", "libcmt"]
 //CHECK: ![[MD1]] = !{!"/DEFAULTLIB:", !"libcmtd"}
 llvm.linker_options ["/DEFAULTLIB:", "libcmtd"]
+
+// -----
+
+// Translation is currently very slow so the test is not enabled.
+//llvm.mlir.global common @big_(dense<0> : vector<4294967296xi8>) {addr_space = 0 : i32} : !llvm.array<4294967296 x i8>
+//XCHECK: @big_ = common global [4294967296 x i8] zeroinitializer

@joker-eph
Copy link
Collaborator

I wished we'd use signed integer consistently to forbid wrap-around :(

@@ -632,7 +632,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
llvm::ElementCount::get(numElements, /*Scalable=*/isScalable), child);
if (llvmType->isArrayTy()) {
auto *arrayType = llvm::ArrayType::get(elementType, numElements);
SmallVector<llvm::Constant *, 8> constants(numElements, child);
std::vector<llvm::Constant *> constants(numElements, child);
return llvm::ConstantArray::get(arrayType, constants);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're in the splat case, we can have a fast path for zero by getting a ConstantAggregateZero instead here.

And if not, we can go the ConstantDataArray route instead when the element type is scalar.

Copy link
Contributor Author

@clementval clementval May 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 Using ConstantAggregateZero improves things a lot. I updated the code to use it when constant is known to be zero. I reverted the std::vector to SmallVector since I don't really have a use case that would require it now.

This allowed me to enable the translation test since it has a normal execution time now.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A splat of non-zero would have the problem, wouldn't it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess. I'll have a look at that as well.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a efficient way to represent splat constant arrays with a non-zero value in LLVM?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense.

Copy link
Contributor Author

@clementval clementval May 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I tried to use ConstantDataArray for some example with i32 element type and I don't see a significant improvement. The code would also be more complicated since we need to to have an ArrayRef of the element type populated with the value of the splat constant.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you push this to see? Because ConstantArray will have to do it, so we at least double the memory consumption (need to copy the vector)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I pushed a version that works for i32 elementType.

I tried it with the translation of :

llvm.mlir.global internal @int_global_array(dense<1> : vector<4294967294xi32>) : !llvm.array<4294967294 x i32>

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@joker-eph Did you have time to have a look at the ConstantDataArray path?


fir.global common @c_(dense<0> : vector<4294967296xi8>) : !fir.array<4294967296xi8>

// CHECK: llvm.mlir.global common @c_(dense<0> : vector<4294967296xi8>) {addr_space = 0 : i32} : !llvm.array<4294967296 x i8>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We also have a zero operation for zero initialized types that uses llvm::Constant::getNullValue during the lowering to LLVM. This should be efficient for any kind of zero initialized values:

  llvm.mlir.global common @c)() {addr_space = 0 : i32} : !llvm.array<4294967296 x i8> {
    %0 = llvm.mlir.zero : !llvm.array<4294967296 x i8>
    llvm.return %0 : !llvm.array<4294967296 x i8>
  }

Note: that your change to the constant lowering still makes sense since there are multiple ways of expressing a zero constant and ideally all of them are efficient in the lowering.

Copy link
Contributor

@gysit gysit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM from my side modulo last comment.

Copy link

github-actions bot commented May 24, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

return llvm::ConstantAggregateZero::get(arrayType);
} else {
if (llvm::ConstantDataSequential::isElementTypeCompatible(elementType)) {
// TODO: Handle all compatible types. This code only handle i32.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// TODO: Handle all compatible types. This code only handle i32.
// TODO: Handle all compatible types. This code only handles i32.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review. I'm not sure we want to keep the ConstantDataArray path. I'm happy to add more types here if we want to keep it in this PR.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect we want to have it for efficiency reasons. We have had compile time issues with large nested constants previously that I suspect using ConstantData avoids. This is perfectly fine for a separate PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok! I have added at least all int types supported.

Copy link
Contributor

@gysit gysit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still LGTM!


// -----

//CHECK: @big_ = common global [4294967296 x i8] zeroinitializer
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
//CHECK: @big_ = common global [4294967296 x i8] zeroinitializer
// CHECK: @big_ = common global [4294967296 x i8] zeroinitializer

ultra nit: missing space

return llvm::ConstantArray::get(arrayType, constants);
if (child->isZeroValue()) {
return llvm::ConstantAggregateZero::get(arrayType);
} else {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can you drop the else since there is a return in the if body? That should save one indention for the else body.

SmallVector<int8_t> constants(numElements, ci->getZExtValue());
return llvm::ConstantDataArray::get(elementType->getContext(),
constants);
} else if (ci->getBitWidth() == 16) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: These else's are probably also redundant?

Copy link
Collaborator

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@clementval clementval merged commit 428b9be into llvm:main May 29, 2024
7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang Flang issues not falling into any other category mlir:llvm mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants