Skip to content

Commit 9df071b

Browse files
committed
Add bypass for bf16 and bf16xN
1 parent bd6b84a commit 9df071b

File tree

4 files changed

+20
-2
lines changed

4 files changed

+20
-2
lines changed

compiler/rustc_codegen_llvm/src/abi.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,17 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
361361
}
362362

363363
match self.type_kind(llvm_ty) {
364+
TypeKind::BFloat => rust_ty == self.type_i16(),
365+
TypeKind::Vector => {
366+
let llvm_element_count = self.vector_length(llvm_ty) as u64;
367+
let llvm_element_ty = self.element_type(llvm_ty);
368+
369+
if llvm_element_ty == self.type_bf16() {
370+
rust_ty == self.type_vector(self.type_i16(), llvm_element_count)
371+
} else {
372+
false
373+
}
374+
}
364375
TypeKind::Struct if self.type_kind(rust_ty) == TypeKind::Struct => {
365376
let rust_element_tys = self.struct_element_types(rust_ty);
366377
let llvm_element_tys = self.struct_element_types(llvm_ty);

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1704,7 +1704,7 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
17041704
}
17051705
ret
17061706
}
1707-
_ => unreachable!(),
1707+
_ => self.bitcast(val, dest_ty), // for `bf16(xN)` <-> `u16(xN)`
17081708
}
17091709
}
17101710

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1051,6 +1051,9 @@ unsafe extern "C" {
10511051
pub(crate) fn LLVMDoubleTypeInContext(C: &Context) -> &Type;
10521052
pub(crate) fn LLVMFP128TypeInContext(C: &Context) -> &Type;
10531053

1054+
// Operations on non-IEEE real types
1055+
pub(crate) fn LLVMBFloatTypeInContext(C: &Context) -> &Type;
1056+
10541057
// Operations on function types
10551058
pub(crate) fn LLVMFunctionType<'a>(
10561059
ReturnType: &'a Type,

compiler/rustc_codegen_llvm/src/type_.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,10 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
168168
)
169169
}
170170
}
171+
172+
pub(crate) fn type_bf16(&self) -> &'ll Type {
173+
unsafe { llvm::LLVMBFloatTypeInContext(self.llcx()) }
174+
}
171175
}
172176

173177
impl<'ll, CX: Borrow<SCx<'ll>>> BaseTypeCodegenMethods for GenericCx<'ll, CX> {
@@ -241,7 +245,7 @@ impl<'ll, CX: Borrow<SCx<'ll>>> BaseTypeCodegenMethods for GenericCx<'ll, CX> {
241245

242246
fn float_width(&self, ty: &'ll Type) -> usize {
243247
match self.type_kind(ty) {
244-
TypeKind::Half => 16,
248+
TypeKind::Half | TypeKind::BFloat => 16,
245249
TypeKind::Float => 32,
246250
TypeKind::Double => 64,
247251
TypeKind::X86_FP80 => 80,

0 commit comments

Comments
 (0)