-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir] Align num elements type to LLVM ArrayType #93230
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-mlir Author: Valentin Clement (バレンタイン クレメン) (clementval) Changes: MLIR LLVMArrayType is using unsigned for the number of elements while LLVM ArrayType is using uint64_t.
This leads to silent truncation when we use it for globals in flang.
The above program would result in a segfault since the global would be of size 0 because of the silent truncation.
became
This patch updates the definition of MLIR ArrayType to take uint64_t as argument of the number of elements to be compatible with LLVM. Full diff: https://github.com/llvm/llvm-project/pull/93230.diff 5 Files Affected:
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index 21323a5e657c9..9bd4475e98009 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -2698,3 +2698,9 @@ func.func @coordinate_array_unknown_size_1d(%arg0: !fir.ptr<!fir.array<? x i32>>
// CHECK: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
// CHECK: llvm.return
// CHECK: }
+
+// -----
+
+fir.global common @c_(dense<0> : vector<4294967296xi8>) : !fir.array<4294967296xi8>
+
+// CHECK: llvm.mlir.global common @c_(dense<0> : vector<4294967296xi8>) {addr_space = 0 : i32} : !llvm.array<4294967296 x i8>
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
index b7176aa93ff1f..8f9c2f2f8a0b4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
@@ -40,7 +40,7 @@ def LLVMArrayType : LLVMType<"LLVMArray", "array", [
```
}];
- let parameters = (ins "Type":$elementType, "unsigned":$numElements);
+ let parameters = (ins "Type":$elementType, "uint64_t":$numElements);
let assemblyFormat = [{
`<` $numElements `x` custom<PrettyLLVMType>($elementType) `>`
}];
@@ -49,7 +49,7 @@ def LLVMArrayType : LLVMType<"LLVMArray", "array", [
let builders = [
TypeBuilderWithInferredContext<(ins "Type":$elementType,
- "unsigned":$numElements)>
+ "uint64_t":$numElements)>
];
let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index ad1dc4a36b82b..cf3f38b710130 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -154,14 +154,14 @@ bool LLVMArrayType::isValidElementType(Type type) {
type);
}
-LLVMArrayType LLVMArrayType::get(Type elementType, unsigned numElements) {
+LLVMArrayType LLVMArrayType::get(Type elementType, uint64_t numElements) {
assert(elementType && "expected non-null subtype");
return Base::get(elementType.getContext(), elementType, numElements);
}
LLVMArrayType
LLVMArrayType::getChecked(function_ref<InFlightDiagnostic()> emitError,
- Type elementType, unsigned numElements) {
+ Type elementType, uint64_t numElements) {
assert(elementType && "expected non-null subtype");
return Base::getChecked(emitError, elementType.getContext(), elementType,
numElements);
@@ -169,7 +169,7 @@ LLVMArrayType::getChecked(function_ref<InFlightDiagnostic()> emitError,
LogicalResult
LLVMArrayType::verify(function_ref<InFlightDiagnostic()> emitError,
- Type elementType, unsigned numElements) {
+ Type elementType, uint64_t numElements) {
if (!isValidElementType(elementType))
return emitError() << "invalid array element type: " << elementType;
return success();
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 1ec0736ec08bf..5867d85ddf88d 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -632,7 +632,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
llvm::ElementCount::get(numElements, /*Scalable=*/isScalable), child);
if (llvmType->isArrayTy()) {
auto *arrayType = llvm::ArrayType::get(elementType, numElements);
- SmallVector<llvm::Constant *, 8> constants(numElements, child);
+ std::vector<llvm::Constant *> constants(numElements, child);
return llvm::ConstantArray::get(arrayType, constants);
}
}
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 97f37939551d8..b080bb52b8c56 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -2396,3 +2396,9 @@ llvm.func @zeroinit_complex_local_aggregate() {
llvm.linker_options ["/DEFAULTLIB:", "libcmt"]
//CHECK: ![[MD1]] = !{!"/DEFAULTLIB:", !"libcmtd"}
llvm.linker_options ["/DEFAULTLIB:", "libcmtd"]
+
+// -----
+
+// Translation is currently very slow so the test is not enabled.
+//llvm.mlir.global common @big_(dense<0> : vector<4294967296xi8>) {addr_space = 0 : i32} : !llvm.array<4294967296 x i8>
+//XCHECK: @big_ = common global [4294967296 x i8] zeroinitializer
|
I wish we used signed integers consistently to forbid wrap-around :( |
@@ -632,7 +632,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant( | |||
llvm::ElementCount::get(numElements, /*Scalable=*/isScalable), child); | |||
if (llvmType->isArrayTy()) { | |||
auto *arrayType = llvm::ArrayType::get(elementType, numElements); | |||
SmallVector<llvm::Constant *, 8> constants(numElements, child); | |||
std::vector<llvm::Constant *> constants(numElements, child); | |||
return llvm::ConstantArray::get(arrayType, constants); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We're in the splat case, we can have a fast path for zero by getting a ConstantAggregateZero
instead here.
And if not, we can go the ConstantDataArray
route instead when the element type is scalar.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍 Using ConstantAggregateZero
improves things a lot. I updated the code to use it when constant is known to be zero. I reverted the std::vector to SmallVector since I don't really have a use case that would require it now.
This allowed me to enable the translation test since it has a normal execution time now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A splat of non-zero would have the problem, wouldn't it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess. I'll have a look at that as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there an efficient way to represent splat constant arrays with a non-zero value in LLVM?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That makes sense.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So I tried to use ConstantDataArray
for some example with i32 element type and I don't see a significant improvement. The code would also be more complicated since we need to have an ArrayRef of the element type populated with the value of the splat constant.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you push this to see? Because ConstantArray will have to do it, so we at least double the memory consumption (need to copy the vector)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I pushed a version that works for i32 elementType.
I tried it with the translation of :
llvm.mlir.global internal @int_global_array(dense<1> : vector<4294967294xi32>) : !llvm.array<4294967294 x i32>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@joker-eph Did you have time to have a look at the ConstantDataArray path?
|
||
fir.global common @c_(dense<0> : vector<4294967296xi8>) : !fir.array<4294967296xi8> | ||
|
||
// CHECK: llvm.mlir.global common @c_(dense<0> : vector<4294967296xi8>) {addr_space = 0 : i32} : !llvm.array<4294967296 x i8> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We also have a zero operation for zero initialized types that uses llvm::Constant::getNullValue
during the lowering to LLVM. This should be efficient for any kind of zero initialized values:
llvm.mlir.global common @c_() {addr_space = 0 : i32} : !llvm.array<4294967296 x i8> {
%0 = llvm.mlir.zero : !llvm.array<4294967296 x i8>
llvm.return %0 : !llvm.array<4294967296 x i8>
}
Note that your change to the constant lowering still makes sense since there are multiple ways of expressing a zero constant and ideally all of them are efficient in the lowering.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM from my side modulo last comment.
✅ With the latest revision this PR passed the C/C++ code formatter. |
return llvm::ConstantAggregateZero::get(arrayType); | ||
} else { | ||
if (llvm::ConstantDataSequential::isElementTypeCompatible(elementType)) { | ||
// TODO: Handle all compatible types. This code only handle i32. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// TODO: Handle all compatible types. This code only handle i32. | |
// TODO: Handle all compatible types. This code only handles i32. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the review. I'm not sure we want to keep the ConstantDataArray
path. I'm happy to add more types here if we want to keep it in this PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suspect we want to have it for efficiency reasons. We have had compile time issues with large nested constants previously that I suspect using ConstantData avoids. This is perfectly fine for a separate PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok! I have added at least all int types supported.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Still LGTM!
mlir/test/Target/LLVMIR/llvmir.mlir
Outdated
|
||
// ----- | ||
|
||
//CHECK: @big_ = common global [4294967296 x i8] zeroinitializer |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
//CHECK: @big_ = common global [4294967296 x i8] zeroinitializer | |
// CHECK: @big_ = common global [4294967296 x i8] zeroinitializer |
ultra nit: missing space
return llvm::ConstantArray::get(arrayType, constants); | ||
if (child->isZeroValue()) { | ||
return llvm::ConstantAggregateZero::get(arrayType); | ||
} else { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: can you drop the else since there is a return in the if body? That should save one level of indentation for the else body.
SmallVector<int8_t> constants(numElements, ci->getZExtValue()); | ||
return llvm::ConstantDataArray::get(elementType->getContext(), | ||
constants); | ||
} else if (ci->getBitWidth() == 16) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: These else's are probably also redundant?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
MLIR LLVMArrayType is using
unsigned
for the number of elements while LLVM ArrayType is using uint64_t
llvm-project/llvm/include/llvm/IR/DerivedTypes.h
Line 377 in 4ae896f
This leads to silent truncation when we use it for globals in flang.
The above program would result in a segfault since the global would be of size 0 because of the silent truncation.
became
This patch updates the definition of MLIR ArrayType to take
uint64_t
as argument of the number of elements to be compatible with LLVM.