Skip to content

Commit 8412c8b

Browse files
committed
Fix "use after release" issues
In some cases, we use handles after releasing them, or incorrectly release handles we shouldn't. This doesn't cause any issues currently, but will when we start using reference counting in the loader.
1 parent 7a4902d commit 8412c8b

File tree

5 files changed

+62
-50
lines changed

5 files changed

+62
-50
lines changed

scripts/templates/valddi.cpp.mako

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,19 @@ namespace ur_validation_layer
9494
%endif
9595
%endfor
9696

97+
%for tp in tracked_params:
98+
<%
99+
tp_handle_funcs = next((hf for hf in handle_create_get_retain_release_funcs if th.subt(n, tags, tp['type']) in [hf['handle'], hf['handle'] + "*"]), None)
100+
is_handle_to_adapter = ("_adapter_handle_t" in tp['type'])
101+
%>
102+
%if func_name in tp_handle_funcs['release']:
103+
if( getContext()->enableLeakChecking )
104+
{
105+
getContext()->refCountContext->decrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
106+
}
107+
%endif
108+
%endfor
109+
97110
${x}_result_t result = ${th.make_pfn_name(n, tags, obj)}( ${", ".join(th.make_param_lines(n, tags, obj, format=["name"]))} );
98111

99112
%for tp in tracked_params:
@@ -114,15 +127,10 @@ namespace ur_validation_layer
114127
}
115128
}
116129
%elif func_name in tp_handle_funcs['retain']:
117-
if( getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS )
130+
if( getContext()->enableLeakChecking )
118131
{
119132
getContext()->refCountContext->incrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
120133
}
121-
%elif func_name in tp_handle_funcs['release']:
122-
if( getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS )
123-
{
124-
getContext()->refCountContext->decrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
125-
}
126134
%endif
127135
%endfor
128136

source/loader/layers/validation/ur_valddi.cpp

Lines changed: 44 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,12 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRelease(
7171
}
7272
}
7373

74-
ur_result_t result = pfnAdapterRelease(hAdapter);
75-
76-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
74+
if (getContext()->enableLeakChecking) {
7775
getContext()->refCountContext->decrementRefCount(hAdapter, true);
7876
}
7977

78+
ur_result_t result = pfnAdapterRelease(hAdapter);
79+
8080
return result;
8181
}
8282

@@ -99,7 +99,7 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRetain(
9999

100100
ur_result_t result = pfnAdapterRetain(hAdapter);
101101

102-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
102+
if (getContext()->enableLeakChecking) {
103103
getContext()->refCountContext->incrementRefCount(hAdapter, true);
104104
}
105105

@@ -558,7 +558,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRetain(
558558

559559
ur_result_t result = pfnRetain(hDevice);
560560

561-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
561+
if (getContext()->enableLeakChecking) {
562562
getContext()->refCountContext->incrementRefCount(hDevice, false);
563563
}
564564

@@ -583,12 +583,12 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRelease(
583583
}
584584
}
585585

586-
ur_result_t result = pfnRelease(hDevice);
587-
588-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
586+
if (getContext()->enableLeakChecking) {
589587
getContext()->refCountContext->decrementRefCount(hDevice, false);
590588
}
591589

590+
ur_result_t result = pfnRelease(hDevice);
591+
592592
return result;
593593
}
594594

@@ -861,7 +861,7 @@ __urdlllocal ur_result_t UR_APICALL urContextRetain(
861861

862862
ur_result_t result = pfnRetain(hContext);
863863

864-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
864+
if (getContext()->enableLeakChecking) {
865865
getContext()->refCountContext->incrementRefCount(hContext, false);
866866
}
867867

@@ -886,12 +886,12 @@ __urdlllocal ur_result_t UR_APICALL urContextRelease(
886886
}
887887
}
888888

889-
ur_result_t result = pfnRelease(hContext);
890-
891-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
889+
if (getContext()->enableLeakChecking) {
892890
getContext()->refCountContext->decrementRefCount(hContext, false);
893891
}
894892

893+
ur_result_t result = pfnRelease(hContext);
894+
895895
return result;
896896
}
897897

