Skip to content

[OpenMP][CodeExtractor]Add align metadata to load instructions #131131

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 51 additions & 2 deletions llvm/lib/Transforms/Utils/CodeExtractor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,21 @@ buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
return Result;
}

/// isAlignmentPreservedForAddrCast - Return true if the cast operation
/// for specified target preserves original alignment
static bool isAlignmentPreservedForAddrCast(const Triple &TargetTriple) {
switch (TargetTriple.getArch()) {
case Triple::ArchType::amdgcn:
case Triple::ArchType::r600:
return true;
// TODO: Add other architectures for which we are certain that alignment
// is preserved during address space cast operations.
default:
return false;
}
return false;
}

CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
bool AggregateArgs, BlockFrequencyInfo *BFI,
BranchProbabilityInfo *BPI, AssumptionCache *AC,
Expand Down Expand Up @@ -1612,8 +1627,42 @@ void CodeExtractor::emitFunctionBody(
Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), aggIdx);
GetElementPtrInst *GEP = GetElementPtrInst::Create(
StructArgTy, AggArg, Idx, "gep_" + inputs[i]->getName(), newFuncRoot);
RewriteVal = new LoadInst(StructArgTy->getElementType(aggIdx), GEP,
"loadgep_" + inputs[i]->getName(), newFuncRoot);
LoadInst *LoadGEP =
new LoadInst(StructArgTy->getElementType(aggIdx), GEP,
"loadgep_" + inputs[i]->getName(), newFuncRoot);
// If we load pointer, we can add optional !align metadata
// The existence of the !align metadata on the instruction tells
// the optimizer that the value loaded is known to be aligned to
// a boundary specified by the integer value in the metadata node.
// Example:
// %res = load ptr, ptr %input, align 8, !align !align_md_node
// ^ ^
// | |
// alignment of %input address |
// |
// alignment of %res object
if (StructArgTy->getElementType(aggIdx)->isPointerTy()) {
unsigned AlignmentValue;
const Triple &TargetTriple =
newFunction->getParent()->getTargetTriple();
const DataLayout &DL = header->getDataLayout();
// Pointers without casting can provide more information about
// alignment. Use pointers without casts if given target preserves
// alignment information for cast the operation.
if (isAlignmentPreservedForAddrCast(TargetTriple))
AlignmentValue =
inputs[i]->stripPointerCasts()->getPointerAlignment(DL).value();
Comment on lines +1653 to +1654
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to make getPointerAlignment to strip irrelevant casts itsself?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The getPointerAlignment function is from the Value class. Do you think it's worth adding another version of this function to this basic class? If yes, I will create a separate patch that will contain a version of the getPointerAlignment function with two arguments. One of them will be a flag to get the value alignment without pointer casts.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think improving getPointerAlignment itself would be generally useful. I was using it myself in

// TODO: Would be great if this could determine alignment through a GEP
EffectiveAlign = AtomicPtr->getPointerAlignment(EmitOptions.DL);
and was disappointed how quickly it gives up.

But also maybe does not belong into this PR.

else
AlignmentValue = inputs[i]->getPointerAlignment(DL).value();
MDBuilder MDB(header->getContext());
LoadGEP->setMetadata(
LLVMContext::MD_align,
MDNode::get(
header->getContext(),
MDB.createConstant(ConstantInt::get(
Type::getInt64Ty(header->getContext()), AlignmentValue))));
}
RewriteVal = LoadGEP;
++aggIdx;
} else
RewriteVal = &*ScalarAI++;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/ReplaceConstant.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/TargetParser/Triple.h"
Expand Down Expand Up @@ -4407,13 +4408,17 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
builder.restoreIP(allocaIP);

omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;

LLVM::TypeToLLVMIRTranslator typeToLLVMIRTranslator(
ompBuilder.M.getContext());
unsigned alignmentValue = 0;
// Find the associated MapInfoData entry for the current input
for (size_t i = 0; i < mapData.MapClause.size(); ++i)
if (mapData.OriginalValue[i] == input) {
auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
capture = mapOp.getMapCaptureType();

// Get information of alignment of mapped object
alignmentValue = typeToLLVMIRTranslator.getPreferredAlignment(
mapOp.getVarType(), ompBuilder.M.getDataLayout());
break;
}

