
[BugFix][Codegen, CUDA] Fix faulty codegen for FP8 #17673


Open · wants to merge 6 commits into main

Conversation

AntonMoberg commented Feb 24, 2025

Fixed a bug where CUDA codegen produces faulty code when a vectorizable BufferLoadNode contains a Float8 type.

Codegen generated the invalid expression make___nv_fp8x2_e5m2(param_0[v_.x], param_0[v_.y]), where "param_0" is of type __nv_fp8_e5m2* __restrict__.

This commit adds a missing is_float8() check to CodeGenCUDA::PrintVecElemLoadExpr, which is called for vectorizable BufferLoadNodes. With the check in place, codegen instead correctly emits __nv_fp8x2_e5m2(make_float2(static_cast<float>(param_0[v_.x]), static_cast<float>(param_0[v_.y]))).

Additionally, this commit removes the added "make_" prefix for float8 in CodeGenCUDA::PrintVecConstructor, as the correct way to instantiate an __nv_fp8x2_[e5m2/e4m3] is through the __nv_fp8x2_[e5m2/e4m3] constructor itself.
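For illustration, here is a minimal, hedged sketch of the before/after shape of the generated load. The names param_0 and v_ come from the snippet above; the wrapper function load_fp8_pair is purely hypothetical and only exists to make the fragment self-contained, since the real TVM-generated kernels embed this expression in much larger functions.

```cuda
#include <cuda_fp8.h>

// Hypothetical wrapper around the generated expression, for illustration only.
__device__ __nv_fp8x2_e5m2 load_fp8_pair(const __nv_fp8_e5m2* __restrict__ param_0,
                                         int2 v_) {
  // Before the fix, codegen emitted a call to a non-existent helper:
  //   make___nv_fp8x2_e5m2(param_0[v_.x], param_0[v_.y])
  // After the fix, each element is widened to float, packed into a float2,
  // and passed to the __nv_fp8x2_e5m2 constructor:
  return __nv_fp8x2_e5m2(make_float2(static_cast<float>(param_0[v_.x]),
                                     static_cast<float>(param_0[v_.y])));
}
```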

tqchen (Member) commented Feb 24, 2025

Thanks @AntonMoberg. @MasterJH5574, it would be great if we could validate this PR.

MasterJH5574 (Contributor) commented

Thank you @AntonMoberg! Would you mind providing an example that reproduces the error?

AntonMoberg (Author) commented

Hi @tqchen & @MasterJH5574! I am trying to produce a minimal reproducible example, but it is proving a bit challenging as the error only occurs in some specific scenarios. In the meantime I have encountered more faulty codegen related to FP8. I'll get back to you with updates ASAP :)

AntonMoberg (Author) commented

I am converting this PR to a draft while I work on fleshing it out for more cases. Will provide basic tests and suggested fixes along the way!

MasterJH5574 (Contributor) commented

Thank you so much @AntonMoberg!

AntonMoberg force-pushed the main branch 3 times, most recently from 8e6a786 to 6f3f13c on March 10, 2025 at 09:42
AntonMoberg marked this pull request as ready for review on March 10, 2025 at 09:52
AntonMoberg (Author) commented

@MasterJH5574 @tqchen This should be ready for review now. I am not 100% familiar with this side of the codebase, so please make sure I am not making any silly mistakes and that this doesn't break anything else.

Also feel free to make edits if something looks fishy :)

AntonMoberg force-pushed the main branch 5 times, most recently from 43b578b to e8d0e04 on March 10, 2025 at 16:03
AntonMoberg (Author) commented

There we go, it should now be good to go! Had some rebase issues, so sorry for spamming updates.

This commit adds tests for the FP8 codegen and compilation of the generated code for the most common operators in LLMs and CNNs. Tested operators are: Matmul, Conv2d, Maxpool2d, Add, Relu, Gelu, GeluTanh, Sigmoid, Silu, Softmax.
Vectorized FP8 values are stored as __nv_[fp8x2/fp8x4]_[e5m2/e4m3] (i.e., packed 16-bit/32-bit storage types). These types do not have overloaded binary operators (such as *). This commit adds the ability to apply such operators by extracting the high and low lanes, statically casting them to float, performing the operation, and then repacking the results into the dual-lane type.
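A minimal sketch of that unpack, float math, repack pattern for a two-lane FP8 value, assuming the conversion operators and constructors provided by cuda_fp8.h; the helper name fp8x2_mul is illustrative and not what TVM emits.

```cuda
#include <cuda_fp8.h>

// Illustrative helper: multiply two packed fp8x2 values lane-wise.
__device__ __nv_fp8x2_e5m2 fp8x2_mul(__nv_fp8x2_e5m2 a, __nv_fp8x2_e5m2 b) {
  // Unpack both lanes to float via the explicit float2 conversion operator.
  float2 af = static_cast<float2>(a);
  float2 bf = static_cast<float2>(b);
  // Perform the arithmetic in float, then repack into the dual-lane FP8 type.
  return __nv_fp8x2_e5m2(make_float2(af.x * bf.x, af.y * bf.y));
}
```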
Non-vectorized FP8 values are stored as __nv_fp8_[e5m2/e4m3] types. These types do not support binary operators either, because FP8 arithmetic has to be performed in wider registers. This commit adds binary operator support by doing the operations in __half instead of FP8 (i.e., cast up to 16 bits, perform the operation, then cast back down to 8 bits).
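A minimal sketch of that widen-to-__half pattern for scalar FP8, again assuming the cuda_fp8.h / cuda_fp16.h conversions; the helper name fp8_add is illustrative.

```cuda
#include <cuda_fp8.h>
#include <cuda_fp16.h>

// Illustrative helper: add two scalar FP8 values by going through __half.
__device__ __nv_fp8_e5m2 fp8_add(__nv_fp8_e5m2 a, __nv_fp8_e5m2 b) {
  __half ah = static_cast<__half>(a);   // cast up to 16-bit
  __half bh = static_cast<__half>(b);
  return __nv_fp8_e5m2(__hadd(ah, bh)); // operate in __half, cast back to 8-bit
}
```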
Add missing support for unary operators on FP8. FP8 requires casting to FP16 to perform mathematical operations; this commit handles the casting to and from __half and adds the missing checks for the TIR intrinsics so that the correct operator signatures are generated.
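The same pattern applies to a unary intrinsic, sketched here with exp (relevant for e.g. Softmax); assumptions as above, and fp8_exp is an illustrative name.

```cuda
#include <cuda_fp8.h>
#include <cuda_fp16.h>

// Illustrative helper: exp of a scalar FP8 value computed in __half.
__device__ __nv_fp8_e4m3 fp8_exp(__nv_fp8_e4m3 x) {
  __half xh = static_cast<__half>(x);  // cast up to 16-bit
  return __nv_fp8_e4m3(hexp(xh));      // half-precision exp, cast back to 8-bit
}
```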
AntonMoberg (Author) commented

Ping! @tqchen @MasterJH5574

MasterJH5574 (Contributor) commented

Thank you @AntonMoberg so much for the update! Will take a look!
