[BugFix][Codegen, CUDA] Fix faulty codegen for FP8 #17673
base: main
Conversation
Thanks @AntonMoberg. @MasterJH5574, it would be great if we can validate this PR.
Thank you @AntonMoberg! Would you mind providing an example which can reproduce the error?
Hi @tqchen & @MasterJH5574! I am trying to produce a minimal reproducible example, but it is proving a bit challenging as the error only occurs in some specific scenarios. However, during this time I have encountered more faulty codegen related to FP8. I'll get back to you with updates ASAP :)
I am converting this PR to a draft while I work on fleshing it out for more cases. Will provide basic tests and suggested fixes along the way!
Thank you so much @AntonMoberg!
(force-pushed from 8e6a786 to 6f3f13c)
@MasterJH5574 @tqchen This should be ready to be reviewed now. Also feel free to make edits if something is fishy :)
(force-pushed from 43b578b to e8d0e04)
There we go, it should now be good to go! Had some rebase issues, so sorry for spamming updates.
This commit adds tests for the FP8 codegen and for compilation of the generated code for the most common operators in LLMs and CNNs. Tested operators are: Matmul, Conv2d, Maxpool2d, Add, Relu, Gelu, GeluTanh, Sigmoid, Silu, Softmax.
Fixed a bug where CUDA codegen produced faulty code when a vectorizable BufferLoadNode contains a Float8 type. Codegen generated the invalid expression "make___nv_fp8x2_e5m2(param_0[v_.x], param_0[v_.y])", where "param_0" is of type "__nv_fp8_e5m2* __restrict__". This commit adds a missing "is_float8()" check to CodeGenCUDA::PrintVecElemLoadExpr, which is called for vectorizable BufferLoadNodes, so that the correct expression "__nv_fp8x2_e5m2(make_float2(static_cast<float>(param_0[v_.x]), static_cast<float>(param_0[v_.y])))" is generated instead. Additionally, this commit removes the "make_" prefix added for float8 in CodeGenCUDA::PrintVecConstructor, as the correct way to instantiate an __nv_fp8x2_[e5m2/e4m3] is through the "__nv_fp8x2_[e5m2/e4m3]" constructor itself.
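To make the before/after concrete, here is a sketch of the generated code for a two-lane e5m2 load (the wrapper function name `load_fp8x2` is illustrative, not something the codegen actually emits):

```cuda
#include <cuda_fp8.h>

// Before the fix, PrintVecElemLoadExpr emitted a nonexistent helper, as if
// fp8x2 were a built-in CUDA vector type:
//   make___nv_fp8x2_e5m2(param_0[v_.x], param_0[v_.y]);   // does not compile
//
// After the fix, each lane is cast up to float, packed into a float2, and the
// __nv_fp8x2_e5m2 constructor converts the pair back down to fp8:
__device__ __nv_fp8x2_e5m2 load_fp8x2(const __nv_fp8_e5m2* __restrict__ param_0,
                                      int2 v_) {
  return __nv_fp8x2_e5m2(make_float2(static_cast<float>(param_0[v_.x]),
                                     static_cast<float>(param_0[v_.y])));
}
```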
Vectorized FP8 values are stored as __nv_[fp8x2/fp8x4]_[e5m2/e4m3] (i.e. 16-bit and 32-bit storage, respectively). These types do not have overloaded binary operators (such as *). This commit adds the ability to apply such operators by extracting the high and low bits, statically casting them to floats, performing the operation, then repacking the results into the dual-lane type.
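A minimal sketch of that unpack/compute/repack pattern for a two-lane multiply, assuming the conversion operators from cuda_fp8.h (the function name `mul_fp8x2` is hypothetical):

```cuda
#include <cuda_fp8.h>

// Vectorized fp8 multiply: __nv_fp8x2_e5m2 has no operator*, so the two
// lanes are unpacked to float, multiplied there, and repacked.
__device__ __nv_fp8x2_e5m2 mul_fp8x2(__nv_fp8x2_e5m2 a, __nv_fp8x2_e5m2 b) {
  // cuda_fp8.h provides an explicit conversion to float2; low lane lands
  // in .x, high lane in .y. (Equivalently one can mask/shift the raw
  // 16-bit storage to split the lanes, as the commit message describes.)
  float2 af = static_cast<float2>(a);
  float2 bf = static_cast<float2>(b);
  return __nv_fp8x2_e5m2(make_float2(af.x * bf.x, af.y * bf.y));
}
```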
Non-vectorized FP8 values are stored as __nv_fp8_[e5m2/e4m3] types, which also do not support binary operations. This commit adds binary operator support by doing the operations in __half instead of FP8 (i.e. cast up to 16-bit, operate, then cast back down to 8-bit).
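The scalar path described above can be sketched like this (again, `mul_fp8` is an illustrative name, not the codegen's emitted signature):

```cuda
#include <cuda_fp8.h>
#include <cuda_fp16.h>

// Scalar fp8 multiply: widen both operands to __half, multiply in fp16,
// then narrow the result back to fp8.
__device__ __nv_fp8_e5m2 mul_fp8(__nv_fp8_e5m2 a, __nv_fp8_e5m2 b) {
  __half ha = __half(a);          // fp8 -> fp16
  __half hb = __half(b);
  return __nv_fp8_e5m2(ha * hb);  // fp16 multiply, then fp16 -> fp8
}
```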
Add missing support for unary operators in FP8. FP8 requires casting to FP16 to perform mathematical operations; this commit handles the casting to and from __half and adds missing checks for the TIR intrinsics to generate the correct operator signatures.
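The unary-intrinsic pattern follows the same cast-up/cast-down shape; a sketch using `hexp` from cuda_fp16.h as one example intrinsic (`exp_fp8` is a hypothetical wrapper name):

```cuda
#include <cuda_fp8.h>
#include <cuda_fp16.h>

// Unary op on fp8: cast up to __half, apply the fp16 intrinsic, cast back.
__device__ __nv_fp8_e5m2 exp_fp8(__nv_fp8_e5m2 a) {
  __half h = __half(a);        // fp8 -> fp16
  __half r = hexp(h);          // fp16 exp intrinsic
  return __nv_fp8_e5m2(r);     // fp16 -> fp8
}
```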
Ping! @tqchen @MasterJH5574
Thank you @AntonMoberg so much for the update! Will take a look!