Skip to content

Add a new optimizer #1136

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions src/TensorFlowNET.Core/Keras/IOptimizerApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,27 @@ IOptimizer Adam(float learning_rate = 0.001f,
bool amsgrad = false,
string name = "Adam");

/// <summary>
/// Construct a new AdamW optimizer: Adam with a weight decay step applied
/// directly to the variables in addition to the regular Adam update.
/// </summary>
/// <param name="learning_rate">The learning rate. Defaults to 0.001.</param>
/// <param name="weight_decay">Weight decay coefficient. Defaults to 0.004; a value of 0 disables decay.</param>
/// <param name="beta_1">Forwarded to the underlying Adam optimizer. Defaults to 0.9.</param>
/// <param name="beta_2">Forwarded to the underlying Adam optimizer. Defaults to 0.999.</param>
/// <param name="epsilon">Forwarded to the underlying Adam optimizer. Defaults to 1e-7.</param>
/// <param name="amsgrad">Forwarded to the underlying Adam optimizer. Defaults to false.</param>
/// <param name="no_decay_params">
/// Optional list of name fragments; a variable whose name contains any of
/// them is excluded from weight decay. Null means every variable is decayed.
/// </param>
/// <param name="name">Name for the optimizer. Defaults to "AdamW".</param>
/// <returns>The newly created optimizer.</returns>
IOptimizer AdamW(float learning_rate = 0.001f,
float weight_decay = 0.004f,
float beta_1 = 0.9f,
float beta_2 = 0.999f,
float epsilon = 1e-7f,
bool amsgrad = false,
List<string> no_decay_params = null,
string name = "AdamW");

/// <summary>
/// Construct a new RMSprop optimizer.
/// </summary>
Expand Down
64 changes: 64 additions & 0 deletions src/TensorFlowNET.Keras/Optimizers/AdamW.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
namespace Tensorflow.Keras.Optimizers
{
    /// <summary>
    /// Adam optimizer extended with a weight decay step: before each dense Adam
    /// update the variable is shrunk by <c>learning_rate * weight_decay * var</c>.
    /// Variables whose name contains any entry of <c>no_decay_params</c> are
    /// left undecayed.
    /// </summary>
    public class AdamW : Adam
    {
        string name;
        float weight_decay;
        // Device/dtype key recorded by the most recent _prepare_local call; used to
        // look up the "weight_decay" constant in apply_state.
        // NOTE(review): the last prepared key is reused for every variable - confirm
        // this is safe when variables span several devices or dtypes.
        DeviceDType deType;
        List<string> no_decay_params = null;

        /// <summary>
        /// Create an AdamW optimizer.
        /// </summary>
        /// <param name="learning_rate">The learning rate, forwarded to Adam.</param>
        /// <param name="weight_decay">Weight decay coefficient; 0 disables decay.</param>
        /// <param name="beta_1">Forwarded to Adam.</param>
        /// <param name="beta_2">Forwarded to Adam.</param>
        /// <param name="epsilon">Forwarded to Adam.</param>
        /// <param name="amsgrad">Forwarded to Adam.</param>
        /// <param name="no_decay_params">
        /// Optional name fragments; variables whose name contains one of them are not decayed.
        /// </param>
        /// <param name="name">Optimizer name, forwarded to Adam.</param>
        public AdamW(float learning_rate = 0.001f,
            float weight_decay = 0.004f,
            float beta_1 = 0.9f,
            float beta_2 = 0.999f,
            float epsilon = 1e-7f,
            bool amsgrad = false,
            List<string> no_decay_params = null,
            string name = "AdamW")
            // Forward the name as well; otherwise the base class falls back to its
            // own default ("Adam") and the requested name is silently ignored.
            : base(learning_rate, beta_1, beta_2, epsilon, amsgrad, name: name)
        {
            this.name = name;
            this.weight_decay = weight_decay;
            this.no_decay_params = no_decay_params;
        }

        /// <summary>
        /// Build the op that applies weight decay to <paramref name="var"/>, or a
        /// no-op when the variable is excluded from decay.
        /// </summary>
        /// <param name="var">Variable to decay.</param>
        /// <param name="learning_rate">Learning rate used to scale the decay.</param>
        /// <param name="apply_state">Per device/dtype state holding the "weight_decay" tensor.</param>
        /// <returns>The decay operation, or <c>tf.no_op()</c>.</returns>
        protected Operation _decay_weights_op(IVariableV1 var, float learning_rate, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
        {
            if (!_do_use_weight_decay(var.Name))
            {
                return tf.no_op();
            }

            // var <- var - learning_rate * var * weight_decay
            return var.assign_add(
                -learning_rate * var.AsTensor() * apply_state[deType]["weight_decay"]);
        }

        /// <summary>
        /// Decide whether weight decay applies to the variable named <paramref name="param_name"/>.
        /// </summary>
        /// <param name="param_name">Name of the variable.</param>
        /// <returns>
        /// False when the decay coefficient is 0 or the name contains any entry of
        /// <c>no_decay_params</c>; true otherwise.
        /// </returns>
        protected bool _do_use_weight_decay(string param_name)
        {
            if (this.weight_decay == 0)
            {
                return false;
            }

            if (this.no_decay_params != null)
            {
                foreach (var excluded in this.no_decay_params)
                {
                    if (param_name.Contains(excluded))
                    {
                        return false;
                    }
                }
            }

            return true;
        }

        /// <summary>
        /// Apply weight decay to <paramref name="var"/> and then run the regular
        /// dense Adam update under a control dependency on the decay op.
        /// </summary>
        protected override Operation _resource_apply_dense(IVariableV1 var, Tensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
        {
            Operation decay = _decay_weights_op(var, _hyper["learning_rate"], apply_state);

            // The controller returned by control_dependencies only takes effect while
            // it is entered, so the base update has to be built inside tf_with;
            // calling control_dependencies and discarding the result orders nothing.
            return tf_with(tf.control_dependencies(new[] { decay }), _ =>
            {
                return base._resource_apply_dense(var, grad, apply_state);
            });
        }

        /// <summary>
        /// Run the base preparation and add the "weight_decay" constant to the
        /// state of <paramref name="device_dtype"/>.
        /// </summary>
        protected override void _prepare_local(DeviceDType device_dtype, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
        {
            this.deType = device_dtype;
            base._prepare_local(device_dtype, apply_state);
            apply_state[device_dtype]["weight_decay"] = tf.constant(
                weight_decay, name: "adam_weight_decay_rate");
        }
    }
}
16 changes: 16 additions & 0 deletions src/TensorFlowNET.Keras/Optimizers/OptimizerApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,22 @@ public IOptimizer Adam(float learning_rate = 0.001f,
amsgrad: amsgrad,
name: name);

/// <summary>
/// Construct a new AdamW optimizer; every argument is passed straight through
/// to the <c>AdamW</c> constructor.
/// </summary>
public IOptimizer AdamW(float learning_rate = 0.001f,
    float weight_decay = 0.004f,
    float beta_1 = 0.9f,
    float beta_2 = 0.999f,
    float epsilon = 1e-7f,
    bool amsgrad = false,
    List<string> no_decay_params = null,
    string name = "AdamW")
{
    // Arguments are listed in the constructor's declaration order.
    return new AdamW(
        learning_rate: learning_rate,
        weight_decay: weight_decay,
        beta_1: beta_1,
        beta_2: beta_2,
        epsilon: epsilon,
        amsgrad: amsgrad,
        no_decay_params: no_decay_params,
        name: name);
}

/// <summary>
/// Construct a new RMSprop optimizer.
/// </summary>
Expand Down