
Commit d3f714e

add fused adam (#822)
* add fused adam
* remove unused code
* remove the fix for the launcher script out of this PR
* move all scalar value casts to the top of the function
* use more meaningful variable names
1 parent: 2ac7360

8 files changed: +1058 −34 lines changed


intel_extension_for_pytorch/csrc/aten/cpu/kernels/optimizer/AdamFusedStepKrnl.cpp

Lines changed: 528 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
#include "optimizer.h"

#include <torch/csrc/autograd/function.h>
#include <torch/extension.h>
#include "csrc/utils/ipex_op_profile.h"

namespace torch_ipex {
namespace cpu {

DEFINE_DISPATCH(adam_fused_step_kernel_stub);

void adam_fused_step(
    const at::Tensor& param_,
    const at::Tensor& exp_avg_,
    const at::Tensor& exp_avg_sq_,
    const at::Tensor& max_exp_avg_sq_,
    const at::Tensor& grad_,
    const at::Tensor& param2_,
    bool amsgrad,
    double step,
    double beta1,
    double beta2,
    double learning_rate,
    double weight_decay,
    double eps) {
  IPEX_RECORD_FUNCTION(
      "torch_ipex::adam_fused_step", c10::ArrayRef<c10::IValue>({}));

  TORCH_CHECK(
      learning_rate >= 0, "Expect learning rate >= 0.0, got ", learning_rate);
  TORCH_CHECK(eps >= 0, "Expect eps >= 0.0, got ", eps);
  TORCH_CHECK(beta1 >= 0 && beta1 < 1, "Expect 0.0 <= beta1 < 1.0, got", beta1);
  TORCH_CHECK(beta2 >= 0 && beta2 < 1, "Expect 0.0 <= beta2 < 1.0, got", beta2);
  TORCH_CHECK(
      weight_decay >= 0, "Expect weight_decay >= 0.0, got ", weight_decay);

  TORCH_CHECK(
      param_.sizes() == grad_.sizes(),
      "Expect param and grad have the same sizes, param sizes: ",
      param_.sizes(),
      "; grad sizes: ",
      grad_.sizes());
  TORCH_CHECK(
      param_.sizes() == exp_avg_.sizes(),
      "Expect param and exp_avg have the same sizes, param sizes: ",
      param_.sizes(),
      "; exp_avg sizes: ",
      exp_avg_.sizes());
  TORCH_CHECK(
      param_.sizes() == exp_avg_sq_.sizes(),
      "Expect param and exp_avg_sq_ have the same sizes, param sizes: ",
      param_.sizes(),
      "; exp_avg_sq sizes: ",
      exp_avg_sq_.sizes());
  if (amsgrad) {
    TORCH_CHECK(
        param_.sizes() == max_exp_avg_sq_.sizes(),
        "Expect param and max_exp_avg_sq_ have the same sizes, param sizes: ",
        param_.sizes(),
        "; max_exp_avg_sq sizes: ",
        max_exp_avg_sq_.sizes());
  }
  TORCH_CHECK(
      param2_.numel() == 0 || param_.sizes() == param2_.sizes(),
      "Expect param and param2_ have the same sizes, param sizes: ",
      param_.sizes(),
      "; param2_ sizes: ",
      param2_.sizes());

  /*
  pointer to adam_fused_step_kernel_impl(
      param_,
      exp_avg_,
      exp_avg_sq_,
      max_exp_avg_sq_,
      grad_,
      param2_,
      amsgrad,
      step,
      beta1,
      beta2,
      learning_rate,
      weight_decay,
      eps);
  */
  adam_fused_step_kernel_stub(
      kCPU,
      param_,
      exp_avg_,
      exp_avg_sq_,
      max_exp_avg_sq_,
      grad_,
      param2_,
      amsgrad,
      step,
      beta1,
      beta2,
      learning_rate,
      weight_decay,
      eps);
}

} // namespace cpu
} // namespace torch_ipex

namespace {

TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
  m.def(
      "adam_fused_step(Tensor(a!) param, Tensor(b!) exp_avg, Tensor(c!) "
      "exp_avg_sq, Tensor(d!) max_exp_avg_sq, Tensor grad, Tensor trail, "
      "bool amsgrad, float step, float beta1, float "
      "beta2, float lr, float weight_decay, float eps) -> ()",
      torch_ipex::cpu::adam_fused_step);
}

} // namespace
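
The fused kernel itself lives in AdamFusedStepKrnl.cpp, whose 528-line diff is not rendered above, so what follows is an assumption, not code from this commit: a minimal scalar sketch in plain C++ of the standard Adam/AMSGrad update (the same math as torch.optim.Adam) that such a fused step computes per element. The function name adam_step_reference and the std::vector signature are illustrative only, and the trail tensor (param2_) plus any low-precision handling are omitted. Once the schema above is registered via TORCH_LIBRARY_FRAGMENT, the wrapper is reachable from Python as torch.ops.torch_ipex.adam_fused_step.

// Hypothetical scalar reference for one Adam step (illustration only; the
// committed kernel is the fused, vectorized version in AdamFusedStepKrnl.cpp).
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

void adam_step_reference(
    std::vector<double>& param,
    std::vector<double>& exp_avg,
    std::vector<double>& exp_avg_sq,
    std::vector<double>& max_exp_avg_sq,
    const std::vector<double>& grad,
    bool amsgrad,
    double step, // 1-based step count, matching the `float step` schema argument
    double beta1,
    double beta2,
    double learning_rate,
    double weight_decay,
    double eps) {
  // Bias corrections for the first and second moment estimates.
  const double bias_correction1 = 1.0 - std::pow(beta1, step);
  const double bias_correction2 = 1.0 - std::pow(beta2, step);
  for (std::size_t i = 0; i < param.size(); ++i) {
    // L2-style weight decay folded into the gradient.
    const double g = grad[i] + weight_decay * param[i];
    // Exponential moving averages of the gradient and the squared gradient.
    exp_avg[i] = beta1 * exp_avg[i] + (1.0 - beta1) * g;
    exp_avg_sq[i] = beta2 * exp_avg_sq[i] + (1.0 - beta2) * g * g;
    double second_moment = exp_avg_sq[i];
    if (amsgrad) {
      // AMSGrad uses the running maximum of the second moment estimate.
      max_exp_avg_sq[i] = std::max(max_exp_avg_sq[i], exp_avg_sq[i]);
      second_moment = max_exp_avg_sq[i];
    }
    const double denom = std::sqrt(second_moment / bias_correction2) + eps;
    param[i] -= learning_rate * (exp_avg[i] / bias_correction1) / denom;
  }
}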

intel_extension_for_pytorch/csrc/aten/cpu/optimizer/optimizer.h

Lines changed: 31 additions & 0 deletions
@@ -48,6 +48,21 @@ void packed_add_kernel_impl(
     const at::Tensor& grad,
     double alpha);
 
+void adam_fused_step_kernel_impl(
+    const at::Tensor& param_,
+    const at::Tensor& exp_avg_,
+    const at::Tensor& exp_avg_sq_,
+    const at::Tensor& max_exp_avg_sq_,
+    const at::Tensor& grad_,
+    const at::Tensor& param2_,
+    bool amsgrad,
+    double step,
+    double beta1,
+    double beta2,
+    double learning_rate,
+    double weight_decay,
+    double eps);
+
 } // namespace
 
 using adagrad_fused_step_kernel_fn = std::tuple<at::Tensor, at::Tensor> (*)(
@@ -93,5 +108,21 @@ using packed_add_kernel_fn =
     void (*)(at::Tensor&, at::Tensor&, const at::Tensor&, double);
 DECLARE_DISPATCH(packed_add_kernel_fn, packed_add_kernel_stub);
 
+using adam_fused_step_kernel_fn = void (*)(
+    const at::Tensor&,
+    const at::Tensor&,
+    const at::Tensor&,
+    const at::Tensor&,
+    const at::Tensor&,
+    const at::Tensor&,
+    bool,
+    double,
+    double,
+    double,
+    double,
+    double,
+    double);
+DECLARE_DISPATCH(adam_fused_step_kernel_fn, adam_fused_step_kernel_stub);
+
 } // namespace cpu
 } // namespace torch_ipex
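
A note on the dispatch plumbing declared here: DECLARE_DISPATCH introduces adam_fused_step_kernel_stub with the adam_fused_step_kernel_fn signature, DEFINE_DISPATCH in the wrapper file gives it storage, and the kernel translation unit presumably registers its ISA-specific adam_fused_step_kernel_impl into it, so the operator call reaches the right implementation at runtime. The toy program below mimics only the shape of that pattern with an ordinary function pointer; it is not ATen's or IPEX's actual macro machinery.

#include <cassert>
#include <cstdio>

// Toy stand-in for a dispatch stub: a function pointer whose signature plays
// the role of adam_fused_step_kernel_fn (here shrunk to a plain SGD step).
using demo_kernel_fn = void (*)(double* param, const double* grad, double lr, int n);

// "DECLARE_DISPATCH"/"DEFINE_DISPATCH" analogue: the stub starts unset.
static demo_kernel_fn demo_kernel_stub = nullptr;

// The kernel translation unit would provide one implementation per ISA
// (e.g. default, AVX2, AVX-512) and register the best available one.
static void demo_kernel_impl(double* param, const double* grad, double lr, int n) {
  for (int i = 0; i < n; ++i) {
    param[i] -= lr * grad[i];
  }
}

// "REGISTER_DISPATCH" analogue: registration happens at static-init time.
static const bool demo_registered = [] {
  demo_kernel_stub = demo_kernel_impl;
  return true;
}();

int main() {
  double param[2] = {1.0, 2.0};
  const double grad[2] = {0.5, 0.5};
  assert(demo_kernel_stub != nullptr);
  demo_kernel_stub(param, grad, /*lr=*/0.1, /*n=*/2); // the operator calls through the stub
  std::printf("param = {%f, %f}\n", param[0], param[1]);
  return 0;
}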
