Add "Quantization in Practice" blogpost #912
Conversation
👷 Deploy Preview for pytorch-dot-org-preview processing. 🔨 Explore the source changes: c1b8fee 🔍 Inspect the deploy log: https://app.netlify.com/sites/pytorch-dot-org-preview/deploys/6202972b0ccea50008fd004a
@woo-kim @holly1238 The MathJax sections are rendering smaller than the surrounding text. Do you have suggestions on how to fix this?
> If someone asks you what time it is, you don't respond "10:14:34:430705", but you might say "a quarter past 10".
Quantization has roots in information compression; in deep networks it refers to reducing the numerical precision of a network's weights and/or activations. Overparameterized DNNs have more degrees of freedom, and this makes them good candidates for information compression. When you quantize a model, two things generally happen: the model gets smaller and runs more efficiently. Processing 8-bit numbers is faster than processing 32-bit numbers, and a smaller model has a lower memory footprint and power consumption.
A source would help for the overparameterized point.
For the point that processing 8-bit numbers is faster: that's the case because hardware providers made it explicitly available, so it's worth mentioning.
There are a few different ways to quantize your model with PyTorch. In this blog post, we'll take a look at what each technique looks like in practice. I will use a non-standard model that is not traceable, to paint an accurate picture of how much effort is really needed when quantizing your model.
<div class="text-center">
<img src="/assets/images/quantization_gif.gif" width="60%">
Very cool image
To reconvert to floating point space, the inverse function is given by $\tilde r = (Q(r) - Z) \cdot S$. $\tilde r \neq r$, and their difference constitutes the **quantization error**.
The scaling factor $S$ is simply the ratio of the input range to the output range: $S = \frac{\beta - \alpha}{\beta_q - \alpha_q}$
why is the scaling factor important to highlight here?
A) to provide an intuition of where this number is coming from, and B) to help make sense of how a/symmetric quantization schemes are better/worse for a given use case
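To make the formulas above concrete, here is a minimal sketch. The clipping-range values are illustrative (not taken from the post); it computes $S$ and $Z$ for an affine uint8 scheme and round-trips a tensor through `torch.quantize_per_tensor` to show the quantization error:

```
import torch

# Illustrative clipping range [alpha, beta] and 8-bit unsigned output range
alpha, beta = -1.0, 3.5
alpha_q, beta_q = 0, 255

S = (beta - alpha) / (beta_q - alpha_q)   # scale: ratio of input range to output range
Z = round(alpha_q - alpha / S)            # zero-point: maps alpha onto alpha_q

r = torch.tensor([-0.9, 0.0, 1.234, 3.2])
q = torch.quantize_per_tensor(r, scale=S, zero_point=Z, dtype=torch.quint8)
r_tilde = q.dequantize()                  # r_tilde = (Q(r) - Z) * S
print(r_tilde - r)                        # the quantization error
```

In a symmetric scheme the zero-point is pinned (e.g. $Z = 0$ for `qint8`), which is part of why the choice of scheme matters for a given use case.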
### Quantization Schemes
$S, Z$ can be calculated and used for quantizing an entire tensor ("per-tensor"), or individually for each channel ("per-channel").
picture would help here
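In lieu of a picture, here is a small sketch (using observers purely for illustration) of how per-tensor and per-channel schemes differ in the quantization parameters they produce:

```
import torch
from torch.quantization import MinMaxObserver, PerChannelMinMaxObserver

w = torch.randn(8, 16)   # e.g. a weight tensor with 8 output channels

per_tensor = MinMaxObserver(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)
per_tensor(w)
print(per_tensor.calculate_qparams())   # a single (scale, zero_point) pair

per_channel = PerChannelMinMaxObserver(ch_axis=0, dtype=torch.qint8,
                                       qscheme=torch.per_channel_symmetric)
per_channel(w)
scales, zero_points = per_channel.calculate_qparams()
print(scales.shape)                     # one scale per output channel: torch.Size([8])
```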
```
## FX GRAPH
from torch.quantization import quantize_fx
qconfig_dict = {"": torch.quantization.get_default_qconfig('fbgemm')}
```
what does `fbgemm` stand for? is this like saying `nn.Linear` in a `qconfig`?
```
model_quantized = quantize_fx.convert_fx(model_prepared)
```
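The snippet above is abridged in the diff. A fuller sketch of the FX Graph Mode post-training static flow (prepare, calibrate, convert), on a hypothetical toy model rather than the blog's model, looks roughly like this; note that on newer PyTorch versions `prepare_fx` also takes an `example_inputs` argument:

```
import copy
import torch
from torch import nn
from torch.quantization import quantize_fx

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))  # toy model
m = copy.deepcopy(model).eval()

qconfig_dict = {"": torch.quantization.get_default_qconfig('fbgemm')}

# Prepare: symbolically trace the model and insert observers
model_prepared = quantize_fx.prepare_fx(m, qconfig_dict)

# Calibrate: run representative data through the observers
with torch.no_grad():
    for _ in range(16):
        model_prepared(torch.randn(8, 16))

# Convert: swap observed ops for their quantized counterparts
model_quantized = quantize_fx.convert_fx(model_prepared)
print(model_quantized(torch.randn(8, 16)).shape)
```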
### Quantization-aware Training (QAT)
NVIDIA had a nice diagram on QAT; since this flow is a bit complex, a picture would help to explain it: https://developer.nvidia.com/blog/improving-int8-accuracy-using-quantization-aware-training-and-tao-toolkit/
Same comment for the other 2 techniques
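For readers who want the shape of the QAT flow in code before a diagram exists, here is a minimal eager-mode sketch on a hypothetical toy model; the stubs, qconfig and the single training step are all illustrative:

```
import torch
from torch import nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.fc = nn.Linear(16, 4)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.fc(self.quant(x)))

model = ToyModel().train()
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
model_prepared = torch.quantization.prepare_qat(model)

# "Fine-tune" with fake-quant modules in place (one dummy step shown)
opt = torch.optim.SGD(model_prepared.parameters(), lr=1e-3)
model_prepared(torch.randn(8, 16)).sum().backward()
opt.step()

# Convert the fine-tuned model to a quantized one
model_quantized = torch.quantization.convert(model_prepared.eval())
```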
**Download the [notebook](https://gist.github.com/suraj813/735357e56321237950a0348b50f2f3b4) or run it on [Colab](https://colab.research.google.com/gist/suraj813/735357e56321237950a0348b50f2f3b4/fx-and-eager-mode-quantization-example.ipynb) (note that Colab runtimes may differ significantly from local machines).**
Traceable models can be easily quantized with FX Graph Mode, but it's possible the model you're using is not traceable end-to-end. Maybe it has loops or `if` statements on inputs (dynamic control flow), or relies on third-party libraries. The model I use in this example has [dynamic control flow and uses third-party libraries](https://github.com/facebookresearch/demucs/blob/v2/demucs/model.py). As a result, it cannot be symbolically traced directly. In this code walkthrough, I show how you can bypass this limitation by quantizing the child modules individually for FX Graph Mode, and how to patch Quant/DeQuant stubs in Eager Mode.
"The model we use", since this is an official tutorial.
It's likely that you can still use QAT by "fine-tuning" it on a sample of the training dataset, but I did not try it on demucs (yet).
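To sketch the FX workaround described above: quantize only a traceable child module and swap it back into the non-traceable parent. The toy model below is a stand-in (not demucs), and on newer PyTorch versions `prepare_fx` may additionally require `example_inputs`:

```
import torch
from torch import nn
from torch.quantization import quantize_fx

class Parent(nn.Module):
    """Not traceable end-to-end: forward() branches on the input shape."""
    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))

    def forward(self, x):
        if x.shape[0] > 1:        # dynamic control flow
            x = x.flip(0)
        return self.body(x)

model = Parent().eval()
qconfig_dict = {"": torch.quantization.get_default_qconfig('fbgemm')}

# Quantize the traceable child on its own, then put it back into the parent
prepared = quantize_fx.prepare_fx(model.body, qconfig_dict)
with torch.no_grad():
    for _ in range(8):            # calibration with representative data
        prepared(torch.randn(4, 16))
model.body = quantize_fx.convert_fx(prepared)
print(model(torch.randn(4, 16)).shape)
```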
## Quantizing "real-world" models
This is interesting in its own right; I'd rather you either write it out in this blog post or do a follow-up post on quantizing non-traceable models.
Yeah, it's the most interesting part of the article imo. But it's also very verbose, containing code and commentary, which is why I've only linked to the notebook from here. A follow-up post might be easier to parse, thanks for the suggestion.
## What's next - Define-by-Run Quantization
This would be in the follow-up post too, and you can get people excited about reading that follow-up by briefly describing how cool it is and how it works.
Also make sure to synthesize everything people learned at the very end, so that your table and figures help this doc be a reference for people new to quantization.
The scaling factor $S$ is simply the ratio of the input range to the output range: $S = \frac{\beta - \alpha}{\beta_q - \alpha_q}$
where $[\alpha, \beta]$ is the clipping range of the input, i.e. the boundaries of permissible inputs, and $[\alpha_q, \beta_q]$ is the range in quantized output space that it is mapped to. For 8-bit quantization, the output range $\beta_q - \alpha_q \leq 2^8 - 1$.
The process of choosing the appropriate input range is known as **calibration**; commonly used methods are MinMax and Entropy.
should we mention observers here? calibration is just a step that runs some sample data through the model and also through the inserted observers, so that the values can be recorded and used to calculate quantization parameters
The `QConfig` ([code](https://github.com/PyTorch/PyTorch/blob/d6b15bfcbdaff8eb73fa750ee47cef4ccee1cd92/torch/ao/quantization/qconfig.py#L165)) NamedTuple specifies the observers and quantization schemes for the network's weights and activations. The default qconfig is at `torch.quantization.get_default_qconfig(backend)`, where `backend='fbgemm'` for x86 CPU and `backend='qnnpack'` for ARM.
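Per the suggestion above, a short sketch of what an observer does during calibration, and of how a `QConfig` bundles observer factories for activations and weights. `MinMaxObserver` is used purely as an illustration here:

```
import torch
from torch.quantization import MinMaxObserver, get_default_qconfig

obs = MinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine)
for _ in range(4):                  # "calibration": feed representative data
    obs(torch.randn(32, 16))
print(obs.min_val, obs.max_val)     # recorded clipping range [alpha, beta]
print(obs.calculate_qparams())      # derived (scale, zero_point)

# A QConfig bundles observer factories for activations and weights
qconfig = get_default_qconfig('fbgemm')
print(qconfig.activation(), qconfig.weight())
```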
## In PyTorch
there are two other dimensions: quantization mode (static/dynamic/weight-only) and backend (server CPU/mobile CPU/GPU). I have more info in my slides
are the gpu backends usable yet? i don't see this as a backend in
GPU path with TensorRT is in prototype for internal users, GPU path with cudnn and cuda is WIP
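To illustrate the quantization-mode dimension mentioned above, here is a one-call sketch of dynamic quantization (weights quantized ahead of time, activations quantized on the fly) on a hypothetical toy model:

```
import torch
from torch import nn

model_fp32 = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
model_int8 = torch.quantization.quantize_dynamic(
    model_fp32,          # model to quantize
    {nn.Linear},         # module types whose weights get int8
    dtype=torch.qint8,
)
print(model_int8(torch.randn(2, 16)).shape)
```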
<img src="/assets/images/quantization_gif.gif" width="60%">
</div>
## A quick introduction to quantization
would you like to integrate some part of this in the official quantization api?
do you mean the API itself, or in the docs? I'm happy to add it!
The documentation. We can add the explanations of core things like qconfig, quantized tensor, observer/fake_quant, and qscheme to the documentation: https://pytorch.org/docs/master/quantization.html#