Remove streamType field from DecodedOutput #457

Merged: 4 commits, Jan 21, 2025
src/torchcodec/decoders/_core/VideoDecoder.cpp (143 changes: 67 additions & 76 deletions)

@@ -869,7 +869,6 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
   AVFrame* frame = rawOutput.frame.get();
   output.streamIndex = streamIndex;
   auto& streamInfo = streams_[streamIndex];
-  output.streamType = streams_[streamIndex].stream->codecpar->codec_type;
   output.pts = frame->pts;
   output.ptsSeconds =
       ptsToSeconds(frame->pts, formatContext_->streams[streamIndex]->time_base);

Review comment (Contributor), on the removed line:
I do think there is some value in doing a TORCH_CHECK(streamInfo.stream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO) here or close by to make our assumptions explicit.
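
For illustration only, the suggested assertion might look roughly like this; it is a sketch of the reviewer's idea, not code from this PR, and the message text is made up:

// Hypothetical assertion making the video-only assumption explicit.
// TORCH_CHECK throws when the condition is false, concatenating the
// remaining arguments into the error message.
TORCH_CHECK(
    streamInfo.stream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO,
    "Expected a video stream, got codec_type = ",
    static_cast<int>(streamInfo.stream->codecpar->codec_type));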
@@ -930,86 +929,78 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
   }

   torch::Tensor outputTensor;
-  if (output.streamType == AVMEDIA_TYPE_VIDEO) {
-    // We need to compare the current frame context with our previous frame
-    // context. If they are different, then we need to re-create our colorspace
-    // conversion objects. We create our colorspace conversion objects late so
-    // that we don't have to depend on the unreliable metadata in the header.
-    // And we sometimes re-create them because it's possible for frame
-    // resolution to change mid-stream. Finally, we want to reuse the colorspace
-    // conversion objects as much as possible for performance reasons.
-    enum AVPixelFormat frameFormat =
-        static_cast<enum AVPixelFormat>(frame->format);
-    auto frameContext = DecodedFrameContext{
-        frame->width,
-        frame->height,
-        frameFormat,
-        expectedOutputWidth,
-        expectedOutputHeight};
-
-    if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
-      outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
-          expectedOutputHeight, expectedOutputWidth, torch::kCPU));
-
-      if (!streamInfo.swsContext ||
-          streamInfo.prevFrameContext != frameContext) {
-        createSwsContext(streamInfo, frameContext, frame->colorspace);
-        streamInfo.prevFrameContext = frameContext;
-      }
-      int resultHeight =
-          convertFrameToTensorUsingSwsScale(streamIndex, frame, outputTensor);
-      // If this check failed, it would mean that the frame wasn't reshaped to
-      // the expected height.
-      // TODO: Can we do the same check for width?
-      TORCH_CHECK(
-          resultHeight == expectedOutputHeight,
-          "resultHeight != expectedOutputHeight: ",
-          resultHeight,
-          " != ",
-          expectedOutputHeight);
-
-      output.frame = outputTensor;
-    } else if (
-        streamInfo.colorConversionLibrary ==
-        ColorConversionLibrary::FILTERGRAPH) {
-      if (!streamInfo.filterState.filterGraph ||
-          streamInfo.prevFrameContext != frameContext) {
-        createFilterGraph(
-            streamInfo, expectedOutputHeight, expectedOutputWidth);
-        streamInfo.prevFrameContext = frameContext;
-      }
-      outputTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
-
-      // Similarly to above, if this check fails it means the frame wasn't
-      // reshaped to its expected dimensions by filtergraph.
-      auto shape = outputTensor.sizes();
-      TORCH_CHECK(
-          (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
-              (shape[1] == expectedOutputWidth) && (shape[2] == 3),
-          "Expected output tensor of shape ",
-          expectedOutputHeight,
-          "x",
-          expectedOutputWidth,
-          "x3, got ",
-          shape);
-
-      if (preAllocatedOutputTensor.has_value()) {
-        // We have already validated that preAllocatedOutputTensor and
-        // outputTensor have the same shape.
-        preAllocatedOutputTensor.value().copy_(outputTensor);
-        output.frame = preAllocatedOutputTensor.value();
-      } else {
-        output.frame = outputTensor;
-      }
-    } else {
-      throw std::runtime_error(
-          "Invalid color conversion library: " +
-          std::to_string(static_cast<int>(streamInfo.colorConversionLibrary)));
-    }
-  } else if (output.streamType == AVMEDIA_TYPE_AUDIO) {
-    // TODO: https://github.com/pytorch-labs/torchcodec/issues/85 implement
-    // audio decoding.
-    throw std::runtime_error("Audio is not supported yet.");
-  }
+  // We need to compare the current frame context with our previous frame
+  // context. If they are different, then we need to re-create our colorspace
+  // conversion objects. We create our colorspace conversion objects late so
+  // that we don't have to depend on the unreliable metadata in the header.
+  // And we sometimes re-create them because it's possible for frame
+  // resolution to change mid-stream. Finally, we want to reuse the colorspace
+  // conversion objects as much as possible for performance reasons.
+  enum AVPixelFormat frameFormat =
+      static_cast<enum AVPixelFormat>(frame->format);
+  auto frameContext = DecodedFrameContext{
+      frame->width,
+      frame->height,
+      frameFormat,
+      expectedOutputWidth,
+      expectedOutputHeight};
+
+  if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
+    outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
+        expectedOutputHeight, expectedOutputWidth, torch::kCPU));
+
+    if (!streamInfo.swsContext || streamInfo.prevFrameContext != frameContext) {
+      createSwsContext(streamInfo, frameContext, frame->colorspace);
+      streamInfo.prevFrameContext = frameContext;
+    }
+    int resultHeight =
+        convertFrameToTensorUsingSwsScale(streamIndex, frame, outputTensor);
+    // If this check failed, it would mean that the frame wasn't reshaped to
+    // the expected height.
+    // TODO: Can we do the same check for width?
+    TORCH_CHECK(
+        resultHeight == expectedOutputHeight,
+        "resultHeight != expectedOutputHeight: ",
+        resultHeight,
+        " != ",
+        expectedOutputHeight);
+
+    output.frame = outputTensor;
+  } else if (
+      streamInfo.colorConversionLibrary ==
+      ColorConversionLibrary::FILTERGRAPH) {
+    if (!streamInfo.filterState.filterGraph ||
+        streamInfo.prevFrameContext != frameContext) {
+      createFilterGraph(streamInfo, expectedOutputHeight, expectedOutputWidth);
+      streamInfo.prevFrameContext = frameContext;
+    }
+    outputTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
+
+    // Similarly to above, if this check fails it means the frame wasn't
+    // reshaped to its expected dimensions by filtergraph.
+    auto shape = outputTensor.sizes();
+    TORCH_CHECK(
+        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
+            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
+        "Expected output tensor of shape ",
+        expectedOutputHeight,
+        "x",
+        expectedOutputWidth,
+        "x3, got ",
+        shape);
+
+    if (preAllocatedOutputTensor.has_value()) {
+      // We have already validated that preAllocatedOutputTensor and
+      // outputTensor have the same shape.
+      preAllocatedOutputTensor.value().copy_(outputTensor);
+      output.frame = preAllocatedOutputTensor.value();
+    } else {
+      output.frame = outputTensor;
+    }
+  } else {
+    throw std::runtime_error(
+        "Invalid color conversion library: " +
+        std::to_string(static_cast<int>(streamInfo.colorConversionLibrary)));
+  }
 }

Review comment (Member Author), on the re-indented block:
Diff looks bigger than it actually is because this block got indented to the left.
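
The comment block kept at the top of this hunk describes the caching scheme that survives the refactor: colorspace-conversion objects are created lazily and recreated only when the frame context changes. As a rough sketch of that idea, assuming DecodedFrameContext is the aggregate initialized in the hunk with member-wise equality (the real definition lives elsewhere in the decoder, and these field names are inferred from the initializer order, so treat them as illustrative):

// Sketch only: fields mirror the initializer in the hunk above
// (width, height, format, expected width, expected height).
// Assumes FFmpeg's AVPixelFormat from <libavutil/pixfmt.h> is in scope.
struct DecodedFrameContext {
  int decodedWidth;
  int decodedHeight;
  AVPixelFormat decodedFormat;
  int expectedWidth;
  int expectedHeight;

  bool operator==(const DecodedFrameContext& other) const {
    return decodedWidth == other.decodedWidth &&
        decodedHeight == other.decodedHeight &&
        decodedFormat == other.decodedFormat &&
        expectedWidth == other.expectedWidth &&
        expectedHeight == other.expectedHeight;
  }
  bool operator!=(const DecodedFrameContext& other) const {
    return !(*this == other);
  }
};

With equality defined this way, the "!streamInfo.swsContext || streamInfo.prevFrameContext != frameContext" test in the hunk recreates the swscale context exactly when the decoded or requested dimensions, or the pixel format, change mid-stream.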

src/torchcodec/decoders/_core/VideoDecoder.h (2 changes: 0 additions & 2 deletions)

@@ -191,8 +191,6 @@ class VideoDecoder {
   struct DecodedOutput {
     // The actual decoded output as a Tensor.
     torch::Tensor frame;
-    // Could be AVMEDIA_TYPE_VIDEO or AVMEDIA_TYPE_AUDIO.
-    AVMediaType streamType;
     // The stream index of the decoded frame. Used to distinguish
     // between streams that are of the same type.
    int streamIndex;
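
For reference, after this change the struct, showing only the fields visible in this hunk (any other members are untouched by the PR), reads:

struct DecodedOutput {
  // The actual decoded output as a Tensor.
  torch::Tensor frame;
  // The stream index of the decoded frame. Used to distinguish
  // between streams that are of the same type.
  int streamIndex;
  // ... remaining fields elided here.
};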