Expand All @@ -4437,9 +4442,34 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
break;
}
case omp::VariableCaptureKind::ByRef: {
retVal = builder.CreateAlignedLoad(
llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
v->getType(), v,
ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
// CreateAlignedLoad function creates similar LLVM IR:
// %res = load ptr, ptr %input, align 8
// This LLVM IR does not contain information about alignment
// of the loaded value. We need to add !align metadata to unblock
// optimizer. The existence of the !align metadata on the instruction
// tells the optimizer that the value loaded is known to be aligned to
// a boundary specified by the integer value in the metadata node.
// Example:
// %res = load ptr, ptr %input, align 8, !align !align_md_node
// ^ ^
// | |
// alignment of %input address |
// |
// alignment of %res object
if (v->getType()->isPointerTy() && alignmentValue) {
llvm::MDBuilder MDB(builder.getContext());
loadInst->setMetadata(
llvm::LLVMContext::MD_align,
llvm::MDNode::get(builder.getContext(),
MDB.createConstant(llvm::ConstantInt::get(
llvm::Type::getInt64Ty(builder.getContext()),
alignmentValue))));
}
retVal = loadInst;

break;
}
case omp::VariableCaptureKind::This:
Expand Down
92 changes: 92 additions & 0 deletions mlir/test/Target/LLVMIR/omptarget-memcpy-align-metadata.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s

// The aim of this test is to verfiy that information of
// alignment of loaded objects is passed to outlined
// functions.

module attributes {llvm.data_layout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128-p9:192:256:256:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8:9", llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} {
omp.private {type = private} @_QFEk_private_i32 : i32
llvm.func @_QQmain() {
%0 = llvm.mlir.constant(1 : i32) : i32
%7 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr<5>
%8 = llvm.addrspacecast %7 : !llvm.ptr<5> to !llvm.ptr
%12 = llvm.mlir.constant(1 : i64) : i64
%13 = llvm.alloca %12 x i32 {bindc_name = "k"} : (i64) -> !llvm.ptr<5>
%14 = llvm.addrspacecast %13 : !llvm.ptr<5> to !llvm.ptr
%15 = llvm.mlir.constant(1 : i64) : i64
%16 = llvm.alloca %15 x i32 {bindc_name = "b"} : (i64) -> !llvm.ptr<5>
%17 = llvm.addrspacecast %16 : !llvm.ptr<5> to !llvm.ptr
%19 = llvm.mlir.constant(1 : index) : i64
%20 = llvm.mlir.constant(0 : index) : i64
%22 = llvm.mlir.addressof @_QFEa : !llvm.ptr
%25 = llvm.mlir.addressof @_QFECnz : !llvm.ptr
%60 = llvm.getelementptr %8[0, 7, %20, 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
%61 = llvm.load %60 : !llvm.ptr -> i64
%62 = llvm.getelementptr %8[0, 7, %20, 1] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
%63 = llvm.load %62 : !llvm.ptr -> i64
%64 = llvm.getelementptr %8[0, 7, %20, 2] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
%65 = llvm.load %64 : !llvm.ptr -> i64
%66 = llvm.sub %63, %19 : i64
%67 = omp.map.bounds lower_bound(%20 : i64) upper_bound(%66 : i64) extent(%63 : i64) stride(%65 : i64) start_idx(%61 : i64) {stride_in_bytes = true}
%68 = llvm.getelementptr %22[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
%69 = omp.map.info var_ptr(%22 : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) var_ptr_ptr(%68 : !llvm.ptr) bounds(%67) -> !llvm.ptr {name = ""}
%70 = omp.map.info var_ptr(%22 : !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>) map_clauses(to) capture(ByRef) members(%69 : [0] : !llvm.ptr) -> !llvm.ptr {name = "a"}
%71 = omp.map.info var_ptr(%17 : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "b"}
%72 = omp.map.info var_ptr(%14 : !llvm.ptr, i32) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !llvm.ptr {name = "k"}
%73 = omp.map.info var_ptr(%25 : !llvm.ptr, i32) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !llvm.ptr {name = "nz"}
omp.target map_entries(%70 -> %arg0, %71 -> %arg1, %72 -> %arg2, %73 -> %arg3, %69 -> %arg4 : !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) {
%106 = llvm.mlir.constant(0 : index) : i64
%107 = llvm.mlir.constant(13 : i32) : i32
%108 = llvm.mlir.constant(1000 : i32) : i32
%109 = llvm.mlir.constant(1 : i32) : i32
omp.teams {
omp.parallel private(@_QFEk_private_i32 %arg2 -> %arg5 : !llvm.ptr) {
%110 = llvm.mlir.constant(1 : i32) : i32
%111 = llvm.alloca %110 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr<5>
%112 = llvm.addrspacecast %111 : !llvm.ptr<5> to !llvm.ptr
omp.distribute {
omp.wsloop {
omp.loop_nest (%arg6) : i32 = (%109) to (%108) inclusive step (%109) {
llvm.store %arg6, %arg5 : i32, !llvm.ptr
%115 = llvm.mlir.constant(48 : i32) : i32
"llvm.intr.memcpy"(%112, %arg0, %115) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
omp.yield
}
} {omp.composite}
} {omp.composite}
omp.terminator
} {omp.composite}
omp.terminator
}
omp.terminator
}
llvm.return
}
llvm.mlir.global internal @_QFEa() {addr_space = 0 : i32} : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> {
%6 = llvm.mlir.undef : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
llvm.return %6 : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
}
llvm.mlir.global internal constant @_QFECnz() {addr_space = 0 : i32} : i32 {
%0 = llvm.mlir.constant(1000 : i32) : i32
llvm.return %0 : i32
}
}

// CHECK: call void @__kmpc_distribute_for_static_loop_4u(
// CHECK-SAME: ptr addrspacecast (ptr addrspace(1) @[[GLOB:[0-9]+]] to ptr),
// CHECK-SAME: ptr @[[LOOP_BODY_FUNC:.*]], ptr %[[LOOP_BODY_FUNC_ARG:.*]],
// CHEKC-SAME i32 1000, i32 %1, i32 0, i32 0)


// CHECK: define internal void @[[LOOP_BODY_FUNC]](i32 %[[CNT:.*]], ptr %[[LOOP_BODY_ARG_PTR:.*]]) #[[ATTRS:[0-9]+]] {
// CHECK: %[[GEP_PTR_0:.*]] = getelementptr { ptr, ptr, ptr }, ptr %[[LOOP_BODY_ARG_PTR]], i32 0, i32 0
// CHECK: %[[INT_PTR:.*]] = load ptr, ptr %[[GEP_PTR_0]], align 8, !align ![[ALIGN_INT:[0-9]+]]
// CHECK: %[[GEP_PTR_1:.*]] = getelementptr { ptr, ptr, ptr }, ptr %[[LOOP_BODY_ARG_PTR]], i32 0, i32 1
// CHECK: %[[STRUCT_PTR_0:.*]] = load ptr, ptr %[[GEP_PTR_1]], align 8, !align ![[ALIGN_STRUCT:[0-9]+]]
// CHECK: %[[GEP_PTR_2:.*]] = getelementptr { ptr, ptr, ptr }, ptr %[[LOOP_BODY_ARG_PTR]], i32 0, i32 2
// CHECK: %[[STRUCT_PTR_1:.*]] = load ptr, ptr %[[GEP_PTR_2]], align 8, !align ![[ALIGN_STRUCT:[0-9]+]]
// CHECK: store i32 %[[DATA_INT:.*]], ptr %[[INT_PTR]], align 4
// CHECK: call void @llvm.memcpy.p0.p0.i32(ptr %[[STRUCT_PTR_0]], ptr %[[STRUCT_PTR_1]], i32 48, i1 false)

// CHECK: ![[ALIGN_STRUCT]] = !{i64 8}
// CHECK: ![[ALIGN_INT]] = !{i64 4}