Skip to content

Commit 3088f07

Browse files
committed
[OPS] Added a type constraint in the TF records parsing ops.
1 parent 537a3c7 commit 3088f07

File tree

3 files changed

+15
-8
lines changed

3 files changed

+15
-8
lines changed

modules/api/src/main/scala/org/platanios/tensorflow/api/core/types/package.scala

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ package object types {
158158
type UByteOrUShort = Union[UByte]#or[UShort]#create
159159
type Integer = Union[Byte]#or[Short]#or[Int]#or[Long]#or[UByte]#or[UShort]#or[UInt]#or[ULong]#create
160160
type StringOrInteger = Union[String]#or[Byte]#or[Short]#or[Int]#or[Long]#or[UByte]#or[UShort]#or[UInt]#or[ULong]#create
161+
type StringOrFloatOrLong = Union[String]#or[Float]#or[Long]#create
161162
type Real = Union[TruncatedHalf]#or[Half]#or[Float]#or[Double]#or[Byte]#or[Short]#or[Int]#or[Long]#or[UByte]#or[UShort]#or[UInt]#or[ULong]#create
162163
type Complex = Union[ComplexFloat]#or[ComplexDouble]#create
163164
type NotQuantized = Union[TruncatedHalf]#or[Half]#or[Float]#or[Double]#or[Byte]#or[Short]#or[Int]#or[Long]#or[UByte]#or[UShort]#or[UInt]#or[ULong]#or[ComplexFloat]#or[ComplexDouble]#create
@@ -177,7 +178,8 @@ package object types {
177178
type IsIntOrLongOrUByte[T] = Contains[T, IntOrLongOrUByte]
178179
type IsIntOrUInt[T] = Contains[T, Integer]
179180
type IsUByteOrUShort[T] = Contains[T, UByteOrUShort]
180-
type IsStringOrIntOrUInt[T] = Contains[T, StringOrInteger]
181+
type IsStringOrInteger[T] = Contains[T, StringOrInteger]
182+
type IsStringOrFloatOrLong[T] = Contains[T, StringOrFloatOrLong]
181183
type IsReal[T] = Contains[T, Real]
182184
type IsComplex[T] = Contains[T, Complex]
183185
type IsNotQuantized[T] = Contains[T, NotQuantized]
@@ -233,8 +235,12 @@ package object types {
233235
def apply[T: IsUByteOrUShort]: IsUByteOrUShort[T] = implicitly[IsUByteOrUShort[T]]
234236
}
235237

236-
object IsStringOrIntOrUInt {
237-
def apply[T: IsStringOrIntOrUInt]: IsStringOrIntOrUInt[T] = implicitly[IsStringOrIntOrUInt[T]]
238+
object IsStringOrInteger {
239+
def apply[T: IsStringOrInteger]: IsStringOrInteger[T] = implicitly[IsStringOrInteger[T]]
240+
}
241+
242+
object IsStringOrFloatOrLong {
243+
def apply[T: IsStringOrFloatOrLong]: IsStringOrFloatOrLong[T] = implicitly[IsStringOrFloatOrLong[T]]
238244
}
239245

240246
object IsReal {

modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Parsing.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ object Parsing extends Parsing {
272272
* @param defaultValue Value to be used if an example is missing this feature. It must match the specified `shape`.
273273
* @tparam T Data type of the input feature.
274274
*/
275-
case class FixedLengthFeature[T: TF](
275+
case class FixedLengthFeature[T: TF : IsStringOrFloatOrLong](
276276
key: String,
277277
shape: Shape,
278278
defaultValue: Option[Tensor[T]] = None
@@ -285,7 +285,7 @@ object Parsing extends Parsing {
285285
*
286286
* @param dataType Data type of the input feature.
287287
*/
288-
case class VariableLengthFeature[T: TF](
288+
case class VariableLengthFeature[T: TF : IsStringOrFloatOrLong](
289289
key: String,
290290
dataType: DataType[T]
291291
) extends Feature
@@ -341,7 +341,7 @@ object Parsing extends Parsing {
341341
* position. If so, we skip sorting.
342342
* @tparam T Data type of the `valueKey` feature.
343343
*/
344-
case class SparseFeature[T: TF](
344+
case class SparseFeature[T: TF : IsStringOrFloatOrLong](
345345
indexKeys: Seq[String],
346346
valueKey: String,
347347
size: Seq[Long],

modules/api/src/main/scala/org/platanios/tensorflow/api/package.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,8 @@ package object api extends implicits.Implicits with Documentation {
212212
type IsIntOrLongOrHalfOrFloatOrDouble[T] = core.types.IsIntOrLongOrHalfOrFloatOrDouble[T]
213213
type IsIntOrLongOrUByte[T] = core.types.IsIntOrLongOrUByte[T]
214214
type IsIntOrUInt[T] = core.types.IsIntOrUInt[T]
215-
type IsStringOrIntOrUInt[T] = core.types.IsStringOrIntOrUInt[T]
215+
type IsStringOrInteger[T] = core.types.IsStringOrInteger[T]
216+
type IsStringOrFloatOrLong[T] = core.types.IsStringOrFloatOrLong[T]
216217
type IsReal[T] = core.types.IsReal[T]
217218
type IsComplex[T] = core.types.IsComplex[T]
218219
type IsNotQuantized[T] = core.types.IsNotQuantized[T]
@@ -231,7 +232,7 @@ package object api extends implicits.Implicits with Documentation {
231232
val IsIntOrLongOrHalfOrFloatOrDouble: core.types.IsIntOrLongOrHalfOrFloatOrDouble.type = core.types.IsIntOrLongOrHalfOrFloatOrDouble
232233
val IsIntOrLongOrUByte : core.types.IsIntOrLongOrUByte.type = core.types.IsIntOrLongOrUByte
233234
val IsIntOrUInt : core.types.IsIntOrUInt.type = core.types.IsIntOrUInt
234-
val IsStringOrIntOrUInt : core.types.IsStringOrIntOrUInt.type = core.types.IsStringOrIntOrUInt
235+
val IsStringOrFloatOrLong : core.types.IsStringOrFloatOrLong.type = core.types.IsStringOrFloatOrLong
235236
val IsReal : core.types.IsReal.type = core.types.IsReal
236237
val IsComplex : core.types.IsComplex.type = core.types.IsComplex
237238
val IsNotQuantized : core.types.IsNotQuantized.type = core.types.IsNotQuantized

0 commit comments

Comments
 (0)