@@ -8601,7 +8601,7 @@ void EmitPass::EmitGenIntrinsicMessage(llvm::GenIntrinsicInst* inst)
8601
8601
break;
8602
8602
case GenISAIntrinsic::GenISA_srnd_ftohf:
8603
8603
case GenISAIntrinsic::GenISA_srnd_hftobf8:
8604
- emitsrnd (inst);
8604
+ emitSrnd (inst);
8605
8605
break;
8606
8606
case GenISAIntrinsic::GenISA_uavSerializeAll:
8607
8607
case GenISAIntrinsic::GenISA_uavSerializeOnResID:
@@ -21647,7 +21647,7 @@ void EmitPass::emitfcvt(llvm::GenIntrinsicInst* GII)
21647
21647
}
21648
21648
}
21649
21649
21650
- void EmitPass::emitsrnd (llvm::GenIntrinsicInst* GII)
21650
+ void EmitPass::emitSrnd (llvm::GenIntrinsicInst* GII)
21651
21651
{
21652
21652
CVariable* dst = m_destination;
21653
21653
CVariable* src0 = GetSymbol(GII->getOperand(0));
@@ -21656,18 +21656,30 @@ void EmitPass::emitsrnd(llvm::GenIntrinsicInst* GII)
21656
21656
bool isSat = CI->getValue().getBoolValue();
21657
21657
GenISAIntrinsic::ID GID = GII->getIntrinsicID();
21658
21658
21659
- switch (GID)
21659
+ // set dst types
21660
+ if (GID == GenISAIntrinsic::GenISA_srnd_ftohf)
21660
21661
{
21661
- case GenISAIntrinsic::GenISA_srnd_hftobf8:
21662
+ if (dst->GetType() != ISA_TYPE_HF)
21663
+ dst = m_currShader->GetNewAlias(dst, ISA_TYPE_HF, 0, 0);
21664
+ }
21665
+ if (GID == GenISAIntrinsic::GenISA_srnd_hftobf8
21666
+ )
21662
21667
{
21663
- if (dst->GetType() != ISA_TYPE_UB)
21664
- { // Use UB for bf8
21668
+ if (dst->GetType() != ISA_TYPE_UB) // Use UB for bf8
21665
21669
dst = m_currShader->GetNewAlias(dst, ISA_TYPE_UB, 0, 0);
21666
- }
21667
- break;
21668
21670
}
21669
- default:
21670
- break;
21671
+
21672
+ // set src types
21673
+ if (GID == GenISAIntrinsic::GenISA_srnd_ftohf)
21674
+ {
21675
+ if (src0->GetType() != ISA_TYPE_F)
21676
+ src0 = m_currShader->GetNewAlias(src0, ISA_TYPE_F, 0, 0);
21677
+ }
21678
+ if (GID == GenISAIntrinsic::GenISA_srnd_hftobf8
21679
+ )
21680
+ {
21681
+ if (src0->GetType() != ISA_TYPE_HF)
21682
+ src0 = m_currShader->GetNewAlias(src0, ISA_TYPE_HF, 0, 0);
21671
21683
}
21672
21684
21673
21685
uint16_t nsimdsize = numLanes(m_currShader->m_SIMDSize);
0 commit comments