Skip to content

Commit 0347299

Browse files
committed
Fix tf.reverse.
1 parent 6ec39ba commit 0347299

File tree

3 files changed

+35
-11
lines changed

3 files changed

+35
-11
lines changed

src/TensorFlowNET.Core/APIs/tf.array.cs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -162,14 +162,17 @@ public Tensor transpose<T1>(T1 a, Axis perm = null, string name = "transpose", b
162162
/// Reverses specific dimensions of a tensor.
163163
/// </summary>
164164
/// <param name="tensor"></param>
165-
/// <param name="axis"></param>
165+
/// <param name="axis">The indices of the dimensions to reverse. Must be in the range [-rank(tensor), rank(tensor)).</param>
166166
/// <param name="name"></param>
167167
/// <returns></returns>
168-
public Tensor reverse(Tensor tensor, int[] axis, string name = null)
169-
=> gen_array_ops.reverse(tensor, ops.convert_to_tensor(axis), name: name);
170-
171-
public Tensor reverse(Tensor tensor, Tensor axis, string name = null)
172-
=> gen_array_ops.reverse(tensor, axis, name: name);
168+
public Tensor reverse(Tensor tensor, Axis axis, string name = null)
169+
{
170+
if (axis.IsScalar)
171+
{
172+
axis = new Axis(axis.axis);
173+
}
174+
return array_ops.reverse(tensor, axis, name: name);
175+
}
173176

174177
/// <summary>
175178
/// Returns the rank of a tensor.

src/TensorFlowNET.Core/Operations/array_ops.cs

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,16 @@ public static Tensor reshape(Tensor tensor, object[] shape, string name = null)
413413
return gen_array_ops.reshape(tensor, dims, name: name);
414414
}
415415

416+
public static Tensor reverse(Tensor tensor, Tensor axis, string name = null)
417+
=> tf.Context.ExecuteOp("ReverseV2", name, new ExecuteOpArgs(tensor, axis)
418+
{
419+
GetGradientAttrs = (op) => new
420+
{
421+
T = op.get_attr<TF_DataType>("T"),
422+
Tidx = op.get_attr<TF_DataType>("Tidx")
423+
}
424+
});
425+
416426
private static Tensor ones_like_impl<T>(T tensor, TF_DataType dtype, string name, bool optimize = true)
417427
{
418428
return tf_with(ops.name_scope(name, "ones_like", new { tensor }), scope =>
@@ -658,19 +668,17 @@ public static Tensor tile(Tensor input, Tensor multiples, string name = null)
658668
}
659669
});
660670

661-
public static Tensor tile(Tensor input, object[] multiples, string name = null)
671+
/*public static Tensor tile(Tensor input, Shape multiples, string name = null)
662672
{
663-
Shape dims = shape_utils.from_object_array(multiples);
664-
665-
return tf.Context.ExecuteOp("Tile", name, new ExecuteOpArgs(input, dims)
673+
return tf.Context.ExecuteOp("Tile", name, new ExecuteOpArgs(input, multiples)
666674
{
667675
GetGradientAttrs = (op) => new
668676
{
669677
T = op.get_attr<TF_DataType>("T"),
670678
Tmultiples = op.get_attr<TF_DataType>("Tmultiples")
671679
}
672680
});
673-
}
681+
}*/
674682

675683
public static Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
676684
{

test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using Tensorflow.NumPy;
33
using Tensorflow;
44
using static Tensorflow.Binding;
5+
using System.Linq;
56

67
namespace TensorFlowNET.UnitTest.ManagedAPI
78
{
@@ -92,5 +93,17 @@ public void TensorArray()
9293
Assert.AreEqual(ta.read(1).numpy(), 20f);
9394
Assert.AreEqual(ta.read(2).numpy(), 30f);
9495
}
96+
97+
/// <summary>
98+
/// https://www.tensorflow.org/api_docs/python/tf/reverse
99+
/// </summary>
100+
[TestMethod]
101+
public void ReverseArray()
102+
{
103+
var a = tf.random.normal((2, 3));
104+
var b = tf.reverse(a, -1);
105+
Assert.IsTrue(Equal(a[0].ToArray<float>().Reverse().ToArray(), b[0].ToArray<float>()));
106+
Assert.IsTrue(Equal(a[1].ToArray<float>().Reverse().ToArray(), b[1].ToArray<float>()));
107+
}
95108
}
96109
}

0 commit comments

Comments
 (0)