Skip to content

Commit ffcc53b

Browse files
authored
Merge pull request #2368 from RossBrunton/ross/release
Fix "use after release" issues
2 parents 1a6ad18 + 8412c8b commit ffcc53b

File tree

5 files changed

+62
-50
lines changed

5 files changed

+62
-50
lines changed

scripts/templates/valddi.cpp.mako

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,19 @@ namespace ur_validation_layer
9494
%endif
9595
%endfor
9696

97+
%for tp in tracked_params:
98+
<%
99+
tp_handle_funcs = next((hf for hf in handle_create_get_retain_release_funcs if th.subt(n, tags, tp['type']) in [hf['handle'], hf['handle'] + "*"]), None)
100+
is_handle_to_adapter = ("_adapter_handle_t" in tp['type'])
101+
%>
102+
%if func_name in tp_handle_funcs['release']:
103+
if( getContext()->enableLeakChecking )
104+
{
105+
getContext()->refCountContext->decrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
106+
}
107+
%endif
108+
%endfor
109+
97110
${x}_result_t result = ${th.make_pfn_name(n, tags, obj)}( ${", ".join(th.make_param_lines(n, tags, obj, format=["name"]))} );
98111

99112
%for tp in tracked_params:
@@ -114,15 +127,10 @@ namespace ur_validation_layer
114127
}
115128
}
116129
%elif func_name in tp_handle_funcs['retain']:
117-
if( getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS )
130+
if( getContext()->enableLeakChecking )
118131
{
119132
getContext()->refCountContext->incrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
120133
}
121-
%elif func_name in tp_handle_funcs['release']:
122-
if( getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS )
123-
{
124-
getContext()->refCountContext->decrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
125-
}
126134
%endif
127135
%endfor
128136

source/loader/layers/validation/ur_valddi.cpp

Lines changed: 44 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,12 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRelease(
7171
}
7272
}
7373

74-
ur_result_t result = pfnAdapterRelease(hAdapter);
75-
76-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
74+
if (getContext()->enableLeakChecking) {
7775
getContext()->refCountContext->decrementRefCount(hAdapter, true);
7876
}
7977

78+
ur_result_t result = pfnAdapterRelease(hAdapter);
79+
8080
return result;
8181
}
8282

@@ -99,7 +99,7 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRetain(
9999

100100
ur_result_t result = pfnAdapterRetain(hAdapter);
101101

102-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
102+
if (getContext()->enableLeakChecking) {
103103
getContext()->refCountContext->incrementRefCount(hAdapter, true);
104104
}
105105

@@ -558,7 +558,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRetain(
558558

559559
ur_result_t result = pfnRetain(hDevice);
560560

561-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
561+
if (getContext()->enableLeakChecking) {
562562
getContext()->refCountContext->incrementRefCount(hDevice, false);
563563
}
564564

@@ -583,12 +583,12 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRelease(
583583
}
584584
}
585585

586-
ur_result_t result = pfnRelease(hDevice);
587-
588-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
586+
if (getContext()->enableLeakChecking) {
589587
getContext()->refCountContext->decrementRefCount(hDevice, false);
590588
}
591589

590+
ur_result_t result = pfnRelease(hDevice);
591+
592592
return result;
593593
}
594594

@@ -861,7 +861,7 @@ __urdlllocal ur_result_t UR_APICALL urContextRetain(
861861

862862
ur_result_t result = pfnRetain(hContext);
863863

864-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
864+
if (getContext()->enableLeakChecking) {
865865
getContext()->refCountContext->incrementRefCount(hContext, false);
866866
}
867867

@@ -886,12 +886,12 @@ __urdlllocal ur_result_t UR_APICALL urContextRelease(
886886
}
887887
}
888888

889-
ur_result_t result = pfnRelease(hContext);
890-
891-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
889+
if (getContext()->enableLeakChecking) {
892890
getContext()->refCountContext->decrementRefCount(hContext, false);
893891
}
894892

893+
ur_result_t result = pfnRelease(hContext);
894+
895895
return result;
896896
}
897897

