@@ -10,6 +10,8 @@ using namespace std;
10
10
11
11
namespace {
12
12
13
+ typedef phmap::flat_hash_map<pair<int64_t , int64_t >, int64_t > temporarl_edge_dict; // key: (node id, root/seed index) -> value: local sample index. NOTE(review): "temporarl" looks like a typo of "temporal" - confirm before renaming.
14
+
13
15
template <bool replace, bool directed>
14
16
tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
15
17
sample (const torch::Tensor &colptr, const torch::Tensor &row,
@@ -146,11 +148,15 @@ hetero_sample(const vector<node_t> &node_types,
146
148
147
149
// Initialize some data structures for the sampling process:
148
150
phmap::flat_hash_map<node_t , vector<int64_t >> samples_dict;
151
+ phmap::flat_hash_map<node_t , vector<pair<int64_t , int64_t >>> temp_samples_dict;
149
152
phmap::flat_hash_map<node_t , phmap::flat_hash_map<int64_t , int64_t >> to_local_node_dict;
153
+ phmap::flat_hash_map<node_t , temporarl_edge_dict> temp_to_local_node_dict;
150
154
phmap::flat_hash_map<node_t , vector<int64_t >> root_time_dict;
151
155
for (const auto &node_type : node_types) {
152
156
samples_dict[node_type];
157
+ temp_samples_dict[node_type];
153
158
to_local_node_dict[node_type];
159
+ temp_to_local_node_dict[node_type];
154
160
root_time_dict[node_type];
155
161
}
156
162
@@ -175,20 +181,33 @@ hetero_sample(const vector<node_t> &node_types,
175
181
}
176
182
177
183
auto &samples = samples_dict.at (node_type);
184
+ auto &temp_samples = temp_samples_dict.at (node_type);
178
185
auto &to_local_node = to_local_node_dict.at (node_type);
186
+ auto &temp_to_local_node = temp_to_local_node_dict.at (node_type);
179
187
auto &root_time = root_time_dict.at (node_type);
180
188
for (int64_t i = 0 ; i < input_node.numel (); i++) {
181
189
const auto &v = input_node_data[i];
182
- samples.push_back (v);
183
- to_local_node.insert ({v, i});
190
+ if (temporal) {
191
+ temp_samples.push_back ({v, i});
192
+ temp_to_local_node.insert ({{v, i}, i});
193
+ } else {
194
+ samples.push_back (v);
195
+ to_local_node.insert ({v, i});
196
+ }
184
197
if (temporal)
185
198
root_time.push_back (node_time_data[v]);
186
199
}
187
200
}
188
201
189
202
phmap::flat_hash_map<node_t , pair<int64_t , int64_t >> slice_dict;
190
- for (const auto &kv : samples_dict)
191
- slice_dict[kv.first ] = {0 , kv.second .size ()};
203
+ if (temporal) {
204
+ for (const auto &kv : temp_samples_dict) {
205
+ slice_dict[kv.first ] = {0 , kv.second .size ()};
206
+ }
207
+ } else {
208
+ for (const auto &kv : samples_dict)
209
+ slice_dict[kv.first ] = {0 , kv.second .size ()};
210
+ }
192
211
193
212
vector<rel_t > all_rel_types;
194
213
for (const auto &kv : num_neighbors_dict) {
@@ -203,8 +222,11 @@ hetero_sample(const vector<node_t> &node_types,
203
222
const auto &dst_node_type = get<2 >(edge_type);
204
223
const auto num_samples = num_neighbors_dict.at (rel_type)[ell];
205
224
const auto &dst_samples = samples_dict.at (dst_node_type);
225
+ const auto &temp_dst_samples = temp_samples_dict.at (dst_node_type);
206
226
auto &src_samples = samples_dict.at (src_node_type);
227
+ auto &temp_src_samples = temp_samples_dict.at (src_node_type);
207
228
auto &to_local_src_node = to_local_node_dict.at (src_node_type);
229
+ auto &temp_to_local_src_node = temp_to_local_node_dict.at (src_node_type);
208
230
209
231
const torch::Tensor &colptr = colptr_dict.at (rel_type);
210
232
const auto *colptr_data = colptr.data_ptr <int64_t >();
@@ -223,7 +245,8 @@ hetero_sample(const vector<node_t> &node_types,
223
245
const auto &begin = slice_dict.at (dst_node_type).first ;
224
246
const auto &end = slice_dict.at (dst_node_type).second ;
225
247
for (int64_t i = begin; i < end; i++) {
226
- const auto &w = dst_samples[i];
248
+ const auto &w = temporal ? temp_dst_samples[i].first : dst_samples[i];
249
+ const int64_t root_w = temporal ? temp_dst_samples[i].second : -1 ;
227
250
int64_t dst_time = 0 ;
228
251
if (temporal)
229
252
dst_time = dst_root_time[i];
@@ -241,15 +264,18 @@ hetero_sample(const vector<node_t> &node_types,
241
264
if (temporal) {
242
265
if (!satisfy_time (node_time_dict, src_node_type, dst_time, v))
243
266
continue ;
244
- // force disjoint of computation tree
267
+ // Force computation trees to be disjoint by keying each sampled node on its root (seed) batch index.
245
268
// note that the sampling always needs to have directed=True
246
269
// for temporal case
247
270
// to_local_src_node is not used for temporal / directed case
248
- const int64_t sample_idx = src_samples.size ();
249
- src_samples.push_back (v);
250
- src_root_time.push_back (dst_time);
271
+ const auto res = temp_to_local_src_node.insert ({{v, root_w}, (int64_t )temp_src_samples.size ()});
272
+ if (res.second ) {
273
+ temp_src_samples.push_back ({v, root_w});
274
+ src_root_time.push_back (dst_time);
275
+ }
276
+
251
277
cols.push_back (i);
252
- rows.push_back (sample_idx );
278
+ rows.push_back (res. first -> second );
253
279
edges.push_back (offset);
254
280
} else {
255
281
const auto res = to_local_src_node.insert ({v, src_samples.size ()});
@@ -272,14 +298,17 @@ hetero_sample(const vector<node_t> &node_types,
272
298
// TODO Infinite loop if no neighbor satisfies the time constraint:
273
299
if (!satisfy_time (node_time_dict, src_node_type, dst_time, v))
274
300
continue ;
275
- // force disjoint of computation tree
301
+ // Force computation trees to be disjoint by keying each sampled node on its root (seed) batch index.
276
302
// note that the sampling always needs to have directed=True
277
303
// for temporal case
278
- const int64_t sample_idx = src_samples.size ();
279
- src_samples.push_back (v);
280
- src_root_time.push_back (dst_time);
304
+ const auto res = temp_to_local_src_node.insert ({{v, root_w}, (int64_t )temp_src_samples.size ()});
305
+ if (res.second ) {
306
+ temp_src_samples.push_back ({v, root_w});
307
+ src_root_time.push_back (dst_time);
308
+ }
309
+
281
310
cols.push_back (i);
282
- rows.push_back (sample_idx );
311
+ rows.push_back (res. first -> second );
283
312
edges.push_back (offset);
284
313
} else {
285
314
const auto res = to_local_src_node.insert ({v, src_samples.size ()});
@@ -307,14 +336,17 @@ hetero_sample(const vector<node_t> &node_types,
307
336
if (temporal) {
308
337
if (!satisfy_time (node_time_dict, src_node_type, dst_time, v))
309
338
continue ;
310
- // force disjoint of computation tree
339
+ // Force computation trees to be disjoint by keying each sampled node on its root (seed) batch index.
311
340
// note that the sampling always needs to have directed=True
312
341
// for temporal case
313
- const int64_t sample_idx = src_samples.size ();
314
- src_samples.push_back (v);
315
- src_root_time.push_back (dst_time);
342
+ const auto res = temp_to_local_src_node.insert ({{v, root_w}, (int64_t )temp_src_samples.size ()});
343
+ if (res.second ) {
344
+ temp_src_samples.push_back ({v, root_w});
345
+ src_root_time.push_back (dst_time);
346
+ }
347
+
316
348
cols.push_back (i);
317
- rows.push_back (sample_idx );
349
+ rows.push_back (res. first -> second );
318
350
edges.push_back (offset);
319
351
} else {
320
352
const auto res = to_local_src_node.insert ({v, src_samples.size ()});
@@ -331,11 +363,18 @@ hetero_sample(const vector<node_t> &node_types,
331
363
}
332
364
}
333
365
334
- for (const auto &kv : samples_dict) {
335
- slice_dict[kv.first ] = {slice_dict.at (kv.first ).second , kv.second .size ()};
366
+ if (temporal) {
367
+ for (const auto &kv : temp_samples_dict) {
368
+ slice_dict[kv.first ] = {0 , kv.second .size ()};
369
+ }
370
+ } else {
371
+ for (const auto &kv : samples_dict)
372
+ slice_dict[kv.first ] = {0 , kv.second .size ()};
336
373
}
337
374
}
338
375
376
+ // Temporal sampling does not support the undirected case:
377
+ assert (!(temporal && !directed));
339
378
if (!directed) { // Construct the subgraph among the sampled nodes:
340
379
phmap::flat_hash_map<int64_t , int64_t >::iterator iter;
341
380
for (const auto &kv : colptr_dict) {
@@ -371,6 +410,18 @@ hetero_sample(const vector<node_t> &node_types,
371
410
}
372
411
}
373
412
413
+ // Construct samples dictionary from temporal sample dictionary.
414
+ if (temporal) {
415
+ for (const auto &kv : temp_samples_dict) {
416
+ const auto &node_type = kv.first ;
417
+ const auto &samples = kv.second ;
418
+ samples_dict[node_type].reserve (samples.size ());
419
+ for (const auto &v : samples) {
420
+ samples_dict[node_type].push_back (v.first );
421
+ }
422
+ }
423
+ }
424
+
374
425
return make_tuple (from_vector<node_t , int64_t >(samples_dict),
375
426
from_vector<rel_t , int64_t >(rows_dict),
376
427
from_vector<rel_t , int64_t >(cols_dict),
0 commit comments