
stsouko
Contributor

@stsouko stsouko commented May 13, 2023

PyTorch's wrapper datasets currently include:
TensorDataset
ConcatDataset
ChainDataset

TensorDataset is useful for stacking sets of tensors, but it can't work with objects that lack a .size() method.

This PR proposes StackDataset, which stacks samples like TensorDataset does but, like ConcatDataset, accepts arbitrary datasets.

Possible uses of StackDataset include multimodal networks with heterogeneous inputs (e.g. image + text), or stacking non-tensor inputs together with the property to predict.
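To illustrate the proposal, here is a minimal pure-Python sketch of a stacking dataset. It is not the code in this PR: the class name `SimpleStackDataset` is hypothetical, and returning a tuple for positional datasets but a dict for keyword datasets is an assumption that mirrors the `Union[Tuple[T_co, ...], Dict[str, T_co]]` signature discussed in the review. Because it only relies on `len()` and indexing, it works with plain lists of strings that have no `.size()` method.

```python
from typing import Any, Dict, Sequence, Tuple, Union


class SimpleStackDataset:
    """Hypothetical sketch: index i returns the i-th sample of every dataset."""

    def __init__(self, *args: Sequence[Any], **kwargs: Sequence[Any]) -> None:
        if args and kwargs:
            raise ValueError("Pass datasets either positionally or by keyword, not both")
        datasets = list(args) if args else list(kwargs.values())
        if not datasets:
            raise ValueError("At least one dataset should be passed")
        # All inputs must line up sample-for-sample.
        self._length = len(datasets[0])
        if any(self._length != len(d) for d in datasets):
            raise ValueError("All datasets must have the same length")
        self._args = args
        self._kwargs = kwargs

    def __len__(self) -> int:
        return self._length

    def __getitem__(self, index: int) -> Union[Tuple[Any, ...], Dict[str, Any]]:
        if self._kwargs:
            # Keyword datasets produce a dict keyed by the argument names.
            return {name: ds[index] for name, ds in self._kwargs.items()}
        # Positional datasets produce a tuple in argument order.
        return tuple(ds[index] for ds in self._args)


# Multimodal-style usage with non-tensor inputs (toy data).
images = ["img_0.png", "img_1.png"]
captions = ["a cat", "a dog"]
labels = [0, 1]

pairs = SimpleStackDataset(images, captions)
named = SimpleStackDataset(image=images, text=captions, label=labels)
print(pairs[1])  # ('img_1.png', 'a dog')
print(named[0])  # {'image': 'img_0.png', 'text': 'a cat', 'label': 0}
```

Keyword arguments give each modality a stable name, which tends to be easier to consume downstream than positional tuples when there are more than two inputs.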

@stsouko stsouko requested review from NivekT and ejguan as code owners May 13, 2023 15:19
@pytorch-bot pytorch-bot bot added the release notes: dataloader release notes category label May 13, 2023
@pytorch-bot

pytorch-bot bot commented May 13, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/101338

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit c5bbb0d:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@zou3519 zou3519 added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label May 15, 2023
@stsouko stsouko requested a review from ejguan May 15, 2023 20:16
@@ -206,6 +208,54 @@ def __len__(self):
return self.tensors[0].size(0)


class StackDataset(Dataset[Union[Tuple[T_co, ...], Dict[str, T_co]]]):
Contributor


One more thing to call out: it might be better to define a TypeVar constrained to either Tuple[T_co, ...] or Dict[str, T_co].

Contributor Author


Not sure I get the idea.
Do you mean:

T_td = Union[Tuple[T_co, ...], Dict[str, T_co]]

class StackDataset(Dataset[T_td]):

instead of:

class StackDataset(Dataset[Union[Tuple[T_co, ...], Dict[str, T_co]]]):

Contributor


No. Let's do something like T_stack = TypeVar('T_stack', Tuple[T_co, ...], Dict[str, T_co]). Using Union means any given output can be either a Tuple or a Dict, whereas a constrained TypeVar works like a template: each dataset instance is bound to exactly one of the two types.
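The difference the reviewer describes can be shown with two toy functions. This is a sketch: `T_stack` here is constrained to the concrete types `Tuple[int, ...]` and `Dict[str, int]` rather than the `T_co`-based types from the PR, and `first_sample_union` / `first_sample_template` are hypothetical names. At runtime both functions behave identically; the difference is only visible to a static type checker such as mypy.

```python
from typing import Dict, Sequence, Tuple, TypeVar, Union

# Union: a checker only knows the result is "a tuple OR a dict".
StackUnion = Union[Tuple[int, ...], Dict[str, int]]

# Constrained TypeVar: each call is bound to exactly one of the two types.
T_stack = TypeVar("T_stack", Tuple[int, ...], Dict[str, int])


def first_sample_union(samples: Sequence[StackUnion]) -> StackUnion:
    return samples[0]


def first_sample_template(samples: Sequence[T_stack]) -> T_stack:
    return samples[0]


batch: Sequence[Dict[str, int]] = [{"label": 1}, {"label": 0}]

# Checker sees the Union: maybe_dict.keys() would be rejected,
# because the Tuple[int, ...] member has no .keys().
maybe_dict = first_sample_union(batch)

# Checker resolves T_stack to Dict[str, int]: .keys() type-checks.
known_dict = first_sample_template(batch)
print(list(known_dict.keys()))  # ['label']
```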

Contributor Author


ok. added.

Comment on lines 241 to 242
tmp = list(kwargs.values())
if any(len(tmp[0]) != len(dataset) for dataset in tmp): # type: ignore[arg-type]
Contributor


nit:

self._length = ...
if any(self._length != ...
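A minimal sketch of the refactor the nit suggests, assuming the keyword datasets arrive as a mapping: compute the reference length once, keep it, and compare every dataset against it. `check_stack_lengths` is a hypothetical standalone helper (in the PR the value would live on `self._length`), and the mismatch error message is an assumption.

```python
from typing import Mapping, Sized


def check_stack_lengths(datasets: Mapping[str, Sized]) -> int:
    """Hypothetical helper: return the common length or raise ValueError."""
    if not datasets:
        raise ValueError("At least one dataset should be passed")
    values = list(datasets.values())
    # Computed once and reused; in the PR this would be stored as self._length.
    length = len(values[0])
    if any(length != len(dataset) for dataset in values):
        raise ValueError("Size mismatch between datasets")
    return length


print(check_stack_lengths({"text": ["a", "b"], "label": [0, 1]}))  # 2
```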

else:
raise ValueError("At least one dataset should be passed")

def __getitem__(self, index):
Contributor


This needs a type annotation using the TypeVar.

Contributor Author


But the output is already typed by the generic Dataset base class.
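For reference, a sketch of what the requested annotation could look like. The `Dataset` base below is a hypothetical stand-in for `torch.utils.data.Dataset` so the snippet runs without PyTorch, `TinyStack` is a hypothetical name, and `T_stack` is simplified to concrete element types. Whether a checker carries the base-class return type over to an unannotated override depends on the tool; mypy, for instance, treats an unannotated `def` as dynamically typed, so an explicit `-> T_stack` can still help.

```python
from typing import Dict, Generic, List, Tuple, TypeVar

T_co = TypeVar("T_co", covariant=True)
# Simplified constraints (concrete element types) for illustration only.
T_stack = TypeVar("T_stack", Tuple[str, ...], Dict[str, str])


class Dataset(Generic[T_co]):
    """Hypothetical stand-in for torch.utils.data.Dataset."""

    def __getitem__(self, index: int) -> T_co:
        raise NotImplementedError


class TinyStack(Dataset[T_stack]):
    def __init__(self, samples: List[T_stack]) -> None:
        self._samples = samples

    # Explicit annotation ties the return type to this instance's T_stack.
    def __getitem__(self, index: int) -> T_stack:
        return self._samples[index]


named: TinyStack[Dict[str, str]] = TinyStack([{"text": "a cat"}, {"text": "a dog"}])
print(named[1]["text"])  # a dog
```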

Contributor

@ejguan ejguan left a comment


Overall LGTM. I only have one question about this use case: do we expect all datasets to have the same size, or should we just support the shortest length, like TorchData's zip operation?
cc: @NivekT
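The two policies in this question differ only in how the combined length is derived. A minimal sketch with hypothetical helper names: `strict_length` enforces equal sizes (the behaviour this PR ends up requiring), while `zip_length` truncates to the shortest input, in the spirit of Python's built-in `zip` and the TorchData zip operation mentioned above.

```python
from typing import Sequence, Sized


def strict_length(datasets: Sequence[Sized]) -> int:
    """Every dataset must have the same size; otherwise fail loudly."""
    lengths = {len(d) for d in datasets}
    if len(lengths) != 1:
        raise ValueError(f"Size mismatch between datasets: {sorted(lengths)}")
    return lengths.pop()


def zip_length(datasets: Sequence[Sized]) -> int:
    """Silently truncate to the shortest dataset, like built-in zip()."""
    return min(len(d) for d in datasets)


images = ["img_0", "img_1", "img_2"]
labels = [0, 1]
print(zip_length([images, labels]))    # 2
print(len(list(zip(images, labels))))  # 2, same truncation as built-in zip
```

Strict checking surfaces misaligned inputs early; truncation is more forgiving but can silently drop samples from the longer datasets.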

Contributor

@NivekT NivekT left a comment


Please fix the mypy linting error

Contributor

@NivekT NivekT left a comment


I think we should be good to go. Thanks!

Looks like this requires all Datasets passed in to have the same length and keys, which seems fine to me.

@NivekT NivekT added the topic: improvements topic category label May 17, 2023
@NivekT NivekT changed the title stacking dataset [DataLoader] Adding StackDataset May 17, 2023
@NivekT
Contributor

NivekT commented May 17, 2023

@pytorchbot merge -r

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label May 17, 2023
@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased stack onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout stack && git pull --rebase)

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

@NivekT
Contributor

NivekT commented May 17, 2023

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@stsouko stsouko deleted the stack branch May 20, 2023 12:23
pytorchmergebot pushed a commit that referenced this pull request May 22, 2023
The new dataset class added by #101338 was missing from the documentation.

Pull Request resolved: #101927
Approved by: https://github.com/kit1980
jcaip pushed a commit that referenced this pull request May 23, 2023
Pull Request resolved: #101338
Approved by: https://github.com/ejguan, https://github.com/NivekT
Labels
ciflow/trunk Trigger trunk jobs on your pull request Merged open source release notes: dataloader release notes category topic: improvements topic category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

6 participants