@@ -116,27 +116,35 @@ Value *SPIRVExpander::visitCallInst(CallInst &CI) {
116
116
auto *Ty = CI.getType ();
117
117
auto CalleeName = Callee->getName ();
118
118
119
+ // Check if it's a SPIR-V function
120
+ const StringRef Prefix (" __spirv_" );
121
+ const auto PrefixPos = CalleeName.find (Prefix);
122
+ if (PrefixPos == StringRef::npos)
123
+ return nullptr ;
124
+
125
+ CalleeName = CalleeName.drop_front (PrefixPos + Prefix.size ());
126
+
119
127
// Addrspace-related builtins.
120
- if (CalleeName.contains ( " __spirv_GenericCastToPtrExplicit " ))
128
+ if (CalleeName.startswith ( " GenericCastToPtrExplicit " ))
121
129
return emitIntrinsic (Builder, vc::InternalIntrinsic::cast_to_ptr_explicit,
122
130
Ty, {CI.getArgOperand (0 )});
123
131
124
132
// SPV_INTEL_bfloat16_conversion extension.
125
- if (CalleeName.contains ( " __spirv_ConvertFToBF16INTEL " )) {
133
+ if (CalleeName.startswith ( " ConvertFToBF16INTEL " )) {
126
134
auto *Arg = CI.getArgOperand (0 );
127
135
auto *ArgTy = Arg->getType ();
128
136
return emitIntrinsic (Builder, vc::InternalIntrinsic::cast_to_bf16,
129
137
{Ty, ArgTy}, {Arg});
130
138
}
131
- if (CalleeName.contains ( " __spirv_ConvertBF16ToFINTEL " )) {
139
+ if (CalleeName.startswith ( " ConvertBF16ToFINTEL " )) {
132
140
auto *Arg = CI.getArgOperand (0 );
133
141
auto *ArgTy = Arg->getType ();
134
142
return emitIntrinsic (Builder, vc::InternalIntrinsic::cast_from_bf16,
135
143
{Ty, ArgTy}, {Arg});
136
144
}
137
145
// SPV_INTEL_tensor_float32_rounding extension.
138
- if (CalleeName.contains ( " __spirv_RoundFToTF32INTEL " ) ||
139
- CalleeName.contains ( " __spirv_ConvertFToTF32INTEL " )) {
146
+ if (CalleeName.startswith ( " RoundFToTF32INTEL " ) ||
147
+ CalleeName.startswith ( " ConvertFToTF32INTEL " )) {
140
148
auto *Arg = CI.getArgOperand (0 );
141
149
auto *ArgTy = Arg->getType ();
142
150
Type *ResTy = Builder.getInt32Ty ();
@@ -148,68 +156,72 @@ Value *SPIRVExpander::visitCallInst(CallInst &CI) {
148
156
return Builder.CreateBitCast (Intr, Ty);
149
157
}
150
158
151
- // Math builtins.
152
- if (!CalleeName.contains (" __spirv_ocl_native_" ) &&
153
- !CalleeName.contains (" __spirv_ocl_half_" ))
159
+ // OpenCL extended instruction set
160
+ if (!CalleeName.consume_front (" ocl_" ))
161
+ return nullptr ;
162
+
163
+ // Native subset
164
+ if (!CalleeName.consume_front (" native_" ) &&
165
+ !CalleeName.consume_front (" half_" ))
154
166
return nullptr ;
155
167
156
- if (CalleeName.contains (" cos" ))
168
+ if (CalleeName.startswith (" cos" ))
157
169
return emitMathIntrinsic (Builder, Intrinsic::cos, Ty, {CI.getArgOperand (0 )},
158
170
true );
159
- if (CalleeName.contains (" divide" ))
171
+ if (CalleeName.startswith (" divide" ))
160
172
return emitFDiv (Builder, CI.getArgOperand (0 ), CI.getArgOperand (1 ), true );
161
- if (CalleeName.contains (" exp2" ))
173
+ if (CalleeName.startswith (" exp2" ))
162
174
return emitMathIntrinsic (Builder, Intrinsic::exp2, Ty,
163
175
{CI.getArgOperand (0 )}, true );
164
- if (CalleeName.contains (" exp10" )) {
176
+ if (CalleeName.startswith (" exp10" )) {
165
177
// exp10(x) == exp2(x * log2(10))
166
178
auto *C = ConstantFP::get (Ty, Log2_10);
167
179
auto *ArgV = Builder.CreateFMul (CI.getArgOperand (0 ), C);
168
180
return emitMathIntrinsic (Builder, Intrinsic::exp2, Ty, {ArgV}, true );
169
181
}
170
- if (CalleeName.contains (" exp" )) {
182
+ if (CalleeName.startswith (" exp" )) {
171
183
// exp(x) == exp2(x * log2(e))
172
184
auto *C = ConstantFP::get (Ty, Log2E);
173
185
auto *ArgV = Builder.CreateFMul (CI.getArgOperand (0 ), C);
174
186
return emitMathIntrinsic (Builder, Intrinsic::exp2, Ty, {ArgV}, true );
175
187
}
176
- if (CalleeName.contains (" log2" ))
188
+ if (CalleeName.startswith (" log2" ))
177
189
return emitMathIntrinsic (Builder, Intrinsic::log2, Ty,
178
190
{CI.getArgOperand (0 )}, true );
179
- if (CalleeName.contains (" log10" )) {
191
+ if (CalleeName.startswith (" log10" )) {
180
192
// log10(x) == log2(x) * log10(2)
181
193
auto *LogV = emitMathIntrinsic (Builder, Intrinsic::log2, Ty,
182
194
{CI.getArgOperand (0 )}, true );
183
195
auto *C = ConstantFP::get (Ty, Log10_2);
184
196
return Builder.CreateFMul (LogV, C);
185
197
}
186
- if (CalleeName.contains (" log" )) {
198
+ if (CalleeName.startswith (" log" )) {
187
199
// ln(x) == log2(x) * ln(2)
188
200
auto *LogV = emitMathIntrinsic (Builder, Intrinsic::log2, Ty,
189
201
{CI.getArgOperand (0 )}, true );
190
202
auto *C = ConstantFP::get (Ty, Ln2);
191
203
return Builder.CreateFMul (LogV, C);
192
204
}
193
- if (CalleeName.contains (" powr" ))
205
+ if (CalleeName.startswith (" powr" ))
194
206
return emitMathIntrinsic (Builder, Intrinsic::pow, Ty,
195
207
{CI.getArgOperand (0 ), CI.getArgOperand (1 )}, true );
196
- if (CalleeName.contains (" recip" )) {
208
+ if (CalleeName.startswith (" recip" )) {
197
209
auto *OneC = ConstantFP::get (Ty, 1.0 );
198
210
return emitFDiv (Builder, OneC, CI.getArgOperand (0 ), true );
199
211
}
200
- if (CalleeName.contains (" rsqrt" )) {
212
+ if (CalleeName.startswith (" rsqrt" )) {
201
213
auto *OneC = ConstantFP::get (Ty, 1.0 );
202
214
auto *SqrtV = emitMathIntrinsic (Builder, Intrinsic::sqrt, Ty,
203
215
{CI.getArgOperand (0 )}, true );
204
216
return emitFDiv (Builder, OneC, SqrtV, true );
205
217
}
206
- if (CalleeName.contains (" sin" ))
218
+ if (CalleeName.startswith (" sin" ))
207
219
return emitMathIntrinsic (Builder, Intrinsic::sin, Ty, {CI.getArgOperand (0 )},
208
220
true );
209
- if (CalleeName.contains (" sqrt" ))
221
+ if (CalleeName.startswith (" sqrt" ))
210
222
return emitMathIntrinsic (Builder, Intrinsic::sqrt, Ty,
211
223
{CI.getArgOperand (0 )}, true );
212
- if (CalleeName.contains (" tan" )) {
224
+ if (CalleeName.startswith (" tan" )) {
213
225
// tan(x) == sin(x) / cos(x)
214
226
auto *ArgV = CI.getArgOperand (0 );
215
227
auto *SinV = emitMathIntrinsic (Builder, Intrinsic::sin, Ty, {ArgV}, true );
0 commit comments