@@ -935,17 +935,17 @@ def custom_nan_invoke(return_type, arg_types, arg_names):
935
935
Def ("vfloatn" , ["vfloatn" , "float" ], invoke_name = "fmax_common" , convert_args = [(1 ,0 )]),
936
936
Def ("vdoublen" , ["vdoublen" , "double" ], invoke_name = "fmax_common" , convert_args = [(1 ,0 )]),
937
937
Def ("vhalfn" , ["vhalfn" , "half" ], invoke_name = "fmax_common" , convert_args = [(1 ,0 )]), # Non-standard. Deprecated.
938
- Def ("igeninteger" , ["igeninteger" , "igeninteger" ], invoke_name = "s_max" , marray_use_loop = True ),
939
- Def ("ugeninteger" , ["ugeninteger" , "ugeninteger" ], invoke_name = "u_max" , marray_use_loop = True ),
938
+ Def ("igeninteger" , ["igeninteger" , "igeninteger" ], invoke_name = "s_max" , marray_use_loop = True , template_scalar_args = True ),
939
+ Def ("ugeninteger" , ["ugeninteger" , "ugeninteger" ], invoke_name = "u_max" , marray_use_loop = True , template_scalar_args = True ),
940
940
Def ("vigeninteger" , ["vigeninteger" , "elementtype0" ], invoke_name = "s_max" ),
941
941
Def ("vugeninteger" , ["vugeninteger" , "elementtype0" ], invoke_name = "u_max" ),
942
942
Def ("mgentype" , ["mgentype" , "elementtype0" ], marray_use_loop = True )],
943
943
"(min)" : [Def ("genfloat" , ["genfloat" , "genfloat" ], invoke_name = "fmin_common" , template_scalar_args = True ),
944
944
Def ("vfloatn" , ["vfloatn" , "float" ], invoke_name = "fmin_common" , convert_args = [(1 ,0 )]),
945
945
Def ("vdoublen" , ["vdoublen" , "double" ], invoke_name = "fmin_common" , convert_args = [(1 ,0 )]),
946
946
Def ("vhalfn" , ["vhalfn" , "half" ], invoke_name = "fmin_common" , convert_args = [(1 ,0 )]), # Non-standard. Deprecated.
947
- Def ("igeninteger" , ["igeninteger" , "igeninteger" ], invoke_name = "s_min" , marray_use_loop = True ),
948
- Def ("ugeninteger" , ["ugeninteger" , "ugeninteger" ], invoke_name = "u_min" , marray_use_loop = True ),
947
+ Def ("igeninteger" , ["igeninteger" , "igeninteger" ], invoke_name = "s_min" , marray_use_loop = True , template_scalar_args = True ),
948
+ Def ("ugeninteger" , ["ugeninteger" , "ugeninteger" ], invoke_name = "u_min" , marray_use_loop = True , template_scalar_args = True ),
949
949
Def ("vigeninteger" , ["vigeninteger" , "elementtype0" ], invoke_name = "s_min" ),
950
950
Def ("vugeninteger" , ["vugeninteger" , "elementtype0" ], invoke_name = "u_min" ),
951
951
Def ("mgentype" , ["mgentype" , "elementtype0" ], marray_use_loop = True )],
@@ -957,7 +957,7 @@ def custom_nan_invoke(return_type, arg_types, arg_names):
957
957
Def ("mdoublen" , ["mdoublen" , "mdoublen" , "double" ]),
958
958
Def ("mhalfn" , ["mhalfn" , "mhalfn" , "half" ])], # Non-standard. Deprecated.
959
959
"radians" : [Def ("genfloat" , ["genfloat" ], template_scalar_args = True )],
960
- "step" : [Def ("genfloat" , ["genfloat" , "genfloat" ]),
960
+ "step" : [Def ("genfloat" , ["genfloat" , "genfloat" ], template_scalar_args = True ),
961
961
Def ("vfloatn" , ["float" , "vfloatn" ], convert_args = [(0 ,1 )]),
962
962
Def ("vdoublen" , ["double" , "vdoublen" ], convert_args = [(0 ,1 )]),
963
963
Def ("vhalfn" , ["half" , "vhalfn" ], convert_args = [(0 ,1 )]), # Non-standard. Deprecated.
@@ -989,25 +989,25 @@ def custom_nan_invoke(return_type, arg_types, arg_names):
989
989
Def ("float" , ["mgeofloat" , "mgeofloat" ], invoke_name = "Dot" ),
990
990
Def ("double" , ["mgeodouble" , "mgeodouble" ], invoke_name = "Dot" ),
991
991
Def ("half" , ["mgeohalf" , "mgeohalf" ], invoke_name = "Dot" ),
992
- Def ("sgenfloat" , ["sgenfloat" , "sgenfloat" ], custom_invoke = (lambda return_types , arg_types , arg_names : ' return ' + ' * ' .join (arg_names ) + ';' ))],
993
- "distance" : [Def ("float" , ["gengeofloat" , "gengeofloat" ]),
994
- Def ("double" , ["gengeodouble" , "gengeodouble" ]),
995
- Def ("half" , ["gengeohalf" , "gengeohalf" ])],
996
- "length" : [Def ("float" , ["gengeofloat" ]),
997
- Def ("double" , ["gengeodouble" ]),
998
- Def ("half" , ["gengeohalf" ])],
999
- "normalize" : [Def ("gengeofloat" , ["gengeofloat" ]),
1000
- Def ("gengeodouble" , ["gengeodouble" ]),
1001
- Def ("gengeohalf" , ["gengeohalf" ])],
1002
- "fast_distance" : [Def ("float" , ["gengeofloat" , "gengeofloat" ]),
1003
- Def ("double" , ["gengeodouble" , "gengeodouble" ]),
1004
- Def ("half" , ["gengeohalf" , "gengeohalf" ])],
1005
- "fast_length" : [Def ("float" , ["gengeofloat" ]),
1006
- Def ("double" , ["gengeodouble" ]),
1007
- Def ("half" , ["gengeohalf" ])],
1008
- "fast_normalize" : [Def ("gengeofloat" , ["gengeofloat" ]),
1009
- Def ("gengeodouble" , ["gengeodouble" ]),
1010
- Def ("gengeohalf" , ["gengeohalf" ])],
992
+ Def ("sgenfloat" , ["sgenfloat" , "sgenfloat" ], template_scalar_args = True , custom_invoke = (lambda return_types , arg_types , arg_names : ' return ' + ' * ' .join (arg_names ) + ';' ))],
993
+ "distance" : [Def ("float" , ["gengeofloat" , "gengeofloat" ], template_scalar_args = True ),
994
+ Def ("double" , ["gengeodouble" , "gengeodouble" ], template_scalar_args = True ),
995
+ Def ("half" , ["gengeohalf" , "gengeohalf" ], template_scalar_args = True )],
996
+ "length" : [Def ("float" , ["gengeofloat" ], template_scalar_args = True ),
997
+ Def ("double" , ["gengeodouble" ], template_scalar_args = True ),
998
+ Def ("half" , ["gengeohalf" ], template_scalar_args = True )],
999
+ "normalize" : [Def ("gengeofloat" , ["gengeofloat" ], template_scalar_args = True ),
1000
+ Def ("gengeodouble" , ["gengeodouble" ], template_scalar_args = True ),
1001
+ Def ("gengeohalf" , ["gengeohalf" ], template_scalar_args = True )],
1002
+ "fast_distance" : [Def ("float" , ["gengeofloat" , "gengeofloat" ], template_scalar_args = True ),
1003
+ Def ("double" , ["gengeodouble" , "gengeodouble" ], template_scalar_args = True ),
1004
+ Def ("half" , ["gengeohalf" , "gengeohalf" ], template_scalar_args = True )],
1005
+ "fast_length" : [Def ("float" , ["gengeofloat" ], template_scalar_args = True ),
1006
+ Def ("double" , ["gengeodouble" ], template_scalar_args = True ),
1007
+ Def ("half" , ["gengeohalf" ], template_scalar_args = True )],
1008
+ "fast_normalize" : [Def ("gengeofloat" , ["gengeofloat" ], template_scalar_args = True ),
1009
+ Def ("gengeodouble" , ["gengeodouble" ], template_scalar_args = True ),
1010
+ Def ("gengeohalf" , ["gengeohalf" ], template_scalar_args = True )],
1011
1011
# Relational functions
1012
1012
"isequal" : [RelDef ("samesizesignedint0" , ["vgenfloat" , "vgenfloat" ], invoke_name = "FOrdEqual" ),
1013
1013
RelDef ("bool" , ["sgenfloat" , "sgenfloat" ], invoke_name = "FOrdEqual" ),
@@ -1052,13 +1052,13 @@ def custom_nan_invoke(return_type, arg_types, arg_names):
1052
1052
RelDef ("bool" , ["sgenfloat" ], invoke_name = "SignBitSet" ),
1053
1053
RelDef ("boolelements0" , ["mgenfloat" ])],
1054
1054
"any" : [Def ("int" , ["vigeninteger" ], custom_invoke = get_custom_any_all_vec_invoke ("Any" )),
1055
- Def ("bool" , ["sigeninteger" ], custom_invoke = (lambda return_type , arg_types , arg_names : f' return bool(int(detail::msbIsSet({ arg_names [0 ]} )));' )),
1055
+ Def ("bool" , ["sigeninteger" ], template_scalar_args = True , custom_invoke = (lambda return_type , arg_types , arg_names : f' return bool(int(detail::msbIsSet({ arg_names [0 ]} )));' )),
1056
1056
Def ("bool" , ["migeninteger" ], custom_invoke = get_custom_any_all_marray_invoke ("any" ))],
1057
1057
"all" : [Def ("int" , ["vigeninteger" ], custom_invoke = get_custom_any_all_vec_invoke ("All" )),
1058
- Def ("bool" , ["sigeninteger" ], custom_invoke = (lambda return_type , arg_types , arg_names : f' return bool(int(detail::msbIsSet({ arg_names [0 ]} )));' )),
1058
+ Def ("bool" , ["sigeninteger" ], template_scalar_args = True , custom_invoke = (lambda return_type , arg_types , arg_names : f' return bool(int(detail::msbIsSet({ arg_names [0 ]} )));' )),
1059
1059
Def ("bool" , ["migeninteger" ], custom_invoke = get_custom_any_all_marray_invoke ("all" ))],
1060
1060
"bitselect" : [Def ("vgentype" , ["vgentype" , "vgentype" , "vgentype" ]),
1061
- Def ("sgentype" , ["sgentype" , "sgentype" , "sgentype" ]),
1061
+ Def ("sgentype" , ["sgentype" , "sgentype" , "sgentype" ], template_scalar_args = True ),
1062
1062
Def ("mgentype" , ["mgentype" , "mgentype" , "mgentype" ], marray_use_loop = True )],
1063
1063
"select" : [Def ("vint8n" , ["vint8n" , "vint8n" , "vint8n" ]),
1064
1064
Def ("vint16n" , ["vint16n" , "vint16n" , "vint16n" ]),
@@ -1082,7 +1082,7 @@ def custom_nan_invoke(return_type, arg_types, arg_names):
1082
1082
Def ("vfloatn" , ["vfloatn" , "vfloatn" , "vuint32n" ]),
1083
1083
Def ("vdoublen" , ["vdoublen" , "vdoublen" , "vuint64n" ]),
1084
1084
Def ("vhalfn" , ["vhalfn" , "vhalfn" , "vuint16n" ]),
1085
- Def ("sgentype" , ["sgentype" , "sgentype" , "bool" ], custom_invoke = custom_bool_select_invoke ),
1085
+ Def ("sgentype" , ["sgentype" , "sgentype" , "bool" ], template_scalar_args = True , custom_invoke = custom_bool_select_invoke ),
1086
1086
Def ("mgentype" , ["mgentype" , "mgentype" , "mbooln" ], marray_use_loop = True )]}
1087
1087
# List of all builtins definitions in the sycl::native namespace.
1088
1088
native_builtins = {"cos" : [Def ("genfloatf" , ["genfloatf" ], invoke_prefix = "native_" )],
@@ -1210,10 +1210,15 @@ def type_combinations(return_type, arg_types, template_scalars):
1210
1210
Generates all return and argument type combinations for a given builtin
1211
1211
definition.
1212
1212
"""
1213
- unique_types = list (dict .fromkeys (arg_types + [ return_type ] ))
1213
+ unique_types = list (dict .fromkeys (arg_types ))
1214
1214
unique_type_lists = [builtin_types [unique_type ] for unique_type in unique_types ]
1215
1215
if template_scalars :
1216
1216
unique_type_lists = [convert_scalars_to_templated (type_list ) for type_list in unique_type_lists ]
1217
+ if return_type not in unique_types :
1218
+ # Add return type after scalars have been turned to template arguments if
1219
+ # it is unique, to avoid undeducible return types.
1220
+ unique_types .append (return_type )
1221
+ unique_type_lists .append (builtin_types [return_type ])
1217
1222
combinations = list (itertools .product (* unique_type_lists ))
1218
1223
result = []
1219
1224
for combination in combinations :
0 commit comments