@@ -155,12 +155,12 @@ class BeamSearchDecoder[T: TF, State: OutputStructure, StateShape](
     val initialState = BeamSearchDecoder.BeamSearchDecoderState[State](
       modelState = processedInitialCellState,
       logProbabilities = Basic.oneHot[Float, Int](
-        indices = Basic.zeros[Int, Int](batchSize.expandDims(0)),
+        indices = Basic.zeros[Int](batchSize.expandDims(0)),
         depth = beamWidth,
         onValue = Basic.zeros[Float](Shape()),
         offValue = Basic.constant(Float.MinValue)),
       finished = finished,
-      sequenceLengths = Basic.zeros[Int, Int](Basic.stack[Int](Seq(batchSize, beamWidth))))
+      sequenceLengths = Basic.zeros[Int](Basic.stack[Int](Seq(batchSize, beamWidth))))
     (finished, beginInput, initialState)
   }
 }
@@ -183,8 +183,8 @@ class BeamSearchDecoder[T: TF, State: OutputStructure, StateShape](
       BeamSearchDecoder.MaybeTensorConverter(
         BeamSearchDecoder.MergeBatchBeamsConverter(batchSize, beamWidth)))
     val mergedNextTuple = cell(Tuple(mergedInput, mergedCellState))
-    val nextTupleOutput = BeamSearchDecoder.SplitBatchBeamsConverter(batchSize, beamWidth)(
-      mergedNextTuple.output, Some(mergedNextTuple.output.shape(1 ::)))
+    val nextTupleOutput = outputLayer(BeamSearchDecoder.SplitBatchBeamsConverter(batchSize, beamWidth)(
+      mergedNextTuple.output, Some(mergedNextTuple.output.shape(1 ::))))
     val nextTupleState = evOutputToShapeState.map(
       mergedNextTuple.state, Some(cell.stateShape),
       BeamSearchDecoder.MaybeTensorConverter(
@@ -199,15 +199,15 @@ class BeamSearchDecoder[T: TF, State: OutputStructure, StateShape](
 
     // Calculate the total log probabilities for the new hypotheses (final shape = [batchSize, beamWidth, vocabSize]).
     val stepLogProbabilities = BeamSearchDecoder.maskLogProbabilities(
-      NN.logSoftmax(nextTupleOutput.castTo[Float]), endToken, previouslyFinished)
+      NN.logSoftmax(nextTupleOutput.toFloat), endToken, previouslyFinished)
     val totalLogProbabilities = state.logProbabilities.expandDims(Output.constant[Int](2)) + stepLogProbabilities
 
     // Calculate the continuation lengths by adding to all continuing search states.
     val vocabSize = {
       if (nextTupleOutput.shape(-1) != -1)
         Basic.constant(nextTupleOutput.shape(-1))
       else
-        Basic.shape(nextTupleOutput).castTo[Int].slice(-1)
+        Basic.shape(nextTupleOutput).toInt.slice(-1)
     }
 
     var lengthsToAdd = Basic.oneHot[Int, Int](
@@ -221,7 +221,7 @@ class BeamSearchDecoder[T: TF, State: OutputStructure, StateShape](
         predictionLengths.expandDims(2))
 
     // Calculate the scores for each search state.
-    val scores = lengthPenalty(totalLogProbabilities, newPredictionLengths).castTo[Float]
+    val scores = lengthPenalty(totalLogProbabilities, newPredictionLengths).toFloat
 
     // During the first time step we only consider the initial search state.
     val scoresFlat = Basic.reshape(scores, Basic.stack[Int](Seq(batchSize, -1)))
@@ -240,8 +240,8 @@ class BeamSearchDecoder[T: TF, State: OutputStructure, StateShape](
       rangeSize = vocabSize * beamWidth,
       gatherShape = Seq(-1),
       name = "NextBeamLogProbabilities")
-    val nextPredictedIDs = Math.mod(wordIndices, vocabSize, name = "NextBeamPredictedIDs").castTo[Int]
-    val nextParentIDs = Math.divide(wordIndices, vocabSize, name = "NextBeamParentIDs").castTo[Int]
+    val nextPredictedIDs = Math.mod(wordIndices, vocabSize, name = "NextBeamPredictedIDs").toInt
+    val nextParentIDs = Math.divide(wordIndices, vocabSize, name = "NextBeamParentIDs").toInt
 
     // Append the new IDs to the current predictions.
     val gatheredFinished = BeamSearchDecoder.gather(