[msan] Generalize handlePairwiseShadowOrIntrinsic, and handle x86 pairwise add/sub #127567

Merged 7 commits on Feb 27, 2025
215 changes: 112 additions & 103 deletions llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
@@ -2607,8 +2607,6 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
///
/// e.g., <2 x i32> @llvm.aarch64.neon.saddlp.v2i32.v4i16(<4 x i16>)
/// <16 x i8> @llvm.aarch64.neon.addp.v16i8(<16 x i8>, <16 x i8>)
///
/// TODO: adapt this function to handle horizontal add/sub?
void handlePairwiseShadowOrIntrinsic(IntrinsicInst &I) {
assert(I.arg_size() == 1 || I.arg_size() == 2);

@@ -2617,8 +2615,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {

FixedVectorType *ParamType =
cast<FixedVectorType>(I.getArgOperand(0)->getType());
if (I.arg_size() == 2)
assert(ParamType == cast<FixedVectorType>(I.getArgOperand(1)->getType()));
assert((I.arg_size() != 2) ||
(ParamType == cast<FixedVectorType>(I.getArgOperand(1)->getType())));
[[maybe_unused]] FixedVectorType *ReturnType =
cast<FixedVectorType>(I.getType());
assert(ParamType->getNumElements() * I.arg_size() ==
@@ -2656,6 +2654,82 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
setOriginForNaryOp(I);
}

/// Propagate shadow for 1- or 2-vector intrinsics that combine adjacent
/// fields, with the parameters reinterpreted to have elements of a specified
/// width. For example:
/// @llvm.x86.ssse3.phadd.w(<1 x i64> [[VAR1]], <1 x i64> [[VAR2]])
/// conceptually operates on
/// (<4 x i16> [[VAR1]], <4 x i16> [[VAR2]])
/// and can be handled with ReinterpretElemWidth == 16.
void handlePairwiseShadowOrIntrinsic(IntrinsicInst &I,
int ReinterpretElemWidth) {
assert(I.arg_size() == 1 || I.arg_size() == 2);

assert(I.getType()->isVectorTy());
assert(I.getArgOperand(0)->getType()->isVectorTy());

FixedVectorType *ParamType =
cast<FixedVectorType>(I.getArgOperand(0)->getType());
assert((I.arg_size() != 2) ||
(ParamType == cast<FixedVectorType>(I.getArgOperand(1)->getType())));

[[maybe_unused]] FixedVectorType *ReturnType =
cast<FixedVectorType>(I.getType());
assert(ParamType->getNumElements() * I.arg_size() ==
2 * ReturnType->getNumElements());

IRBuilder<> IRB(&I);

unsigned TotalNumElems = ParamType->getNumElements() * I.arg_size();
FixedVectorType *ReinterpretShadowTy = nullptr;
assert(isAligned(Align(ReinterpretElemWidth),
ParamType->getPrimitiveSizeInBits()));
ReinterpretShadowTy = FixedVectorType::get(
IRB.getIntNTy(ReinterpretElemWidth),
ParamType->getPrimitiveSizeInBits() / ReinterpretElemWidth);
TotalNumElems = ReinterpretShadowTy->getNumElements() * I.arg_size();

// Horizontal OR of shadow
SmallVector<int, 8> EvenMask;
SmallVector<int, 8> OddMask;
for (unsigned X = 0; X < TotalNumElems - 1; X += 2) {
EvenMask.push_back(X);
OddMask.push_back(X + 1);
}

Value *FirstArgShadow = getShadow(&I, 0);
FirstArgShadow = IRB.CreateBitCast(FirstArgShadow, ReinterpretShadowTy);

// If we had two parameters each with an odd number of elements, the total
// number of elements would still be even, but we have never seen this in
// extant instruction sets, so we enforce that each parameter must have an
// even number of elements.
assert(isAligned(
Align(2),
cast<FixedVectorType>(FirstArgShadow->getType())->getNumElements()));

Value *EvenShadow;
Value *OddShadow;
if (I.arg_size() == 2) {
Value *SecondArgShadow = getShadow(&I, 1);
SecondArgShadow = IRB.CreateBitCast(SecondArgShadow, ReinterpretShadowTy);

EvenShadow =
IRB.CreateShuffleVector(FirstArgShadow, SecondArgShadow, EvenMask);
OddShadow =
IRB.CreateShuffleVector(FirstArgShadow, SecondArgShadow, OddMask);
} else {
EvenShadow = IRB.CreateShuffleVector(FirstArgShadow, EvenMask);
OddShadow = IRB.CreateShuffleVector(FirstArgShadow, OddMask);
}

Value *OrShadow = IRB.CreateOr(EvenShadow, OddShadow);
OrShadow = CreateShadowCast(IRB, OrShadow, getShadowTy(&I));

setShadow(&I, OrShadow);
setOriginForNaryOp(I);
}
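
For intuition, the pairwise shadow-OR rule implemented by the new overload can be sketched outside the IRBuilder machinery: reinterpret each operand's shadow as lanes of ReinterpretElemWidth bits, concatenate the operand shadows, and OR each adjacent pair, so an output lane is poisoned iff either input of its pair is poisoned. Below is a minimal standalone C++ sketch, assuming 16-bit lanes (as for phadd.w) and illustrative shadow values that are not taken from this patch:

#include <array>
#include <cstdint>
#include <cstdio>

int main() {
  // Shadows of the two <4 x i16> operands; all-ones marks an uninitialized lane.
  std::array<std::uint16_t, 4> ShadowA = {0x0000, 0xFFFF, 0x0000, 0x0000};
  std::array<std::uint16_t, 4> ShadowB = {0x0000, 0x0000, 0x0000, 0xFFFF};

  // Concatenate the operand shadows (what the even/odd shufflevectors index into).
  std::array<std::uint16_t, 8> Concat;
  for (int I = 0; I < 4; ++I) {
    Concat[I] = ShadowA[I];
    Concat[I + 4] = ShadowB[I];
  }

  // OR adjacent pairs: output lane I combines concatenated lanes 2*I and 2*I+1.
  std::array<std::uint16_t, 4> ResultShadow;
  for (int I = 0; I < 4; ++I)
    ResultShadow[I] = Concat[2 * I] | Concat[2 * I + 1];

  for (std::uint16_t S : ResultShadow)
    std::printf("0x%04X ", static_cast<unsigned>(S)); // 0xFFFF 0x0000 0x0000 0xFFFF
  std::printf("\n");
  return 0;
}

The output lane order matches x86 horizontal add/sub semantics: the first half of the result pairs up the first operand and the second half the second operand, which is what shuffling (FirstArgShadow, SecondArgShadow) with EvenMask/OddMask and OR-ing the two results produces.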

void visitFNeg(UnaryOperator &I) { handleShadowOr(I); }

// Handle multiplication by constant.
@@ -4156,87 +4230,6 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
setOriginForNaryOp(I);
}

void handleAVXHorizontalAddSubIntrinsic(IntrinsicInst &I) {
// Approximation only:
// output = horizontal_add/sub(A, B)
// => shadow[output] = horizontal_add(shadow[A], shadow[B])
//
// We always use horizontal add instead of subtract, because subtracting
// one fully uninitialized shadow from another would result in a fully
// initialized shadow.
//
// - If we add two adjacent zero (initialized) shadow values, the
// result will always be zero, i.e., no false positives.
// - If we add two shadows, one of which is uninitialized, the
// result will always be non-zero, i.e., no false negatives.
// - However, we can have false negatives if we do an addition that wraps
// to zero; we consider this an acceptable tradeoff for performance.
//
// To make shadow propagation precise, we want the equivalent of
// "horizontal OR", but this is not available for SSE3/SSSE3/AVX/AVX2.

Intrinsic::ID shadowIntrinsicID = I.getIntrinsicID();

switch (I.getIntrinsicID()) {
case Intrinsic::x86_sse3_hsub_ps:
shadowIntrinsicID = Intrinsic::x86_sse3_hadd_ps;
break;

case Intrinsic::x86_sse3_hsub_pd:
shadowIntrinsicID = Intrinsic::x86_sse3_hadd_pd;
break;

case Intrinsic::x86_ssse3_phsub_d:
shadowIntrinsicID = Intrinsic::x86_ssse3_phadd_d;
break;

case Intrinsic::x86_ssse3_phsub_d_128:
shadowIntrinsicID = Intrinsic::x86_ssse3_phadd_d_128;
break;

case Intrinsic::x86_ssse3_phsub_w:
shadowIntrinsicID = Intrinsic::x86_ssse3_phadd_w;
break;

case Intrinsic::x86_ssse3_phsub_w_128:
shadowIntrinsicID = Intrinsic::x86_ssse3_phadd_w_128;
break;

case Intrinsic::x86_ssse3_phsub_sw:
shadowIntrinsicID = Intrinsic::x86_ssse3_phadd_sw;
break;

case Intrinsic::x86_ssse3_phsub_sw_128:
shadowIntrinsicID = Intrinsic::x86_ssse3_phadd_sw_128;
break;

case Intrinsic::x86_avx_hsub_pd_256:
shadowIntrinsicID = Intrinsic::x86_avx_hadd_pd_256;
break;

case Intrinsic::x86_avx_hsub_ps_256:
shadowIntrinsicID = Intrinsic::x86_avx_hadd_ps_256;
break;

case Intrinsic::x86_avx2_phsub_d:
shadowIntrinsicID = Intrinsic::x86_avx2_phadd_d;
break;

case Intrinsic::x86_avx2_phsub_w:
shadowIntrinsicID = Intrinsic::x86_avx2_phadd_w;
break;

case Intrinsic::x86_avx2_phsub_sw:
shadowIntrinsicID = Intrinsic::x86_avx2_phadd_sw;
break;

default:
break;
}

return handleIntrinsicByApplyingToShadow(I, shadowIntrinsicID,
/*trailingVerbatimArgs*/ 0);
}
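
For context on why the removed handleAVXHorizontalAddSubIntrinsic above was only an approximation: it propagated shadow for both horizontal add and horizontal sub by applying a horizontal add to the shadows, which keeps clean+clean clean and clean+poisoned poisoned, but can silently drop poison when the addition wraps to zero. A small sketch of that wrap-around false negative, using hypothetical 16-bit shadow lanes rather than anything from this patch:

#include <cstdint>
#include <cstdio>

int main() {
  const std::uint16_t Clean = 0x0000;  // fully initialized lane
  const std::uint16_t Poison = 0xFFFF; // fully uninitialized lane

  std::uint16_t R1 = Clean + Clean;   // 0x0000: stays clean, no false positive
  std::uint16_t R2 = Clean + Poison;  // 0xFFFF: poison is preserved
  std::uint16_t R3 = 0x0001 + Poison; // 0x0000: wraps around, poison is lost

  std::printf("0x%04X 0x%04X 0x%04X\n", static_cast<unsigned>(R1),
              static_cast<unsigned>(R2), static_cast<unsigned>(R3));
  return 0;
}

The pairwise shadow-OR used by the replacement handler cannot lose poison this way, because OR-ing a non-zero shadow with anything is still non-zero.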

/// Handle Arm NEON vector store intrinsics (vst{2,3,4}, vst1x_{2,3,4},
/// and vst{2,3,4}lane).
///
@@ -4783,33 +4776,49 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
handleVtestIntrinsic(I);
break;

case Intrinsic::x86_sse3_hadd_ps:
case Intrinsic::x86_sse3_hadd_pd:
case Intrinsic::x86_ssse3_phadd_d:
case Intrinsic::x86_ssse3_phadd_d_128:
// Packed Horizontal Add/Subtract
case Intrinsic::x86_ssse3_phadd_w:
case Intrinsic::x86_ssse3_phadd_w_128:
case Intrinsic::x86_avx2_phadd_w:
case Intrinsic::x86_ssse3_phsub_w:
case Intrinsic::x86_ssse3_phsub_w_128:
case Intrinsic::x86_avx2_phsub_w: {
handlePairwiseShadowOrIntrinsic(I, /*ReinterpretElemWidth=*/16);
break;
}

// Packed Horizontal Add/Subtract
case Intrinsic::x86_ssse3_phadd_d:
case Intrinsic::x86_ssse3_phadd_d_128:
case Intrinsic::x86_avx2_phadd_d:
case Intrinsic::x86_ssse3_phsub_d:
case Intrinsic::x86_ssse3_phsub_d_128:
case Intrinsic::x86_avx2_phsub_d: {
handlePairwiseShadowOrIntrinsic(I, /*ReinterpretElemWidth=*/32);
break;
}

// Packed Horizontal Add/Subtract and Saturate
case Intrinsic::x86_ssse3_phadd_sw:
case Intrinsic::x86_ssse3_phadd_sw_128:
case Intrinsic::x86_avx2_phadd_sw:
case Intrinsic::x86_ssse3_phsub_sw:
case Intrinsic::x86_ssse3_phsub_sw_128:
case Intrinsic::x86_avx2_phsub_sw: {
handlePairwiseShadowOrIntrinsic(I, /*ReinterpretElemWidth=*/16);
break;
}

// Packed Single/Double Precision Floating-Point Horizontal Add
case Intrinsic::x86_sse3_hadd_ps:
case Intrinsic::x86_sse3_hadd_pd:
case Intrinsic::x86_avx_hadd_pd_256:
case Intrinsic::x86_avx_hadd_ps_256:
case Intrinsic::x86_avx2_phadd_d:
case Intrinsic::x86_avx2_phadd_w:
case Intrinsic::x86_avx2_phadd_sw:
case Intrinsic::x86_sse3_hsub_ps:
case Intrinsic::x86_sse3_hsub_pd:
case Intrinsic::x86_ssse3_phsub_d:
case Intrinsic::x86_ssse3_phsub_d_128:
case Intrinsic::x86_ssse3_phsub_w:
case Intrinsic::x86_ssse3_phsub_w_128:
case Intrinsic::x86_ssse3_phsub_sw:
case Intrinsic::x86_ssse3_phsub_sw_128:
case Intrinsic::x86_avx_hsub_pd_256:
case Intrinsic::x86_avx_hsub_ps_256:
case Intrinsic::x86_avx2_phsub_d:
case Intrinsic::x86_avx2_phsub_w:
case Intrinsic::x86_avx2_phsub_sw: {
handleAVXHorizontalAddSubIntrinsic(I);
case Intrinsic::x86_avx_hsub_ps_256: {
handlePairwiseShadowOrIntrinsic(I);
break;
}

28 changes: 12 additions & 16 deletions llvm/test/Instrumentation/MemorySanitizer/X86/avx-intrinsics-x86.ll
@@ -435,10 +435,9 @@ define <4 x double> @test_x86_avx_hadd_pd_256(<4 x double> %a0, <4 x double> %a1
; CHECK-NEXT: [[TMP1:%.*]] = load <4 x i64>, ptr @__msan_param_tls, align 8
; CHECK-NEXT: [[TMP2:%.*]] = load <4 x i64>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 32) to ptr), align 8
; CHECK-NEXT: call void @llvm.donothing()
; CHECK-NEXT: [[A0:%.*]] = bitcast <4 x i64> [[TMP1]] to <4 x double>
; CHECK-NEXT: [[A1:%.*]] = bitcast <4 x i64> [[TMP2]] to <4 x double>
; CHECK-NEXT: [[RES:%.*]] = call <4 x double> @llvm.x86.avx.hadd.pd.256(<4 x double> [[A0]], <4 x double> [[A1]])
; CHECK-NEXT: [[_MSPROP:%.*]] = bitcast <4 x double> [[RES]] to <4 x i64>
; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <4 x i64> [[TMP1]], <4 x i64> [[TMP2]], <4 x i32> <i32 0, i32 2, i32 4, i32 6>
; CHECK-NEXT: [[TMP4:%.*]] = shufflevector <4 x i64> [[TMP1]], <4 x i64> [[TMP2]], <4 x i32> <i32 1, i32 3, i32 5, i32 7>
; CHECK-NEXT: [[_MSPROP:%.*]] = or <4 x i64> [[TMP3]], [[TMP4]]
; CHECK-NEXT: [[RES1:%.*]] = call <4 x double> @llvm.x86.avx.hadd.pd.256(<4 x double> [[A2:%.*]], <4 x double> [[A3:%.*]])
; CHECK-NEXT: store <4 x i64> [[_MSPROP]], ptr @__msan_retval_tls, align 8
; CHECK-NEXT: ret <4 x double> [[RES1]]
@@ -454,10 +453,9 @@ define <8 x float> @test_x86_avx_hadd_ps_256(<8 x float> %a0, <8 x float> %a1) #
; CHECK-NEXT: [[TMP1:%.*]] = load <8 x i32>, ptr @__msan_param_tls, align 8
; CHECK-NEXT: [[TMP2:%.*]] = load <8 x i32>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 32) to ptr), align 8
; CHECK-NEXT: call void @llvm.donothing()
; CHECK-NEXT: [[A0:%.*]] = bitcast <8 x i32> [[TMP1]] to <8 x float>
; CHECK-NEXT: [[A1:%.*]] = bitcast <8 x i32> [[TMP2]] to <8 x float>
; CHECK-NEXT: [[RES:%.*]] = call <8 x float> @llvm.x86.avx.hadd.ps.256(<8 x float> [[A0]], <8 x float> [[A1]])
; CHECK-NEXT: [[_MSPROP:%.*]] = bitcast <8 x float> [[RES]] to <8 x i32>
; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <8 x i32> [[TMP1]], <8 x i32> [[TMP2]], <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
; CHECK-NEXT: [[TMP4:%.*]] = shufflevector <8 x i32> [[TMP1]], <8 x i32> [[TMP2]], <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
; CHECK-NEXT: [[_MSPROP:%.*]] = or <8 x i32> [[TMP3]], [[TMP4]]
; CHECK-NEXT: [[RES1:%.*]] = call <8 x float> @llvm.x86.avx.hadd.ps.256(<8 x float> [[A2:%.*]], <8 x float> [[A3:%.*]])
; CHECK-NEXT: store <8 x i32> [[_MSPROP]], ptr @__msan_retval_tls, align 8
; CHECK-NEXT: ret <8 x float> [[RES1]]
@@ -473,10 +471,9 @@ define <4 x double> @test_x86_avx_hsub_pd_256(<4 x double> %a0, <4 x double> %a1
; CHECK-NEXT: [[TMP1:%.*]] = load <4 x i64>, ptr @__msan_param_tls, align 8
; CHECK-NEXT: [[TMP2:%.*]] = load <4 x i64>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 32) to ptr), align 8
; CHECK-NEXT: call void @llvm.donothing()
; CHECK-NEXT: [[A0:%.*]] = bitcast <4 x i64> [[TMP1]] to <4 x double>
; CHECK-NEXT: [[A1:%.*]] = bitcast <4 x i64> [[TMP2]] to <4 x double>
; CHECK-NEXT: [[RES:%.*]] = call <4 x double> @llvm.x86.avx.hadd.pd.256(<4 x double> [[A0]], <4 x double> [[A1]])
; CHECK-NEXT: [[_MSPROP:%.*]] = bitcast <4 x double> [[RES]] to <4 x i64>
; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <4 x i64> [[TMP1]], <4 x i64> [[TMP2]], <4 x i32> <i32 0, i32 2, i32 4, i32 6>
; CHECK-NEXT: [[TMP4:%.*]] = shufflevector <4 x i64> [[TMP1]], <4 x i64> [[TMP2]], <4 x i32> <i32 1, i32 3, i32 5, i32 7>
; CHECK-NEXT: [[_MSPROP:%.*]] = or <4 x i64> [[TMP3]], [[TMP4]]
; CHECK-NEXT: [[RES1:%.*]] = call <4 x double> @llvm.x86.avx.hsub.pd.256(<4 x double> [[A2:%.*]], <4 x double> [[A3:%.*]])
; CHECK-NEXT: store <4 x i64> [[_MSPROP]], ptr @__msan_retval_tls, align 8
; CHECK-NEXT: ret <4 x double> [[RES1]]
@@ -492,10 +489,9 @@ define <8 x float> @test_x86_avx_hsub_ps_256(<8 x float> %a0, <8 x float> %a1) #
; CHECK-NEXT: [[TMP1:%.*]] = load <8 x i32>, ptr @__msan_param_tls, align 8
; CHECK-NEXT: [[TMP2:%.*]] = load <8 x i32>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 32) to ptr), align 8
; CHECK-NEXT: call void @llvm.donothing()
; CHECK-NEXT: [[A0:%.*]] = bitcast <8 x i32> [[TMP1]] to <8 x float>
; CHECK-NEXT: [[A1:%.*]] = bitcast <8 x i32> [[TMP2]] to <8 x float>
; CHECK-NEXT: [[RES:%.*]] = call <8 x float> @llvm.x86.avx.hadd.ps.256(<8 x float> [[A0]], <8 x float> [[A1]])
; CHECK-NEXT: [[_MSPROP:%.*]] = bitcast <8 x float> [[RES]] to <8 x i32>
; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <8 x i32> [[TMP1]], <8 x i32> [[TMP2]], <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
; CHECK-NEXT: [[TMP4:%.*]] = shufflevector <8 x i32> [[TMP1]], <8 x i32> [[TMP2]], <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
; CHECK-NEXT: [[_MSPROP:%.*]] = or <8 x i32> [[TMP3]], [[TMP4]]
; CHECK-NEXT: [[RES1:%.*]] = call <8 x float> @llvm.x86.avx.hsub.ps.256(<8 x float> [[A2:%.*]], <8 x float> [[A3:%.*]])
; CHECK-NEXT: store <8 x i32> [[_MSPROP]], ptr @__msan_retval_tls, align 8
; CHECK-NEXT: ret <8 x float> [[RES1]]