@@ -341,6 +341,13 @@ def __init__(self, args: list[str]) -> None:
341
341
self .lang_spec = TestLanguage (get_lang_spec_and_remove_from_list (args ))
342
342
self .args = args
343
343
344
+ if any (("cuda" in arg for arg in self .args )) and not any (
345
+ "-x" in arg for arg in self .args
346
+ ):
347
+ self .args .append ("-xcuda" )
348
+ self .args .append ("-nocudainc" )
349
+ self .args .append ("-nocudalib" )
350
+
344
351
def is_cuda (self ) -> bool :
345
352
return any ("cuda" in cmd for cmd in self .args )
346
353
@@ -397,7 +404,7 @@ def get_with_lang_spec(args: CompileArgs) -> list[str]:
397
404
398
405
399
406
cuda_header : str = """
400
- typedef unsigned int size_t;
407
+ typedef unsigned long long size_t;
401
408
#define __constant__ __attribute__((constant))
402
409
#define __device__ __attribute__((device))
403
410
#define __global__ __attribute__((global))
@@ -595,18 +602,15 @@ def get_formated_headers(self) -> str:
595
602
def build_test_case (self ):
596
603
self .code = self .code .strip ("\n " )
597
604
has_cuda = self .compile_args and self .compile_args .is_cuda ()
605
+ if has_cuda :
606
+ self .headers .append (("cuda.h" , cuda_header ))
607
+
598
608
res = ""
599
609
if has_cuda :
600
610
res += "#if LLVM_HAS_NVPTX_TARGET\n "
601
611
602
612
res += f"""TEST_P(ASTMatchersDocTest, docs_{ self .line + 1 } ) {{
603
- const StringRef Code = R"cpp(\n { self .code } )cpp";\n """
604
-
605
- if has_cuda :
606
- res += f"""
607
- const StringRef CudaHeader = R"cuda({ cuda_header }
608
- )cuda";
609
- """
613
+ const StringRef Code = R"cpp(\n { "\t #include \" cuda.h\" \n " if has_cuda else "" } { self .code } )cpp";\n """
610
614
611
615
if self .has_headers ():
612
616
res += f"\t const FileContentMappings VirtualMappedFiles = {{{ self .get_formated_headers ()} }};"
@@ -687,7 +691,7 @@ def build_test_case(self):
687
691
{ code_adding_matches }
688
692
689
693
EXPECT_TRUE({ match_function } (
690
- { "CudaHeader + " if has_cuda else "" } Code,
694
+ Code,
691
695
{ matcher .matcher } .bind("match"),
692
696
std::make_unique<Verifier>("match", Matches)"""
693
697
0 commit comments