
Add "Quantization in Practice" blogpost #912


Merged: 10 commits into pytorch:site on Feb 8, 2022

Conversation

subramen (Contributor)

No description provided.

netlify bot commented Jan 19, 2022

👷 Deploy Preview for pytorch-dot-org-preview processing.

🔨 Explore the source changes: c1b8fee

🔍 Inspect the deploy log: https://app.netlify.com/sites/pytorch-dot-org-preview/deploys/6202972b0ccea50008fd004a

woo-kim marked this pull request as ready for review January 20, 2022 02:18
subramen (Contributor Author) commented Jan 20, 2022

@woo-kim @holly1238 The MathJax sections are rendering smaller than the surrounding text. Do you have suggestions on how to fix this?


> If someone asks you what time it is, you don't respond with "10:14:34:430705"; you might say "a quarter past 10".

Quantization has its roots in information compression; in deep learning, it refers to reducing the numerical precision of a network's weights and/or activations. Overparameterized DNNs have more degrees of freedom, which makes them good candidates for information compression. When you quantize a model, two things generally happen: the model gets smaller and it runs more efficiently. Processing 8-bit numbers is faster than processing 32-bit numbers, and a smaller model has a lower memory footprint and power consumption.
Member

A source would help for the overparametrized point.

As for the point that processing 8-bit numbers is faster: that's the case because hardware providers made it explicitly available, so it's worth mentioning.
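(For concreteness, here's a minimal sketch, not taken from the post, of the size reduction: quantizing an fp32 tensor to int8 stores each element in 1 byte instead of 4. The scale and zero-point values are placeholders chosen only for illustration.)

```python
import torch

x = torch.randn(1024, 1024)   # fp32 tensor: 4 bytes per element

# scale/zero_point below are placeholder values, chosen only for illustration
xq = torch.quantize_per_tensor(x, scale=0.05, zero_point=0, dtype=torch.qint8)

print(x.element_size())    # 4 bytes per element
print(xq.element_size())   # 1 byte per element, i.e. roughly a 4x smaller footprint
```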

There are a few different ways to quantize your model with PyTorch. In this blog post, we'll take a look at what each technique looks like in practice. I will use a non-standard model that is not traceable, to paint an accurate picture of how much effort is really needed to quantize your model.

<div class="text-center">
<img src="/assets/images/quantization_gif.gif" width="60%">
Member

Very cool image


To convert back to floating-point space, the inverse function is given by $\tilde r = (Q(r) - Z) \cdot S$. In general $\tilde r \neq r$, and their difference constitutes the **quantization error**.

The scaling factor $S$ is simply the ratio of the input range to the output range: $S = \frac{\beta - \alpha}{\beta_q - \alpha_q}$
Member

why is the scaling factor important to highlight here?

Contributor Author

A) to provide an intuition of where this number comes from, and B) to help make sense of how symmetric and asymmetric quantization schemes are better or worse for a given use case
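(To make the formulas above concrete, a minimal sketch, not from the post, assuming the affine mapping $Q(r) = \text{clamp}(\text{round}(r/S + Z), \alpha_q, \beta_q)$ and the zero-point $Z = \alpha_q - \alpha/S$, which follows from requiring that $\alpha$ maps to $\alpha_q$.)

```python
import torch

alpha, beta = -1.0, 1.0          # clipping range of the fp32 input
alpha_q, beta_q = -128, 127      # int8 output range

S = (beta - alpha) / (beta_q - alpha_q)   # scale
Z = round(alpha_q - alpha / S)            # zero-point (assumed formula, see above)

def quantize(r):
    # Q(r) = clamp(round(r / S + Z), alpha_q, beta_q)
    return torch.clamp(torch.round(r / S + Z), alpha_q, beta_q)

def dequantize(q):
    # r~ = (Q(r) - Z) * S
    return (q - Z) * S

r = torch.tensor([0.3141, -0.9, 0.5])
r_tilde = dequantize(quantize(r))
print(r - r_tilde)   # the quantization error
```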



### Quantization Schemes
$S, Z$ can be calculated and used for quantizing an entire tensor ("per-tensor"), or individually for each channel ("per-channel").
Member

A picture would help here.
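(While waiting for a picture, a minimal sketch, not from the post, of the two granularities; the weight tensor and the per-channel scale heuristic are placeholders for illustration.)

```python
import torch

w = torch.randn(8, 16)   # e.g. the weight of a Linear layer with 8 output channels

# Per-tensor: a single (S, Z) pair for the whole tensor (placeholder values)
wq_tensor = torch.quantize_per_tensor(w, scale=0.1, zero_point=0, dtype=torch.qint8)

# Per-channel: one (S, Z) pair per output channel (axis 0)
scales = w.abs().amax(dim=1) / 127        # crude per-channel scale, for illustration only
zero_points = torch.zeros(8, dtype=torch.int64)
wq_channel = torch.quantize_per_channel(w, scales, zero_points, axis=0, dtype=torch.qint8)

print(wq_tensor.q_scale())                # one scale
print(wq_channel.q_per_channel_scales())  # eight scales, one per output channel
```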


```python
## FX GRAPH
from torch.quantization import quantize_fx
qconfig_dict = {"": torch.quantization.get_default_qconfig('fbgemm')}
```
Member

what does fbgemm stand for? is this like saying nn.Linear in a qconfig?

```python
model_quantized = quantize_fx.convert_fx(model_prepared)
```
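(The excerpt above only shows the first and last lines of the flow; for readers following along, a minimal end-to-end sketch, not the post's exact code, with a toy model and random calibration data as placeholders. On newer PyTorch releases, `prepare_fx` also expects an `example_inputs` argument.)

```python
import copy
import torch
from torch.quantization import quantize_fx

model_fp32 = torch.nn.Sequential(torch.nn.Linear(16, 8), torch.nn.ReLU()).eval()  # toy model
qconfig_dict = {"": torch.quantization.get_default_qconfig('fbgemm')}

# Insert observers into a copy of the model
model_prepared = quantize_fx.prepare_fx(copy.deepcopy(model_fp32), qconfig_dict)

# Calibrate: run representative data through the observers
with torch.no_grad():
    for _ in range(32):
        model_prepared(torch.randn(4, 16))

# Convert observed modules to quantized ones
model_quantized = quantize_fx.convert_fx(model_prepared)
```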

### Quantization-aware Training (QAT)
Member

NVIDIA has a nice diagram on QAT; since this flow is a bit complex, it would help to explain it in a picture: https://developer.nvidia.com/blog/improving-int8-accuracy-using-quantization-aware-training-and-tao-toolkit/

Same comment for the other 2 techniques.
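(In lieu of a diagram, a minimal QAT sketch using the FX API; this is a hedged illustration, not the post's code, and the toy model, optimizer, and training loop are placeholders.)

```python
import copy
import torch
from torch.quantization import quantize_fx

model_fp32 = torch.nn.Sequential(torch.nn.Linear(16, 8), torch.nn.ReLU())  # toy model
qat_qconfig_dict = {"": torch.quantization.get_default_qat_qconfig('fbgemm')}

model_fp32.train()
model_prepared = quantize_fx.prepare_qat_fx(copy.deepcopy(model_fp32), qat_qconfig_dict)

# Fine-tune with fake-quant modules inserted (placeholder objective and data)
opt = torch.optim.SGD(model_prepared.parameters(), lr=1e-3)
for _ in range(10):
    loss = model_prepared(torch.randn(4, 16)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()

model_prepared.eval()
model_quantized = quantize_fx.convert_fx(model_prepared)
```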


**Download the [notebook](https://gist.github.com/suraj813/735357e56321237950a0348b50f2f3b4) or run it on [Colab](https://colab.research.google.com/gist/suraj813/735357e56321237950a0348b50f2f3b4/fx-and-eager-mode-quantization-example.ipynb) (note that Colab runtimes may differ significantly from local machines).**

Traceable models can be easily quantized with FX Graph Mode, but it's possible the model you're using is not traceable end-to-end. Maybe it has loops or `if` statements on inputs (dynamic control flow), or relies on third-party libraries. The model I use in this example has [dynamic control flow and uses third-party libraries](https://github.com/facebookresearch/demucs/blob/v2/demucs/model.py). As a result, it cannot be symbolically traced directly. In this code walkthrough, I show how you can bypass this limitation by quantizing the child modules individually for FX Graph Mode, and how to patch Quant/DeQuant stubs in Eager Mode.
Member

Say "the model we use" rather than "the model I use", since this is an official tutorial.

It's likely that you can still use QAT by "fine-tuning" it on a sample of the training dataset, but I did not try it on demucs (yet).
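(For the Eager Mode path mentioned above, a minimal sketch of patching Quant/DeQuant stubs around a child module; `StubbedChild` is a hypothetical wrapper name, not from the post or the PyTorch API.)

```python
import torch
from torch import nn

class StubbedChild(nn.Module):
    """Hypothetical wrapper: quantize at the child's boundary and dequantize on the
    way out, so the surrounding non-traceable code keeps operating on fp32 tensors."""
    def __init__(self, child):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.child = child
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.child(self.quant(x)))

# Usage sketch: wrap the quantizable child, then run the usual Eager Mode flow
# (assign a qconfig, torch.quantization.prepare, calibrate, torch.quantization.convert).
```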


## Quantizing "real-world" models
Member

This is interesting in its own right. I'd rather you either write it out in this blog post or do a follow-up post on quantizing non-traceable models.

Contributor Author

Yeah, it's the most interesting part of the article imo. But it's also very verbose (containing code and commentary), which is why I've only linked to the notebook from here. A follow-up post might be easier to parse; thanks for the suggestion.




## What's next - Define-by-Run Quantization
Member

This would be in the follow-up post too, and you can get people excited about reading it by briefly describing how cool it is and how it works.

Also make sure to synthesize everything people learnt at the very end, so that your table and figures help this doc become a reference for people new to quantization.

The scaling factor $S$ is simply the ratio of the input range to the output range: $S = \frac{\beta - \alpha}{\beta_q - \alpha_q}$
where $[\alpha, \beta]$ is the clipping range of the input, i.e. the boundaries of permissible inputs, and $[\alpha_q, \beta_q]$ is the range in quantized output space to which it is mapped. For 8-bit quantization, the output range satisfies $\beta_q - \alpha_q \leq 2^8 - 1$.

The process of choosing the appropriate input range is known as **calibration**; commonly used methods are MinMax and Entropy.
Contributor

Should we mention observers here? Calibration is just a step that runs some sample data through the model and through the inserted observers, so that the values can be recorded and used to calculate quantization parameters.
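(A minimal observer sketch along those lines, not from the post; the random data stands in for real calibration samples.)

```python
import torch
from torch.quantization import MinMaxObserver

obs = MinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine)

# "Calibration": pass sample data through the observer so it records min/max
for _ in range(10):
    obs(torch.randn(32, 64))

scale, zero_point = obs.calculate_qparams()
print(scale, zero_point)
```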

The `QConfig` ([code](https://github.com/PyTorch/PyTorch/blob/d6b15bfcbdaff8eb73fa750ee47cef4ccee1cd92/torch/ao/quantization/qconfig.py#L165)) NamedTuple specifies the observers and quantization schemes for the network's weights and activations. The default qconfig is returned by `torch.quantization.get_default_qconfig(backend)`, where `backend='fbgemm'` for x86 CPUs and `backend='qnnpack'` for ARM.
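(A quick way to see what a default `QConfig` contains; a sketch, with the printed observer types depending on the PyTorch version.)

```python
import torch

# Default post-training static quantization config for x86 server CPUs
qconfig = torch.quantization.get_default_qconfig('fbgemm')

print(qconfig.activation())   # activation observer instance (e.g. a HistogramObserver)
print(qconfig.weight())       # weight observer instance (per-channel on fbgemm)
```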


## In PyTorch
Contributor

There are two other dimensions: quantization mode (static / dynamic / weight-only) and backend (server CPU / mobile CPU / GPU). I have more info in my slides.
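(As an example of one of those modes, a minimal dynamic-quantization sketch: weights are converted to int8 ahead of time and activations are quantized on the fly; the toy model is a placeholder.)

```python
import torch

model_fp32 = torch.nn.Sequential(torch.nn.Linear(16, 8)).eval()  # toy model

# Dynamic quantization: only the nn.Linear weights are stored in int8 here;
# activations are quantized per batch at runtime.
model_int8 = torch.quantization.quantize_dynamic(
    model_fp32, {torch.nn.Linear}, dtype=torch.qint8
)
print(model_int8)
```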


Contributor

The GPU path with TensorRT is in prototype for internal users; the GPU path with cuDNN and CUDA is WIP.

<img src="/assets/images/quantization_gif.gif" width="60%">
</div>

## A quick introduction to quantization
Contributor

Would you like to integrate some part of this into the official quantization API?

Contributor Author

Do you mean the API itself, or the docs? I'm happy to add it!

Contributor

The documentation; we can add explanations of core things like qconfig, quantized tensor, observer/fake_quant, and qscheme to the documentation: https://pytorch.org/docs/master/quantization.html#

subramen merged commit 89bab36 into pytorch:site on Feb 8, 2022