Skip to content

[mlir] Align num elements type to LLVM ArrayType #93230

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions flang/test/Fir/convert-to-llvm.fir
Original file line number Diff line number Diff line change
Expand Up @@ -2698,3 +2698,9 @@ func.func @coordinate_array_unknown_size_1d(%arg0: !fir.ptr<!fir.array<? x i32>>
// CHECK: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
// CHECK: llvm.return
// CHECK: }

// -----

fir.global common @c_(dense<0> : vector<4294967296xi8>) : !fir.array<4294967296xi8>

// CHECK: llvm.mlir.global common @c_(dense<0> : vector<4294967296xi8>) {addr_space = 0 : i32} : !llvm.array<4294967296 x i8>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We also have a zero operation for zero initialized types that uses llvm::Constant::getNullValue during the lowering to LLVM. This should be efficient for any kind of zero initialized values:

  llvm.mlir.global common @c)() {addr_space = 0 : i32} : !llvm.array<4294967296 x i8> {
    %0 = llvm.mlir.zero : !llvm.array<4294967296 x i8>
    llvm.return %0 : !llvm.array<4294967296 x i8>
  }

Note: that your change to the constant lowering still makes sense since there are multiple ways of expressing a zero constant and ideally all of them are efficient in the lowering.

4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def LLVMArrayType : LLVMType<"LLVMArray", "array", [
```
}];

let parameters = (ins "Type":$elementType, "unsigned":$numElements);
let parameters = (ins "Type":$elementType, "uint64_t":$numElements);
let assemblyFormat = [{
`<` $numElements `x` custom<PrettyLLVMType>($elementType) `>`
}];
Expand All @@ -49,7 +49,7 @@ def LLVMArrayType : LLVMType<"LLVMArray", "array", [

let builders = [
TypeBuilderWithInferredContext<(ins "Type":$elementType,
"unsigned":$numElements)>
"uint64_t":$numElements)>
];

let extraClassDeclaration = [{
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,22 +154,22 @@ bool LLVMArrayType::isValidElementType(Type type) {
type);
}

LLVMArrayType LLVMArrayType::get(Type elementType, unsigned numElements) {
LLVMArrayType LLVMArrayType::get(Type elementType, uint64_t numElements) {
assert(elementType && "expected non-null subtype");
return Base::get(elementType.getContext(), elementType, numElements);
}

LLVMArrayType
LLVMArrayType::getChecked(function_ref<InFlightDiagnostic()> emitError,
Type elementType, unsigned numElements) {
Type elementType, uint64_t numElements) {
assert(elementType && "expected non-null subtype");
return Base::getChecked(emitError, elementType.getContext(), elementType,
numElements);
}

LogicalResult
LLVMArrayType::verify(function_ref<InFlightDiagnostic()> emitError,
Type elementType, unsigned numElements) {
Type elementType, uint64_t numElements) {
if (!isValidElementType(elementType))
return emitError() << "invalid array element type: " << elementType;
return success();
Expand Down
39 changes: 37 additions & 2 deletions mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -632,8 +632,43 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
llvm::ElementCount::get(numElements, /*Scalable=*/isScalable), child);
if (llvmType->isArrayTy()) {
auto *arrayType = llvm::ArrayType::get(elementType, numElements);
SmallVector<llvm::Constant *, 8> constants(numElements, child);
return llvm::ConstantArray::get(arrayType, constants);
if (child->isZeroValue()) {
return llvm::ConstantAggregateZero::get(arrayType);
} else {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can you drop the else since there is a return in the if body? That should save one indention for the else body.

if (llvm::ConstantDataSequential::isElementTypeCompatible(
elementType)) {
// TODO: Handle all compatible types. This code only handles integer.
if (llvm::IntegerType *iTy =
dyn_cast<llvm::IntegerType>(elementType)) {
if (llvm::ConstantInt *ci = dyn_cast<llvm::ConstantInt>(child)) {
if (ci->getBitWidth() == 8) {
SmallVector<int8_t> constants(numElements, ci->getZExtValue());
return llvm::ConstantDataArray::get(elementType->getContext(),
constants);
}
if (ci->getBitWidth() == 16) {
SmallVector<int16_t> constants(numElements, ci->getZExtValue());
return llvm::ConstantDataArray::get(elementType->getContext(),
constants);
}
if (ci->getBitWidth() == 32) {
SmallVector<int32_t> constants(numElements, ci->getZExtValue());
return llvm::ConstantDataArray::get(elementType->getContext(),
constants);
}
if (ci->getBitWidth() == 64) {
SmallVector<int64_t> constants(numElements, ci->getZExtValue());
return llvm::ConstantDataArray::get(elementType->getContext(),
constants);
}
}
}
}
// std::vector is used here to accomodate large number of elements that
// exceed SmallVector capacity.
std::vector<llvm::Constant *> constants(numElements, child);
return llvm::ConstantArray::get(arrayType, constants);
}
}
}

Expand Down
5 changes: 5 additions & 0 deletions mlir/test/Target/LLVMIR/llvmir.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2396,3 +2396,8 @@ llvm.func @zeroinit_complex_local_aggregate() {
llvm.linker_options ["/DEFAULTLIB:", "libcmt"]
//CHECK: ![[MD1]] = !{!"/DEFAULTLIB:", !"libcmtd"}
llvm.linker_options ["/DEFAULTLIB:", "libcmtd"]

// -----

// CHECK: @big_ = common global [4294967296 x i8] zeroinitializer
llvm.mlir.global common @big_(dense<0> : vector<4294967296xi8>) {addr_space = 0 : i32} : !llvm.array<4294967296 x i8>
Loading