@@ -41,22 +41,6 @@ namespace
41
41
DEFINE_SIMPLE_CONVERSION_FUNCTIONS (device, DPCTLSyclDeviceRef)
42
42
DEFINE_SIMPLE_CONVERSION_FUNCTIONS (context, DPCTLSyclContextRef)
43
43
44
- /* Checks if two devices are equal based on the underlying native pointer.
45
- */
46
- bool deviceEqChecker (const device &D1, const device &D2)
47
- {
48
- if (D1.is_host () && D2.is_host ()) {
49
- return true ;
50
- }
51
- else if ((D1.is_host () && !D2.is_host ()) || (D2.is_host () && !D1.is_host ()))
52
- {
53
- return false ;
54
- }
55
- else {
56
- return D1.get () == D2.get ();
57
- }
58
- }
59
-
60
44
/*
61
45
* Helper function to print the metadata for a sycl::device.
62
46
*/
@@ -80,64 +64,9 @@ void print_device_info(const device &Device)
80
64
std::cout << ss.str ();
81
65
}
82
66
83
- /*
84
- * Helper class to store DPCTLSyclDeviceType and DPCTLSyclBackendType attributes
85
- * for a device along with the SYCL device.
86
- */
87
- struct DeviceWrapper
88
- {
89
- device SyclDevice;
90
- DPCTLSyclBackendType Bty;
91
- DPCTLSyclDeviceType Dty;
92
-
93
- DeviceWrapper (const device &Device)
94
- : SyclDevice(Device), Bty(DPCTL_SyclBackendToDPCTLBackendType(
95
- Device.get_platform().get_backend())),
96
- Dty (DPCTL_SyclDeviceTypeToDPCTLDeviceType(
97
- Device.get_info<info::device::device_type>()))
98
- {
99
- }
100
-
101
- // The constructor is provided for convenience, so that we do not have to
102
- // lookup the BackendType and DeviceType if not needed.
103
- DeviceWrapper (const device &Device,
104
- DPCTLSyclBackendType Bty,
105
- DPCTLSyclDeviceType Dty)
106
- : SyclDevice(Device), Bty(Bty), Dty(Dty)
107
- {
108
- }
109
- };
110
-
111
- auto getHash (const device &d)
112
- {
113
- if (d.is_host ()) {
114
- return std::hash<unsigned long long >{}(-1 );
115
- }
116
- else {
117
- return std::hash<decltype (d.get ())>{}(d.get ());
118
- }
119
- }
120
-
121
- struct DeviceHasher
122
- {
123
- size_t operator ()(const DeviceWrapper &d) const
124
- {
125
- return getHash (d.SyclDevice );
126
- }
127
- };
128
-
129
- struct DeviceEqPred
130
- {
131
- bool operator ()(const DeviceWrapper &d1, const DeviceWrapper &d2) const
132
- {
133
- return deviceEqChecker (d1.SyclDevice , d2.SyclDevice );
134
- }
135
- };
136
-
137
67
struct DeviceCacheBuilder
138
68
{
139
- using DeviceCache =
140
- std::unordered_map<DeviceWrapper, context, DeviceHasher, DeviceEqPred>;
69
+ using DeviceCache = std::unordered_map<device, context>;
141
70
/* This function implements a workaround to the current lack of a default
142
71
* context per root device in DPC++. The map stores a "default" context for
143
72
* each root device, and the QMgrHelper uses the map whenever it creates a
@@ -181,40 +110,29 @@ struct DeviceCacheBuilder
181
110
#include " dpctl_vector_templ.cpp"
182
111
#undef EL
183
112
184
- bool DPCTLDeviceMgr_AreEq (__dpctl_keep const DPCTLSyclDeviceRef DRef1,
185
- __dpctl_keep const DPCTLSyclDeviceRef DRef2 )
113
+ DPCTLSyclContextRef
114
+ DPCTLDeviceMgr_GetDefaultContext ( __dpctl_keep const DPCTLSyclDeviceRef DRef )
186
115
{
187
- auto D1 = unwrap (DRef1);
188
- auto D2 = unwrap (DRef2);
189
- if (D1 && D2)
190
- return deviceEqChecker (*D1, *D2);
191
- else
192
- return false ;
193
- }
116
+ DPCTLSyclContextRef CRef = nullptr ;
194
117
195
- DPCTL_DeviceAndContextPair DPCTLDeviceMgr_GetDeviceAndContextPair (
196
- __dpctl_keep const DPCTLSyclDeviceRef DRef)
197
- {
198
- DPCTL_DeviceAndContextPair rPair{nullptr , nullptr };
199
118
auto Device = unwrap (DRef);
200
- if (!Device) {
201
- return rPair;
202
- }
203
- DeviceWrapper DWrapper{*Device, DPCTLSyclBackendType::DPCTL_UNKNOWN_BACKEND,
204
- DPCTLSyclDeviceType::DPCTL_UNKNOWN_DEVICE};
119
+ if (!Device)
120
+ return CRef;
121
+
205
122
auto &cache = DeviceCacheBuilder::getDeviceCache ();
206
- auto entry = cache.find (DWrapper );
123
+ auto entry = cache.find (*Device );
207
124
if (entry != cache.end ()) {
208
125
try {
209
- rPair.DRef = wrap (new device (entry->first .SyclDevice ));
210
- rPair.CRef = wrap (new context (entry->second ));
126
+ CRef = wrap (new context (entry->second ));
211
127
} catch (std::bad_alloc const &ba) {
212
128
std::cerr << ba.what () << std::endl;
213
- rPair.DRef = nullptr ;
214
- rPair.CRef = nullptr ;
129
+ CRef = nullptr ;
215
130
}
216
131
}
217
- return rPair;
132
+ else {
133
+ std::cerr << " No cached default context for device" << std::endl;
134
+ }
135
+ return CRef;
218
136
}
219
137
220
138
__dpctl_give DPCTLDeviceVectorRef
@@ -228,12 +146,14 @@ DPCTLDeviceMgr_GetDevices(int device_identifier)
228
146
return nullptr ;
229
147
}
230
148
auto &cache = DeviceCacheBuilder::getDeviceCache ();
231
- Devices-> reserve (cache. size ());
149
+
232
150
for (const auto &entry : cache) {
233
- if ((device_identifier & entry.first .Bty ) &&
234
- (device_identifier & entry.first .Dty ))
235
- {
236
- Devices->emplace_back (wrap (new device (entry.first .SyclDevice )));
151
+ auto Bty (DPCTL_SyclBackendToDPCTLBackendType (
152
+ entry.first .get_platform ().get_backend ()));
153
+ auto Dty (DPCTL_SyclDeviceTypeToDPCTLDeviceType (
154
+ entry.first .get_info <info::device::device_type>()));
155
+ if ((device_identifier & Bty) && (device_identifier & Dty)) {
156
+ Devices->emplace_back (wrap (new device (entry.first )));
237
157
}
238
158
}
239
159
// the wrap function is defined inside dpctl_vector_templ.cpp
@@ -248,11 +168,14 @@ size_t DPCTLDeviceMgr_GetNumDevices(int device_identifier)
248
168
{
249
169
size_t nDevices = 0 ;
250
170
auto &cache = DeviceCacheBuilder::getDeviceCache ();
251
- for (const auto &entry : cache)
252
- if ((device_identifier & entry.first .Bty ) &&
253
- (device_identifier & entry.first .Dty ))
171
+ for (const auto &entry : cache) {
172
+ auto Bty (DPCTL_SyclBackendToDPCTLBackendType (
173
+ entry.first .get_platform ().get_backend ()));
174
+ auto Dty (DPCTL_SyclDeviceTypeToDPCTLDeviceType (
175
+ entry.first .get_info <info::device::device_type>()));
176
+ if ((device_identifier & Bty) && (device_identifier & Dty))
254
177
++nDevices;
255
-
178
+ }
256
179
return nDevices;
257
180
}
258
181
0 commit comments