@@ -22,7 +22,7 @@ def __init__(self, related_arg_type, vec_type):
22
22
self .vec_type = vec_type
23
23
24
24
def __str__ (self ):
25
- return f'detail::vec_return_t <{ self .related_arg_type } >'
25
+ return f'detail::get_vec_t <{ self .related_arg_type } >'
26
26
27
27
class MultiPtr :
28
28
def __init__ (self , element_type ):
@@ -66,20 +66,19 @@ def __init__(self, signed_type, parent_idx):
66
66
def __str__ (self ):
67
67
return f'detail::make_unsigned_t<{ self .signed_type } >'
68
68
69
- class SameSizeIntType :
70
- def __init__ (self , signed , parent_idx ):
71
- self .signed = signed
69
+ class ConversionTraitType :
70
+ def __init__ (self , trait , parent_idx ):
71
+ self .trait = trait
72
72
self .parent_idx = parent_idx
73
73
74
- class InstantiatedSameSizeIntType :
75
- def __init__ (self , parent_type , signed , parent_idx ):
74
+ class InstantiatedConversionTraitType :
75
+ def __init__ (self , parent_type , trait , parent_idx ):
76
76
self .parent_type = parent_type
77
- self .signed = signed
77
+ self .trait = trait
78
78
self .parent_idx = parent_idx
79
79
80
80
def __str__ (self ):
81
- signedness = 'signed' if self .signed else 'unsigned'
82
- return f'detail::same_size_{ signedness } _int_t<{ self .parent_type } >'
81
+ return f'detail::{ self .trait } <{ self .parent_type } >'
83
82
84
83
### GENTYPE DEFINITIONS
85
84
# NOTE: Marray is currently explicitly defined.
@@ -190,8 +189,10 @@ def __str__(self):
190
189
# argument.
191
190
elementtype0 = [ElementType (0 )]
192
191
unsignedtype0 = [UnsignedType (0 )]
193
- samesizesignedint0 = [SameSizeIntType (True , 0 )]
194
- samesizeunsignedint0 = [SameSizeIntType (False , 0 )]
192
+ samesizesignedint0 = [ConversionTraitType ("same_size_signed_int_t" , 0 )]
193
+ samesizeunsignedint0 = [ConversionTraitType ("same_size_unsigned_int_t" , 0 )]
194
+ samesizefloat0 = [ConversionTraitType ("same_size_float_t" , 0 )]
195
+ upsampledint0 = [ConversionTraitType ("upsampled_int_t" , 0 )]
195
196
196
197
builtin_types = {
197
198
"floatn" : floatn ,
@@ -276,6 +277,8 @@ def __str__(self):
276
277
"unsignedtype0" : unsignedtype0 ,
277
278
"samesizesignedint0" : samesizesignedint0 ,
278
279
"samesizeunsignedint0" : samesizeunsignedint0 ,
280
+ "samesizefloat0" : samesizefloat0 ,
281
+ "upsampledint0" : upsampledint0 ,
279
282
"char" : ["char" ],
280
283
"signed char" : ["signed char" ],
281
284
"short" : ["short" ],
@@ -309,6 +312,11 @@ def find_first_vec_arg(arg_types):
309
312
return arg_type
310
313
return None
311
314
315
+ def convert_vec_arg_name (arg_type , arg_name ):
316
+ if isinstance (arg_type , InstantiatedVecArg ):
317
+ return f'typename detail::get_vec_t<{ arg_type } >({ arg_name } )'
318
+ return arg_name
319
+
312
320
class Def :
313
321
def __init__ (self , return_type , arg_types , invoke_name = None ,
314
322
invoke_prefix = "" , custom_invoke = None , fast_math_invoke_name = None ,
@@ -328,7 +336,7 @@ def __init__(self, return_type, arg_types, invoke_name=None,
328
336
self .vec_size_alias = vec_size_alias
329
337
330
338
def get_invoke_args (self , arg_types , arg_names ):
331
- result = arg_names
339
+ result = list ( map ( convert_vec_arg_name , arg_types , arg_names ))
332
340
for (arg_idx , type_conv ) in self .convert_args :
333
341
# type_conv is either an index or a conversion function/type.
334
342
conv = type_conv if isinstance (type_conv , str ) else arg_types [type_conv ]
@@ -424,8 +432,8 @@ def custom_bool_select_invoke(return_type, _, arg_names):
424
432
"fract" : [Def ("vgenfloat" , ["vgenfloat" , "vgenfloatptr" ]),
425
433
Def ("float" , ["float" , "floatptr" ]),
426
434
Def ("double" , ["double" , "doubleptr" ]),
427
- Def ("half" , ["half" , "halfptr" ]),
428
- Def ("vgenfloat" , ["vgenfloat" , "vint32nptr" ]),
435
+ Def ("half" , ["half" , "halfptr" ])] ,
436
+ "frexp" : [ Def ("vgenfloat" , ["vgenfloat" , "vint32nptr" ]),
429
437
Def ("float" , ["float" , "intptr" ]),
430
438
Def ("double" , ["double" , "intptr" ]),
431
439
Def ("half" , ["half" , "intptr" ])],
@@ -456,13 +464,13 @@ def custom_bool_select_invoke(return_type, _, arg_names):
456
464
Def ("float" , ["float" , "floatptr" ]),
457
465
Def ("double" , ["double" , "doubleptr" ]),
458
466
Def ("half" , ["half" , "halfptr" ])],
459
- "nan" : [Def ("vfloatn " , ["vuint32n" ]),
460
- Def ("float " , ["unsigned int" ]),
461
- Def ("vdoublen " , ["vuint64n_ext" ]),
462
- Def ("double " , ["unsigned long" ]),
463
- Def ("double " , ["unsigned long long" ]),
464
- Def ("vhalfn " , ["vuint16n" ]),
465
- Def ("half " , ["unsigned short" ])],
467
+ "nan" : [Def ("samesizefloat0 " , ["vuint32n" ]),
468
+ Def ("samesizefloat0 " , ["unsigned int" ]),
469
+ Def ("samesizefloat0 " , ["vuint64n_ext" ]),
470
+ Def ("samesizefloat0 " , ["unsigned long" ]),
471
+ Def ("samesizefloat0 " , ["unsigned long long" ]),
472
+ Def ("samesizefloat0 " , ["vuint16n" ]),
473
+ Def ("samesizefloat0 " , ["unsigned short" ])],
466
474
"nextafter" : [Def ("genfloat" , ["genfloat" , "genfloat" ])],
467
475
"pow" : [Def ("genfloat" , ["genfloat" , "genfloat" ])],
468
476
"pown" : [Def ("vgenfloat" , ["vgenfloat" , "vint32n" ]),
@@ -498,9 +506,6 @@ def custom_bool_select_invoke(return_type, _, arg_names):
498
506
"tgamma" : [Def ("genfloat" , ["genfloat" ])],
499
507
"trunc" : [Def ("genfloat" , ["genfloat" ])],
500
508
# Integer functions
501
- "abs" : [Def ("sigeninteger" , ["sigeninteger" ], custom_invoke = custom_signed_abs_scalar_invoke ),
502
- Def ("vigeninteger" , ["vigeninteger" ], custom_invoke = custom_signed_abs_vector_invoke ),
503
- Def ("ugeninteger" , ["ugeninteger" ], invoke_prefix = "u_" )],
504
509
"abs_diff" : [Def ("unsignedtype0" , ["igeninteger" , "igeninteger" ], invoke_prefix = "s_" ),
505
510
Def ("ugeninteger" , ["ugeninteger" , "ugeninteger" ], invoke_prefix = "u_" )],
506
511
"add_sat" : [Def ("igeninteger" , ["igeninteger" , "igeninteger" ], invoke_prefix = "s_" ),
@@ -520,18 +525,19 @@ def custom_bool_select_invoke(return_type, _, arg_names):
520
525
"rotate" : [Def ("geninteger" , ["geninteger" , "geninteger" ])],
521
526
"sub_sat" : [Def ("igeninteger" , ["igeninteger" , "igeninteger" ], invoke_prefix = "s_" ),
522
527
Def ("ugeninteger" , ["ugeninteger" , "ugeninteger" ], invoke_prefix = "u_" )],
523
- "upsample" : [Def ("int16_t" , ["int8_t" , "uint8_t" ], invoke_prefix = "s_" ),
524
- Def ("vint16n" , ["vint8n" , "vuint8n" ], invoke_prefix = "s_" ),
525
- Def ("uint16_t" , ["uint8_t" , "uint8_t" ], invoke_prefix = "u_" ),
526
- Def ("vuint16n" , ["vuint8n" , "vuint8n" ], invoke_prefix = "u_" ),
527
- Def ("int32_t" , ["int16_t" , "uint16_t" ], invoke_prefix = "s_" ),
528
- Def ("vint32n" , ["vint16n" , "vuint16n" ], invoke_prefix = "s_" ),
529
- Def ("uint32_t" , ["uint16_t" , "uint16_t" ], invoke_prefix = "u_" ),
530
- Def ("vuint32n" , ["vuint16n" , "vuint16n" ], invoke_prefix = "u_" ),
531
- Def ("int64_t" , ["int32_t" , "uint32_t" ], invoke_prefix = "s_" ),
532
- Def ("vint64n" , ["vint32n" , "vuint32n" ], invoke_prefix = "s_" ),
533
- Def ("uint64_t" , ["uint32_t" , "uint32_t" ], invoke_prefix = "u_" ),
534
- Def ("vuint64n" , ["vuint32n" , "vuint32n" ], invoke_prefix = "u_" )],
528
+ "upsample" : [Def ("upsampledint0" , ["int8_t" , "uint8_t" ], invoke_prefix = "s_" ),
529
+ Def ("upsampledint0" , ["char" , "uint8_t" ], invoke_prefix = "s_" ), # TODO: Non-standard. Deprecate.
530
+ Def ("upsampledint0" , ["vint8n" , "vuint8n" ], invoke_prefix = "s_" ),
531
+ Def ("upsampledint0" , ["uint8_t" , "uint8_t" ], invoke_prefix = "u_" ),
532
+ Def ("upsampledint0" , ["vuint8n" , "vuint8n" ], invoke_prefix = "u_" ),
533
+ Def ("upsampledint0" , ["int16_t" , "uint16_t" ], invoke_prefix = "s_" ),
534
+ Def ("upsampledint0" , ["vint16n" , "vuint16n" ], invoke_prefix = "s_" ),
535
+ Def ("upsampledint0" , ["uint16_t" , "uint16_t" ], invoke_prefix = "u_" ),
536
+ Def ("upsampledint0" , ["vuint16n" , "vuint16n" ], invoke_prefix = "u_" ),
537
+ Def ("upsampledint0" , ["int32_t" , "uint32_t" ], invoke_prefix = "s_" ),
538
+ Def ("upsampledint0" , ["vint32n" , "vuint32n" ], invoke_prefix = "s_" ),
539
+ Def ("upsampledint0" , ["uint32_t" , "uint32_t" ], invoke_prefix = "u_" ),
540
+ Def ("upsampledint0" , ["vuint32n" , "vuint32n" ], invoke_prefix = "u_" )],
535
541
"popcount" : [Def ("geninteger" , ["geninteger" ])],
536
542
"mad24" : [Def ("igenint32" , ["igenint32" , "igenint32" , "igenint32" ], invoke_prefix = "s_" ),
537
543
Def ("ugenint32" , ["ugenint32" , "ugenint32" , "ugenint32" ], invoke_prefix = "u_" )],
@@ -571,7 +577,10 @@ def custom_bool_select_invoke(return_type, _, arg_names):
571
577
Def ("vfloatn" , ["float" , "float" , "vfloatn" ], convert_args = [(0 ,2 ),(1 ,2 )]),
572
578
Def ("vdoublen" , ["double" , "double" , "vdoublen" ], convert_args = [(0 ,2 ),(1 ,2 )])],
573
579
"sign" : [Def ("genfloat" , ["genfloat" ])],
574
- "abs" : [Def ("genfloat" , ["genfloat" ], invoke_prefix = "f" )], # TODO: Non-standard. Deprecate.
580
+ "abs" : [Def ("genfloat" , ["genfloat" ], invoke_prefix = "f" ), # TODO: Non-standard. Deprecate.
581
+ Def ("sigeninteger" , ["sigeninteger" ], custom_invoke = custom_signed_abs_scalar_invoke ),
582
+ Def ("vigeninteger" , ["vigeninteger" ], custom_invoke = custom_signed_abs_vector_invoke ),
583
+ Def ("ugeninteger" , ["ugeninteger" ], invoke_prefix = "u_" )],
575
584
# Geometric functions
576
585
"cross" : [Def ("vfloat3or4" , ["vfloat3or4" , "vfloat3or4" ]),
577
586
Def ("vdouble3or4" , ["vdouble3or4" , "vdouble3or4" ]),
@@ -698,9 +707,9 @@ def select_from_mapping(mappings, arg_types, arg_type):
698
707
if isinstance (mapping , UnsignedType ):
699
708
parent_mapping = mappings [arg_types [mapping .parent_idx ]]
700
709
return InstantiatedUnsignedType (parent_mapping , mapping .parent_idx )
701
- if isinstance (mapping , SameSizeIntType ):
710
+ if isinstance (mapping , ConversionTraitType ):
702
711
parent_mapping = mappings [arg_types [mapping .parent_idx ]]
703
- return InstantiatedSameSizeIntType (parent_mapping , mapping .signed , mapping .parent_idx )
712
+ return InstantiatedConversionTraitType (parent_mapping , mapping .trait , mapping .parent_idx )
704
713
return mapping
705
714
706
715
def instantiate_arg (idx , arg ):
@@ -714,8 +723,8 @@ def instantiate_arg(idx, arg):
714
723
return InstantiatedElementType (instantiate_arg (arg .parent_idx , arg .referenced_type ), arg .parent_idx )
715
724
if isinstance (arg , InstantiatedUnsignedType ):
716
725
return InstantiatedUnsignedType (instantiate_arg (arg .parent_idx , arg .signed_type ), arg .parent_idx )
717
- if isinstance (arg , InstantiatedSameSizeIntType ):
718
- return InstantiatedSameSizeIntType (instantiate_arg (arg .parent_idx , arg .parent_type ), arg .signed , arg .parent_idx )
726
+ if isinstance (arg , InstantiatedConversionTraitType ):
727
+ return InstantiatedConversionTraitType (instantiate_arg (arg .parent_idx , arg .parent_type ), arg .trait , arg .parent_idx )
719
728
return arg
720
729
721
730
def instantiate_return_type (return_type , instantiated_args ):
@@ -730,8 +739,8 @@ def instantiate_return_type(return_type, instantiated_args):
730
739
return InstantiatedElementType (instantiate_return_type (return_type .referenced_type , instantiated_args ), return_type .parent_idx )
731
740
if isinstance (return_type , InstantiatedUnsignedType ):
732
741
return InstantiatedUnsignedType (instantiate_return_type (return_type .signed_type , instantiated_args ), return_type .parent_idx )
733
- if isinstance (return_type , InstantiatedSameSizeIntType ):
734
- return InstantiatedSameSizeIntType (instantiate_return_type (return_type .parent_type , instantiated_args ), return_type .signed , return_type .parent_idx )
742
+ if isinstance (return_type , InstantiatedConversionTraitType ):
743
+ return InstantiatedConversionTraitType (instantiate_return_type (return_type .parent_type , instantiated_args ), return_type .trait , return_type .parent_idx )
735
744
return return_type
736
745
737
746
def type_combinations (return_type , arg_types ):
@@ -772,9 +781,9 @@ def get_all_vec_args(arg_types):
772
781
def get_vec_arg_requirement (vec_arg ):
773
782
valid_type_str = ', ' .join (vec_arg .vec_type .valid_types )
774
783
valid_sizes_str = ', ' .join (map (str , vec_arg .vec_type .valid_sizes ))
775
- checks = [f'detail::is_vec_v <{ vec_arg } >' ,
776
- f'detail::is_valid_vec_type_v <{ vec_arg .template_name } , { valid_type_str } >' ,
777
- f'detail::is_valid_vec_size_v <{ vec_arg .template_name } , { valid_sizes_str } >' ]
784
+ checks = [f'detail::is_vec_or_swizzle_v <{ vec_arg } >' ,
785
+ f'detail::is_valid_elem_type_v <{ vec_arg .template_name } , { valid_type_str } >' ,
786
+ f'detail::is_valid_size_v <{ vec_arg .template_name } , { valid_sizes_str } >' ]
778
787
return '(' + (' && ' .join (checks )) + ')'
779
788
780
789
def get_func_return (return_type , arg_types ):
0 commit comments