Skip to content

Commit 39ff801

Browse files
rohithmenonpytorchmergebot
authored andcommitted
Add support for an operator level thread local observer (pytorch#108822)
Summary: Add support for an operator level thread local observer Test Plan: Verified the interception as part of a pytorch model evaluation with static runtime. Differential Revision: D49082250 Pull Request resolved: pytorch#108822 Approved by: https://github.com/davidberard98
1 parent 6823860 commit 39ff801

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

torch/csrc/jit/runtime/static/impl.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -760,6 +760,31 @@ size_t StaticModule::prepareStaticNodeInfos(
760760
return node_idx - node_start;
761761
}
762762

763+
#ifdef FBCODE_CAFFE2
764+
thread_local SROperatorObserver* tlsOpObserver = nullptr;
765+
766+
void SROperatorObserver::setCurrentThreadObserver(
767+
SROperatorObserver* observer) {
768+
tlsOpObserver = observer;
769+
}
770+
771+
SROperatorObserver* SROperatorObserver::getCurrentThreadObserver() {
772+
return tlsOpObserver;
773+
}
774+
775+
void SROperatorObserver::onStart(const Node* node) {
776+
if (tlsOpObserver != nullptr && tlsOpObserver->startCb != nullptr) {
777+
tlsOpObserver->startCb(node);
778+
}
779+
}
780+
781+
void SROperatorObserver::onEnd(const Node* node) {
782+
if (tlsOpObserver != nullptr && tlsOpObserver->endCb != nullptr) {
783+
tlsOpObserver->endCb(node);
784+
}
785+
}
786+
#endif // FBCODE_CAFFE2
787+
763788
BlockInfo::BlockInfo(uint32_t input_idx, Block& block)
764789
: input_idx_(input_idx), block_(block) {}
765790

@@ -2052,6 +2077,9 @@ std::vector<IValue> ProcessedNode::inputs_ivalue_vec() const {
20522077
}
20532078

20542079
void ProcessedNode::run() {
2080+
#ifdef FBCODE_CAFFE2
2081+
SROperatorObserver::onStart(node());
2082+
#endif
20552083
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
20562084
auto step_callbacks =
20572085
at::getStepCallbacksUnlessEmpty(at::RecordScope::STATIC_RUNTIME_OP);
@@ -2085,6 +2113,9 @@ void ProcessedNode::run() {
20852113
DCHECK(verify_no_memory_overlap());
20862114
}
20872115
#endif
2116+
#ifdef FBCODE_CAFFE2
2117+
SROperatorObserver::onEnd(node());
2118+
#endif
20882119
}
20892120

20902121
static bool checkNoMemoryOverlap(const at::Tensor& a, const at::Tensor& b) {

torch/csrc/jit/runtime/static/impl.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,19 @@ class StaticRuntime;
248248

249249
using SROperator = std::function<void(ProcessedNode*)>;
250250

251+
#ifdef FBCODE_CAFFE2
252+
struct TORCH_API SROperatorObserver {
253+
using OperatorCallback = void (*)(const Node*);
254+
OperatorCallback startCb = nullptr;
255+
OperatorCallback endCb = nullptr;
256+
257+
static void setCurrentThreadObserver(SROperatorObserver* observer);
258+
static SROperatorObserver* getCurrentThreadObserver();
259+
static void onStart(const Node* name);
260+
static void onEnd(const Node* name);
261+
};
262+
#endif
263+
251264
// A `BlockInfo` instance stores all of the shared state that each
252265
// `BlockRunner` will need to access. Most of this information is
253266
// read-only and shared between threads.

0 commit comments

Comments
 (0)