@@ -1248,7 +1248,7 @@ __urdlllocal ur_result_t UR_APICALL urMemRetain(
12481248

12491249
ur_result_t result = pfnRetain(hMem);
12501250

1251-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
1251+
if (getContext()->enableLeakChecking) {
12521252
getContext()->refCountContext->incrementRefCount(hMem, false);
12531253
}
12541254

@@ -1273,12 +1273,12 @@ __urdlllocal ur_result_t UR_APICALL urMemRelease(
12731273
}
12741274
}
12751275

1276-
ur_result_t result = pfnRelease(hMem);
1277-
1278-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
1276+
if (getContext()->enableLeakChecking) {
12791277
getContext()->refCountContext->decrementRefCount(hMem, false);
12801278
}
12811279

1280+
ur_result_t result = pfnRelease(hMem);
1281+
12821282
return result;
12831283
}
12841284

@@ -1657,7 +1657,7 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRetain(
16571657

16581658
ur_result_t result = pfnRetain(hSampler);
16591659

1660-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
1660+
if (getContext()->enableLeakChecking) {
16611661
getContext()->refCountContext->incrementRefCount(hSampler, false);
16621662
}
16631663

@@ -1682,12 +1682,12 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRelease(
16821682
}
16831683
}
16841684

1685-
ur_result_t result = pfnRelease(hSampler);
1686-
1687-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
1685+
if (getContext()->enableLeakChecking) {
16881686
getContext()->refCountContext->decrementRefCount(hSampler, false);
16891687
}
16901688

1689+
ur_result_t result = pfnRelease(hSampler);
1690+
16911691
return result;
16921692
}
16931693

@@ -2154,7 +2154,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRetain(
21542154

21552155
ur_result_t result = pfnPoolRetain(pPool);
21562156

2157-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
2157+
if (getContext()->enableLeakChecking) {
21582158
getContext()->refCountContext->incrementRefCount(pPool, false);
21592159
}
21602160

@@ -2178,12 +2178,12 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRelease(
21782178
}
21792179
}
21802180

2181-
ur_result_t result = pfnPoolRelease(pPool);
2182-
2183-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
2181+
if (getContext()->enableLeakChecking) {
21842182
getContext()->refCountContext->decrementRefCount(pPool, false);
21852183
}
21862184

2185+
ur_result_t result = pfnPoolRelease(pPool);
2186+
21872187
return result;
21882188
}
21892189

@@ -2631,7 +2631,7 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRetain(
26312631

26322632
ur_result_t result = pfnRetain(hPhysicalMem);
26332633

2634-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
2634+
if (getContext()->enableLeakChecking) {
26352635
getContext()->refCountContext->incrementRefCount(hPhysicalMem, false);
26362636
}
26372637

@@ -2656,12 +2656,12 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRelease(
26562656
}
26572657
}
26582658

2659-
ur_result_t result = pfnRelease(hPhysicalMem);
2660-
2661-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
2659+
if (getContext()->enableLeakChecking) {
26622660
getContext()->refCountContext->decrementRefCount(hPhysicalMem, false);
26632661
}
26642662

2663+
ur_result_t result = pfnRelease(hPhysicalMem);
2664+
26652665
return result;
26662666
}
26672667

@@ -2952,7 +2952,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramRetain(
29522952

29532953
ur_result_t result = pfnRetain(hProgram);
29542954

2955-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
2955+
if (getContext()->enableLeakChecking) {
29562956
getContext()->refCountContext->incrementRefCount(hProgram, false);
29572957
}
29582958

@@ -2977,12 +2977,12 @@ __urdlllocal ur_result_t UR_APICALL urProgramRelease(
29772977
}
29782978
}
29792979

2980-
ur_result_t result = pfnRelease(hProgram);
2981-
2982-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
2980+
if (getContext()->enableLeakChecking) {
29832981
getContext()->refCountContext->decrementRefCount(hProgram, false);
29842982
}
29852983

2984+
ur_result_t result = pfnRelease(hProgram);
2985+
29862986
return result;
29872987
}
29882988