@@ -1248,7 +1248,7 @@ __urdlllocal ur_result_t UR_APICALL urMemRetain(
12481248

12491249
ur_result_t result = pfnRetain(hMem);
12501250

1251-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
1251+
if (getContext()->enableLeakChecking) {
12521252
getContext()->refCountContext->incrementRefCount(hMem, false);
12531253
}
12541254

@@ -1273,12 +1273,12 @@ __urdlllocal ur_result_t UR_APICALL urMemRelease(
12731273
}
12741274
}
12751275

1276-
ur_result_t result = pfnRelease(hMem);
1277-
1278-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
1276+
if (getContext()->enableLeakChecking) {
12791277
getContext()->refCountContext->decrementRefCount(hMem, false);
12801278
}
12811279

1280+
ur_result_t result = pfnRelease(hMem);
1281+
12821282
return result;
12831283
}
12841284

@@ -1657,7 +1657,7 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRetain(
16571657

16581658
ur_result_t result = pfnRetain(hSampler);
16591659

1660-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
1660+
if (getContext()->enableLeakChecking) {
16611661
getContext()->refCountContext->incrementRefCount(hSampler, false);
16621662
}
16631663

@@ -1682,12 +1682,12 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRelease(
16821682
}
16831683
}
16841684

1685-
ur_result_t result = pfnRelease(hSampler);
1686-
1687-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
1685+
if (getContext()->enableLeakChecking) {
16881686
getContext()->refCountContext->decrementRefCount(hSampler, false);
16891687
}
16901688

1689+
ur_result_t result = pfnRelease(hSampler);
1690+
16911691
return result;
16921692
}
16931693

@@ -2154,7 +2154,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRetain(
21542154

21552155
ur_result_t result = pfnPoolRetain(pPool);
21562156

2157-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
2157+
if (getContext()->enableLeakChecking) {
21582158
getContext()->refCountContext->incrementRefCount(pPool, false);
21592159
}
21602160

@@ -2178,12 +2178,12 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRelease(
21782178
}
21792179
}
21802180

2181-
ur_result_t result = pfnPoolRelease(pPool);
2182-
2183-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
2181+
if (getContext()->enableLeakChecking) {
21842182
getContext()->refCountContext->decrementRefCount(pPool, false);
21852183
}
21862184

2185+
ur_result_t result = pfnPoolRelease(pPool);
2186+
21872187
return result;
21882188
}
21892189

@@ -2631,7 +2631,7 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRetain(
26312631

26322632
ur_result_t result = pfnRetain(hPhysicalMem);
26332633

2634-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
2634+
if (getContext()->enableLeakChecking) {
26352635
getContext()->refCountContext->incrementRefCount(hPhysicalMem, false);
26362636
}
26372637

@@ -2656,12 +2656,12 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRelease(
26562656
}
26572657
}
26582658

2659-
ur_result_t result = pfnRelease(hPhysicalMem);
2660-
2661-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
2659+
if (getContext()->enableLeakChecking) {
26622660
getContext()->refCountContext->decrementRefCount(hPhysicalMem, false);
26632661
}
26642662

2663+
ur_result_t result = pfnRelease(hPhysicalMem);
2664+
26652665
return result;
26662666
}
26672667

@@ -2952,7 +2952,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramRetain(
29522952

29532953
ur_result_t result = pfnRetain(hProgram);
29542954

2955-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
2955+
if (getContext()->enableLeakChecking) {
29562956
getContext()->refCountContext->incrementRefCount(hProgram, false);
29572957
}
29582958

@@ -2977,12 +2977,12 @@ __urdlllocal ur_result_t UR_APICALL urProgramRelease(
29772977
}
29782978
}
29792979

2980-
ur_result_t result = pfnRelease(hProgram);
2981-
2982-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
2980+
if (getContext()->enableLeakChecking) {
29832981
getContext()->refCountContext->decrementRefCount(hProgram, false);
29842982
}
29852983

2984+
ur_result_t result = pfnRelease(hProgram);
2985+
29862986
return result;
29872987
}
29882988

