Skip to content

Commit f1161d6

Browse files
committed
Changes to visitor model to run handlers conditionally.
1 parent c09dae4 commit f1161d6

File tree

1 file changed

+55
-52
lines changed

1 file changed

+55
-52
lines changed

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 55 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,11 @@
2727
#include "llvm/Support/raw_ostream.h"
2828

2929
#include <array>
30+
#include <functional>
31+
#include <initializer_list>
3032

3133
using namespace clang;
34+
using namespace std::placeholders;
3235

3336
using KernelParamKind = SYCLIntegrationHeader::kernel_param_kind_t;
3437

@@ -681,30 +684,60 @@ QualType getItemType(const FieldDecl *FD) { return FD->getType(); }
681684
QualType getItemType(const CXXBaseSpecifier &BS) { return BS.getType(); }
682685

683686
// These enable handler execution only when previous handlers succeed.
684-
template <typename T>
685-
static bool handleField(FieldDecl *FD, QualType FDTy, T &t) {
686-
return (t.first->*t.second)(FD, FDTy);
687+
template <typename... Tn>
688+
static bool handleField(FieldDecl *FD, QualType FDTy, Tn &&... tn) {
689+
bool result = true;
690+
std::initializer_list<int>{(result = result && tn(FD, FDTy), 0)...};
691+
return result;
687692
}
688-
template <typename T, typename... Tn>
689-
static bool handleField(FieldDecl *FD, QualType FDTy, T &t, Tn &... tn) {
690-
return (t.first->*t.second)(FD, FDTy) && handleField(FD, FDTy, tn...);
693+
template <typename... Tn>
694+
static bool handleField(const CXXBaseSpecifier &BD, QualType BDTy,
695+
Tn &&... tn) {
696+
bool result = true;
697+
std::initializer_list<int>{(result = result && tn(BD, BDTy), 0)...};
698+
return result;
691699
}
692700

701+
template <typename T> struct bind_param { using type = T; };
702+
703+
template <> struct bind_param<CXXBaseSpecifier &> {
704+
using type = const CXXBaseSpecifier &;
705+
};
706+
707+
template <> struct bind_param<FieldDecl *&> { using type = FieldDecl *; };
708+
709+
template <> struct bind_param<FieldDecl *const &> { using type = FieldDecl *; };
710+
711+
template <typename T> using bind_param_t = typename bind_param<T>::type;
712+
713+
// This definition using std::bind is necessary because of a gcc 7.x bug.
714+
#define KF_FOR_EACH(FUNC, Item, Qt) \
715+
handleField( \
716+
Item, Qt, \
717+
std::bind(static_cast<bool (std::decay_t<decltype(handlers)>::*)( \
718+
bind_param_t<decltype(Item)>, QualType)>( \
719+
&std::decay_t<decltype(handlers)>::FUNC), \
720+
std::ref(handlers), _1, _2)...)
721+
722+
// The following simpler definition works with gcc 8.x and later.
723+
//#define KF_FOR_EACH(FUNC) \
724+
// handleField(Field, FieldTy, ([&](FieldDecl *FD, QualType FDTy) { \
725+
// return handlers.f(FD, FDTy); \
726+
// })...)
727+
693728
// Implements the 'for-each-visitor' pattern.
694729
template <typename ParentTy, typename... Handlers>
695730
static void VisitAccessorWrapper(CXXRecordDecl *Owner, ParentTy &Parent,
696731
CXXRecordDecl *Wrapper,
697732
Handlers &... handlers);
698733

699734
template <typename RangeTy, typename... Handlers>
700-
static void VisitField(CXXRecordDecl *Owner, RangeTy Item, QualType ItemTy,
735+
static void VisitField(CXXRecordDecl *Owner, RangeTy &&Item, QualType ItemTy,
701736
Handlers &... handlers) {
702737
if (Util::isSyclAccessorType(ItemTy))
703-
(void)std::initializer_list<int>{
704-
(handlers.handleSyclAccessorType(Item, ItemTy), 0)...};
738+
KF_FOR_EACH(handleSyclAccessorType, Item, ItemTy);
705739
if (Util::isSyclStreamType(ItemTy))
706-
(void)std::initializer_list<int>{
707-
(handlers.handleSyclStreamType(Item, ItemTy), 0)...};
740+
KF_FOR_EACH(handleSyclStreamType, Item, ItemTy);
708741
if (ItemTy->isStructureOrClassType())
709742
VisitAccessorWrapper(Owner, Item, ItemTy->getAsCXXRecordDecl(),
710743
handlers...);
@@ -757,40 +790,38 @@ template <typename... Handlers>
757790
static void VisitRecordFields(RecordDecl::field_range Fields,
758791
Handlers &... handlers) {
759792

760-
#define KF_FOR_EACH(FUNC) handleField(Field, FieldTy, handlers.FUNC()...)
761-
762-
for (const auto &Field : Fields) {
793+
for (const auto Field : Fields) {
763794
(void)std::initializer_list<int>{
764795
(handlers.enterField(nullptr, Field), 0)...};
765796
QualType FieldTy = Field->getType();
766797

767798
if (Util::isSyclAccessorType(FieldTy))
768-
KF_FOR_EACH(processSyclAccessorType);
799+
KF_FOR_EACH(handleSyclAccessorType, Field, FieldTy);
769800
else if (Util::isSyclSamplerType(FieldTy))
770-
KF_FOR_EACH(processSyclSamplerType);
801+
KF_FOR_EACH(handleSyclSamplerType, Field, FieldTy);
771802
else if (Util::isSyclSpecConstantType(FieldTy))
772-
KF_FOR_EACH(processSyclSpecConstantType);
803+
KF_FOR_EACH(handleSyclSpecConstantType, Field, FieldTy);
773804
else if (Util::isSyclStreamType(FieldTy)) {
774805
// Stream actually wraps accessors, so do recursion
775806
CXXRecordDecl *RD = FieldTy->getAsCXXRecordDecl();
776807
VisitAccessorWrapper(nullptr, Field, RD, handlers...);
777-
KF_FOR_EACH(processSyclStreamType);
808+
KF_FOR_EACH(handleSyclStreamType, Field, FieldTy);
778809
} else if (FieldTy->isStructureOrClassType()) {
779-
if (KF_FOR_EACH(processStructType)) {
810+
if (KF_FOR_EACH(handleStructType, Field, FieldTy)) {
780811
CXXRecordDecl *RD = FieldTy->getAsCXXRecordDecl();
781812
VisitAccessorWrapper(nullptr, Field, RD, handlers...);
782813
}
783814
} else if (FieldTy->isReferenceType())
784-
KF_FOR_EACH(processReferenceType);
815+
KF_FOR_EACH(handleReferenceType, Field, FieldTy);
785816
else if (FieldTy->isPointerType())
786-
KF_FOR_EACH(processPointerType);
817+
KF_FOR_EACH(handlePointerType, Field, FieldTy);
787818
else if (FieldTy->isArrayType()) {
788-
if (KF_FOR_EACH(processArrayType))
819+
if (KF_FOR_EACH(handleArrayType, Field, FieldTy))
789820
VisitArrayElements(Field, FieldTy, handlers...);
790821
} else if (FieldTy->isScalarType())
791-
KF_FOR_EACH(processScalarType);
822+
KF_FOR_EACH(handleScalarType, Field, FieldTy);
792823
else
793-
KF_FOR_EACH(processOtherType);
824+
KF_FOR_EACH(handleOtherType, Field, FieldTy);
794825
(void)std::initializer_list<int>{
795826
(handlers.leaveField(nullptr, Field), 0)...};
796827
}
@@ -804,21 +835,6 @@ template <typename Derived> class SyclKernelFieldHandler {
804835
Sema &SemaRef;
805836
SyclKernelFieldHandler(Sema &S) : SemaRef(S) {}
806837

807-
// The following capture a handler::function pair.
808-
809-
typedef bool (SyclKernelFieldHandler::*SMemFn)(FieldDecl *, QualType);
810-
using tuple = std::pair<SyclKernelFieldHandler *, SMemFn>;
811-
tuple pmfAccessor{this, &SyclKernelFieldHandler::handleSyclAccessorType};
812-
tuple pmfSampler{this, &SyclKernelFieldHandler::handleSyclSamplerType};
813-
tuple pmfConstant{this, &SyclKernelFieldHandler::handleSyclSpecConstantType};
814-
tuple pmfStream{this, &SyclKernelFieldHandler::handleSyclStreamType};
815-
tuple pmfStruct{this, &SyclKernelFieldHandler::handleStructType};
816-
tuple pmfReference{this, &SyclKernelFieldHandler::handleReferenceType};
817-
tuple pmfPointer{this, &SyclKernelFieldHandler::handlePointerType};
818-
tuple pmfScalar{this, &SyclKernelFieldHandler::handleScalarType};
819-
tuple pmfArray{this, &SyclKernelFieldHandler::handleArrayType};
820-
tuple pmfOther{this, &SyclKernelFieldHandler::handleOtherType};
821-
822838
public:
823839
// Mark these virtual so that we can use override in the implementer classes,
824840
// despite virtual dispatch never being used.
@@ -863,19 +879,6 @@ template <typename Derived> class SyclKernelFieldHandler {
863879
virtual void enterArray() {}
864880
virtual void nextElement(QualType) {}
865881
virtual void leaveArray(QualType, int64_t) {}
866-
867-
// The following return prebuilt handler::function pairs.
868-
869-
virtual tuple &processSyclAccessorType() { return pmfAccessor; }
870-
virtual tuple &processSyclSamplerType() { return pmfSampler; }
871-
virtual tuple &processSyclSpecConstantType() { return pmfConstant; }
872-
virtual tuple &processSyclStreamType() { return pmfStream; }
873-
virtual tuple &processStructType() { return pmfStruct; }
874-
virtual tuple &processReferenceType() { return pmfReference; }
875-
virtual tuple &processPointerType() { return pmfPointer; }
876-
virtual tuple &processScalarType() { return pmfScalar; }
877-
virtual tuple &processArrayType() { return pmfArray; }
878-
virtual tuple &processOtherType() { return pmfOther; }
879882
};
880883

881884
// A type to check the validity of all of the argument types.

0 commit comments

Comments
 (0)