Skip to content
This repository was archived by the owner on Jun 2, 2025. It is now read-only.

Commit 9f9d2de

Browse files
committed
Add TensorBatch type
1 parent c37ea9f commit 9f9d2de

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

ocf_datapipes/batch/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""Datapipes for batching together data"""
2-
from .batches import BatchKey, NumpyBatch, NWPBatchKey, NWPNumpyBatch, XarrayBatch
2+
from .batches import BatchKey, NumpyBatch, NWPBatchKey, NWPNumpyBatch, TensorBatch, XarrayBatch
33
from .merge_numpy_examples_to_batch import (
44
MergeNumpyBatchIterDataPipe as MergeNumpyBatch,
55
)

ocf_datapipes/batch/batches.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Union
55

66
import numpy as np
7+
import torch
78
import xarray as xr
89

910

@@ -229,3 +230,5 @@ class NWPBatchKey(Enum):
229230
NumpyBatch = dict[BatchKey, Union[np.ndarray, dict[str, NWPNumpyBatch]]]
230231

231232
XarrayBatch = dict[BatchKey, Union[xr.DataArray, xr.Dataset]]
233+
234+
TensorBatch = dict[BatchKey, Union[torch.Tensor, dict[str, dict[NWPBatchKey, torch.Tensor]]]]

0 commit comments

Comments
 (0)