
Conversation

@mlazos (Contributor) commented Feb 14, 2023

Summary:

Adds NNC-like logging that is configured through an env var `TORCH_LOGS`.
Examples:
`TORCH_LOGS="dynamo,guards" python script.py` - prints dynamo logs at level INFO, along with the guards of all functions that are compiled

`TORCH_LOGS="+dynamo,guards,graph" python script.py` - prints dynamo logs at level DEBUG, along with the guards and graphs (in tabular format) of all graphs that are compiled

[More examples with full output](https://gist.github.com/mlazos/b17f474457308ce15e88c91721ac1cce)

Implementation:
The implementation parses the log settings from the environment, finds any components (aot, dynamo, inductor) or other loggable artifacts (guards, graph, etc.), and generates a log_state object. This object contains all of the enabled artifacts and a qualified-log-name -> level mapping. `_init_logs` then adds handlers to the highest-level logs (the registered logs) and sets any artifact logger to level DEBUG if its artifact is enabled.
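As a rough illustration of the parsing step (hypothetical helper and registry names; the real code lives in `torch._logging`):

```python
import logging
import os

# hypothetical registries for illustration; the real ones are populated by
# torch._logging._registrations
LOG_REGISTRY = {
    "dynamo": "torch._dynamo",
    "aot": "torch._functorch.aot_autograd",
    "inductor": "torch._inductor",
}
ARTIFACT_REGISTRY = {"guards", "graph", "bytecode"}

def parse_torch_logs(settings=None):
    """Turn e.g. '+dynamo,guards' into (qualified log name -> level, enabled artifacts)."""
    settings = settings if settings is not None else os.environ.get("TORCH_LOGS", "")
    log_levels, artifacts = {}, set()
    for item in settings.split(","):
        item = item.strip()
        if not item:
            continue
        # a leading "+" bumps a component from the default INFO to DEBUG
        level = logging.DEBUG if item.startswith("+") else logging.INFO
        name = item.lstrip("+")
        if name in LOG_REGISTRY:
            log_levels[LOG_REGISTRY[name]] = level
        elif name in ARTIFACT_REGISTRY:
            artifacts.add(name)
        else:
            raise ValueError(f"unrecognized TORCH_LOGS entry: {name!r}")
    return log_levels, artifacts
```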

Note: `set_logs` is an alternative way to manipulate the log_state, but if the environment contains `TORCH_LOGS`, the environment settings take priority.
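For example, a minimal usage sketch of `set_logs`, assuming it accepts component levels and artifact booleans as keywords:

```python
import logging
import torch._logging

# roughly the Python-side equivalent of TORCH_LOGS="+dynamo,guards,graph";
# note this is ignored if the TORCH_LOGS env var is set, since environment
# settings take priority
torch._logging.set_logs(dynamo=logging.DEBUG, guards=True, graph=True)
```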

Adding a new log:
To add a new log, a dev should add their log name to `torch._logging._registrations` (there are examples there already).

Adding a new artifact:
To add a new artifact, a dev should add their artifact name to `torch._logging._registrations` as well.
Additionally, wherever the artifact is logged, `torch._logging.getArtifactLogger(__name__, <artifact_name>)` should be used instead of the standard logging implementation.
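For instance, a call site logging the "guards" artifact might look like this sketch (hypothetical file and function names):

```python
# e.g. somewhere in torch/_dynamo/guards.py
from torch._logging import getArtifactLogger

guards_log = getArtifactLogger(__name__, "guards")

def print_guards(guard_str):
    # emitted only when the "guards" artifact is enabled (e.g. TORCH_LOGS="guards")
    guards_log.debug("GUARDS:\n%s", guard_str)
```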

[design doc](https://docs.google.com/document/d/1ZRfTWKa8eaPq1AxaiHrq4ASTPouzzlPiuquSBEJYwS8/edit#)

cc @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire

@pytorch-bot bot commented Feb 14, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/94858

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Failure

As of commit 605088b:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@mlazos mlazos changed the title Consistent, configurable, extensible logging NNC-like configurable logging Feb 14, 2023
@mlazos mlazos requested a review from albanD February 14, 2023 22:01
@mlazos mlazos requested a review from cpuhrsch February 14, 2023 22:15
@mlazos mlazos changed the title NNC-like configurable logging component-level configurable logging Feb 15, 2023
@williamwen42 (Member) commented:

Are there any configs that can be removed because of this change? e.g. log_level or output_code?

@ezyang (Contributor) commented Feb 15, 2023

cc @stas00

@stas00 (Contributor) commented Feb 15, 2023

Thank you, @ezyang for the ping. I have already followed up here:
#94788 (comment)

but let me copy it here:


@mlazos, if it's the way proposed by your PR, please make sure there is an actual API to use besides env vars.

Env vars are fantastic for developers of the component, but when a 3rd-party application/framework needs to control it, env vars become very difficult to use, and you have to write an API anyway to support those env var overrides - so exposing that API to users would make their code more robust, IMHO.

Additionally, it again looks like your solution is pytorch-developer oriented (which is super useful). Users need a simple blanket, cover-all flag, so that they don't need to list out all the possible components.

@stas00 (Contributor) commented Feb 15, 2023

Additionally, this PR invents some sort of new logging-level definition semantics, which again looks very neat for devs, but this is not what I think is needed by non-pytorch developers. The proposal here also doesn't allow for the full range of log levels.

The log levels are:

```python
log_levels = {
    "debug": logging.DEBUG,
    "info": logging.INFO,
    "warning": logging.WARNING,
    "error": logging.ERROR,
    "critical": logging.CRITICAL,
}
```

and each of these should ideally be settable at will. So perhaps syntactic sugar can be added on top of the boring long full definitions, but it should not replace them.

I did show how we implement these across various projects at HuggingFace in logging.py - I'm not insisting it be done the same way here, just showing what appears to work really well.

If you want to add sub-systems to it, perhaps there should be an additional argument that speaks to a specific sub-system, as in:

```python
torch.utils.logging.set_verbosity(all=logging.INFO, dynamo=logging.DEBUG, graph=logging.ERROR)
```

so most users will just use all, and developers can then override specific sub-systems as I have shown above. This API is also future-proof if new sub-systems are added or renamed - just use **kwargs in the util definition.
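A minimal sketch of what that **kwargs-based utility could look like (hypothetical names; not the API this PR ships):

```python
import logging

# hypothetical mapping of sub-system names to their logger names
_SUBSYSTEMS = {
    "all": "torch",
    "dynamo": "torch._dynamo",
    "graph": "torch._dynamo.graph",
}

def set_verbosity(**levels):
    """set_verbosity(all=logging.INFO, dynamo=logging.DEBUG, ...)"""
    # apply "all" first so more specific sub-systems can override it
    for name in sorted(levels, key=lambda n: n != "all"):
        logging.getLogger(_SUBSYSTEMS[name]).setLevel(levels[name])
```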

Please let me know whether this is at all helpful and whether I'm going in the right direction.

Please note that I'm on both sides of the fence - I would like to have a very simple API for users, while allowing for developers to achieve their needs as well.

@stas00 (Contributor) commented Feb 15, 2023

Also, from the description in the OP I don't see it proposing to cover everything - e.g. it doesn't look like torch.distributed is there, and it can be pretty noisy on a multi-GPU setup.

When in doubt, please always think of someone using 256 GPUs who is going to see the same info line 256 times.

@mlazos (Contributor, Author) commented Feb 15, 2023

@stas00 Thanks for the suggestions, I think these all seem super useful and are doable.

You're correct that this is not for everything; initially it is for the PyTorch 2.0 components - TorchDynamo, AOTAutograd, and TorchInductor. After seeing the RFC and issue I thought maybe we could expand this, since currently there isn't a centralized system. Agreed on the torch.distributed piece - I hadn't thought about that at all; it would need to be handled before expanding into that domain.

Here's what I gathered from your comments:

  1. We need a user-facing API other than env vars - agreed, I can add this; I really like your suggested API.
  2. Full range of log levels - this is taken into account in the component syntax (an additional > in front of a component indicates more verbosity); for the user-facing API, the level can be provided through the kwargs as you showed.
  3. One flag for all on - again agreed. Right now I considered TORCH_COMPILE_DEBUG=1 for this, but I think that would just be confusing tbh. A component "all" would definitely work for this and allows enablement through a user-facing API (see the sketch below).
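For illustration, the blanket switch from point 3 might look like this (hypothetical at the time of this comment; the exact spelling wasn't settled):

```python
import logging
import torch._logging

# a single cover-all switch: every registered component at DEBUG,
# roughly the user-facing equivalent of a hypothetical TORCH_LOGS="+all"
torch._logging.set_logs(all=logging.DEBUG)
```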

@mlazos (Contributor, Author) commented Feb 15, 2023

Are there any configs that can be removed because of this change? e.g. log_level or output_code?

Yeah, good catch - I will remove these.

@mlazos mlazos changed the title component-level configurable logging component-level configurable logging for dynamo,inductor,aot Feb 15, 2023
@mlazos mlazos changed the title component-level configurable logging for dynamo,inductor,aot component-level configurable logging for dynamo, inductor, aot Feb 15, 2023
@ezyang (Contributor) left a comment:

There are some bugs, but the overall structure is good. Land as soon as you can!

@mlazos (Contributor, Author) commented Mar 17, 2023

There are some bugs, but the overall structure is good. Land as soon as you can!

Thanks! Really appreciated your design feedback!

@mlazos (Contributor, Author) commented Mar 17, 2023

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 17, 2023
@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator):

Merge failed

Reason: 1 jobs have failed, first few of them are: linux-binary-manywheel / manywheel-py3_8-cuda11_7-build / build

Details for Dev Infra team (raised by workflow job)



Review thread on this hunk:

```python
@functools.lru_cache(None)
def warning_once(self, *args, **kwargs):
```
@stas00 (Contributor) commented Mar 17, 2023

So how was this resolved in the end?

Our discussion was wiped out here and I can't find it in the sea of resolved discussions on the discussion tab.

The API declares this as a method, but it can't be used as a method, no?

@stas00 (Contributor) commented Mar 17, 2023

This can't work:

```python
def test(self): pass
class A: pass
a = A()
a.test()
```

```
Traceback (most recent call last):
  File "/tmp/test1.py", line 7, in <module>
    a.test()
AttributeError: 'A' object has no attribute 'test'
```

So the doc is invalid, as it's not identical to logger.warning and it's not a method.

@stas00 (Contributor) commented Mar 17, 2023

  1. It has to be called as a function, `warning_once(logger, ...)` (see the sketch below),
  2. and thus has to be imported.
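A minimal sketch of that function-call form (illustrative; the committed helper may differ):

```python
import functools
import logging

@functools.lru_cache(None)
def warning_once(logger, *args, **kwargs):
    # lru_cache memoizes on (logger, args), so each distinct warning is
    # emitted at most once per process
    logger.warning(*args, **kwargs)

log = logging.getLogger(__name__)
warning_once(log, "flaky fallback taken")  # emitted
warning_once(log, "flaky fallback taken")  # cached, silenced
```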

@stas00 (Contributor) commented Mar 17, 2023

I made a proposal below of one possible way to rectify this.

@mlazos (Contributor, Author) commented Mar 17, 2023

I ended up going with your approach and leaving this as is. I made the call that the UX was better when patching the logger class.

The only con was that patching the std lib is bad form, but I felt this was okay.
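For context, "patching the logger class" would look roughly like this (a sketch of the approach being weighed here, not the shipped code):

```python
import functools
import logging

@functools.lru_cache(None)
def _warning_once(self, *args, **kwargs):
    self.warning(*args, **kwargs)

# monkey-patch the std-lib Logger so every logger gains .warning_once();
# the con discussed above: mutating the std lib is bad form, and the cache
# holds a reference to every (logger, message) pair for the process lifetime
logging.Logger.warning_once = _warning_once

logging.getLogger(__name__).warning_once("printed once")
```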

@mlazos (Contributor, Author) commented Mar 18, 2023

I'll commit your code below; I think it's fine for what's needed. This is kind of a standalone thing, so let me know if you want to follow up on it in a separate PR - happy to consider other solutions.

@mlazos (Contributor, Author):

Agreed it's an issue with different libraries patching stdlib

@mlazos (Contributor, Author):

The pytorch repo is so big that I worry I won't get traction opting everyone into a custom version of get_logger for this one method. I followed the same pattern as everything else: have a clear way of opting into extra functionality (i.e. getArtifactLogger if artifacts are desired, and now a specific warning_once helper when that's desired), while the basic case of vanilla logging stays unmodified.

@stas00 (Contributor):

That works - thank you for handling so many nuances, @mlazos! Awesome work!

@mlazos (Contributor, Author):

Thanks!! I really appreciate the time and effort you put into reviewing the document and PR.

@mlazos (Contributor, Author) commented Mar 18, 2023

@pytorchbot merge

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator):

Merge failed

Reason: 1 jobs have failed, first few of them are: inductor / cuda11.8-py3.10-gcc7-sm86 / test (inductor_timm, 1, 2, linux.g5.4xlarge.nvidia.gpu)

Details for Dev Infra team (raised by workflow job)

@mlazos (Contributor, Author) commented Mar 18, 2023

@pytorchbot merge -f "logs don't affect accuracy"

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 23, 2023 (commit message mirrors the PR summary above).
Pull Request resolved: pytorch/pytorch#94858
Approved by: https://github.com/ezyang
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 27, 2023 (same commit message as above).
pytorchmergebot pushed a commit that referenced this pull request Sep 18, 2023. The configs

```
torch._dynamo.config.log_level = logging.INFO
torch._dynamo.config.output_code = True
```

were replaced with the module-level log control from #94858.
Pull Request resolved: #109409
Approved by: https://github.com/msaroufim
@github-actions github-actions bot deleted the mlazos/logging branch September 17, 2024 01:53