Skip to content

Commit 50f3a8c

Browse files
committed
fix cuda tests
1 parent 3bfe588 commit 50f3a8c

File tree

2 files changed

+17
-13
lines changed

2 files changed

+17
-13
lines changed

clang/include/clang/ASTMatchers/ASTMatchers.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10292,7 +10292,7 @@ AST_MATCHER_P(CaseStmt, hasCaseConstant, internal::Matcher<Expr>,
1029210292
/// \code
1029310293
/// __attribute__((device)) void f() {}
1029410294
/// \endcode
10295-
/// \compile_args{--cuda-gpu-arch=sm_70}
10295+
/// \compile_args{--cuda-gpu-arch=sm_70;-std=c++}
1029610296
/// The matcher \matcher{decl(hasAttr(clang::attr::CUDADevice))}
1029710297
/// matches \match{type=name$f}.
1029810298
/// If the matcher is used from clang-query, attr::Kind
@@ -10331,12 +10331,12 @@ AST_MATCHER_P(ReturnStmt, hasReturnValue, internal::Matcher<Expr>,
1033110331
/// \code
1033210332
/// __global__ void kernel() {}
1033310333
/// void f() {
10334-
/// kernel<<<32,32>>>();
10334+
/// kernel<<<32, 32>>>();
1033510335
/// }
1033610336
/// \endcode
10337-
/// \compile_args{--cuda-gpu-arch=sm_70}
10337+
/// \compile_args{--cuda-gpu-arch=sm_70;-std=c++}
1033810338
/// The matcher \matcher{cudaKernelCallExpr()}
10339-
/// matches \match{kernel<<<i, k>>>()}
10339+
/// matches \match{kernel<<<32, 32>>>()}
1034010340
extern const internal::VariadicDynCastAllOfMatcher<Stmt, CUDAKernelCallExpr>
1034110341
cudaKernelCallExpr;
1034210342

clang/utils/generate_ast_matcher_doc_tests.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,13 @@ def __init__(self, args: list[str]) -> None:
341341
self.lang_spec = TestLanguage(get_lang_spec_and_remove_from_list(args))
342342
self.args = args
343343

344+
if any(("cuda" in arg for arg in self.args)) and not any(
345+
"-x" in arg for arg in self.args
346+
):
347+
self.args.append("-xcuda")
348+
self.args.append("-nocudainc")
349+
self.args.append("-nocudalib")
350+
344351
def is_cuda(self) -> bool:
345352
return any("cuda" in cmd for cmd in self.args)
346353

@@ -397,7 +404,7 @@ def get_with_lang_spec(args: CompileArgs) -> list[str]:
397404

398405

399406
cuda_header: str = """
400-
typedef unsigned int size_t;
407+
typedef unsigned long long size_t;
401408
#define __constant__ __attribute__((constant))
402409
#define __device__ __attribute__((device))
403410
#define __global__ __attribute__((global))
@@ -595,18 +602,15 @@ def get_formated_headers(self) -> str:
595602
def build_test_case(self):
596603
self.code = self.code.strip("\n")
597604
has_cuda = self.compile_args and self.compile_args.is_cuda()
605+
if has_cuda:
606+
self.headers.append(("cuda.h", cuda_header))
607+
598608
res = ""
599609
if has_cuda:
600610
res += "#if LLVM_HAS_NVPTX_TARGET\n"
601611

602612
res += f"""TEST_P(ASTMatchersDocTest, docs_{self.line + 1}) {{
603-
const StringRef Code = R"cpp(\n{self.code})cpp";\n"""
604-
605-
if has_cuda:
606-
res += f"""
607-
const StringRef CudaHeader = R"cuda({cuda_header}
608-
)cuda";
609-
"""
613+
const StringRef Code = R"cpp(\n{"\t#include \"cuda.h\"\n" if has_cuda else ""}{self.code})cpp";\n"""
610614

611615
if self.has_headers():
612616
res += f"\tconst FileContentMappings VirtualMappedFiles = {{{self.get_formated_headers()}}};"
@@ -687,7 +691,7 @@ def build_test_case(self):
687691
{code_adding_matches}
688692
689693
EXPECT_TRUE({match_function}(
690-
{"CudaHeader + " if has_cuda else ""}Code,
694+
Code,
691695
{matcher.matcher}.bind("match"),
692696
std::make_unique<Verifier>("match", Matches)"""
693697

0 commit comments

Comments
 (0)