@@ -3618,7 +3618,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelRetain(
36183618

36193619
ur_result_t result = pfnRetain(hKernel);
36203620

3621-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
3621+
if (getContext()->enableLeakChecking) {
36223622
getContext()->refCountContext->incrementRefCount(hKernel, false);
36233623
}
36243624

@@ -3643,12 +3643,12 @@ __urdlllocal ur_result_t UR_APICALL urKernelRelease(
36433643
}
36443644
}
36453645

3646-
ur_result_t result = pfnRelease(hKernel);
3647-
3648-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
3646+
if (getContext()->enableLeakChecking) {
36493647
getContext()->refCountContext->decrementRefCount(hKernel, false);
36503648
}
36513649

3650+
ur_result_t result = pfnRelease(hKernel);
3651+
36523652
return result;
36533653
}
36543654

@@ -4138,7 +4138,7 @@ __urdlllocal ur_result_t UR_APICALL urQueueRetain(
41384138

41394139
ur_result_t result = pfnRetain(hQueue);
41404140

4141-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
4141+
if (getContext()->enableLeakChecking) {
41424142
getContext()->refCountContext->incrementRefCount(hQueue, false);
41434143
}
41444144

@@ -4163,12 +4163,12 @@ __urdlllocal ur_result_t UR_APICALL urQueueRelease(
41634163
}
41644164
}
41654165

4166-
ur_result_t result = pfnRelease(hQueue);
4167-
4168-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
4166+
if (getContext()->enableLeakChecking) {
41694167
getContext()->refCountContext->decrementRefCount(hQueue, false);
41704168
}
41714169

4170+
ur_result_t result = pfnRelease(hQueue);
4171+
41724172
return result;
41734173
}
41744174

@@ -4454,7 +4454,7 @@ __urdlllocal ur_result_t UR_APICALL urEventRetain(
44544454

44554455
ur_result_t result = pfnRetain(hEvent);
44564456

4457-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
4457+
if (getContext()->enableLeakChecking) {
44584458
getContext()->refCountContext->incrementRefCount(hEvent, false);
44594459
}
44604460

@@ -4478,12 +4478,12 @@ __urdlllocal ur_result_t UR_APICALL urEventRelease(
44784478
}
44794479
}
44804480

4481-
ur_result_t result = pfnRelease(hEvent);
4482-
4483-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
4481+
if (getContext()->enableLeakChecking) {
44844482
getContext()->refCountContext->decrementRefCount(hEvent, false);
44854483
}
44864484

4485+
ur_result_t result = pfnRelease(hEvent);
4486+
44874487
return result;
44884488
}
44894489

test/conformance/adapter/urAdapterRelease.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ struct urAdapterReleaseTest : uur::runtime::urAdapterTest {
1616

1717
TEST_F(urAdapterReleaseTest, Success) {
1818
uint32_t referenceCountBefore = 0;
19+
ASSERT_SUCCESS(urAdapterRetain(adapter));
1920

2021
ASSERT_SUCCESS(urAdapterGetInfo(adapter, UR_ADAPTER_INFO_REFERENCE_COUNT,
2122
sizeof(referenceCountBefore),

test/conformance/device/urDeviceRelease.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ struct urDeviceReleaseTest : uur::urAllDevicesTest {};
88

99
TEST_F(urDeviceReleaseTest, Success) {
1010
for (auto device : devices) {
11+
ASSERT_SUCCESS(urDeviceRetain(device));
12+
1113
uint32_t prevRefCount = 0;
1214
ASSERT_SUCCESS(uur::GetObjectReferenceCount(device, prevRefCount));
1315

test/conformance/testing/include/uur/fixtures.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ struct urDeviceTest : urPlatformTest,
9595
void SetUp() override {
9696
UUR_RETURN_ON_FATAL_FAILURE(urPlatformTest::SetUp());
9797
device = GetParam();
98+
EXPECT_SUCCESS(urDeviceRetain(device));
9899
}
99100

100101
void TearDown() override {

0 commit comments

Comments
 (0)