Skip to content

Commit b6193a2

Browse files
authored
libclc: clspv: update gen_convert.cl for clspv (#66902)
Add a clspv switch in gen_convert.cl This is needed as Vulkan SPIR-V does not respect the assumptions needed to have the generic convert.cl compliant on many platforms. It is needed because of the conversion of TYPE_MAX and TYPE_MIN. Depending on the platform the behaviour can vary, but most of them just do not convert correctly those 2 values. Because of that, we also need to avoid having explicit function for simple conversions because it allows llvm to optimise the code, thus removing some of the added checks that are in fact needed.
1 parent 2595931 commit b6193a2

File tree

2 files changed

+116
-28
lines changed

2 files changed

+116
-28
lines changed

libclc/CMakeLists.txt

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,12 @@ add_custom_command(
174174
DEPENDS ${script_loc} )
175175
add_custom_target( "generate_convert.cl" DEPENDS convert.cl )
176176

177+
add_custom_command(
178+
OUTPUT clspv-convert.cl
179+
COMMAND ${Python3_EXECUTABLE} ${script_loc} --clspv > clspv-convert.cl
180+
DEPENDS ${script_loc} )
181+
add_custom_target( "clspv-generate_convert.cl" DEPENDS clspv-convert.cl )
182+
177183
enable_testing()
178184

179185
foreach( t ${LIBCLC_TARGETS_TO_BUILD} )
@@ -218,11 +224,14 @@ foreach( t ${LIBCLC_TARGETS_TO_BUILD} )
218224
# Add the generated convert.cl here to prevent adding
219225
# the one listed in SOURCES
220226
if( NOT ${ARCH} STREQUAL "spirv" AND NOT ${ARCH} STREQUAL "spirv64" )
221-
set( rel_files convert.cl )
222-
set( objects convert.cl )
223227
if( NOT ENABLE_RUNTIME_SUBNORMAL AND NOT ${ARCH} STREQUAL "clspv" AND
224228
NOT ${ARCH} STREQUAL "clspv64" )
229+
set( rel_files convert.cl )
230+
set( objects convert.cl )
225231
list( APPEND rel_files generic/lib/subnormal_use_default.ll )
232+
elseif(${ARCH} STREQUAL "clspv" OR ${ARCH} STREQUAL "clspv64")
233+
set( rel_files clspv-convert.cl )
234+
set( objects clspv-convert.cl )
226235
endif()
227236
else()
228237
set( rel_files )
@@ -286,6 +295,8 @@ foreach( t ${LIBCLC_TARGETS_TO_BUILD} )
286295
# multiple invocations
287296
add_dependencies( builtins.link.${arch_suffix}
288297
generate_convert.cl )
298+
add_dependencies( builtins.link.${arch_suffix}
299+
clspv-generate_convert.cl )
289300
# CMake will turn this include into absolute path
290301
target_include_directories( builtins.link.${arch_suffix} PRIVATE
291302
"generic/include" )

libclc/generic/lib/gen_convert.py

Lines changed: 103 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# Copyright (c) 2013 Victor Oliveira <[email protected]>
44
# Copyright (c) 2013 Jesse Towner <[email protected]>
5+
# Copyright (c) 2024 Romaric Jodin <[email protected]>
56
#
67
# Permission is hereby granted, free of charge, to any person obtaining a copy
78
# of this software and associated documentation files (the "Software"), to deal
@@ -26,6 +27,16 @@
2627
#
2728
# convert_<destTypen><_sat><_roundingMode>(<sourceTypen>)
2829

30+
import argparse
31+
32+
parser = argparse.ArgumentParser()
33+
parser.add_argument(
34+
"--clspv", action="store_true", help="Generate the clspv variant of the code"
35+
)
36+
args = parser.parse_args()
37+
38+
clspv = args.clspv
39+
2940
types = [
3041
"char",
3142
"uchar",
@@ -251,13 +262,19 @@ def generate_default_conversion(src, dst, mode):
251262
print("#endif")
252263

253264

254-
for src in types:
255-
for dst in types:
256-
generate_default_conversion(src, dst, "")
265+
# Do not generate default conversion for clspv as they are handled natively
266+
if not clspv:
267+
for src in types:
268+
for dst in types:
269+
generate_default_conversion(src, dst, "")
257270

258271
for src in int_types:
259272
for dst in int_types:
260273
for mode in rounding_modes:
274+
# Do not generate "_rte" conversion for clspv as they are handled
275+
# natively
276+
if clspv and mode == "_rte":
277+
continue
261278
generate_default_conversion(src, dst, mode)
262279

263280
#
@@ -304,21 +321,38 @@ def generate_saturated_conversion(src, dst, size):
304321

305322
elif src in float_types:
306323

307-
# Conversion from float to int
308-
print(
309-
""" {DST}{N} y = convert_{DST}{N}(x);
310-
y = select(y, ({DST}{N}){DST_MIN}, {BP}(x < ({SRC}{N}){DST_MIN}){BS});
311-
y = select(y, ({DST}{N}){DST_MAX}, {BP}(x > ({SRC}{N}){DST_MAX}){BS});
312-
return y;""".format(
313-
SRC=src,
314-
DST=dst,
315-
N=size,
316-
DST_MIN=limit_min[dst],
317-
DST_MAX=limit_max[dst],
318-
BP=bool_prefix,
319-
BS=bool_suffix,
324+
if clspv:
325+
# Conversion from float to int
326+
print(
327+
""" {DST}{N} y = convert_{DST}{N}(x);
328+
y = select(y, ({DST}{N}){DST_MIN}, {BP}(x <= ({SRC}{N}){DST_MIN}){BS});
329+
y = select(y, ({DST}{N}){DST_MAX}, {BP}(x >= ({SRC}{N}){DST_MAX}){BS});
330+
return y;""".format(
331+
SRC=src,
332+
DST=dst,
333+
N=size,
334+
DST_MIN=limit_min[dst],
335+
DST_MAX=limit_max[dst],
336+
BP=bool_prefix,
337+
BS=bool_suffix,
338+
)
339+
)
340+
else:
341+
# Conversion from float to int
342+
print(
343+
""" {DST}{N} y = convert_{DST}{N}(x);
344+
y = select(y, ({DST}{N}){DST_MIN}, {BP}(x < ({SRC}{N}){DST_MIN}){BS});
345+
y = select(y, ({DST}{N}){DST_MAX}, {BP}(x > ({SRC}{N}){DST_MAX}){BS});
346+
return y;""".format(
347+
SRC=src,
348+
DST=dst,
349+
N=size,
350+
DST_MIN=limit_min[dst],
351+
DST_MAX=limit_max[dst],
352+
BP=bool_prefix,
353+
BS=bool_suffix,
354+
)
320355
)
321-
)
322356

323357
else:
324358

@@ -432,7 +466,10 @@ def generate_float_conversion(src, dst, size, mode, sat):
432466
print(" return convert_{DST}{N}(x);".format(DST=dst, N=size))
433467
else:
434468
print(" {DST}{N} r = convert_{DST}{N}(x);".format(DST=dst, N=size))
435-
print(" {SRC}{N} y = convert_{SRC}{N}(r);".format(SRC=src, N=size))
469+
if clspv:
470+
print(" {SRC}{N} y = convert_{SRC}{N}_sat(r);".format(SRC=src, N=size))
471+
else:
472+
print(" {SRC}{N} y = convert_{SRC}{N}(r);".format(SRC=src, N=size))
436473
if mode == "_rtz":
437474
if src in int_types:
438475
print(
@@ -448,23 +485,59 @@ def generate_float_conversion(src, dst, size, mode, sat):
448485
else:
449486
print(" {SRC}{N} abs_x = fabs(x);".format(SRC=src, N=size))
450487
print(" {SRC}{N} abs_y = fabs(y);".format(SRC=src, N=size))
451-
print(
452-
" return select(r, nextafter(r, sign(r) * ({DST}{N})-INFINITY), convert_{BOOL}{N}(abs_y > abs_x));".format(
453-
DST=dst, N=size, BOOL=bool_type[dst]
488+
if clspv:
489+
print(
490+
" {BOOL}{N} c = convert_{BOOL}{N}(abs_y > abs_x);".format(
491+
BOOL=bool_type[dst], N=size
492+
)
493+
)
494+
if sizeof_type[src] >= 4 and src in int_types:
495+
print(
496+
" c = c || convert_{BOOL}{N}(({SRC}{N}){SRC_MAX} == x);".format(
497+
BOOL=bool_type[dst], N=size, SRC=src, SRC_MAX=limit_max[src]
498+
)
499+
)
500+
print(
501+
" return select(r, nextafter(r, sign(r) * ({DST}{N})-INFINITY), c);".format(
502+
DST=dst, N=size, BOOL=bool_type[dst], SRC=src
503+
)
504+
)
505+
else:
506+
print(
507+
" return select(r, nextafter(r, sign(r) * ({DST}{N})-INFINITY), convert_{BOOL}{N}(abs_y > abs_x));".format(
508+
DST=dst, N=size, BOOL=bool_type[dst]
509+
)
454510
)
455-
)
456511
if mode == "_rtp":
457512
print(
458513
" return select(r, nextafter(r, ({DST}{N})INFINITY), convert_{BOOL}{N}(y < x));".format(
459514
DST=dst, N=size, BOOL=bool_type[dst]
460515
)
461516
)
462517
if mode == "_rtn":
463-
print(
464-
" return select(r, nextafter(r, ({DST}{N})-INFINITY), convert_{BOOL}{N}(y > x));".format(
465-
DST=dst, N=size, BOOL=bool_type[dst]
518+
if clspv:
519+
print(
520+
" {BOOL}{N} c = convert_{BOOL}{N}(y > x);".format(
521+
BOOL=bool_type[dst], N=size
522+
)
523+
)
524+
if sizeof_type[src] >= 4 and src in int_types:
525+
print(
526+
" c = c || convert_{BOOL}{N}(({SRC}{N}){SRC_MAX} == x);".format(
527+
BOOL=bool_type[dst], N=size, SRC=src, SRC_MAX=limit_max[src]
528+
)
529+
)
530+
print(
531+
" return select(r, nextafter(r, ({DST}{N})-INFINITY), c);".format(
532+
DST=dst, N=size, BOOL=bool_type[dst], SRC=src
533+
)
534+
)
535+
else:
536+
print(
537+
" return select(r, nextafter(r, ({DST}{N})-INFINITY), convert_{BOOL}{N}(y > x));".format(
538+
DST=dst, N=size, BOOL=bool_type[dst]
539+
)
466540
)
467-
)
468541

469542
# Footer
470543
print("}")
@@ -484,4 +557,8 @@ def generate_float_conversion(src, dst, size, mode, sat):
484557
for dst in float_types:
485558
for size in vector_sizes:
486559
for mode in rounding_modes:
560+
# Do not generate "_rte" conversion for clspv as they are
561+
# handled natively
562+
if clspv and mode == "_rte":
563+
continue
487564
generate_float_conversion(src, dst, size, mode, "")

0 commit comments

Comments
 (0)