Skip to content

Commit 5fcd30f

Browse files
jayfoadtmsri
authored andcommitted
[AMDGPU] Fold llvm.amdgcn.cvt.pkrtz when either operand is fpext (llvm#108237)
This also generalizes the Undef handling and adds Poison handling.
1 parent aae38c1 commit 5fcd30f

File tree

2 files changed

+108
-18
lines changed

2 files changed

+108
-18
lines changed

llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -640,27 +640,38 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
640640
break;
641641
}
642642
case Intrinsic::amdgcn_cvt_pkrtz: {
643-
Value *Src0 = II.getArgOperand(0);
644-
Value *Src1 = II.getArgOperand(1);
645-
if (const ConstantFP *C0 = dyn_cast<ConstantFP>(Src0)) {
646-
if (const ConstantFP *C1 = dyn_cast<ConstantFP>(Src1)) {
647-
const fltSemantics &HalfSem =
648-
II.getType()->getScalarType()->getFltSemantics();
643+
auto foldFPTruncToF16RTZ = [](Value *Arg) -> Value * {
644+
Type *HalfTy = Type::getHalfTy(Arg->getContext());
645+
646+
if (isa<PoisonValue>(Arg))
647+
return PoisonValue::get(HalfTy);
648+
if (isa<UndefValue>(Arg))
649+
return UndefValue::get(HalfTy);
650+
651+
ConstantFP *CFP = nullptr;
652+
if (match(Arg, m_ConstantFP(CFP))) {
649653
bool LosesInfo;
650-
APFloat Val0 = C0->getValueAPF();
651-
APFloat Val1 = C1->getValueAPF();
652-
Val0.convert(HalfSem, APFloat::rmTowardZero, &LosesInfo);
653-
Val1.convert(HalfSem, APFloat::rmTowardZero, &LosesInfo);
654-
655-
Constant *Folded =
656-
ConstantVector::get({ConstantFP::get(II.getContext(), Val0),
657-
ConstantFP::get(II.getContext(), Val1)});
658-
return IC.replaceInstUsesWith(II, Folded);
654+
APFloat Val(CFP->getValueAPF());
655+
Val.convert(APFloat::IEEEhalf(), APFloat::rmTowardZero, &LosesInfo);
656+
return ConstantFP::get(HalfTy, Val);
659657
}
660-
}
661658

662-
if (isa<UndefValue>(Src0) && isa<UndefValue>(Src1)) {
663-
return IC.replaceInstUsesWith(II, UndefValue::get(II.getType()));
659+
Value *Src = nullptr;
660+
if (match(Arg, m_FPExt(m_Value(Src)))) {
661+
if (Src->getType()->isHalfTy())
662+
return Src;
663+
}
664+
665+
return nullptr;
666+
};
667+
668+
if (Value *Src0 = foldFPTruncToF16RTZ(II.getArgOperand(0))) {
669+
if (Value *Src1 = foldFPTruncToF16RTZ(II.getArgOperand(1))) {
670+
Value *V = PoisonValue::get(II.getType());
671+
V = IC.Builder.CreateInsertElement(V, Src0, (uint64_t)0);
672+
V = IC.Builder.CreateInsertElement(V, Src1, (uint64_t)1);
673+
return IC.replaceInstUsesWith(II, V);
674+
}
664675
}
665676

666677
break;

llvm/test/Transforms/InstCombine/AMDGPU/amdgcn-intrinsics.ll

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1161,6 +1161,85 @@ define <2 x half> @constant_rtz_pkrtz() {
11611161
ret <2 x half> %cvt
11621162
}
11631163

1164+
define <2 x half> @fpext_const_cvt_pkrtz(half %x) {
1165+
; CHECK-LABEL: @fpext_const_cvt_pkrtz(
1166+
; CHECK-NEXT: [[CVT:%.*]] = insertelement <2 x half> <half poison, half 0xH4200>, half [[X:%.*]], i64 0
1167+
; CHECK-NEXT: ret <2 x half> [[CVT]]
1168+
;
1169+
%ext = fpext half %x to float
1170+
%cvt = call <2 x half> @llvm.amdgcn.cvt.pkrtz(float %ext, float 3.0)
1171+
ret <2 x half> %cvt
1172+
}
1173+
1174+
define <2 x half> @const_fpext_cvt_pkrtz(half %y) {
1175+
; CHECK-LABEL: @const_fpext_cvt_pkrtz(
1176+
; CHECK-NEXT: [[CVT:%.*]] = insertelement <2 x half> <half 0xH4500, half poison>, half [[Y:%.*]], i64 1
1177+
; CHECK-NEXT: ret <2 x half> [[CVT]]
1178+
;
1179+
%ext = fpext half %y to float
1180+
%cvt = call <2 x half> @llvm.amdgcn.cvt.pkrtz(float 5.0, float %ext)
1181+
ret <2 x half> %cvt
1182+
}
1183+
1184+
define <2 x half> @const_fpext_multi_cvt_pkrtz(half %y) {
1185+
; CHECK-LABEL: @const_fpext_multi_cvt_pkrtz(
1186+
; CHECK-NEXT: [[CVT1:%.*]] = insertelement <2 x half> <half 0xH4500, half poison>, half [[Y:%.*]], i64 1
1187+
; CHECK-NEXT: [[CVT2:%.*]] = insertelement <2 x half> <half 0xH4200, half poison>, half [[Y]], i64 1
1188+
; CHECK-NEXT: [[ADD:%.*]] = fadd <2 x half> [[CVT1]], [[CVT2]]
1189+
; CHECK-NEXT: ret <2 x half> [[ADD]]
1190+
;
1191+
%ext = fpext half %y to float
1192+
%cvt1 = call <2 x half> @llvm.amdgcn.cvt.pkrtz(float 5.0, float %ext)
1193+
%cvt2 = call <2 x half> @llvm.amdgcn.cvt.pkrtz(float 3.0, float %ext)
1194+
%add = fadd <2 x half> %cvt1, %cvt2
1195+
ret <2 x half> %add
1196+
}
1197+
1198+
define <2 x half> @fpext_fpext_cvt_pkrtz(half %x, half %y) {
1199+
; CHECK-LABEL: @fpext_fpext_cvt_pkrtz(
1200+
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x half> poison, half [[X:%.*]], i64 0
1201+
; CHECK-NEXT: [[CVT:%.*]] = insertelement <2 x half> [[TMP1]], half [[Y:%.*]], i64 1
1202+
; CHECK-NEXT: ret <2 x half> [[CVT]]
1203+
;
1204+
%extx = fpext half %x to float
1205+
%exty = fpext half %y to float
1206+
%cvt = call <2 x half> @llvm.amdgcn.cvt.pkrtz(float %extx, float %exty)
1207+
ret <2 x half> %cvt
1208+
}
1209+
1210+
define <2 x half> @fpext_fpext_bf16_cvt_pkrtz(bfloat %x, bfloat %y) {
1211+
; CHECK-LABEL: @fpext_fpext_bf16_cvt_pkrtz(
1212+
; CHECK-NEXT: [[EXTX:%.*]] = fpext bfloat [[X:%.*]] to float
1213+
; CHECK-NEXT: [[EXTY:%.*]] = fpext bfloat [[Y:%.*]] to float
1214+
; CHECK-NEXT: [[CVT:%.*]] = call <2 x half> @llvm.amdgcn.cvt.pkrtz(float [[EXTX]], float [[EXTY]])
1215+
; CHECK-NEXT: ret <2 x half> [[CVT]]
1216+
;
1217+
%extx = fpext bfloat %x to float
1218+
%exty = fpext bfloat %y to float
1219+
%cvt = call <2 x half> @llvm.amdgcn.cvt.pkrtz(float %extx, float %exty)
1220+
ret <2 x half> %cvt
1221+
}
1222+
1223+
define <2 x half> @poison_fpext_cvt_pkrtz(half %y) {
1224+
; CHECK-LABEL: @poison_fpext_cvt_pkrtz(
1225+
; CHECK-NEXT: [[CVT:%.*]] = insertelement <2 x half> poison, half [[Y:%.*]], i64 1
1226+
; CHECK-NEXT: ret <2 x half> [[CVT]]
1227+
;
1228+
%ext = fpext half %y to float
1229+
%cvt = call <2 x half> @llvm.amdgcn.cvt.pkrtz(float poison, float %ext)
1230+
ret <2 x half> %cvt
1231+
}
1232+
1233+
define <2 x half> @fpext_poison_cvt_pkrtz(half %x) {
1234+
; CHECK-LABEL: @fpext_poison_cvt_pkrtz(
1235+
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x half> poison, half [[X:%.*]], i64 0
1236+
; CHECK-NEXT: ret <2 x half> [[TMP1]]
1237+
;
1238+
%ext = fpext half %x to float
1239+
%cvt = call <2 x half> @llvm.amdgcn.cvt.pkrtz(float %ext, float poison)
1240+
ret <2 x half> %cvt
1241+
}
1242+
11641243
; --------------------------------------------------------------------
11651244
; llvm.amdgcn.cvt.pknorm.i16
11661245
; --------------------------------------------------------------------

0 commit comments

Comments
 (0)