Skip to content

Commit fd82f11

Browse files
JacobSzwejbka authored and
pytorchmergebot committed
[lite interpreter][hack] Add batch_norm_update_stats if batchnorm and training are present (#100134)
Summary: Not sure how the train bool passed to batch_norm gets set, but it is not the is_training module-level flag. Because of this, we get weird behavior for teams trying to do on-device training. Test Plan: CI. Differential Revision: D45335791. Pull Request resolved: #100134. Approved by: https://github.com/larryliu0820
1 parent d5bd236 commit fd82f11

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

torch/csrc/jit/mobile/model_tracer/TracerRunner.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,44 @@ void call_setup_methods() {
9393
torch::from_blob(storage.data(), at::IntArrayRef(sizes), at::kFloat);
9494
}
9595

96+
/**
97+
* Similar to setup methods there are a suite a functions that often appear
98+
* under certain conditions but may avoid getting called in the trace due to the
99+
* narrow nature of bundled inputs
100+
*/
101+
void call_dependent_methods(std::set<std::string>& root_ops) {
102+
bool is_training = false;
103+
bool has_batchnorm = false;
104+
bool has_dropout = false;
105+
for (const std::string& op : root_ops) {
106+
if (op.find("backward") != std::string::npos ||
107+
op.find("requires_grad_") != std::string::npos) {
108+
is_training = true;
109+
}
110+
if (op.find("batch_norm") != std::string::npos) {
111+
has_batchnorm = true;
112+
}
113+
if (op.find("dropout") != std::string::npos) {
114+
has_dropout = true;
115+
}
116+
}
117+
if (is_training && has_batchnorm) {
118+
at::batch_norm(
119+
at::ones({2, 2}),
120+
c10::nullopt,
121+
c10::nullopt,
122+
c10::nullopt,
123+
c10::nullopt,
124+
true,
125+
0.1,
126+
0.1,
127+
false);
128+
}
129+
if (is_training && has_dropout) {
130+
at::dropout(at::ones({20, 20, 20}), 0.2, true);
131+
}
132+
}
133+
96134
/**
97135
* Call methods on the Tensor object that we expect to be called
98136
* in production on this Tensor.
@@ -307,6 +345,8 @@ TracerResult trace_run(const std::vector<std::string>& input_module_paths) {
307345
}
308346
}
309347

348+
call_dependent_methods(root_ops);
349+
310350
op_tracer.getCalledOperators().withLock(
311351
[&](std::set<std::string>& called_operators) {
312352
traced_operators = called_operators;

0 commit comments

Comments
 (0)