Skip to content
Kelang edited this page Aug 4, 2020 · 3 revisions

This is some text! This is some text! This is some text!

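定义表格、行、列 (Defining a table with rows and columns)

<table>
  <tr bgcolor="red"><th>Month</th><th>Savings</th></tr>
  <tr><td>January</td><td>$100</td></tr>
</table>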
#include <cstdint>
#include <iostream>
#include <vector>

// Each #include must be on its own line: with both on one line the
// preprocessor ignores the second directive, and torch::optim is never declared.
#include "torch/optim.h"
#include "torch/script.h"

// Fits y ≈ clamp(x · w1, 0) · w2 to random data with the Adam optimizer,
// using raw ATen tensor ops (no torch::nn::Module), and prints the
// sum-of-squares loss every 100 iterations.
int main() {
  // Batch size, input width, hidden width, output width (IntArrayRef takes int64_t).
  const int64_t N = 64, D_in = 1000, H = 100, D_out = 10;
  const double learning_rate = 1e-3;

  // Inputs and targets are fixed random data, not trainable parameters.
  auto x = torch::randn({N, D_in}, at::TensorOptions().requires_grad(false));
  auto y = torch::randn({N, D_out}, at::TensorOptions().requires_grad(false));

  // The Adam optimizer wants its parameters in a std::vector.
  std::vector<at::Tensor> params = {
      torch::randn({D_in, H}, at::TensorOptions().requires_grad(true)),
      torch::randn({H, D_out}, at::TensorOptions().requires_grad(true))};

  // Build the optimizer.
  torch::optim::Adam adam(params, torch::optim::AdamOptions(learning_rate));

  // Quick references to the optimizer's parameters for the forward pass.
  const at::Tensor& w1 = adam.parameters()[0];
  const at::Tensor& w2 = adam.parameters()[1];

  for (int i = 0; i < 500; ++i) {
    // Forward pass: clamp(.., 0) acts as a ReLU between the two matmuls.
    auto y_pred = at::mm(at::clamp(at::mm(x, w1), 0), w2);
    // Sum of squared errors over every element of the batch.
    auto loss = at::sum(at::pow(at::sub(y_pred, y), 2));

    if ((i % 100) == 99) {
      std::cout << "loss = " << loss << std::endl;
    }

    // Clear stale gradients, backpropagate, then apply one Adam update.
    adam.zero_grad();
    loss.backward();
    adam.step();
  }
  return 0;
}

package main

import ( "fmt"

at <span class="pl-s">"github.com/gotorch/gotorch/aten"</span>
<span class="pl-s">"github.com/gotorch/gotorch/torch"</span>
<span class="pl-s">"github.com/gotorch/gotorch/torch/optim"</span>

)

func main() { N, D_in, H, D_out := 64, 1000, 100, 10 learning_rate := 1e-3

<span class="pl-s1">x</span> <span class="pl-c1">:=</span> <span class="pl-s1">torch</span>.<span class="pl-en">RandN</span>([]<span class="pl-smi">int</span>{<span class="pl-s1">N</span>, <span class="pl-s1">Din</span>},
	<span class="pl-s1">at</span>.<span class="pl-en">TensorOptions</span>().<span class="pl-en">RequiresGrad</span>(<span class="pl-c1">false</span>))
<span class="pl-s1">y</span> <span class="pl-c1">:=</span> <span class="pl-s1">torch</span>.<span class="pl-en">RandN</span>([]<span class="pl-smi">int</span>{<span class="pl-s1">N</span>, <span class="pl-s1">Dout</span>},
	<span class="pl-s1">at</span>.<span class="pl-en">TensorOptions</span>().<span class="pl-en">RequiresGrad</span>(<span class="pl-c1">false</span>))

<span class="pl-s1">params</span> <span class="pl-c1">:=</span> []at.<span class="pl-smi">Tensor</span>{
	<span class="pl-s1">torch</span>.<span class="pl-en">RandN</span>([]<span class="pl-smi">int</span>{<span class="pl-s1">Din</span>, <span class="pl-s1">H</span>},
		<span class="pl-s1">at</span>.<span class="pl-en">TensorOptions</span>().<span class="pl-en">RequiresGrad</span>(<span class="pl-c1">true</span>)),
	<span class="pl-s1">torch</span>.<span class="pl-en">RandN</span>([]<span class="pl-smi">int</span>{<span class="pl-s1">H</span>, <span class="pl-s1">Dout</span>},
		<span class="pl-s1">at</span>.<span class="pl-en">TensorOptions</span>().<span class="pl-en">RequiresGrad</span>(<span class="pl-c1">true</span>)),
}

<span class="pl-s1">adam</span> <span class="pl-c1">:=</span> <span class="pl-s1">optim</span>.<span class="pl-en">NewAdam</span>(<span class="pl-s1">params</span>, <span class="pl-s1">optim</span>.<span class="pl-en">AdamOptions</span>(<span class="pl-s1">learning_rate</span>))

<span class="pl-s1">w1</span> <span class="pl-c1">:=</span> <span class="pl-s1">adam</span>.<span class="pl-en">parameters</span>()[<span class="pl-c1">0</span>]
<span class="pl-s1">w2</span> <span class="pl-c1">:=</span> <span class="pl-s1">adam</span>.<span class="pl-en">parameters</span>()[<span class="pl-c1">1</span>]

<span class="pl-k">for</span> <span class="pl-s1">i</span> <span class="pl-c1">:=</span> <span class="pl-c1">0</span>; <span class="pl-s1">i</span> <span class="pl-c1">&lt;</span> <span class="pl-c1">500</span>; <span class="pl-s1">i</span><span class="pl-c1">++</span> {
	<span class="pl-s1">y_pred</span> <span class="pl-c1">:=</span> <span class="pl-s1">at</span>.<span class="pl-en">Sum</span>(<span class="pl-s1">at</span>.<span class="pl-en">Clamp</span>(<span class="pl-s1">at</span>.<span class="pl-en">MM</span>(<span class="pl-s1">x</span>, <span class="pl-s1">w1</span>), <span class="pl-c1">0</span>), <span class="pl-s1">w2</span>)
	<span class="pl-s1">loss</span> <span class="pl-c1">:=</span> <span class="pl-s1">at</span>.<span class="pl-en">Sum</span>(<span class="pl-s1">at</span>.<span class="pl-en">Pow</span>(<span class="pl-s1">at</span>.<span class="pl-en">Sub</span>(<span class="pl-s1">y_pred</span>, <span class="pl-s1">y</span>), <span class="pl-c1">2</span>))

	<span class="pl-k">if</span> <span class="pl-s1">i</span><span class="pl-c1">%</span><span class="pl-c1">100</span> <span class="pl-c1">==</span> <span class="pl-c1">0</span> {
		<span class="pl-s1">fmt</span>.<span class="pl-en">Println</span>(<span class="pl-s">"loss = "</span>, <span class="pl-s1">loss</span>)
	}

	<span class="pl-s1">adam</span>.<span class="pl-en">ZeroGrad</span>()
	<span class="pl-s1">loss</span>.<span class="pl-en">Backward</span>()
	<span class="pl-s1">adam</span>.<span class="pl-en">Step</span>()
}

}

Clone this wiki locally