Skip to content

Commit 609a906

Browse files
anmyachevvlad-penkin
authored andcommitted
[TEST] Use device fixture for assert_helper.py and print_helper.py (#4643)
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent fbb63e2 commit 609a906

File tree

2 files changed

+4
-9
lines changed

2 files changed

+4
-9
lines changed

python/test/unit/language/assert_helper.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ def test_assert(func: str, device: str):
5151
N = 128 # This value should match with test_print in test_subprocess.py.
5252
num_warps = N // get_current_target_warp_size()
5353

54-
x = torch.arange(0, N, dtype=torch.int32, device='xpu')
55-
y = torch.zeros((N, ), dtype=x.dtype, device="xpu")
54+
x = torch.arange(0, N, dtype=torch.int32, device=device)
55+
y = torch.zeros((N, ), dtype=x.dtype, device=device)
5656
if func == "device_assert":
5757
kernel_device_assert[(1, )](x, y, num_warps=num_warps, BLOCK=N)
5858
if func == "device_assert_passes":
@@ -80,7 +80,7 @@ def test_assert(func: str, device: str):
8080
kernel_assert_passes[(1, )](x, y, num_warps=num_warps, BLOCK=N)
8181
assert_close(y, x)
8282
# GPU/host synchronization before exiting the test.
83-
torch.xpu.synchronize()
83+
getattr(torch, device).synchronize()
8484

8585

8686
@triton.jit

python/test/unit/language/print_helper.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -131,19 +131,14 @@ def test_print(func: str, data_type: str, device: str):
131131
else:
132132
assert f"Unknown kernel: {func}"
133133

134-
if device == "xpu":
135-
# FIXME: remove trigger to get output from kernel
136-
repr(x)
137-
repr(y)
138-
139134
if func != "print_no_arg" and func != "no_arg_print" and func != "device_print_large" and \
140135
func != "print_multiple_args" and func != "device_print_multiple_args" and \
141136
func != "device_print_pointer" and func != "device_print_scalar":
142137
assert_close(y, x)
143138

144139
# Wait until driver complete all the jobs for the device_print, especially test_subprocess
145140
# require this which captures stdout when child exits.
146-
torch.xpu.synchronize()
141+
getattr(torch, device).synchronize()
147142

148143

149144
if __name__ == "__main__":

0 commit comments

Comments
 (0)