Skip to content

[AMDGPU] Fold llvm.amdgcn.cvt.pkrtz when either operand is fpext #108237

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 29 additions & 18 deletions llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -643,27 +643,38 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
break;
}
case Intrinsic::amdgcn_cvt_pkrtz: {
Value *Src0 = II.getArgOperand(0);
Value *Src1 = II.getArgOperand(1);
if (const ConstantFP *C0 = dyn_cast<ConstantFP>(Src0)) {
if (const ConstantFP *C1 = dyn_cast<ConstantFP>(Src1)) {
const fltSemantics &HalfSem =
II.getType()->getScalarType()->getFltSemantics();
auto foldFPTruncToF16RTZ = [](Value *Arg) -> Value * {
Type *HalfTy = Type::getHalfTy(Arg->getContext());

if (isa<PoisonValue>(Arg))
return PoisonValue::get(HalfTy);
if (isa<UndefValue>(Arg))
return UndefValue::get(HalfTy);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

usually we do undef -> qnan for FP folds (although I think this is overly conservative, and I assume is only to exclude snan bit patterns which isn't guaranteed to quiet anyway)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was following ConstantFoldCastInstruction which will fold (fptrunc undef) to undef.


ConstantFP *CFP = nullptr;
if (match(Arg, m_ConstantFP(CFP))) {
bool LosesInfo;
APFloat Val0 = C0->getValueAPF();
APFloat Val1 = C1->getValueAPF();
Val0.convert(HalfSem, APFloat::rmTowardZero, &LosesInfo);
Val1.convert(HalfSem, APFloat::rmTowardZero, &LosesInfo);

Constant *Folded =
ConstantVector::get({ConstantFP::get(II.getContext(), Val0),
ConstantFP::get(II.getContext(), Val1)});
return IC.replaceInstUsesWith(II, Folded);
APFloat Val(CFP->getValueAPF());
Val.convert(APFloat::IEEEhalf(), APFloat::rmTowardZero, &LosesInfo);
return ConstantFP::get(HalfTy, Val);
}
}

if (isa<UndefValue>(Src0) && isa<UndefValue>(Src1)) {
return IC.replaceInstUsesWith(II, UndefValue::get(II.getType()));
Value *Src = nullptr;
if (match(Arg, m_FPExt(m_Value(Src)))) {
if (Src->getType()->isHalfTy())
return Src;
}

return nullptr;
};

if (Value *Src0 = foldFPTruncToF16RTZ(II.getArgOperand(0))) {
if (Value *Src1 = foldFPTruncToF16RTZ(II.getArgOperand(1))) {
Value *V = PoisonValue::get(II.getType());
V = IC.Builder.CreateInsertElement(V, Src0, (uint64_t)0);
V = IC.Builder.CreateInsertElement(V, Src1, (uint64_t)1);
return IC.replaceInstUsesWith(II, V);
}
}

break;
Expand Down
79 changes: 79 additions & 0 deletions llvm/test/Transforms/InstCombine/AMDGPU/amdgcn-intrinsics.ll
Original file line number Diff line number Diff line change
Expand Up @@ -1161,6 +1161,85 @@ define <2 x half> @constant_rtz_pkrtz() {
ret <2 x half> %cvt
}

define <2 x half> @fpext_const_cvt_pkrtz(half %x) {
; CHECK-LABEL: @fpext_const_cvt_pkrtz(
; CHECK-NEXT: [[CVT:%.*]] = insertelement <2 x half> <half poison, half 0xH4200>, half [[X:%.*]], i64 0
; CHECK-NEXT: ret <2 x half> [[CVT]]
;
%ext = fpext half %x to float
%cvt = call <2 x half> @llvm.amdgcn.cvt.pkrtz(float %ext, float 3.0)
ret <2 x half> %cvt
}

define <2 x half> @const_fpext_cvt_pkrtz(half %y) {
; CHECK-LABEL: @const_fpext_cvt_pkrtz(
; CHECK-NEXT: [[CVT:%.*]] = insertelement <2 x half> <half 0xH4500, half poison>, half [[Y:%.*]], i64 1
; CHECK-NEXT: ret <2 x half> [[CVT]]
;
%ext = fpext half %y to float
%cvt = call <2 x half> @llvm.amdgcn.cvt.pkrtz(float 5.0, float %ext)
ret <2 x half> %cvt
}

define <2 x half> @const_fpext_multi_cvt_pkrtz(half %y) {
; CHECK-LABEL: @const_fpext_multi_cvt_pkrtz(
; CHECK-NEXT: [[CVT1:%.*]] = insertelement <2 x half> <half 0xH4500, half poison>, half [[Y:%.*]], i64 1
; CHECK-NEXT: [[CVT2:%.*]] = insertelement <2 x half> <half 0xH4200, half poison>, half [[Y]], i64 1
; CHECK-NEXT: [[ADD:%.*]] = fadd <2 x half> [[CVT1]], [[CVT2]]
; CHECK-NEXT: ret <2 x half> [[ADD]]
;
%ext = fpext half %y to float
%cvt1 = call <2 x half> @llvm.amdgcn.cvt.pkrtz(float 5.0, float %ext)
%cvt2 = call <2 x half> @llvm.amdgcn.cvt.pkrtz(float 3.0, float %ext)
%add = fadd <2 x half> %cvt1, %cvt2
ret <2 x half> %add
}

define <2 x half> @fpext_fpext_cvt_pkrtz(half %x, half %y) {
; CHECK-LABEL: @fpext_fpext_cvt_pkrtz(
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x half> poison, half [[X:%.*]], i64 0
; CHECK-NEXT: [[CVT:%.*]] = insertelement <2 x half> [[TMP1]], half [[Y:%.*]], i64 1
; CHECK-NEXT: ret <2 x half> [[CVT]]
;
%extx = fpext half %x to float
%exty = fpext half %y to float
%cvt = call <2 x half> @llvm.amdgcn.cvt.pkrtz(float %extx, float %exty)
ret <2 x half> %cvt
}

define <2 x half> @fpext_fpext_bf16_cvt_pkrtz(bfloat %x, bfloat %y) {
; CHECK-LABEL: @fpext_fpext_bf16_cvt_pkrtz(
; CHECK-NEXT: [[EXTX:%.*]] = fpext bfloat [[X:%.*]] to float
; CHECK-NEXT: [[EXTY:%.*]] = fpext bfloat [[Y:%.*]] to float
; CHECK-NEXT: [[CVT:%.*]] = call <2 x half> @llvm.amdgcn.cvt.pkrtz(float [[EXTX]], float [[EXTY]])
; CHECK-NEXT: ret <2 x half> [[CVT]]
;
%extx = fpext bfloat %x to float
%exty = fpext bfloat %y to float
%cvt = call <2 x half> @llvm.amdgcn.cvt.pkrtz(float %extx, float %exty)
ret <2 x half> %cvt
}

define <2 x half> @poison_fpext_cvt_pkrtz(half %y) {
; CHECK-LABEL: @poison_fpext_cvt_pkrtz(
; CHECK-NEXT: [[CVT:%.*]] = insertelement <2 x half> poison, half [[Y:%.*]], i64 1
; CHECK-NEXT: ret <2 x half> [[CVT]]
;
%ext = fpext half %y to float
%cvt = call <2 x half> @llvm.amdgcn.cvt.pkrtz(float poison, float %ext)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also test poison on RHS

ret <2 x half> %cvt
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add some tests with bfloat sources. Also negative multi use test

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bfloat tests would be negative since the intrinsic only supports half.

As for multi use tests, I'm not sure if they should be negative. I know I used m_OneUse in the implementation, but now I think the optimization might be beneficial even with multiple uses.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, negative

define <2 x half> @fpext_poison_cvt_pkrtz(half %x) {
; CHECK-LABEL: @fpext_poison_cvt_pkrtz(
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x half> poison, half [[X:%.*]], i64 0
; CHECK-NEXT: ret <2 x half> [[TMP1]]
;
%ext = fpext half %x to float
%cvt = call <2 x half> @llvm.amdgcn.cvt.pkrtz(float %ext, float poison)
ret <2 x half> %cvt
}

; --------------------------------------------------------------------
; llvm.amdgcn.cvt.pknorm.i16
; --------------------------------------------------------------------
Expand Down
Loading