@@ -22,6 +22,7 @@ SPDX-License-Identifier: MIT
22
22
#include " vc/Support/GenXDiagnostic.h"
23
23
#include " vc/Utils/GenX/IntrinsicsWrapper.h"
24
24
#include " vc/Utils/General/BiF.h"
25
+ #include " vc/Utils/General/Types.h"
25
26
26
27
#include " Probe/Assertion.h"
27
28
@@ -382,6 +383,62 @@ static bool isSPIRVBuiltinDecl(const Function &F) {
382
383
return Name.contains (" __spirv" );
383
384
}
384
385
386
+ static void emitError (Type *ArgTy, Type *NewArgTy, unsigned Index,
387
+ LLVMContext &Ctx, CallInst *CI) {
388
+ SmallString<128 > Message;
389
+ raw_svector_ostream Out (Message);
390
+ Out << " Unexpected function argument #" << Index << " type: " << *ArgTy
391
+ << " , expected: " << *NewArgTy << " \n " ;
392
+ vc::diagnose (Ctx, " GenXTranslateSPIRVBuiltins" , Message, CI);
393
+ }
394
+
395
+ static inline void checkTypesFixPtrs (Function *Func, Function *NewFunc) {
396
+ SmallVector<CallInst *, 1 > CallInstList;
397
+
398
+ // If types not matched - we try to modify it by cast's
399
+ if (Func->getFunctionType () != NewFunc->getFunctionType ()) {
400
+ for (auto *U : Func->users ()) {
401
+ auto *CI = dyn_cast<CallInst>(U);
402
+ if (!CI)
403
+ continue ;
404
+
405
+ CallInstList.push_back (CI);
406
+ IRBuilder<> Builder (CI);
407
+
408
+ for (auto &U : CI->args ()) {
409
+ auto Index = U.getOperandNo ();
410
+ auto *ArgTy = U->getType ();
411
+ auto *NewArgTy = NewFunc->getArg (Index)->getType ();
412
+
413
+ if (isa<PointerType>(ArgTy) && isa<PointerType>(NewArgTy)) {
414
+ auto AS = cast<PointerType>(ArgTy)->getAddressSpace ();
415
+ auto NewAS = cast<PointerType>(NewArgTy)->getAddressSpace ();
416
+ if (AS != NewAS && NewAS != vc::AddrSpace::Generic)
417
+ emitError (ArgTy, NewArgTy, Index, CI->getContext (), CI);
418
+
419
+ U.set (Builder.CreatePointerBitCastOrAddrSpaceCast (U.get (), NewArgTy));
420
+ } else if (ArgTy != NewArgTy)
421
+ emitError (ArgTy, NewArgTy, Index, CI->getContext (), CI);
422
+ }
423
+ }
424
+ }
425
+ Func->deleteBody ();
426
+
427
+ if (!CallInstList.empty ()) {
428
+ Func->stealArgumentListFrom (*NewFunc);
429
+ // A new function is needed to replase FunctionType in all calls
430
+ auto *CastFunc =
431
+ Function::Create (NewFunc->getFunctionType (), Func->getLinkage (),
432
+ NewFunc->getName (), Func->getParent ());
433
+ CastFunc->copyAttributesFrom (Func);
434
+ for (auto *CI : CallInstList)
435
+ CI->setCalledFunction (CastFunc);
436
+
437
+ Func->eraseFromParent ();
438
+ CastFunc->setName (NewFunc->getName ());
439
+ }
440
+ }
441
+
385
442
bool GenXTranslateSPIRVBuiltins::runOnModule (Module &M) {
386
443
bool Changed = false ;
387
444
Expander = SPIRVExpander (&M);
@@ -404,10 +461,8 @@ bool GenXTranslateSPIRVBuiltins::runOnModule(Module &M) {
404
461
for (auto &FuncName : SPIRVBuiltins) {
405
462
auto *Func = M.getFunction (FuncName);
406
463
auto *NewFunc = SPIRVBuiltinsModule->getFunction (FuncName);
407
- if (Func && !Func->isDeclaration () && NewFunc &&
408
- !NewFunc->isDeclaration () &&
409
- Func->getFunctionType () == NewFunc->getFunctionType ())
410
- Func->deleteBody ();
464
+ if (Func && !Func->isDeclaration () && NewFunc && !NewFunc->isDeclaration ())
465
+ checkTypesFixPtrs (Func, NewFunc);
411
466
}
412
467
413
468
if (Linker::linkModules (M, std::move (SPIRVBuiltinsModule),
0 commit comments