Skip to content

Commit 6fca645

Browse files
committed
[OPS] Fixed a bug in the beam search decoder.
1 parent 61f3490 commit 6fca645

File tree

2 files changed

+12
-9
lines changed

2 files changed

+12
-9
lines changed

modules/api/src/main/scala/org/platanios/tensorflow/api/core/types/package.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ package org.platanios.tensorflow.api.core
1717

1818
import org.tensorflow.framework.DataType._
1919

20+
import scala.annotation.implicitNotFound
21+
2022
/**
2123
* @author Emmanouil Antonios Platanios
2224
*/
@@ -73,6 +75,7 @@ package object types {
7375

7476
//region Type Traits
7577

78+
@implicitNotFound(msg = "Cannot prove that ${T} is a supported TensorFlow data type.")
7679
trait TF[T] {
7780
@inline def dataType: org.platanios.tensorflow.api.core.types.DataType[T]
7881
}

modules/api/src/main/scala/org/platanios/tensorflow/api/ops/seq2seq/decoders/BeamSearchDecoder.scala

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -155,12 +155,12 @@ class BeamSearchDecoder[T: TF, State: OutputStructure, StateShape](
155155
val initialState = BeamSearchDecoder.BeamSearchDecoderState[State](
156156
modelState = processedInitialCellState,
157157
logProbabilities = Basic.oneHot[Float, Int](
158-
indices = Basic.zeros[Int, Int](batchSize.expandDims(0)),
158+
indices = Basic.zeros[Int](batchSize.expandDims(0)),
159159
depth = beamWidth,
160160
onValue = Basic.zeros[Float](Shape()),
161161
offValue = Basic.constant(Float.MinValue)),
162162
finished = finished,
163-
sequenceLengths = Basic.zeros[Int, Int](Basic.stack[Int](Seq(batchSize, beamWidth))))
163+
sequenceLengths = Basic.zeros[Int](Basic.stack[Int](Seq(batchSize, beamWidth))))
164164
(finished, beginInput, initialState)
165165
}
166166
}
@@ -183,8 +183,8 @@ class BeamSearchDecoder[T: TF, State: OutputStructure, StateShape](
183183
BeamSearchDecoder.MaybeTensorConverter(
184184
BeamSearchDecoder.MergeBatchBeamsConverter(batchSize, beamWidth)))
185185
val mergedNextTuple = cell(Tuple(mergedInput, mergedCellState))
186-
val nextTupleOutput = BeamSearchDecoder.SplitBatchBeamsConverter(batchSize, beamWidth)(
187-
mergedNextTuple.output, Some(mergedNextTuple.output.shape(1 ::)))
186+
val nextTupleOutput = outputLayer(BeamSearchDecoder.SplitBatchBeamsConverter(batchSize, beamWidth)(
187+
mergedNextTuple.output, Some(mergedNextTuple.output.shape(1 ::))))
188188
val nextTupleState = evOutputToShapeState.map(
189189
mergedNextTuple.state, Some(cell.stateShape),
190190
BeamSearchDecoder.MaybeTensorConverter(
@@ -199,15 +199,15 @@ class BeamSearchDecoder[T: TF, State: OutputStructure, StateShape](
199199

200200
// Calculate the total log probabilities for the new hypotheses (final shape = [batchSize, beamWidth, vocabSize]).
201201
val stepLogProbabilities = BeamSearchDecoder.maskLogProbabilities(
202-
NN.logSoftmax(nextTupleOutput.castTo[Float]), endToken, previouslyFinished)
202+
NN.logSoftmax(nextTupleOutput.toFloat), endToken, previouslyFinished)
203203
val totalLogProbabilities = state.logProbabilities.expandDims(Output.constant[Int](2)) + stepLogProbabilities
204204

205205
// Calculate the continuation lengths by adding to all continuing search states.
206206
val vocabSize = {
207207
if (nextTupleOutput.shape(-1) != -1)
208208
Basic.constant(nextTupleOutput.shape(-1))
209209
else
210-
Basic.shape(nextTupleOutput).castTo[Int].slice(-1)
210+
Basic.shape(nextTupleOutput).toInt.slice(-1)
211211
}
212212

213213
var lengthsToAdd = Basic.oneHot[Int, Int](
@@ -221,7 +221,7 @@ class BeamSearchDecoder[T: TF, State: OutputStructure, StateShape](
221221
predictionLengths.expandDims(2))
222222

223223
// Calculate the scores for each search state.
224-
val scores = lengthPenalty(totalLogProbabilities, newPredictionLengths).castTo[Float]
224+
val scores = lengthPenalty(totalLogProbabilities, newPredictionLengths).toFloat
225225

226226
// During the first time step we only consider the initial search state.
227227
val scoresFlat = Basic.reshape(scores, Basic.stack[Int](Seq(batchSize, -1)))
@@ -240,8 +240,8 @@ class BeamSearchDecoder[T: TF, State: OutputStructure, StateShape](
240240
rangeSize = vocabSize * beamWidth,
241241
gatherShape = Seq(-1),
242242
name = "NextBeamLogProbabilities")
243-
val nextPredictedIDs = Math.mod(wordIndices, vocabSize, name = "NextBeamPredictedIDs").castTo[Int]
244-
val nextParentIDs = Math.divide(wordIndices, vocabSize, name = "NextBeamParentIDs").castTo[Int]
243+
val nextPredictedIDs = Math.mod(wordIndices, vocabSize, name = "NextBeamPredictedIDs").toInt
244+
val nextParentIDs = Math.divide(wordIndices, vocabSize, name = "NextBeamParentIDs").toInt
245245

246246
// Append the new IDs to the current predictions.
247247
val gatheredFinished = BeamSearchDecoder.gather(

0 commit comments

Comments (0)