Skip to content

Commit 9e574a3

Browse files
committed
DAG: Fix expansion of bf16 sourced extloads
Also fix assorted vector extload failures for AMDGPU.
1 parent 701f647 commit 9e574a3

File tree

3 files changed

+3299 -2 lines changed

3 files changed

+3299 -2 lines changed

llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -913,14 +913,17 @@ void SelectionDAGLegalize::LegalizeLoadOps(SDNode *Node) {
913913
// normal undefined upper bits behavior to allow using an in-reg extend
914914
// with the illegal FP type, so load as an integer and do the
915915
// from-integer conversion.
916-
if (SrcVT.getScalarType() == MVT::f16) {
916+
EVT SVT = SrcVT.getScalarType();
917+
if (SVT == MVT::f16 || SVT == MVT::bf16) {
917918
EVT ISrcVT = SrcVT.changeTypeToInteger();
918919
EVT IDestVT = DestVT.changeTypeToInteger();
919920
EVT ILoadVT = TLI.getRegisterType(IDestVT.getSimpleVT());
920921

921922
SDValue Result = DAG.getExtLoad(ISD::ZEXTLOAD, dl, ILoadVT, Chain,
922923
Ptr, ISrcVT, LD->getMemOperand());
923-
Value = DAG.getNode(ISD::FP16_TO_FP, dl, DestVT, Result);
924+
Value =
925+
DAG.getNode(SVT == MVT::f16 ? ISD::FP16_TO_FP : ISD::BF16_TO_FP,
926+
dl, DestVT, Result);
924927
Chain = Result.getValue(1);
925928
break;
926929
}

llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -169,11 +169,17 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM,
169169
setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand);
170170
setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::bf16, Expand);
171171
setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2f16, Expand);
172+
setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2bf16, Expand);
172173
setLoadExtAction(ISD::EXTLOAD, MVT::v3f32, MVT::v3f16, Expand);
174+
setLoadExtAction(ISD::EXTLOAD, MVT::v3f32, MVT::v3bf16, Expand);
173175
setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4f16, Expand);
176+
setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4bf16, Expand);
174177
setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8f16, Expand);
178+
setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8bf16, Expand);
175179
setLoadExtAction(ISD::EXTLOAD, MVT::v16f32, MVT::v16f16, Expand);
180+
setLoadExtAction(ISD::EXTLOAD, MVT::v16f32, MVT::v16bf16, Expand);
176181
setLoadExtAction(ISD::EXTLOAD, MVT::v32f32, MVT::v32f16, Expand);
182+
setLoadExtAction(ISD::EXTLOAD, MVT::v32f32, MVT::v32bf16, Expand);
177183

178184
setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand);
179185
setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f32, Expand);
@@ -185,10 +191,15 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM,
185191
setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand);
186192
setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::bf16, Expand);
187193
setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f16, Expand);
194+
setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2bf16, Expand);
188195
setLoadExtAction(ISD::EXTLOAD, MVT::v3f64, MVT::v3f16, Expand);
196+
setLoadExtAction(ISD::EXTLOAD, MVT::v3f64, MVT::v3bf16, Expand);
189197
setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f16, Expand);
198+
setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4bf16, Expand);
190199
setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8f16, Expand);
200+
setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8bf16, Expand);
191201
setLoadExtAction(ISD::EXTLOAD, MVT::v16f64, MVT::v16f16, Expand);
202+
setLoadExtAction(ISD::EXTLOAD, MVT::v16f64, MVT::v16bf16, Expand);
192203

193204
setOperationAction(ISD::STORE, MVT::f32, Promote);
194205
AddPromotedToType(ISD::STORE, MVT::f32, MVT::i32);

0 commit comments

Comments
 (0)