
[PyTorch] Hook up fp16_gemv_trans to x86 fp16 GEMM #137918


Closed
wants to merge 12 commits

Conversation

Contributor

@swolchok swolchok commented Oct 14, 2024

This is the first big milestone we've been building towards!
(TODO: also hook fp16_gemv_trans up to the GEMV path, in the same way it is hooked up to GEMM here.)

Differential Revision: [D64280688](https://our.internmc.facebook.com/intern/diff/D64280688/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D64280688/)!

[ghstack-poisoned]

pytorch-bot bot commented Oct 14, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137918

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 06b6c11 with merge base 86602a6:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D64280688

swolchok added a commit that referenced this pull request Oct 14, 2024
This is the first big milestone we've been building towards!
(TODO: also hook it up to GEMV in the same way fp16_gemv_trans is hooked up)

Differential Revision: [D64280688](https://our.internmc.facebook.com/intern/diff/D64280688/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D64280688/)!

ghstack-source-id: 247859556
Pull Request resolved: #137918
@swolchok swolchok requested review from malfet and jgong5 October 14, 2024 18:17
Collaborator

@jgong5 jgong5 left a comment

Do you have performance numbers?

This is the first big milestone we've been building towards!
(Following rev also hooks this up to actual gemv.)
Differential Revision: [D64280688](https://our.internmc.facebook.com/intern/diff/D64280688/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D64280688/)!

[ghstack-poisoned]
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D64280688

This is the first big milestone we've been building towards!
(Following rev also hooks this up to actual gemv.)
Differential Revision: [D64280688](https://our.internmc.facebook.com/intern/diff/D64280688/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D64280688/)!

[ghstack-poisoned]
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D64280688

@swolchok
Contributor Author

Do you have performance numbers?

It improves decoding performance by about 5x for `python torchchat.py generate stories110M --dtype fp16 --device cpu` on an x86 machine without AVX512FP16.

This is the first big milestone we've been building towards!
(Following rev also hooks this up to actual gemv.)
Differential Revision: [D64280688](https://our.internmc.facebook.com/intern/diff/D64280688/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D64280688/)!

[ghstack-poisoned]
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D64280688

// is to upconvert to fp32 and call sgemm. We can do better by
// fusing the conversion.
const bool fp16_gemv_trans_fast_path_would_be_beneficial =
    cpuinfo_initialize() && cpuinfo_has_x86_f16c() && !cpuinfo_has_x86_avx512fp16();
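For context, the "fusing the conversion" mentioned in the comment means widening fp16 elements to fp32 on the fly inside the GEMV loop, rather than first materializing an fp32 copy of the inputs and calling sgemm. A minimal scalar sketch of that idea (illustrative only, with an assumed function name; the real kernel is vectorized):

```cpp
// Illustrative scalar sketch, not the actual PyTorch kernel: a transposed fp16
// GEMV that converts each fp16 element to fp32 as it is consumed and
// accumulates in fp32. Compile with -mf16c for the _cvtsh_ss intrinsic.
#include <immintrin.h>
#include <cstdint>

void fp16_gemv_trans_ref(int m, int n, float alpha, const uint16_t* a, int lda,
                         const uint16_t* x, float beta, float* y) {
  for (int j = 0; j < n; ++j) {
    float acc = 0.0f;  // fp32 accumulator along the reduction dimension
    for (int i = 0; i < m; ++i) {
      // Fused conversion: no separate fp32 copy of A or x is ever built.
      acc += _cvtsh_ss(a[j * lda + i]) * _cvtsh_ss(x[i]);
    }
    y[j] = alpha * acc + beta * y[j];
  }
}
```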
Collaborator

I guess checking `cpuinfo_has_x86_avx512fp16` is not necessary, since oneDNN (mkldnn) won't use AVX512-FP16 to compute GEMMs by default, because the AVX512-FP16 FMA would incur accuracy loss.

Contributor Author

@jgong5 https://community.intel.com/t5/Intel-oneAPI-Math-Kernel-Library/FP16-GEMM-using-AVX512-on-Sapphire-Rapids/m-p/1570739 doesn't seem to agree. I'm also confused -- I thought FMA was supposed to improve accuracy because it has high internal precision, so the result of the multiply doesn't have to be rounded to FP16 before the addition.

Collaborator

@swolchok The link you referred to is about MKL, not oneDNN (or mkldnn). MKL has a dedicated API (hgemm) that uses AVX512_FP16 instructions, but users should be aware of the accuracy loss due to FP16 accumulators. It is not about the accumulator for a single FMA (which has high internal precision, as you mentioned) but about accumulation along the K-dim across multiple FMAs. oneDNN, on the other hand, uses FP32 accumulators to keep accuracy high.
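To make the accumulator distinction concrete, here is a small standalone demonstration (an editorial sketch, not code from the PR) of fp32 versus fp16 accumulation along a long K dimension; the fp16 accumulator stops advancing once the increment falls below half its ulp, while the fp32 accumulator keeps tracking the sum:

```cpp
// Emulates an fp16 accumulator using F16C scalar conversions (compile with
// -mf16c): each partial sum is rounded back to fp16, as an fp16-accumulating
// hgemm would do, while acc32 keeps full single precision throughout.
#include <immintrin.h>
#include <cstdio>

int main() {
  const int k = 4096;        // reduction (K-dim) length
  const float term = 0.1f;   // stand-in for one already-computed product
  float acc32 = 0.0f;
  unsigned short acc16 = _cvtss_sh(0.0f, 0);  // fp16 accumulator (raw half bits)
  for (int i = 0; i < k; ++i) {
    acc32 += term;                                  // fp32 accumulation
    acc16 = _cvtss_sh(_cvtsh_ss(acc16) + term, 0);  // round back to fp16 each step
  }
  // Expect about 409.6 for fp32 and a noticeably smaller value for fp16,
  // because the fp16 accumulator stalls once 0.1 is below half of its ulp.
  std::printf("fp32 acc = %f, fp16 acc = %f\n", acc32, _cvtsh_ss(acc16));
  return 0;
}
```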

Contributor Author

oh, we use FP32 accumulation as well.

Contributor

`mkldnn_fp16_gemm`, despite its name, is also available on ARM. It looks like your change will skip MKLDNN unless it is on an x86 platform, wouldn't it? What is the motivation for that?

Contributor Author

mkldnn_fp16_gemm despite the name is also available on ARM,

news to me!

it looks like your change will skip MKLDNN unless it is on x86 platform, wouldn't it?

Fortunately no, because ARM machines won't pass `cpuinfo_has_x86_f16c()`.

Contributor Author

also available on ARM

I don't think so? https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/mkldnn/Utils.h#L120

something to look at for BF16 though.

Collaborator

oh, we use FP32 accumulation as well.

Will you remove this cpuinfo_has_x86_avx512fp16() check then? I don't think it is relevant.

Contributor Author

remove this cpuinfo_has_x86_avx512fp16() check then? I don't think it is relevant.

I am surprised, but you're definitely the authority on this and I don't have a Sapphire Rapids machine to test on. I'll leave a note for posterity though.
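(For illustration, the shape of the change being discussed would be roughly the following; a hypothetical sketch of dropping the check and recording the reasoning, not the merged diff:)

```cpp
// Hypothetical sketch, not the merged code: drop the AVX512-FP16 condition per
// review and leave the reasoning as a note for posterity.
// Note: oneDNN keeps fp32 accumulators for these GEMMs by default, so the
// presence of AVX512-FP16 does not make the fused F16C fast path unnecessary.
const bool fp16_gemv_trans_fast_path_would_be_beneficial =
    cpuinfo_initialize() && cpuinfo_has_x86_f16c();
```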

This is the first big milestone we've been building towards!
(Following rev also hooks this up to actual gemv.)
Differential Revision: [D64280688](https://our.internmc.facebook.com/intern/diff/D64280688/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D64280688/)!

[ghstack-poisoned]
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D64280688

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 31, 2024
This is the first big milestone we've been building towards!
(Following rev also hooks this up to actual gemv.)
Differential Revision: [D64280688](https://our.internmc.facebook.com/intern/diff/D64280688/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D64280688/)!

[ghstack-poisoned]
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D64280688

This is the first big milestone we've been building towards!
(Following rev also hooks this up to actual gemv.)
Testing: To check perf, I ran python torchchat.py generate stories110M
--dtype fp16 --device cpu on an x86 machine without AVX512FP16. Observed roughly 5x tokens/sec increase.
Differential Revision: [D64280688](https://our.internmc.facebook.com/intern/diff/D64280688/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D64280688/)!

[ghstack-poisoned]
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D64280688

This is the first big milestone we've been building towards!
(Following rev also hooks this up to actual gemv.)
Testing: To check perf, I ran python torchchat.py generate stories110M
--dtype fp16 --device cpu on an x86 machine without AVX512FP16. Observed roughly 5x tokens/sec increase.
Differential Revision: [D64280688](https://our.internmc.facebook.com/intern/diff/D64280688/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D64280688/)!

[ghstack-poisoned]
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D64280688

pytorchmergebot pushed a commit that referenced this pull request Nov 1, 2024
…rchitectures (#138005)

Following up on previous rev to use fp16_gemv_trans in gemv, not just gemm-used-for-gemv.

Differential Revision: [D64351092](https://our.internmc.facebook.com/intern/diff/D64351092/)
Pull Request resolved: #138005
Approved by: https://github.com/malfet
ghstack dependencies: #139082, #139083, #137918
pytorchmergebot pushed a commit that referenced this pull request Nov 1, 2024
No real reason to have the zero-beta restriction, so let's lift it.

Testing: intentionally broke new paths locally to verify test coverage existed

Differential Revision: [D64407752](https://our.internmc.facebook.com/intern/diff/D64407752/)

Pull Request resolved: #138275
Approved by: https://github.com/malfet
ghstack dependencies: #139082, #139083, #137918, #138005
malfet added a commit that referenced this pull request Nov 1, 2024
Caused by #137918
By guarding all cpuinfo use with `!defined(__s390x__ ) && !defined(__powerpc__)`
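(For context, the guard described here is a preprocessor check of roughly the following shape; a hypothetical sketch with an assumed helper name, not the actual #139491 diff:)

```cpp
// Hypothetical sketch of the platform guard: cpuinfo is not usable on s390x or
// powerpc, so any cpuinfo-based fast-path check must compile away on those targets.
#if !defined(__s390x__) && !defined(__powerpc__)
#include <cpuinfo.h>
static bool fp16_fast_path_supported() {
  return cpuinfo_initialize() && cpuinfo_has_x86_f16c();
}
#else
static bool fp16_fast_path_supported() {
  return false;  // fall back to the existing fp32-upconvert / MKLDNN paths
}
#endif
```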
@malfet malfet mentioned this pull request Nov 1, 2024
pytorchmergebot pushed a commit that referenced this pull request Nov 1, 2024
Caused by #137918; fixed by guarding all cpuinfo use with `!defined(__s390x__) && !defined(__powerpc__)`.

Pull Request resolved: #139491
Approved by: https://github.com/huydhn, https://github.com/Skylion007
rahulsingh-intel pushed a commit to rahulsingh-intel/pytorch that referenced this pull request Nov 5, 2024
This is the first big milestone we've been building towards!
(Following rev also hooks this up to actual gemv.)
Testing: To check perf, I ran python torchchat.py generate stories110M
--dtype fp16 --device cpu on an x86 machine without AVX512FP16. Observed roughly 5x tokens/sec increase.
Differential Revision: [D64280688](https://our.internmc.facebook.com/intern/diff/D64280688/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D64280688/)!
Pull Request resolved: pytorch#137918
Approved by: https://github.com/malfet
ghstack dependencies: pytorch#139082, pytorch#139083
rahulsingh-intel pushed a commit to rahulsingh-intel/pytorch that referenced this pull request Nov 5, 2024
…rchitectures (pytorch#138005)

Following up on previous rev to use fp16_gemv_trans in gemv, not just gemm-used-for-gemv.

Differential Revision: [D64351092](https://our.internmc.facebook.com/intern/diff/D64351092/)
Pull Request resolved: pytorch#138005
Approved by: https://github.com/malfet
ghstack dependencies: pytorch#139082, pytorch#139083, pytorch#137918
rahulsingh-intel pushed a commit to rahulsingh-intel/pytorch that referenced this pull request Nov 5, 2024
No real reason to have the zero-beta restriction, so let's lift it.

Testing: intentionally broke new paths locally to verify test coverage existed

Differential Revision: [D64407752](https://our.internmc.facebook.com/intern/diff/D64407752/)

Pull Request resolved: pytorch#138275
Approved by: https://github.com/malfet
ghstack dependencies: pytorch#139082, pytorch#139083, pytorch#137918, pytorch#138005
rahulsingh-intel pushed a commit to rahulsingh-intel/pytorch that referenced this pull request Nov 5, 2024
Caused by pytorch#137918; fixed by guarding all cpuinfo use with `!defined(__s390x__) && !defined(__powerpc__)`.

Pull Request resolved: pytorch#139491
Approved by: https://github.com/huydhn, https://github.com/Skylion007
@github-actions github-actions bot deleted the gh/swolchok/666/head branch December 2, 2024 02:13
Labels: ciflow/trunk, fb-exported, Merged
5 participants