test
This is some text! This is some text! This is some text!
Define the table, its rows, and its columns:

| Month   | Savings |
|---------|---------|
| January | $100    |
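Below is a C++ program that uses libtorch to train a two-layer network on random data with the Adam optimizer: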
```cpp
#include <iostream>
#include <vector>

#include "torch/script.h"
#include "torch/optim.h"

int main() {
  int N = 64, D_in = 1000, H = 100, D_out = 10;
  double learning_rate = 1e-3;

  // Random input and target tensors; these do not need gradients.
  auto x = torch::randn({N, D_in}, at::TensorOptions().requires_grad(false));
  auto y = torch::randn({N, D_out}, at::TensorOptions().requires_grad(false));

  // The Adam optimizer wants parameters in a std::vector.
  std::vector<at::Tensor> params = {
      torch::randn({D_in, H}, at::TensorOptions().requires_grad(true)),
      torch::randn({H, D_out}, at::TensorOptions().requires_grad(true))};

  // Build the optimizer.
  torch::optim::Adam adam(params, torch::optim::AdamOptions(learning_rate));

  // Make quick references for use in the forward pass.
  const at::Tensor &w1 = adam.parameters()[0];
  const at::Tensor &w2 = adam.parameters()[1];

  for (int i = 0; i < 500; ++i) {
    // Forward pass: linear -> ReLU (clamp at 0) -> linear.
    auto y_pred = at::mm(at::clamp(at::mm(x, w1), 0), w2);
    // Sum-of-squares loss.
    auto loss = at::sum(at::pow(at::sub(y_pred, y), 2));
    if ((i % 100) == 99) {
      std::cout << "loss = " << loss << std::endl;
    }
    adam.zero_grad();
    loss.backward();
    adam.step();
  }
  return 0;
}
```
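And the same program written in Go, using the gotorch packages: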
```go
package main

import (
	"fmt"

	at "github.com/gotorch/gotorch/aten"
	"github.com/gotorch/gotorch/torch"
	"github.com/gotorch/gotorch/torch/optim"
)

func main() {
	N, D_in, H, D_out := 64, 1000, 100, 10
	learning_rate := 1e-3

	// Random input and target tensors; these do not need gradients.
	x := torch.RandN([]int{N, D_in}, at.TensorOptions().RequiresGrad(false))
	y := torch.RandN([]int{N, D_out}, at.TensorOptions().RequiresGrad(false))

	// The Adam optimizer takes the parameters as a slice.
	params := []at.Tensor{
		torch.RandN([]int{D_in, H}, at.TensorOptions().RequiresGrad(true)),
		torch.RandN([]int{H, D_out}, at.TensorOptions().RequiresGrad(true)),
	}
	adam := optim.NewAdam(params, optim.AdamOptions(learning_rate))

	// Quick references for use in the forward pass.
	w1 := adam.Parameters()[0]
	w2 := adam.Parameters()[1]

	for i := 0; i < 500; i++ {
		// Forward pass: linear -> ReLU (clamp at 0) -> linear.
		y_pred := at.MM(at.Clamp(at.MM(x, w1), 0), w2)
		// Sum-of-squares loss.
		loss := at.Sum(at.Pow(at.Sub(y_pred, y), 2))
		if i%100 == 0 {
			fmt.Println("loss = ", loss)
		}
		adam.ZeroGrad()
		loss.Backward()
		adam.Step()
	}
}
```
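Both versions follow the same steps: create random input and target tensors, hand two randomly initialized weight tensors to an Adam optimizer, then run 500 iterations of forward pass, sum-of-squares loss, backward pass, and optimizer step. The Go version mirrors the C++ API call for call, e.g. `at::mm` becomes `at.MM` and `adam.zero_grad()` becomes `adam.ZeroGrad()`.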