Skip to content

Commit 0c86c5b

Browse files
jithunnair-amd, pruthvistony
authored and committed
[SWDEV-466849] Enhancements for PyTorch UT helper scripts (#1491)
* Check that >1 GPUs are visible when running TEST_CONFIG=distributed
* Add EXECUTION_TIME to file-level and aggregate statistics
1 parent d933327 commit 0c86c5b

File tree

1 file changed

+20
-1
lines changed

1 file changed

+20
-1
lines changed

.automation_scripts/run_pytorch_unit_tests.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,11 @@ def get_test_message(test_case, status=None):
116116
else:
117117
return ""
118118

119+
def get_test_file_running_time(test_suite):
    """Return the running time recorded for a test suite entry.

    Args:
        test_suite: Mapping-like record for one test suite. It is looked up
            by the ``"time"`` key (presumably the ``time`` attribute of a
            JUnit ``<testsuite>`` element -- confirm against the parser).

    Returns:
        The value stored under ``"time"`` when that key is present,
        otherwise ``0`` so callers can accumulate totals unconditionally.
    """
    # Use the membership operator rather than calling __contains__ directly.
    if 'time' in test_suite:
        return test_suite["time"]
    return 0
123+
119124
def get_test_running_time(test_case):
120125
if test_case.__contains__('time'):
121126
return test_case["time"]
@@ -129,9 +134,11 @@ def summarize_xml_files(path, workflow_name):
129134
TOTAL_XFAIL_NUM = 0
130135
TOTAL_FAILED_NUM = 0
131136
TOTAL_ERROR_NUM = 0
137+
TOTAL_EXECUTION_TIME = 0
132138

133139
#parse the xml files
134140
test_cases = parse_xml_reports_as_dict(-1, -1, 'testcase', workflow_name, path)
141+
test_suites = parse_xml_reports_as_dict(-1, -1, 'testsuite', workflow_name, path)
135142
test_file_and_status = namedtuple("test_file_and_status", ["file_name", "status"])
136143
# results dict
137144
res = {}
@@ -146,7 +153,14 @@ def summarize_xml_files(path, workflow_name):
146153
temp_item = test_file_and_status(file_name, item)
147154
res[temp_item] = {}
148155
temp_item_statistics = test_file_and_status(file_name, "STATISTICS")
149-
res[temp_item_statistics] = {'TOTAL': 0, 'PASSED': 0, 'SKIPPED': 0, 'XFAILED': 0, 'FAILED': 0, 'ERROR': 0}
156+
res[temp_item_statistics] = {'TOTAL': 0, 'PASSED': 0, 'SKIPPED': 0, 'XFAILED': 0, 'FAILED': 0, 'ERROR': 0, 'EXECUTION_TIME': 0}
157+
158+
for (k,v) in list(test_suites.items()):
159+
file_name = k[0]
160+
test_tuple_key_statistics = test_file_and_status(file_name, "STATISTICS")
161+
test_running_time = get_test_file_running_time(v)
162+
res[test_tuple_key_statistics]["EXECUTION_TIME"] += test_running_time
163+
TOTAL_EXECUTION_TIME += test_running_time
150164

151165
for (k,v) in list(test_cases.items()):
152166
file_name = k[0]
@@ -195,13 +209,18 @@ def summarize_xml_files(path, workflow_name):
195209
statistics_dict["XFAILED"] = TOTAL_XFAIL_NUM
196210
statistics_dict["FAILED"] = TOTAL_FAILED_NUM
197211
statistics_dict["ERROR"] = TOTAL_ERROR_NUM
212+
statistics_dict["EXECUTION_TIME"] = TOTAL_EXECUTION_TIME
198213
aggregate_item = workflow_name + "_aggregate"
199214
total_item = test_file_and_status(aggregate_item, "STATISTICS")
200215
res[total_item] = statistics_dict
201216

202217
return res
203218

204219
def run_command_and_capture_output(cmd):
220+
if os.environ['TEST_CONFIG'] == 'distributed':
221+
p = subprocess.run("rocminfo | grep -cE 'Name:\s+gfx'", shell=True, capture_output=True, text=True)
222+
num_gpus_visible = int(p.stdout)
223+
assert num_gpus_visible > 1, "Number of visible GPUs should be >1 to run TEST_CONFIG=distributed"
205224
try:
206225
print(f"Running command '{cmd}'")
207226
with open(CONSOLIDATED_LOG_FILE_PATH, "a+") as output_file:

0 commit comments

Comments
 (0)