
Commit 618c807 (parent: be60cdd)
Author: Hugh Delaney

Add device code check for conversion builtin

2 files changed (+10 −6 lines)


sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp

Lines changed: 6 additions & 4 deletions
@@ -609,17 +609,19 @@ float float_to_tf32(float a) {
   int32_t tmp_int = __nvvm_f2tf32_rna(a);
   return __nvvm_bitcast_i2f(tmp_int);
 #else
-  throw runtime_error("When using SYCL_EXT_ONEAPI_MATRIX=3 float_to_tf32 is "
-                      "only supported by CUDA devices",
-                      PI_INVALID_DEVICE);
+  uint32_t tmp_uint = reinterpret_cast<uint32_t &>(a);
+  tmp_uint += 0x1000u;
+  float ret = reinterpret_cast<float &>(tmp_uint);
+  return ret;
 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
 }
 
 // This function just zeros out the bottom 13 bits of the tf32 type
 float tf32_to_float(float a) {
   uint32_t tmp_uint = reinterpret_cast<uint32_t &>(a);
   tmp_uint &= 0xFFFFE000u;
-  return reinterpret_cast<float &>(tmp_uint);
+  float ret = reinterpret_cast<float &>(tmp_uint);
+  return ret;
 }
 
 } // namespace experimental::matrix
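
For readers of this diff: the new #else branch replaces the old host-side runtime_error with an emulation of the CUDA round-to-nearest conversion. Adding 0x1000u (half of 0x2000, the value of bit 13, which is the LSB of the retained tf32 mantissa) biases the bit pattern so that the 13-bit truncation later performed by tf32_to_float rounds to nearest instead of toward zero. Below is a minimal standalone sketch of that bit manipulation, not part of the commit: the emulate_* names are hypothetical, and std::memcpy stands in for the reinterpret_cast to keep the sketch free of strict-aliasing questions.

#include <cassert>
#include <cstdint>
#include <cstring>
#include <iostream>

// Hypothetical helpers mirroring the host (#else) path in the diff above.

// Bias the mantissa by half of the tf32 LSB (bit 13) so that the later
// truncation rounds to nearest rather than toward zero.
float emulate_float_to_tf32(float a) {
  std::uint32_t bits;
  std::memcpy(&bits, &a, sizeof(bits)); // well-defined type punning
  bits += 0x1000u;                      // 0x1000 == half of 0x2000 (bit 13)
  float ret;
  std::memcpy(&ret, &bits, sizeof(ret));
  return ret;
}

// Zero the bottom 13 bits, leaving the sign, the 8 exponent bits, and the
// 10-bit tf32 mantissa -- the same masking tf32_to_float performs.
float emulate_tf32_to_float(float a) {
  std::uint32_t bits;
  std::memcpy(&bits, &a, sizeof(bits));
  bits &= 0xFFFFE000u;
  float ret;
  std::memcpy(&ret, &bits, sizeof(ret));
  return ret;
}

int main() {
  // 1.0001f carries bits below the tf32 mantissa; the nearest point on the
  // 10-bit mantissa grid is exactly 1.0f, so the round trip snaps to it.
  float r = emulate_tf32_to_float(emulate_float_to_tf32(1.0001f));
  assert(r == 1.0f);
  std::cout << r << '\n';
}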

sycl/test/check_device_code/matrix/matrix-nvptx-tf32-test.cpp

Lines changed: 4 additions & 2 deletions
@@ -77,6 +77,7 @@ int main() {
   //CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16) #{{.*}}
   joint_matrix_load(sg, sub_c, accC.get_pointer(), N);
 
+  // CHECK: tail call i32 @llvm.nvvm.f2tf32.rna(float {{.*}}
   // Round a, b to tf32
   for (auto i = 0; i < 4; ++i)
     sub_a.data[i] = float_to_tf32(sub_a.data[i]);
@@ -120,14 +121,15 @@ int main() {
   joint_matrix_load(sg, sub_b, accB.get_pointer(), N);
   //CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.f32.p1f32(float addrspace(1)* {{.*}}, i32 {{.*}}) #{{.*}}
   joint_matrix_load(sg, sub_c, accC.get_pointer(), N);
-
+
+  // CHECK: tail call i32 @llvm.nvvm.f2tf32.rna(float {{.*}}
   // Round a, b to tf32
   for (auto i = 0; i < 4; ++i)
     sub_a.data[i] = float_to_tf32(sub_a.data[i]);
 
   for (auto i = 0; i < 4; ++i)
     sub_b.data[i] = float_to_tf32(sub_b.data[i]);
-
+
   //CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k8.mma.col.col.tf32(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) #{{.*}}
   sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
   //CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1f32(float addrspace(1)* {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) #{{.*}}
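
The added CHECK lines assert that, on the device code path, each float_to_tf32 call lowers to the llvm.nvvm.f2tf32.rna intrinsic. This is a lit/FileCheck test: the device IR the compiler emits is piped through FileCheck, which matches the CHECK patterns in order. The test's real RUN line sits at the top of the file and is outside this hunk; purely as an illustration of the typical shape of such a check_device_code test (flags approximate, not taken from the source):

// Illustrative only -- not the test's actual RUN line, which this hunk
// does not show.
// RUN: %clangxx -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda \
// RUN:   -S -emit-llvm %s -o - | FileCheck %s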
