Initial fake / meta tensor support for nested tensors #96354
Conversation
The Nested changes look gooood
This PR adds initial fake / meta tensor support for nested tensors. In particular:

* Avoids crashing for `device="meta"` in the `NestedTensorImpl` constructor
* Adds `NestedTensorMeta` registrations for:
  * `detach` - used by `torch.Tensor._make_subclass()`, which `FakeTensor` uses
  * `values` - used for printing so it's nice to have
* Tweaks `FakeTensorMode.run_fallback_kernel()` to work with nested tensors
* Enhances `MetaConverter` to work with nested tensors
* Registers fake tensor `op_impl` for the nested tensor constructor `_nested_tensor_from_tensor_list`
* Adds tests to `test/test_nestedtensor.py`. I'm open to being convinced these should live in `test/test_fake_tensor.py` instead

**Hack alert**: The NT constructor `_nested_tensor_from_tensor_list(tensor_list, dtype, layout, device, pin_memory)` needs to build a real, *non-meta* size metadata tensor from the individual sizes in the `tensor_list`. This is currently done by dispatching to `at::stack`. I added a horrible hack to disable the meta key in TLS here so it doesn't incorrectly dispatch via meta (see the sketch below for the size-metadata relationship). What's a better way to do this? Possible alternatives:

* Create a Python meta registration for the constructor and dispatch to a (newly exposed to Python) `_nested_view_from_buffer`. We run into a similar problem in the implementation of `_nested_view_from_buffer` where we wrongly dispatch via meta when checking the passed-in NT metadata, but we could skip these checks.
* Some better mechanism for selectively disabling meta when operating within `FakeTensorMode`?

(heavily based on work by mikaylagawarecki in #93981)

[ghstack-poisoned]
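A minimal Python-level sketch of the size-metadata relationship the hack above works around (illustrative only; the real constructor logic lives in C++, and the exact call pattern here is an assumption — only `_nested_tensor_from_tensor_list` and the `at::stack` dispatch come from the PR text):

```python
import torch

# Two variable-length components; the batch size is the list length.
components = [torch.randn(2, 5), torch.randn(3, 5)]
nt = torch._nested_tensor_from_tensor_list(components)

# The nested size metadata is effectively a stack of the per-component sizes,
# which is why the constructor dispatches to at::stack internally and why that
# stack must not be redirected to the meta backend when building a fake NT.
expected_sizes = torch.stack([torch.tensor(t.shape) for t in components])
assert torch.equal(nt._nested_tensor_size(), expected_sizes)
```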
Regarding the hack, I want to mention something else I proposed to @wanchaol when they had a similar problem. In Wanchao's case, they are in a fake tensor mode but want to represent dtensor meshes as tensors; those would get fakeified and you'd lose the mesh data. My proposal for dealing with this is more aggressive constant propagation in fake tensor. Fake tensor is already willing to do constant propagation in limited cases (e.g., scalar tensors). In Wanchao's case, we probably could have gotten away with saying that any tensor with < 256 numel can just get constant propagated. Not sure if that's enough for you here; maybe it is, because if your index tensors are really big, other things are going to be bad. But we can also have other ways to identify index tensors. In the worst-case scenario, we could introduce a new device type, "constant", which always gets repped in CPU and never gets fakeified away.

cc @eellison
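A hypothetical sketch of the constant-propagation idea above; the helper name is made up, and the 256-element threshold is just the number mentioned in the comment, not FakeTensor's actual policy:

```python
import torch

CONSTANT_NUMEL_LIMIT = 256  # threshold suggested in the comment above

def should_keep_as_constant(t: torch.Tensor) -> bool:
    # Tiny tensors (index tensors, mesh descriptors, ...) keep their real
    # values instead of being fakeified, so code that inspects their contents
    # still works under fake-tensor tracing.
    return t.numel() <= CONSTANT_NUMEL_LIMIT
```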
This PR adds initial fake / meta tensor support for nested tensors. In particular:

* Expands some `NestedTensorImpl` constructors to both accept an explicit dim and avoid validating metadata
  * Explicit dim is required when we're passing symbolic metadata because it needs to be non-symbolic (i.e. it can't be computed from `nested_sizes`)
  * `bool validate_metadata = true` can be disabled for a check-less construction
  * Used by new `_nested_from_buffer_unchecked(buffer, sizes, strides, offsets, dim)` private op
  * **Feedback desired here** on the nicest way to do this
* Adds `NestedTensorMeta` registrations for:
  * `detach` - used by `torch.Tensor._make_subclass()`, which `FakeTensor` uses
  * `values` - used for printing so it's nice to have
  * `zero_` / `normal_` / `fill_` - commonly-used init functions
  * `ones_like` / `empty_like` - commonly-used factories
  * `add` - simple op for validation purposes
  * `_nested_tensor_size()` / `_nested_tensor_strides()` - private accessors used in testing
* Tweaks `FakeTensorMode.run_fallback_kernel()` to work with nested tensors
* Enhances `MetaConverter` to work with nested tensors
  * `nested_sizes` / `nested_strides` have size `(s0, N)` where N = the rank of the underlying components of the NT
  * `storage_offsets` has size `s0`
  * The NT's `buffer` is of size `i0`, an unbacked SymInt
* Registers fake tensor `op_impl` for the nested tensor constructor `_nested_tensor_from_tensor_list`
  * Errors for now since we don't support this NT <-> T boundary!
* Expands tensor serialization to support the combo of NT + meta
  * Example: `FakeTensor(nested_tensor(..., device='meta', num_components=s0, dim=2), cpu)`
  * (bikeshedding on how this should look is welcome)
* Adds tests to `test/test_nestedtensor.py` that use `CrossRefFakeMode`
  * exFails: anything that uses `_nested_tensor_from_tensor_list` since we don't support this for fake
* Adds a basic Dynamo test that demonstrates valid fake usage (i.e. a simple graph with NT inputs can be traced)

Note: The initial direction for this PR defined the fake version of a real NT as a meta NT with the same nested structure (i.e. non-meta metadata tensors). This is no longer the case, as it is too restrictive. In particular, we want dynamism over the batch size, and a fixed nested structure fixes the batch size. Now the metadata tensors have symbolic sizes themselves. Due to this, dynamic shapes support is **required** to be on for fake / meta conversion of NTs (see the sketch after this description).

Open questions:

* We're creating a buffer for the NT sized using an unbacked SymInt. Do we need this to propagate? e.g. if we call `empty_like(buffer)`, should we use the same unbacked SymInt to size the result?
* Is it okay to require dynamic shapes for NT?

(heavily based on work by mikaylagawarecki in #93981)

cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
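A rough sketch, assuming the support described above has landed, of what fakeifying an NT looks like. The import paths and the `shape_env` argument reflect dynamic shapes being mandatory, but the exact API surface is an assumption and may differ:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

nt = torch.nested.nested_tensor([torch.randn(2, 5), torch.randn(3, 5)])

# Dynamic shapes (a ShapeEnv) are required for fake / meta conversion of NTs.
mode = FakeTensorMode(shape_env=ShapeEnv())
fake_nt = mode.from_tensor(nt)

# Per the description: nested_sizes / nested_strides have size (s0, N),
# storage_offsets has size s0, and the buffer has unbacked size i0.
```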
Left some small comments; I need to look at the constructor overload hierarchy again just for my own clarity.
if self.is_meta:
    nt_sizes = self._nested_tensor_size()
    suffixes.append(
        "num_components="
Is `num_components` actually a valid kwarg to the constructor? (In general, it is best for the printing function to hew to the actual working syntax as much as possible.)
> Is `num_components` actually a valid kwarg to the constructor?

Sadly, no :( The serialization here is made up and doesn't fit the proper semantics for `__repr__()`.

We don't have many public APIs for NT creation at the moment, so almost nothing we do here will actually work.
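For reference, a minimal sketch of the main public constructor that does exist today; the repr discussed above (`nested_tensor(..., device='meta', num_components=s0, dim=2)`) is illustrative output only and is not valid constructor syntax:

```python
import torch

# Construct an NT from a list of variable-length tensors.
nt = torch.nested.nested_tensor([torch.randn(2, 3), torch.randn(4, 3)])
print(nt)
```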
This PR adds initial fake / meta tensor support for nested tensors. In particular:

* Adds `NestedTensorMeta` registrations for:
  * `detach` - used by `torch.Tensor._make_subclass()`, which `FakeTensor` uses
  * `values` - used for printing so it's nice to have
  * `zero_` / `normal_` / `fill_` - commonly-used init functions
  * `ones_like` / `empty_like` - commonly-used factories
  * `add` - simple op for validation purposes
  * `_nested_tensor_size()` / `_nested_tensor_strides()` - private accessors used in testing
* Tweaks `FakeTensorMode.run_fallback_kernel()` to work with nested tensors
* Enhances `MetaConverter` to work with nested tensors
  * `nested_sizes` / `nested_strides` have size `(s0, N)` where N = the rank of the underlying components of the NT
  * `storage_offsets` has size `s0`
  * The NT's `buffer` is of size `i0`, an unbacked SymInt
* Registers fake tensor `op_impl` for the nested tensor constructor `_nested_tensor_from_tensor_list`
  * Errors for now since we don't support this NT <-> T boundary!
* Expands tensor serialization to support the combo of NT + meta
  * Example: `FakeTensor(nested_tensor(..., device='meta', num_components=s0, dim=2), cpu)`
  * (bikeshedding on how this should look is welcome)
* Adds tests to `test/test_nestedtensor.py` that use `CrossRefFakeMode`
  * exFails: anything that uses `_nested_tensor_from_tensor_list` since we don't support this for fake
* Adds a basic Dynamo test that demonstrates valid fake usage (i.e. a simple graph with NT inputs can be traced; see the sketch below)

Note: The initial direction for this PR defined the fake version of a real NT as a meta NT with the same nested structure (i.e. non-meta metadata tensors). This is no longer the case, as it is too restrictive. In particular, we want dynamism over the batch size, and a fixed nested structure fixes the batch size. Now the metadata tensors have symbolic sizes themselves. Due to this, dynamic shapes support is **required** to be on for fake / meta conversion of NTs.

Open questions:

* ~~Is it okay to require dynamic shapes for NT?~~ We're saying yes!

TODOs:

* Properly track SymInts for NT metadata in `torch/fx/experimental/symbolic_shapes.py`
* Throw a special error for `_nested_tensor_from_tensor_list()` and catch it in the test harness instead of manually listing cross-ref exFails
* Be more granular about which checks are skipped for the symbolic metadata case

(heavily based on work by mikaylagawarecki in #93981)

cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
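A sketch of the kind of "simple graph with NT inputs can be traced" Dynamo test mentioned above, assuming the fake NT support from this PR; the ops used are illustrative (`add` is among the meta registrations listed):

```python
import torch

def f(nt):
    # A trivial graph over an NT input.
    return nt + nt

nt = torch.nested.nested_tensor([torch.randn(2, 3), torch.randn(4, 3)])
# dynamic=True since dynamic shapes are required for NT fakeification
out = torch.compile(f, dynamic=True)(nt)
```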
if t.is_nested:
    assert all([c is None for c in constraint]), \
        "Dim constraints are not supported for nested tensors"
    track_symint(NestedTensorPropertySource(source, TensorProperty.SIZE),
                 t._nested_tensor_size().size(0))
    track_symint(NestedTensorPropertySource(source, TensorProperty.STRIDE),
                 t._nested_tensor_strides().size(0))
    track_symint(NestedTensorPropertySource(source, TensorProperty.STORAGE_OFFSET),
                 t._nested_tensor_storage_offsets().size(0))
@ezyang does this match what you expect here? I'm scared we're not guarding on enough because the metadata is so symbolic.
Hmm, well, it's better. But the way to figure out if you're guarding enough is to just exhaustively run `track_symint` on every internal data structure. I don't agree with the `NestedTensorPropertySource` structure, by the way; you should make an entirely new enum for stuff like `_nested_tensor_size`, and then you can index each entry individually.
Okay gotcha, I changed it to work recursively on the metadata tensors. Now it's more aggressively guarding on the size, stride, and storage offset of `_nested_tensor_size`, `_nested_tensor_strides`, and `_nested_tensor_storage_offsets`, but I'm not convinced this buys us much in practice.

Since, for example, contiguity of a nested tensor is defined by the values of `_nested_tensor_size`, `_nested_tensor_strides`, and `_nested_tensor_storage_offsets`, is it still possible we're not guarding enough? We are guarding on the metadata structure, e.g. `(s0, n)`, but this structure is fairly permissive. We can pass those checks and still not know if the NT is contiguous.
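To illustrate the contrast being drawn here (this example is not from the PR): for dense tensors, guards that relate strides to sizes implicitly pin down contiguity, whereas for an NT the analogous values live inside `_nested_tensor_size()` / `_nested_tensor_strides()`, which shape-only guards on the metadata tensors cannot see.

```python
import torch

a = torch.randn(4, 6)  # contiguous 2D tensor
# Guards of this form hold exactly when `a` is contiguous, so guarding on
# them is an implicit contiguity guard.
assert a.stride()[0] == a.size()[1]
assert a.stride()[1] == 1
```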
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as stale.
Stack from ghstack (oldest at bottom):
This PR adds initial fake / meta tensor support for nested tensors. In particular:

* Adds `NestedTensorMeta` registrations for:
  * `detach` - used by `torch.Tensor._make_subclass()`, which `FakeTensor` uses
  * `values` - used for printing so it's nice to have
  * `zero_` / `normal_` / `fill_` - commonly-used init functions
  * `ones_like` / `empty_like` - commonly-used factories
  * `add` - simple op for validation purposes
  * `_nested_tensor_size()` / `_nested_tensor_strides()` - private accessors used in testing
* Tweaks `FakeTensorMode.run_fallback_kernel()` to work with nested tensors
* Enhances `MetaConverter` to work with nested tensors
  * `nested_sizes` / `nested_strides` have size `(s0, N)` where N = the rank of the underlying components of the NT, i.e. `(batch size, rank of underlying components)`, with e.g. `nested_sizes[i, j]` = the jth dim of the ith component
  * The metadata tensors are symbolic rather than copied from the real NT's `nested_sizes`, as the latter is not dynamic enough to avoid specializing on batch size
  * The `(s0, N)` structure for `nested_sizes` / `nested_strides` is too loose to be able to guard on NT contiguity, for example. To determine contiguity, we need to index into the values of `nested_sizes` and `nested_strides`, and they don't exist. For dense tensors, guarding on things like `['a.stride()[0] == a.size()[1]', 'a.stride()[1] == 1', ...]` implicitly guards on contiguity.
    * Should we restrict `torch.compile()` with NT to contiguous-only and explicitly maintain / propagate a contiguity property even for meta / fake NTs?
    * Note that we hit `transpose()` within the math fallback kernel of SDPA, and `chunk()` / `split_with_sizes()` (which produce non-contiguous views) in internal models we're targeting
  * `storage_offsets` has size `s0`
  * The NT's `buffer` is of size `i0`, an unbacked SymInt
* Registers a fake tensor `@op_impl` for the nested tensor constructor `_nested_tensor_from_tensor_list`
  * Errors for now since we don't support this NT <-> T boundary!
* Expands tensor serialization to support the combo of NT + meta
  * Example: `FakeTensor(nested_tensor(..., device='meta', num_components=s0, dim=2), cpu)`
* Adds tests to `test/test_nestedtensor.py` that use `CrossRefFakeMode`
  * exFails: anything that uses `_nested_tensor_from_tensor_list` since we don't support this for fake
* Adds a basic Dynamo test that demonstrates valid fake usage (i.e. a simple graph with NT inputs can be traced)

Note: The initial direction for this PR defined the fake version of a real NT as a meta NT with the same nested structure (i.e. non-meta metadata tensors). This is no longer the case, as it is too restrictive. In particular, we want dynamism over the batch size, and a fixed nested structure fixes the batch size. Now the metadata tensors have symbolic sizes themselves. Due to this, dynamic shapes support is required to be on for fake / meta conversion of NTs.

Open questions:

* ~~Is it okay to require dynamic shapes for NT?~~ We're saying yes!

TODOs:

* Properly track SymInts for NT metadata in `torch/fx/experimental/symbolic_shapes.py`
* Throw a special error for `_nested_tensor_from_tensor_list()` and catch it in the test harness instead of manually listing cross-ref exFails
* Be more granular about which checks are skipped for the symbolic metadata case

(heavily based on work by @mikaylagawarecki in #93981)
cc @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @ipiszy @bertmaher @soumith @desertfire