Commit 6143af2

yaoyaowd (Dong Wang) and Dong Wang authored
use batch idx and node id as unique key for dedup in temporal sampling (#267)
Co-authored-by: Dong Wang <[email protected]>
1 parent 916ba55 commit 6143af2


csrc/cpu/neighbor_sample_cpu.cpp

Lines changed: 73 additions & 22 deletions
@@ -10,6 +10,8 @@ using namespace std;
 
 namespace {
 
+typedef phmap::flat_hash_map<pair<int64_t, int64_t>, int64_t> temporarl_edge_dict;
+
 template <bool replace, bool directed>
 tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
 sample(const torch::Tensor &colptr, const torch::Tensor &row,
@@ -146,11 +148,15 @@ hetero_sample(const vector<node_t> &node_types,
 
   // Initialize some data structures for the sampling process:
   phmap::flat_hash_map<node_t, vector<int64_t>> samples_dict;
+  phmap::flat_hash_map<node_t, vector<pair<int64_t, int64_t>>> temp_samples_dict;
   phmap::flat_hash_map<node_t, phmap::flat_hash_map<int64_t, int64_t>> to_local_node_dict;
+  phmap::flat_hash_map<node_t, temporarl_edge_dict> temp_to_local_node_dict;
   phmap::flat_hash_map<node_t, vector<int64_t>> root_time_dict;
   for (const auto &node_type : node_types) {
     samples_dict[node_type];
+    temp_samples_dict[node_type];
     to_local_node_dict[node_type];
+    temp_to_local_node_dict[node_type];
     root_time_dict[node_type];
   }
 
@@ -175,20 +181,33 @@ hetero_sample(const vector<node_t> &node_types,
     }
 
     auto &samples = samples_dict.at(node_type);
+    auto &temp_samples = temp_samples_dict.at(node_type);
     auto &to_local_node = to_local_node_dict.at(node_type);
+    auto &temp_to_local_node = temp_to_local_node_dict.at(node_type);
     auto &root_time = root_time_dict.at(node_type);
     for (int64_t i = 0; i < input_node.numel(); i++) {
       const auto &v = input_node_data[i];
-      samples.push_back(v);
-      to_local_node.insert({v, i});
+      if (temporal) {
+        temp_samples.push_back({v, i});
+        temp_to_local_node.insert({{v, i}, i});
+      } else {
+        samples.push_back(v);
+        to_local_node.insert({v, i});
+      }
       if (temporal)
         root_time.push_back(node_time_data[v]);
     }
   }
 
   phmap::flat_hash_map<node_t, pair<int64_t, int64_t>> slice_dict;
-  for (const auto &kv : samples_dict)
-    slice_dict[kv.first] = {0, kv.second.size()};
+  if (temporal) {
+    for (const auto &kv : temp_samples_dict) {
+      slice_dict[kv.first] = {0, kv.second.size()};
+    }
+  } else {
+    for (const auto &kv : samples_dict)
+      slice_dict[kv.first] = {0, kv.second.size()};
+  }
 
   vector<rel_t> all_rel_types;
   for (const auto &kv : num_neighbors_dict) {
@@ -203,8 +222,11 @@ hetero_sample(const vector<node_t> &node_types,
      const auto &dst_node_type = get<2>(edge_type);
      const auto num_samples = num_neighbors_dict.at(rel_type)[ell];
      const auto &dst_samples = samples_dict.at(dst_node_type);
+     const auto &temp_dst_samples = temp_samples_dict.at(dst_node_type);
      auto &src_samples = samples_dict.at(src_node_type);
+     auto &temp_src_samples = temp_samples_dict.at(src_node_type);
      auto &to_local_src_node = to_local_node_dict.at(src_node_type);
+     auto &temp_to_local_src_node = temp_to_local_node_dict.at(src_node_type);
 
      const torch::Tensor &colptr = colptr_dict.at(rel_type);
      const auto *colptr_data = colptr.data_ptr<int64_t>();
@@ -223,7 +245,8 @@ hetero_sample(const vector<node_t> &node_types,
      const auto &begin = slice_dict.at(dst_node_type).first;
      const auto &end = slice_dict.at(dst_node_type).second;
      for (int64_t i = begin; i < end; i++) {
-       const auto &w = dst_samples[i];
+       const auto &w = temporal ? temp_dst_samples[i].first : dst_samples[i];
+       const int64_t root_w = temporal ? temp_dst_samples[i].second : -1;
        int64_t dst_time = 0;
        if (temporal)
          dst_time = dst_root_time[i];
@@ -241,15 +264,18 @@ hetero_sample(const vector<node_t> &node_types,
           if (temporal) {
             if (!satisfy_time(node_time_dict, src_node_type, dst_time, v))
               continue;
-            // force disjoint of computation tree
+            // force disjoint of computation tree based on source batch idx.
             // note that the sampling always needs to have directed=True
             // for temporal case
             // to_local_src_node is not used for temporal / directed case
-            const int64_t sample_idx = src_samples.size();
-            src_samples.push_back(v);
-            src_root_time.push_back(dst_time);
+            const auto res = temp_to_local_src_node.insert({{v, root_w}, (int64_t)temp_src_samples.size()});
+            if (res.second) {
+              temp_src_samples.push_back({v, root_w});
+              src_root_time.push_back(dst_time);
+            }
+
             cols.push_back(i);
-            rows.push_back(sample_idx);
+            rows.push_back(res.first->second);
             edges.push_back(offset);
           } else {
             const auto res = to_local_src_node.insert({v, src_samples.size()});
@@ -272,14 +298,17 @@ hetero_sample(const vector<node_t> &node_types,
             // TODO Infinity loop if no neighbor satisfies time constraint:
             if (!satisfy_time(node_time_dict, src_node_type, dst_time, v))
               continue;
-            // force disjoint of computation tree
+            // force disjoint of computation tree based on source batch idx.
             // note that the sampling always needs to have directed=True
             // for temporal case
-            const int64_t sample_idx = src_samples.size();
-            src_samples.push_back(v);
-            src_root_time.push_back(dst_time);
+            const auto res = temp_to_local_src_node.insert({{v, root_w}, (int64_t)temp_src_samples.size()});
+            if (res.second) {
+              temp_src_samples.push_back({v, root_w});
+              src_root_time.push_back(dst_time);
+            }
+
             cols.push_back(i);
-            rows.push_back(sample_idx);
+            rows.push_back(res.first->second);
            edges.push_back(offset);
          } else {
            const auto res = to_local_src_node.insert({v, src_samples.size()});
@@ -307,14 +336,17 @@ hetero_sample(const vector<node_t> &node_types,
           if (temporal) {
             if (!satisfy_time(node_time_dict, src_node_type, dst_time, v))
               continue;
-            // force disjoint of computation tree
+            // force disjoint of computation tree based on source batch idx.
             // note that the sampling always needs to have directed=True
             // for temporal case
-            const int64_t sample_idx = src_samples.size();
-            src_samples.push_back(v);
-            src_root_time.push_back(dst_time);
+            const auto res = temp_to_local_src_node.insert({{v, root_w}, (int64_t)temp_src_samples.size()});
+            if (res.second) {
+              temp_src_samples.push_back({v, root_w});
+              src_root_time.push_back(dst_time);
+            }
+
             cols.push_back(i);
-            rows.push_back(sample_idx);
+            rows.push_back(res.first->second);
             edges.push_back(offset);
           } else {
             const auto res = to_local_src_node.insert({v, src_samples.size()});
@@ -331,11 +363,18 @@ hetero_sample(const vector<node_t> &node_types,
      }
    }
 
-    for (const auto &kv : samples_dict) {
-      slice_dict[kv.first] = {slice_dict.at(kv.first).second, kv.second.size()};
+    if (temporal) {
+      for (const auto &kv : temp_samples_dict) {
+        slice_dict[kv.first] = {0, kv.second.size()};
+      }
+    } else {
+      for (const auto &kv : samples_dict)
+        slice_dict[kv.first] = {0, kv.second.size()};
    }
  }
 
+  // Temporal sample disable undirected
+  assert(!(temporal && !directed));
  if (!directed) { // Construct the subgraph among the sampled nodes:
    phmap::flat_hash_map<int64_t, int64_t>::iterator iter;
    for (const auto &kv : colptr_dict) {
@@ -371,6 +410,18 @@ hetero_sample(const vector<node_t> &node_types,
    }
  }
 
+  // Construct samples dictionary from temporal sample dictionary.
+  if (temporal) {
+    for (const auto &kv : temp_samples_dict) {
+      const auto &node_type = kv.first;
+      const auto &samples = kv.second;
+      samples_dict[node_type].reserve(samples.size());
+      for (const auto &v : samples) {
+        samples_dict[node_type].push_back(v.first);
+      }
+    }
+  }
+
  return make_tuple(from_vector<node_t, int64_t>(samples_dict),
                    from_vector<rel_t, int64_t>(rows_dict),
                    from_vector<rel_t, int64_t>(cols_dict),
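
The heart of the change, for reference: on the temporal path the per-type dedup map is keyed on the pair (node id, root/batch idx) instead of the node id alone, so a node reached from two different roots yields two separate local samples (keeping the per-root computation trees disjoint), while repeats under the same root are merged and reuse the existing local index. Below is a minimal standalone sketch of that insert-or-reuse pattern; the names are hypothetical and std::map stands in for phmap::flat_hash_map so it compiles without extra dependencies.

// dedup_sketch.cpp -- illustrative only, not part of the commit.
#include <cstdint>
#include <iostream>
#include <map>
#include <utility>
#include <vector>

int main() {
  // Sampled source nodes, stored as (node id, root/batch idx) pairs,
  // mirroring temp_src_samples in the patch.
  std::vector<std::pair<int64_t, int64_t>> samples;
  // Dedup map: (node id, root/batch idx) -> local index into `samples`,
  // playing the role of temp_to_local_src_node.
  std::map<std::pair<int64_t, int64_t>, int64_t> to_local;

  auto local_index = [&](int64_t node, int64_t root) -> int64_t {
    // insert() only takes effect if the (node, root) key is new; either way,
    // res.first->second is the local index that would be recorded in `rows`.
    const auto res = to_local.insert({{node, root}, (int64_t)samples.size()});
    if (res.second)
      samples.push_back({node, root});
    return res.first->second;
  };

  std::cout << local_index(7, /*root=*/0) << "\n"; // 0: first time for root 0
  std::cout << local_index(7, /*root=*/1) << "\n"; // 1: same node, other root -> new sample
  std::cout << local_index(7, /*root=*/0) << "\n"; // 0: deduplicated within root 0
  return 0;
}

The committed code follows the same shape with temp_to_local_src_node.insert({{v, root_w}, ...}) and rows.push_back(res.first->second), which is why the non-temporal to_local_src_node map is no longer consulted on the temporal path.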

0 commit comments