@@ -3618,7 +3618,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelRetain(
36183618

36193619
ur_result_t result = pfnRetain(hKernel);
36203620

3621-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
3621+
if (getContext()->enableLeakChecking) {
36223622
getContext()->refCountContext->incrementRefCount(hKernel, false);
36233623
}
36243624

@@ -3643,12 +3643,12 @@ __urdlllocal ur_result_t UR_APICALL urKernelRelease(
36433643
}
36443644
}
36453645

3646-
ur_result_t result = pfnRelease(hKernel);
3647-
3648-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
3646+
if (getContext()->enableLeakChecking) {
36493647
getContext()->refCountContext->decrementRefCount(hKernel, false);
36503648
}
36513649

3650+
ur_result_t result = pfnRelease(hKernel);
3651+
36523652
return result;
36533653
}
36543654

@@ -4138,7 +4138,7 @@ __urdlllocal ur_result_t UR_APICALL urQueueRetain(
41384138

41394139
ur_result_t result = pfnRetain(hQueue);
41404140

4141-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
4141+
if (getContext()->enableLeakChecking) {
41424142
getContext()->refCountContext->incrementRefCount(hQueue, false);
41434143
}
41444144

@@ -4163,12 +4163,12 @@ __urdlllocal ur_result_t UR_APICALL urQueueRelease(
41634163
}
41644164
}
41654165

4166-
ur_result_t result = pfnRelease(hQueue);
4167-
4168-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
4166+
if (getContext()->enableLeakChecking) {
41694167
getContext()->refCountContext->decrementRefCount(hQueue, false);
41704168
}
41714169

4170+
ur_result_t result = pfnRelease(hQueue);
4171+
41724172
return result;
41734173
}
41744174

@@ -4454,7 +4454,7 @@ __urdlllocal ur_result_t UR_APICALL urEventRetain(
44544454

44554455
ur_result_t result = pfnRetain(hEvent);
44564456

4457-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
4457+
if (getContext()->enableLeakChecking) {
44584458
getContext()->refCountContext->incrementRefCount(hEvent, false);
44594459
}
44604460

@@ -4478,12 +4478,12 @@ __urdlllocal ur_result_t UR_APICALL urEventRelease(
44784478
}
44794479
}
44804480

4481-
ur_result_t result = pfnRelease(hEvent);
4482-
4483-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
4481+
if (getContext()->enableLeakChecking) {
44844482
getContext()->refCountContext->decrementRefCount(hEvent, false);
44854483
}
44864484

4485+
ur_result_t result = pfnRelease(hEvent);
4486+
44874487
return result;
44884488
}
44894489

test/conformance/adapter/urAdapterRelease.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ struct urAdapterReleaseTest : uur::runtime::urAdapterTest {
1616

1717
TEST_F(urAdapterReleaseTest, Success) {
1818
uint32_t referenceCountBefore = 0;
19+
ASSERT_SUCCESS(urAdapterRetain(adapter));
1920

2021
ASSERT_SUCCESS(urAdapterGetInfo(adapter, UR_ADAPTER_INFO_REFERENCE_COUNT,
2122
sizeof(referenceCountBefore),

test/conformance/device/urDeviceRelease.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ struct urDeviceReleaseTest : uur::urAllDevicesTest {};
88

99
TEST_F(urDeviceReleaseTest, Success) {
1010
for (auto device : devices) {
11+
ASSERT_SUCCESS(urDeviceRetain(device));
12+
1113
uint32_t prevRefCount = 0;
1214
ASSERT_SUCCESS(uur::GetObjectReferenceCount(device, prevRefCount));
1315

test/conformance/testing/include/uur/fixtures.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ struct urDeviceTest : urPlatformTest,
9595
void SetUp() override {
9696
UUR_RETURN_ON_FATAL_FAILURE(urPlatformTest::SetUp());
9797
device = GetParam();
98+
EXPECT_SUCCESS(urDeviceRetain(device));
9899
}
99100

100101
void TearDown() override {

0 commit comments

Comments
 (0)