
Commit b1719e7

Remove streamType field from DecodedOutput (#457)
1 parent 4f3e491 commit b1719e7

File tree

2 files changed (+68, -78 lines):
  src/torchcodec/decoders/_core/VideoDecoder.cpp
  src/torchcodec/decoders/_core/VideoDecoder.h


src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 68 additions & 76 deletions
@@ -871,7 +871,7 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
   AVFrame* frame = rawOutput.frame.get();
   output.streamIndex = streamIndex;
   auto& streamInfo = streams_[streamIndex];
-  output.streamType = streams_[streamIndex].stream->codecpar->codec_type;
+  TORCH_CHECK(streamInfo.stream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO);
   output.pts = frame->pts;
   output.ptsSeconds =
       ptsToSeconds(frame->pts, formatContext_->streams[streamIndex]->time_base);
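With streamType gone, the decoder asserts up front that the stream being converted is a video stream instead of recording the codec type on the output. As a rough illustration of what that TORCH_CHECK buys (a minimal standalone sketch, not torchcodec code; requireVideo is a hypothetical helper): when the condition is false, TORCH_CHECK throws a c10::Error whose message includes the stringified trailing arguments, so a non-video stream now fails fast.

#include <torch/torch.h>
extern "C" {
#include <libavutil/avutil.h> // AVMediaType, AVMEDIA_TYPE_VIDEO
}

// Hypothetical helper mirroring the new check: reject anything that is
// not a video stream before any conversion work starts.
void requireVideo(AVMediaType codecType) {
  TORCH_CHECK(
      codecType == AVMEDIA_TYPE_VIDEO,
      "expected a video stream, got codec_type=",
      static_cast<int>(codecType));
}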
@@ -932,86 +932,78 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
   }

   torch::Tensor outputTensor;
-  if (output.streamType == AVMEDIA_TYPE_VIDEO) {
-    // We need to compare the current frame context with our previous frame
-    // context. If they are different, then we need to re-create our colorspace
-    // conversion objects. We create our colorspace conversion objects late so
-    // that we don't have to depend on the unreliable metadata in the header.
-    // And we sometimes re-create them because it's possible for frame
-    // resolution to change mid-stream. Finally, we want to reuse the colorspace
-    // conversion objects as much as possible for performance reasons.
-    enum AVPixelFormat frameFormat =
-        static_cast<enum AVPixelFormat>(frame->format);
-    auto frameContext = DecodedFrameContext{
-        frame->width,
-        frame->height,
-        frameFormat,
-        expectedOutputWidth,
-        expectedOutputHeight};
+  // We need to compare the current frame context with our previous frame
+  // context. If they are different, then we need to re-create our colorspace
+  // conversion objects. We create our colorspace conversion objects late so
+  // that we don't have to depend on the unreliable metadata in the header.
+  // And we sometimes re-create them because it's possible for frame
+  // resolution to change mid-stream. Finally, we want to reuse the colorspace
+  // conversion objects as much as possible for performance reasons.
+  enum AVPixelFormat frameFormat =
+      static_cast<enum AVPixelFormat>(frame->format);
+  auto frameContext = DecodedFrameContext{
+      frame->width,
+      frame->height,
+      frameFormat,
+      expectedOutputWidth,
+      expectedOutputHeight};

-    if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
-      outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
-          expectedOutputHeight, expectedOutputWidth, torch::kCPU));
+  if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
+    outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
+        expectedOutputHeight, expectedOutputWidth, torch::kCPU));

-      if (!streamInfo.swsContext ||
-          streamInfo.prevFrameContext != frameContext) {
-        createSwsContext(streamInfo, frameContext, frame->colorspace);
-        streamInfo.prevFrameContext = frameContext;
-      }
-      int resultHeight =
-          convertFrameToTensorUsingSwsScale(streamIndex, frame, outputTensor);
-      // If this check failed, it would mean that the frame wasn't reshaped to
-      // the expected height.
-      // TODO: Can we do the same check for width?
-      TORCH_CHECK(
-          resultHeight == expectedOutputHeight,
-          "resultHeight != expectedOutputHeight: ",
-          resultHeight,
-          " != ",
-          expectedOutputHeight);
+    if (!streamInfo.swsContext || streamInfo.prevFrameContext != frameContext) {
+      createSwsContext(streamInfo, frameContext, frame->colorspace);
+      streamInfo.prevFrameContext = frameContext;
+    }
+    int resultHeight =
+        convertFrameToTensorUsingSwsScale(streamIndex, frame, outputTensor);
+    // If this check failed, it would mean that the frame wasn't reshaped to
+    // the expected height.
+    // TODO: Can we do the same check for width?
+    TORCH_CHECK(
+        resultHeight == expectedOutputHeight,
+        "resultHeight != expectedOutputHeight: ",
+        resultHeight,
+        " != ",
+        expectedOutputHeight);
+
+    output.frame = outputTensor;
+  } else if (
+      streamInfo.colorConversionLibrary ==
+      ColorConversionLibrary::FILTERGRAPH) {
+    if (!streamInfo.filterState.filterGraph ||
+        streamInfo.prevFrameContext != frameContext) {
+      createFilterGraph(streamInfo, expectedOutputHeight, expectedOutputWidth);
+      streamInfo.prevFrameContext = frameContext;
+    }
+    outputTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);

-      output.frame = outputTensor;
-    } else if (
-        streamInfo.colorConversionLibrary ==
-        ColorConversionLibrary::FILTERGRAPH) {
-      if (!streamInfo.filterState.filterGraph ||
-          streamInfo.prevFrameContext != frameContext) {
-        createFilterGraph(
-            streamInfo, expectedOutputHeight, expectedOutputWidth);
-        streamInfo.prevFrameContext = frameContext;
-      }
-      outputTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
-
-      // Similarly to above, if this check fails it means the frame wasn't
-      // reshaped to its expected dimensions by filtergraph.
-      auto shape = outputTensor.sizes();
-      TORCH_CHECK(
-          (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
-              (shape[1] == expectedOutputWidth) && (shape[2] == 3),
-          "Expected output tensor of shape ",
-          expectedOutputHeight,
-          "x",
-          expectedOutputWidth,
-          "x3, got ",
-          shape);
-
-      if (preAllocatedOutputTensor.has_value()) {
-        // We have already validated that preAllocatedOutputTensor and
-        // outputTensor have the same shape.
-        preAllocatedOutputTensor.value().copy_(outputTensor);
-        output.frame = preAllocatedOutputTensor.value();
-      } else {
-        output.frame = outputTensor;
-      }
+    // Similarly to above, if this check fails it means the frame wasn't
+    // reshaped to its expected dimensions by filtergraph.
+    auto shape = outputTensor.sizes();
+    TORCH_CHECK(
+        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
+            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
+        "Expected output tensor of shape ",
+        expectedOutputHeight,
+        "x",
+        expectedOutputWidth,
+        "x3, got ",
+        shape);
+
+    if (preAllocatedOutputTensor.has_value()) {
+      // We have already validated that preAllocatedOutputTensor and
+      // outputTensor have the same shape.
+      preAllocatedOutputTensor.value().copy_(outputTensor);
+      output.frame = preAllocatedOutputTensor.value();
     } else {
-      throw std::runtime_error(
-          "Invalid color conversion library: " +
-          std::to_string(static_cast<int>(streamInfo.colorConversionLibrary)));
+      output.frame = outputTensor;
     }
-  } else if (output.streamType == AVMEDIA_TYPE_AUDIO) {
-    // TODO: https://github.com/pytorch-labs/torchcodec/issues/85 implement
-    // audio decoding.
-    throw std::runtime_error("Audio is not supported yet.");
+  } else {
+    throw std::runtime_error(
+        "Invalid color conversion library: " +
+        std::to_string(static_cast<int>(streamInfo.colorConversionLibrary)));
   }
 }
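The comment block this change keeps describes a caching pattern worth spelling out: the color-conversion objects (swscale context or filtergraph) are created lazily and recreated only when the frame geometry changes mid-stream, then reused for performance. Below is a minimal sketch of that pattern with names and fields simplified from the real DecodedFrameContext; it is an illustration, not torchcodec code.

#include <optional>

struct FrameContext {
  int width = 0;
  int height = 0;
  int outWidth = 0;
  int outHeight = 0;
  bool operator==(const FrameContext& other) const {
    return width == other.width && height == other.height &&
        outWidth == other.outWidth && outHeight == other.outHeight;
  }
  bool operator!=(const FrameContext& other) const {
    return !(*this == other);
  }
};

struct ConversionCache {
  std::optional<FrameContext> prevContext;
  bool hasConverter = false;

  // Recreate the (expensive) conversion object only when the input frame
  // or output geometry changes; otherwise reuse the existing one.
  void ensureConverter(const FrameContext& current) {
    if (!hasConverter || !prevContext || *prevContext != current) {
      // createSwsContext(...) / createFilterGraph(...) would go here.
      hasConverter = true;
      prevContext = current;
    }
  }
};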

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 0 additions & 2 deletions
@@ -171,8 +171,6 @@ class VideoDecoder {
   struct DecodedOutput {
     // The actual decoded output as a Tensor.
     torch::Tensor frame;
-    // Could be AVMEDIA_TYPE_VIDEO or AVMEDIA_TYPE_AUDIO.
-    AVMediaType streamType;
     // The stream index of the decoded frame. Used to distinguish
     // between streams that are of the same type.
     int streamIndex;
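For reference, this is roughly what DecodedOutput looks like after the change, as far as this commit shows it. The frame and streamIndex members come straight from the hunk above; pts and ptsSeconds are inferred from their use in convertAVFrameToDecodedOutput earlier in the diff, and other members not touched by this commit are omitted.

#include <cstdint>
#include <torch/torch.h>

struct DecodedOutput {
  // The actual decoded output as a Tensor.
  torch::Tensor frame;
  // The stream index of the decoded frame. Used to distinguish
  // between streams that are of the same type.
  int streamIndex;
  // Presentation timestamp in stream time base, and the same converted
  // to seconds (inferred from the .cpp hunk; not shown in the header diff).
  int64_t pts;
  double ptsSeconds;
};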
