Skip to content

Commit c7f8d77

Browse files
committed
Add bypass for bf16 and bf16xN
1 parent 45283e4 commit c7f8d77

File tree

4 files changed

+20
-2
lines changed

4 files changed

+20
-2
lines changed

compiler/rustc_codegen_llvm/src/abi.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,17 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
361361
}
362362

363363
match self.type_kind(llvm_ty) {
364+
TypeKind::BFloat => rust_ty == self.type_i16(),
365+
TypeKind::Vector => {
366+
let llvm_element_count = self.vector_length(llvm_ty) as u64;
367+
let llvm_element_ty = self.element_type(llvm_ty);
368+
369+
if llvm_element_ty == self.type_bf16() {
370+
rust_ty == self.type_vector(self.type_i16(), llvm_element_count)
371+
} else {
372+
false
373+
}
374+
}
364375
TypeKind::Struct if self.type_kind(rust_ty) == TypeKind::Struct => {
365376
let rust_element_tys = self.struct_element_types(rust_ty);
366377
let llvm_element_tys = self.struct_element_types(llvm_ty);

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1624,7 +1624,7 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
16241624
}
16251625
ret
16261626
}
1627-
_ => unreachable!(),
1627+
_ => self.bitcast(val, dest_ty), // for `bf16(xN)` <-> `u16(xN)`
16281628
}
16291629
}
16301630

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1051,6 +1051,9 @@ unsafe extern "C" {
10511051
pub(crate) fn LLVMDoubleTypeInContext(C: &Context) -> &Type;
10521052
pub(crate) fn LLVMFP128TypeInContext(C: &Context) -> &Type;
10531053

1054+
// Operations on non-IEEE real types
1055+
pub(crate) fn LLVMBFloatTypeInContext(C: &Context) -> &Type;
1056+
10541057
// Operations on function types
10551058
pub(crate) fn LLVMFunctionType<'a>(
10561059
ReturnType: &'a Type,

compiler/rustc_codegen_llvm/src/type_.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,10 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
170170
)
171171
}
172172
}
173+
174+
pub(crate) fn type_bf16(&self) -> &'ll Type {
175+
unsafe { llvm::LLVMBFloatTypeInContext(self.llcx()) }
176+
}
173177
}
174178

175179
impl<'ll, CX: Borrow<SCx<'ll>>> BaseTypeCodegenMethods for GenericCx<'ll, CX> {
@@ -243,7 +247,7 @@ impl<'ll, CX: Borrow<SCx<'ll>>> BaseTypeCodegenMethods for GenericCx<'ll, CX> {
243247

244248
fn float_width(&self, ty: &'ll Type) -> usize {
245249
match self.type_kind(ty) {
246-
TypeKind::Half => 16,
250+
TypeKind::Half | TypeKind::BFloat => 16,
247251
TypeKind::Float => 32,
248252
TypeKind::Double => 64,
249253
TypeKind::X86_FP80 => 80,

0 commit comments

Comments
 (0)