@@ -535,48 +535,39 @@ static bool needsInt32RetTy(Op OC) {
535
535
536
536
void SPIRVToOCLBase::visitCallSPIRVGroupBuiltin (CallInst *CI, Op OC) {
537
537
auto FuncName = groupOCToOCLBuiltinName (CI, OC);
538
- auto ModifyArguments = [=](CallInst *, std::vector<Value *> &Args,
539
- llvm::Type *&RetTy) {
540
- Type *Int32Ty = Type::getInt32Ty (*Ctx);
541
- bool HasArg0ExtendedToi32 =
542
- OC == OpGroupAny || OC == OpGroupAll || OC == OpGroupNonUniformAny ||
543
- OC == OpGroupNonUniformAll || OC == OpGroupNonUniformBallot ||
544
- isGroupLogicalOpCode (OC);
545
- // / Remove Group Operation argument,
546
- // / as in OpenCL representation this is included in the function name
547
- Args.erase (Args.begin (), Args.begin () + (hasGroupOperation (OC) ? 2 : 1 ));
548
-
549
- // Handle function arguments
550
- if (OC == OpGroupBroadcast)
551
- expandVector (CI, Args, 1 );
552
- else if (HasArg0ExtendedToi32)
553
- Args[0 ] = CastInst::CreateZExtOrBitCast (Args[0 ], Int32Ty, " " , CI);
554
-
555
- // Handle function return type
556
- if (needsInt32RetTy (OC))
557
- RetTy = Int32Ty;
558
-
559
- return FuncName;
560
- };
561
- auto ModifyRetTy = [=](CallInst *CI) -> Instruction * {
562
- if (needsInt32RetTy (OC)) {
538
+ auto Mutator = mutateCallInst (CI, FuncName);
539
+ // / Remove Group Operation argument,
540
+ // / as in OpenCL representation this is included in the function name
541
+ Mutator.removeArgs (0 , (hasGroupOperation (OC) ? 2 : 1 ));
542
+
543
+ Type *Int32Ty = Type::getInt32Ty (*Ctx);
544
+ bool HasArg0ExtendedToi32 =
545
+ OC == OpGroupAny || OC == OpGroupAll || OC == OpGroupNonUniformAny ||
546
+ OC == OpGroupNonUniformAll || OC == OpGroupNonUniformBallot ||
547
+ isGroupLogicalOpCode (OC);
548
+
549
+ // Handle function arguments
550
+ if (OC == OpGroupBroadcast) {
551
+ Value *VecArg = Mutator.getArg (1 );
552
+ if (auto *VT = dyn_cast<FixedVectorType>(VecArg->getType ())) {
553
+ unsigned NumElements = VT->getNumElements ();
554
+ for (unsigned I = 0 ; I < NumElements; I++)
555
+ Mutator.insertArg (1 + I, Mutator.Builder .CreateExtractElement (
556
+ VecArg, Mutator.Builder .getInt32 (I)));
557
+ Mutator.removeArg (1 + NumElements);
558
+ }
559
+ } else if (HasArg0ExtendedToi32)
560
+ Mutator.mapArg (0 , [](IRBuilder<> &Builder, Value *V) {
561
+ return Builder.CreateZExt (V, Builder.getInt32Ty ());
562
+ });
563
+
564
+ // Handle function return type
565
+ if (needsInt32RetTy (OC))
566
+ Mutator.changeReturnType (Int32Ty, [](IRBuilder<> &Builder, CallInst *CI) {
563
567
// The OpenCL builtin returns a non-zero integer value. Convert to a
564
568
// boolean value.
565
- Constant *Zero = ConstantInt::get (CI->getType (), 0 );
566
- return new ICmpInst (CI->getNextNode (), CmpInst::ICMP_NE, CI, Zero);
567
- } else
568
- return CI;
569
- };
570
-
571
- assert (CI->getCalledFunction () && " Unexpected indirect call" );
572
- AttributeList Attrs = CI->getCalledFunction ()->getAttributes ();
573
- SmallVector<AttributeSet, 2 > ArgAttrs;
574
- for (int I = (hasGroupOperation (OC) ? 2 : 1 );
575
- I < (int )Attrs.getNumAttrSets () - 2 ; I++)
576
- ArgAttrs.push_back (Attrs.getParamAttrs (I));
577
- Attrs = AttributeList::get (*Ctx, Attrs.getFnAttrs (), Attrs.getRetAttrs (),
578
- ArgAttrs);
579
- mutateCallInstOCL (M, CI, ModifyArguments, ModifyRetTy, &Attrs);
569
+ return Builder.CreateICmpNE (CI, Builder.getInt32 (0 ));
570
+ });
580
571
}
581
572
582
573
void SPIRVToOCLBase::visitCallSPIRVPipeBuiltin (CallInst *CI, Op OC) {
@@ -586,64 +577,52 @@ void SPIRVToOCLBase::visitCallSPIRVPipeBuiltin(CallInst *CI, Op OC) {
586
577
DemangledName = getGroupBuiltinPrefix (CI) + DemangledName;
587
578
588
579
assert (CI->getCalledFunction () && " Unexpected indirect call" );
589
- AttributeList Attrs = CI->getCalledFunction ()->getAttributes ();
590
- mutateCallInstOCL (
591
- M, CI,
592
- [=](CallInst *, std::vector<Value *> &Args) {
593
- if (HasScope)
594
- Args.erase (Args.begin (), Args.begin () + 1 );
595
-
596
- if (!(OC == OpReadPipe || OC == OpWritePipe ||
597
- OC == OpReservedReadPipe || OC == OpReservedWritePipe ||
598
- OC == OpReadPipeBlockingINTEL || OC == OpWritePipeBlockingINTEL))
599
- return DemangledName;
600
-
601
- auto &P = Args[Args.size () - 3 ];
602
- auto T = P->getType ();
603
- assert (isa<PointerType>(T));
604
- auto *NewTy = PointerType::getInt8PtrTy (*Ctx, SPIRAS_Generic);
605
- if (T != NewTy) {
606
- P = CastInst::CreatePointerBitCastOrAddrSpaceCast (P, NewTy, " " , CI);
607
- }
608
- return DemangledName;
609
- },
610
- &Attrs);
580
+ auto Mutator = mutateCallInst (CI, DemangledName);
581
+ if (HasScope)
582
+ Mutator.removeArg (0 );
583
+ if (OC == OpReadPipe || OC == OpWritePipe || OC == OpReservedReadPipe ||
584
+ OC == OpReservedWritePipe || OC == OpReadPipeBlockingINTEL ||
585
+ OC == OpWritePipeBlockingINTEL) {
586
+ Mutator.mapArg (Mutator.arg_size () - 3 , [](IRBuilder<> &Builder, Value *P) {
587
+ Type *T = P->getType ();
588
+ assert (isa<PointerType>(T));
589
+ auto *NewTy = Builder.getInt8PtrTy (SPIRAS_Generic);
590
+ if (T != NewTy) {
591
+ P = Builder.CreatePointerBitCastOrAddrSpaceCast (P, NewTy);
592
+ }
593
+ return std::pair<Value *, Type *>(P, Builder.getInt8Ty ());
594
+ });
595
+ }
611
596
}
612
597
613
598
void SPIRVToOCLBase::visitCallSPIRVImageMediaBlockBuiltin (CallInst *CI, Op OC) {
614
- AttributeList Attrs = CI->getCalledFunction ()->getAttributes ();
615
- mutateCallInstOCL (
616
- M, CI,
617
- [=](CallInst *, std::vector<Value *> &Args) {
618
- // Moving the first argument to the end.
619
- std::rotate (Args.rbegin (), Args.rend () - 1 , Args.rend ());
620
- Type *RetType = CI->getType ();
621
- if (OC == OpSubgroupImageMediaBlockWriteINTEL) {
622
- assert (Args.size () >= 4 && " Wrong media block write signature" );
623
- RetType = Args.at (3 )->getType (); // texel type
624
- }
625
- unsigned int BitWidth = RetType->getScalarSizeInBits ();
626
- std::string FuncPostfix;
627
- if (BitWidth == 8 )
628
- FuncPostfix = " _uc" ;
629
- else if (BitWidth == 16 )
630
- FuncPostfix = " _us" ;
631
- else if (BitWidth == 32 )
632
- FuncPostfix = " _ui" ;
633
- else
634
- assert (0 && " Unsupported texel type!" );
599
+ Type *RetType = CI->getType ();
600
+ if (OC == OpSubgroupImageMediaBlockWriteINTEL) {
601
+ assert (CI->arg_size () >= 5 && " Wrong media block write signature" );
602
+ RetType = CI->getArgOperand (4 )->getType (); // texel type
603
+ }
604
+ unsigned int BitWidth = RetType->getScalarSizeInBits ();
605
+ std::string FuncPostfix;
606
+ if (BitWidth == 8 )
607
+ FuncPostfix = " _uc" ;
608
+ else if (BitWidth == 16 )
609
+ FuncPostfix = " _us" ;
610
+ else if (BitWidth == 32 )
611
+ FuncPostfix = " _ui" ;
612
+ else
613
+ assert (0 && " Unsupported texel type!" );
635
614
636
- if (auto *VecTy = dyn_cast<FixedVectorType>(RetType)) {
637
- unsigned int NumEl = VecTy->getNumElements ();
638
- assert ((NumEl == 2 || NumEl == 4 || NumEl == 8 || NumEl == 16 ) &&
639
- " Wrong function type!" );
640
- FuncPostfix += std::to_string (NumEl);
641
- }
615
+ if (auto *VecTy = dyn_cast<FixedVectorType>(RetType)) {
616
+ unsigned int NumEl = VecTy->getNumElements ();
617
+ assert ((NumEl == 2 || NumEl == 4 || NumEl == 8 || NumEl == 16 ) &&
618
+ " Wrong function type!" );
619
+ FuncPostfix += std::to_string (NumEl);
620
+ }
642
621
643
- return OCLSPIRVBuiltinMap::rmap (OC) + FuncPostfix;
644
- },
645
- &Attrs);
622
+ mutateCallInst (CI, OCLSPIRVBuiltinMap::rmap (OC) + FuncPostfix)
623
+ .moveArg (0 , CI->arg_size () - 1 );
646
624
}
625
+
647
626
void SPIRVToOCLBase::visitCallBuildNDRangeBuiltIn (CallInst *CI, Op OC,
648
627
StringRef DemangledName) {
649
628
assert (CI->getCalledFunction () && " Unexpected indirect call" );
@@ -837,40 +816,34 @@ void SPIRVToOCLBase::visitCallSPIRVImageQueryBuiltIn(CallInst *CI, Op OC) {
837
816
}
838
817
839
818
void SPIRVToOCLBase::visitCallSPIRVSubgroupINTELBuiltIn (CallInst *CI, Op OC) {
840
- assert (CI->getCalledFunction () && " Unexpected indirect call" );
841
- AttributeList Attrs = CI->getCalledFunction ()->getAttributes ();
842
- mutateCallInstOCL (
843
- M, CI,
844
- [=](CallInst *, std::vector<Value *> &Args) {
845
- std::stringstream Name;
846
- Type *DataTy = nullptr ;
847
- switch (OC) {
848
- case OpSubgroupBlockReadINTEL:
849
- case OpSubgroupImageBlockReadINTEL:
850
- Name << " intel_sub_group_block_read" ;
851
- DataTy = CI->getType ();
852
- break ;
853
- case OpSubgroupBlockWriteINTEL:
854
- Name << " intel_sub_group_block_write" ;
855
- DataTy = CI->getOperand (1 )->getType ();
856
- break ;
857
- case OpSubgroupImageBlockWriteINTEL:
858
- Name << " intel_sub_group_block_write" ;
859
- DataTy = CI->getOperand (2 )->getType ();
860
- break ;
861
- default :
862
- return OCLSPIRVBuiltinMap::rmap (OC);
863
- }
864
- assert (DataTy && " Intel subgroup block builtins should have data type" );
865
- unsigned VectorNumElements = 1 ;
866
- if (FixedVectorType *VT = dyn_cast<FixedVectorType>(DataTy))
867
- VectorNumElements = VT->getNumElements ();
868
- unsigned ElementBitSize = DataTy->getScalarSizeInBits ();
869
- Name << getIntelSubgroupBlockDataPostfix (ElementBitSize,
870
- VectorNumElements);
871
- return Name.str ();
872
- },
873
- &Attrs);
819
+ std::stringstream Name;
820
+ Type *DataTy = nullptr ;
821
+ switch (OC) {
822
+ case OpSubgroupBlockReadINTEL:
823
+ case OpSubgroupImageBlockReadINTEL:
824
+ Name << " intel_sub_group_block_read" ;
825
+ DataTy = CI->getType ();
826
+ break ;
827
+ case OpSubgroupBlockWriteINTEL:
828
+ Name << " intel_sub_group_block_write" ;
829
+ DataTy = CI->getOperand (1 )->getType ();
830
+ break ;
831
+ case OpSubgroupImageBlockWriteINTEL:
832
+ Name << " intel_sub_group_block_write" ;
833
+ DataTy = CI->getOperand (2 )->getType ();
834
+ break ;
835
+ default :
836
+ Name << OCLSPIRVBuiltinMap::rmap (OC);
837
+ break ;
838
+ }
839
+ if (DataTy) {
840
+ unsigned VectorNumElements = 1 ;
841
+ if (FixedVectorType *VT = dyn_cast<FixedVectorType>(DataTy))
842
+ VectorNumElements = VT->getNumElements ();
843
+ unsigned ElementBitSize = DataTy->getScalarSizeInBits ();
844
+ Name << getIntelSubgroupBlockDataPostfix (ElementBitSize, VectorNumElements);
845
+ }
846
+ mutateCallInst (CI, Name.str ());
874
847
}
875
848
876
849
void SPIRVToOCLBase::visitCallSPIRVAvcINTELEvaluateBuiltIn (CallInst *CI,
0 commit comments