[DataLoader] Adding StackDataset
#101338
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/101338
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit c5bbb0d.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
torch/utils/data/dataset.py
Outdated
```
@@ -206,6 +208,54 @@ def __len__(self):
        return self.tensors[0].size(0)


class StackDataset(Dataset[Union[Tuple[T_co, ...], Dict[str, T_co]]]):
```
Another thing to call out: it might be better to define a TypeVar that contains either `Tuple[T_co, ...]` or `Dict[str, T_co]`.
Not sure I get the idea. Do you mean:

```python
T_td = Union[Tuple[T_co, ...], Dict[str, T_co]]

class StackDataset(Dataset[T_td]):
```

instead of:

```python
class StackDataset(Dataset[Union[Tuple[T_co, ...], Dict[str, T_co]]]):
```
No. Let's do something like `T_stack = TypeVar('T_stack', Tuple[T_co, ...], Dict[str, T_co])`. Using `Union` means the output can be either `Tuple` or `Dict`. But using a `TypeVar` is template-style: it only allows one type per dataset instance.
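For reference, a minimal sketch of the constrained-TypeVar approach being suggested; the helper aliases `T_tuple`/`T_dict` are illustrative, and the exact declarations in the merged patch may differ:

```python
from typing import Dict, Tuple, TypeVar

from torch.utils.data import Dataset

T_co = TypeVar('T_co', covariant=True)

T_tuple = Tuple[T_co, ...]  # tuple-style sample
T_dict = Dict[str, T_co]    # dict-style sample

# A constrained TypeVar: each StackDataset instance is parameterized by
# exactly one of the two forms, unlike a Union, which would allow either
# form at every call site.
T_stack = TypeVar('T_stack', T_tuple, T_dict)


class StackDataset(Dataset[T_stack]):
    ...
```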
ok. added.
torch/utils/data/dataset.py
Outdated
```python
tmp = list(kwargs.values())
if any(len(tmp[0]) != len(dataset) for dataset in tmp):  # type: ignore[arg-type]
```
nit: compute the length once and reuse it:

```python
self._length = ...
if any(self._length != ...
```
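As an illustration of this nit, here is a sketch of how the constructor might cache the length once and reuse it for the mismatch check. The structure is inferred from the diff above; names, error messages, and details of the merged code may differ:

```python
from typing import Dict, Tuple, Union

from torch.utils.data import Dataset


class StackDataset(Dataset):
    # Holds either the positional datasets (tuple-style) or the keyword
    # datasets (dict-style), never both.
    datasets: Union[Tuple[Dataset, ...], Dict[str, Dataset]]

    def __init__(self, *args: Dataset, **kwargs: Dataset) -> None:
        if args:
            if kwargs:
                raise ValueError("Pass either positional or keyword datasets, not both")
            self._length = len(args[0])  # type: ignore[arg-type]
            if any(self._length != len(dataset) for dataset in args):  # type: ignore[arg-type]
                raise ValueError("Size mismatch between datasets")
            self.datasets = args
        elif kwargs:
            tmp = list(kwargs.values())
            # Cache the reference length once instead of recomputing
            # len(tmp[0]) on every iteration of the generator expression.
            self._length = len(tmp[0])  # type: ignore[arg-type]
            if any(self._length != len(dataset) for dataset in tmp):  # type: ignore[arg-type]
                raise ValueError("Size mismatch between datasets")
            self.datasets = kwargs
        else:
            raise ValueError("At least one dataset should be passed")

    def __len__(self) -> int:
        return self._length
```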
```python
        else:
            raise ValueError("At least one dataset should be passed")

    def __getitem__(self, index):
```
This needs a type annotation using the TypeVar.
But the output is already typed via the generic `Dataset` base class.
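Continuing the constructor sketch above, `__getitem__` could look roughly like the following for both call styles, assuming `self.datasets` holds either the positional tuple or the keyword dict set up in `__init__` (illustrative only, not the exact merged code):

```python
    def __getitem__(self, index):
        if isinstance(self.datasets, dict):
            # dict-style: keys are the keyword-argument names passed to __init__
            return {k: dataset[index] for k, dataset in self.datasets.items()}
        # tuple-style: one element per positional dataset, in order
        return tuple(dataset[index] for dataset in self.datasets)
```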
Overall LGTM. I only have one question about this use case: do we expect all datasets to have the same size, or should we just support the smallest length, like TorchData's `zip` operation?

cc: @NivekT
Please fix the `mypy` linting error.
I think we should be good to go. Thanks!
Looks like this requires all Datasets that are passed in to have the same length and keys, which seems fine with me.
@pytorchbot merge -r
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Successfully rebased.
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 mandatory check(s) failed. Dig deeper by viewing the failures on hud.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
New dataset class added by #101338 was missed in the documentation. Pull Request resolved: #101927. Approved by: https://github.com/kit1980
Torch's list of wrapping datasets currently has `TensorDataset`, `ConcatDataset`, and `ChainDataset`. `TensorDataset` is useful for stacking sets of tensors but can't work with objects that lack a `.size()` method. This PR proposes `StackDataset`, similar to `TensorDataset` but for the general case, like `ConcatDataset`. Possible uses of `StackDataset` include multimodal networks with different inputs such as image+text, or stacking non-tensor input together with the property to predict. Pull Request resolved: #101338. Approved by: https://github.com/ejguan, https://github.com/NivekT
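A hedged usage sketch of the new class. It assumes `StackDataset` is exported from `torch.utils.data` like the other wrapping datasets, and that plain tensors and lists are accepted because only `__getitem__`/`__len__` are required of each input; treat both as illustrative assumptions rather than documented behavior:

```python
import torch
from torch.utils.data import DataLoader, StackDataset

# Map-style inputs: anything with __getitem__ and __len__ of the same length.
images = torch.randn(4, 3, 32, 32)             # image-like tensors
captions = [f"caption {i}" for i in range(4)]  # non-tensor input
labels = torch.randint(0, 2, (4,))

# Tuple-style: each sample is (image, caption, label)
tuple_ds = StackDataset(images, captions, labels)

# Dict-style: each sample is {"image": ..., "caption": ..., "label": ...}
dict_ds = StackDataset(image=images, caption=captions, label=labels)

sample = dict_ds[0]
print(sample["caption"], sample["label"])

# Works with DataLoader like any other map-style dataset.
loader = DataLoader(dict_ds, batch_size=2)
```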