@@ -871,7 +871,7 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
   AVFrame* frame = rawOutput.frame.get();
   output.streamIndex = streamIndex;
   auto& streamInfo = streams_[streamIndex];
-  output.streamType = streams_[streamIndex].stream->codecpar->codec_type;
+  TORCH_CHECK(streamInfo.stream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO);
   output.pts = frame->pts;
   output.ptsSeconds =
       ptsToSeconds(frame->pts, formatContext_->streams[streamIndex]->time_base);
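Note on the hunk above: instead of recording the codec type in a `DecodedOutput::streamType` field, the decoder now asserts up front that the stream is video (audio decoding is unimplemented; the second hunk below drops the now-dead `AVMEDIA_TYPE_AUDIO` branch accordingly). A minimal sketch of the assertion pattern, assuming LibTorch's `TORCH_CHECK` macro; `requireVideoStream` is a hypothetical helper, not part of the decoder:

    // TORCH_CHECK(cond, msg...) throws c10::Error when cond is false, so a
    // non-video stream now fails fast instead of flowing through decoding.
    extern "C" {
    #include <libavformat/avformat.h>
    }
    #include <c10/util/Exception.h>

    // Hypothetical helper mirroring the new check.
    void requireVideoStream(const AVStream* avStream) {
      TORCH_CHECK(
          avStream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO,
          "expected a video stream, got codec_type=",
          avStream->codecpar->codec_type);
    }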
@@ -932,86 +932,78 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
   }
 
   torch::Tensor outputTensor;
-  if (output.streamType == AVMEDIA_TYPE_VIDEO) {
-    // We need to compare the current frame context with our previous frame
-    // context. If they are different, then we need to re-create our colorspace
-    // conversion objects. We create our colorspace conversion objects late so
-    // that we don't have to depend on the unreliable metadata in the header.
-    // And we sometimes re-create them because it's possible for frame
-    // resolution to change mid-stream. Finally, we want to reuse the colorspace
-    // conversion objects as much as possible for performance reasons.
-    enum AVPixelFormat frameFormat =
-        static_cast<enum AVPixelFormat>(frame->format);
-    auto frameContext = DecodedFrameContext{
-        frame->width,
-        frame->height,
-        frameFormat,
-        expectedOutputWidth,
-        expectedOutputHeight};
+  // We need to compare the current frame context with our previous frame
+  // context. If they are different, then we need to re-create our colorspace
+  // conversion objects. We create our colorspace conversion objects late so
+  // that we don't have to depend on the unreliable metadata in the header.
+  // And we sometimes re-create them because it's possible for frame
+  // resolution to change mid-stream. Finally, we want to reuse the colorspace
+  // conversion objects as much as possible for performance reasons.
+  enum AVPixelFormat frameFormat =
+      static_cast<enum AVPixelFormat>(frame->format);
+  auto frameContext = DecodedFrameContext{
+      frame->width,
+      frame->height,
+      frameFormat,
+      expectedOutputWidth,
+      expectedOutputHeight};
 
-    if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
-      outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
-          expectedOutputHeight, expectedOutputWidth, torch::kCPU));
+  if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
+    outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
+        expectedOutputHeight, expectedOutputWidth, torch::kCPU));
 
-      if (!streamInfo.swsContext ||
-          streamInfo.prevFrameContext != frameContext) {
-        createSwsContext(streamInfo, frameContext, frame->colorspace);
-        streamInfo.prevFrameContext = frameContext;
-      }
-      int resultHeight =
-          convertFrameToTensorUsingSwsScale(streamIndex, frame, outputTensor);
-      // If this check failed, it would mean that the frame wasn't reshaped to
-      // the expected height.
-      // TODO: Can we do the same check for width?
-      TORCH_CHECK(
-          resultHeight == expectedOutputHeight,
-          "resultHeight != expectedOutputHeight: ",
-          resultHeight,
-          " != ",
-          expectedOutputHeight);
+    if (!streamInfo.swsContext || streamInfo.prevFrameContext != frameContext) {
+      createSwsContext(streamInfo, frameContext, frame->colorspace);
+      streamInfo.prevFrameContext = frameContext;
+    }
+    int resultHeight =
+        convertFrameToTensorUsingSwsScale(streamIndex, frame, outputTensor);
+    // If this check failed, it would mean that the frame wasn't reshaped to
+    // the expected height.
+    // TODO: Can we do the same check for width?
+    TORCH_CHECK(
+        resultHeight == expectedOutputHeight,
+        "resultHeight != expectedOutputHeight: ",
+        resultHeight,
+        " != ",
+        expectedOutputHeight);
+
+    output.frame = outputTensor;
+  } else if (
+      streamInfo.colorConversionLibrary ==
+      ColorConversionLibrary::FILTERGRAPH) {
+    if (!streamInfo.filterState.filterGraph ||
+        streamInfo.prevFrameContext != frameContext) {
+      createFilterGraph(streamInfo, expectedOutputHeight, expectedOutputWidth);
+      streamInfo.prevFrameContext = frameContext;
+    }
+    outputTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
 
-      output.frame = outputTensor;
-    } else if (
-        streamInfo.colorConversionLibrary ==
-        ColorConversionLibrary::FILTERGRAPH) {
-      if (!streamInfo.filterState.filterGraph ||
-          streamInfo.prevFrameContext != frameContext) {
-        createFilterGraph(
-            streamInfo, expectedOutputHeight, expectedOutputWidth);
-        streamInfo.prevFrameContext = frameContext;
-      }
-      outputTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
-
-      // Similarly to above, if this check fails it means the frame wasn't
-      // reshaped to its expected dimensions by filtergraph.
-      auto shape = outputTensor.sizes();
-      TORCH_CHECK(
-          (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
-              (shape[1] == expectedOutputWidth) && (shape[2] == 3),
-          "Expected output tensor of shape ",
-          expectedOutputHeight,
-          "x",
-          expectedOutputWidth,
-          "x3, got ",
-          shape);
-
-      if (preAllocatedOutputTensor.has_value()) {
-        // We have already validated that preAllocatedOutputTensor and
-        // outputTensor have the same shape.
-        preAllocatedOutputTensor.value().copy_(outputTensor);
-        output.frame = preAllocatedOutputTensor.value();
-      } else {
-        output.frame = outputTensor;
-      }
+    // Similarly to above, if this check fails it means the frame wasn't
+    // reshaped to its expected dimensions by filtergraph.
+    auto shape = outputTensor.sizes();
+    TORCH_CHECK(
+        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
+            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
+        "Expected output tensor of shape ",
+        expectedOutputHeight,
+        "x",
+        expectedOutputWidth,
+        "x3, got ",
+        shape);
+
+    if (preAllocatedOutputTensor.has_value()) {
+      // We have already validated that preAllocatedOutputTensor and
+      // outputTensor have the same shape.
+      preAllocatedOutputTensor.value().copy_(outputTensor);
+      output.frame = preAllocatedOutputTensor.value();
     } else {
-      throw std::runtime_error(
-          "Invalid color conversion library: " +
-          std::to_string(static_cast<int>(streamInfo.colorConversionLibrary)));
+      output.frame = outputTensor;
     }
-  } else if (output.streamType == AVMEDIA_TYPE_AUDIO) {
-    // TODO: https://github.com/pytorch-labs/torchcodec/issues/85 implement
-    // audio decoding.
-    throw std::runtime_error("Audio is not supported yet.");
+  } else {
+    throw std::runtime_error(
+        "Invalid color conversion library: " +
+        std::to_string(static_cast<int>(streamInfo.colorConversionLibrary)));
   }
 }
 
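The comment block kept in this hunk describes the decoder's lazy create/re-create strategy for its colorspace conversion objects. A standalone sketch of that caching pattern, with hypothetical `FrameContext`, `ConversionObject`, and `getConversionObject` stand-ins for `DecodedFrameContext` and the cached `swsContext`/`filterGraph` state (C++20 for the defaulted comparison):

    #include <memory>
    #include <optional>

    // Hypothetical stand-in for DecodedFrameContext: the fields whose change
    // must invalidate the cached conversion object.
    struct FrameContext {
      int width;
      int height;
      int format;
      int outWidth;
      int outHeight;
      bool operator==(const FrameContext&) const = default; // C++20
    };

    struct ConversionObject {}; // stands in for SwsContext / AVFilterGraph state

    struct StreamState {
      std::unique_ptr<ConversionObject> conversion;
      std::optional<FrameContext> prevFrameContext;
    };

    // Create lazily (never trust header metadata), re-create only when the
    // frame context changes (e.g. mid-stream resolution change), else reuse.
    ConversionObject& getConversionObject(StreamState& s, const FrameContext& ctx) {
      if (!s.conversion || s.prevFrameContext != ctx) {
        s.conversion = std::make_unique<ConversionObject>();
        s.prevFrameContext = ctx;
      }
      return *s.conversion;
    }

The same `prevFrameContext` comparison guards both the swscale and filtergraph paths in the diff, so a mid-stream resolution change rebuilds whichever conversion object is in use.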