Commit 6f3f13c

Fix faulty codegen for unary fp8 operators
Add missing support for unary operators in fp8. FP8 requires casting to fp16 to perform mathematical operations, so this commit handles the casting to and from __half and adds the missing checks for the TIR intrinsics to generate the correct operator signatures.
1 parent 5434141 commit 6f3f13c
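
For context, the pattern this commit makes the CUDA backend emit can be written by hand roughly as follows. This is an illustrative sketch rather than actual TVM output: the kernel, its name, and the choice of hsin are made up for the example; the fp8 handling mirrors the cuda_fp8.h conversions used in codegen_cuda.cc below (widen the e4m3/e5m2 value to __half, apply the half-precision intrinsic, narrow back to fp8).

#include <cuda_fp16.h>
#include <cuda_fp8.h>

// Illustrative hand-written kernel (not generated by TVM): fp8 has no native
// math intrinsics, so each element is widened to __half, the half-precision
// intrinsic is applied, and the result is narrowed back to fp8.
__global__ void unary_sin_e4m3(__nv_fp8_e4m3* out, const __nv_fp8_e4m3* in, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    __half x = __half(__nv_cvt_fp8_to_halfraw(in[i].__x, __NV_E4M3));  // fp8 -> half
    out[i] = __nv_fp8_e4m3(hsin(x));                                   // half op, then half -> fp8
  }
}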

File tree

7 files changed: +35 -21 lines

  python/tvm/relax/transform/legalize_ops/nn.py
  src/relax/op/distributed/nn.cc
  src/relax/op/distributed/unary.h
  src/relax/op/nn/nn.cc
  src/relax/op/op_common.h
  src/target/source/codegen_cuda.cc
  src/target/source/intrin_rule_cuda.cc

python/tvm/relax/transform/legalize_ops/nn.py (+2 -2)

@@ -475,8 +475,8 @@ def te_gelu(x: te.Tensor):
         dtype = x.dtype
         erf_inp = x * tir.const(0.5**0.5, dtype)
 
-        if dtype == "float16":
-            erf = topi.math.cast(topi.erf(topi.math.cast(erf_inp, "float32")), "float16")
+        if dtype == "float16" or dtype == "e5m2_float8" or dtype == "e4m3_float8":
+            erf = topi.math.cast(topi.erf(topi.math.cast(erf_inp, "float32")), dtype)
         else:
             erf = topi.erf(erf_inp)
 
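Put differently, when the input dtype is one of the fp8 formats, the erf term is now computed in float32 and only the result is cast back to the input dtype. A rough CUDA sketch of that numeric path for the e4m3 case (illustrative only; the helper name is made up and this is not what TVM emits):

#include <cuda_fp16.h>
#include <cuda_fp8.h>

// Illustrative only: the erf itself runs in float32; the result is narrowed
// back to the original fp8 dtype, matching the cast-around-erf in the diff.
__device__ __nv_fp8_e4m3 erf_term_e4m3(__nv_fp8_e4m3 erf_inp) {
  float wide = float(__half(__nv_cvt_fp8_to_halfraw(erf_inp.__x, __NV_E4M3)));
  return __nv_fp8_e4m3(erff(wide));
}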

src/relax/op/distributed/nn.cc (+2 -1)

@@ -32,7 +32,8 @@ StructInfo InferDistStructInfoSoftmax(const Call& call, const BlockBuilder& ctx)
     ctx->ReportFatal(Diagnostic::Error(call)
                      << "Input of distributed operator must have known ndim");
   }
-  if (!input_tensor_sinfo->IsUnknownDtype() && !input_tensor_sinfo->dtype.is_float()) {
+  if (!input_tensor_sinfo->IsUnknownDtype() && !input_tensor_sinfo->dtype.is_float() &&
+      !input_tensor_sinfo->dtype.is_float16() && !input_tensor_sinfo->dtype.is_float8()) {
     ctx->ReportFatal(Diagnostic::Error(call) << "Softmax requires the input tensor to have float "
                                                 "dtype. However, the given input dtype is "
                                              << input_tensor_sinfo->dtype);

src/relax/op/distributed/unary.h (+2 -1)

@@ -40,7 +40,8 @@ StructInfo InferDistStructInfoUnary(const Call& call, const BlockBuilder& ctx,
   TensorStructInfo input_tensor_sinfo = input_dtensor_sinfo->tensor_sinfo;
 
   if (require_float_dtype && !input_tensor_sinfo->IsUnknownDtype() &&
-      !input_tensor_sinfo->dtype.is_float()) {
+      !input_tensor_sinfo->dtype.is_float() && !input_tensor_sinfo->dtype.is_float16() &&
+      !input_tensor_sinfo->dtype.is_float8()) {
     ctx->ReportFatal(
         Diagnostic::Error(call)
         << call->op

src/relax/op/nn/nn.cc (+2 -1)

@@ -74,7 +74,8 @@ StructInfo InferStructInfoSoftmax(const Call& call, const BlockBuilder& ctx) {
   if (data_sinfo->IsUnknownNdim()) {
     return data_sinfo;
   }
-  if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float()) {
+  if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float() &&
+      !data_sinfo->dtype.is_float16() && !data_sinfo->dtype.is_float8()) {
     ctx->ReportFatal(Diagnostic::Error(call) << "Softmax requires the input tensor to have float "
                                                 "dtype. However, the given input dtype is "
                                              << data_sinfo->dtype);

src/relax/op/op_common.h (+3 -2)

@@ -199,8 +199,9 @@ template <bool require_float_dtype, typename FType>
 inline StructInfo InferStructInfoUnary(const Call& call, const BlockBuilder& ctx,
                                        FType f_compute_out_dtype) {
   TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
-  if (require_float_dtype && !input_sinfo->IsUnknownDtype() &&
-      (!input_sinfo->dtype.is_float() && !input_sinfo->dtype.is_bfloat())) {
+  if (require_float_dtype && !input_sinfo->IsUnknownDtype() && !input_sinfo->dtype.is_float() &&
+      !input_sinfo->dtype.is_bfloat() && !input_sinfo->dtype.is_float16() &&
+      !input_sinfo->dtype.is_float8()) {
     ctx->ReportFatal(
         Diagnostic::Error(call)
         << call->op

src/target/source/codegen_cuda.cc (+23 -13)

@@ -149,8 +149,7 @@ std::string CodeGenCUDA::Finish() {
   if (enable_fp16_) {
     decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)\n";
     decl_stream << "#include <cuda_fp16.h>\n";
-    decl_stream << "__device__ half max"
-                << "(half a, half b)\n"
+    decl_stream << "__device__ half max" << "(half a, half b)\n"
                 << "{\n return __hgt(__half(a), __half(b)) ? a : b;\n}\n";
     decl_stream << "__device__ half min(half a, half b)\n"
                 << "{\n return __hlt(__half(a), __half(b)) ? a : b;\n}\n";
@@ -165,8 +164,7 @@ std::string CodeGenCUDA::Finish() {
   if (enable_bf16_) {
     decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)\n";
     decl_stream << "#include <cuda_bf16.h>\n";
-    decl_stream << "__device__ nv_bfloat16 max"
-                << "(nv_bfloat16 a, nv_bfloat16 b)\n"
+    decl_stream << "__device__ nv_bfloat16 max" << "(nv_bfloat16 a, nv_bfloat16 b)\n"
                 << "{\n return __hgt(a, b) ? a : b;\n}\n";
     decl_stream << "__device__ nv_bfloat16 min(nv_bfloat16 a, nv_bfloat16 b)\n"
                 << "{\n return __hlt(a, b) ? a : b;\n}\n";
@@ -542,8 +540,7 @@ void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr l
   }
   for (int i = 0, lanes = t.lanes() / 2; i < lanes; ++i) {
     if (isalpha(op[0]) || op[0] == '_') {
-      value_temp << op << "2"
-                 << "(__half2(";
+      value_temp << op << "2" << "(__half2(";
       PrintVecElemLoad(vlhs, lhs.dtype(), i * lanes, value_temp);
       value_temp << "), __half2(";
       PrintVecElemLoad(vrhs, rhs.dtype(), i * lanes, value_temp);
@@ -653,8 +650,7 @@ void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i,
   ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4));
   if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
     if (t.lanes() == 2 || t.lanes() == 3) {
-      stream << vec << '.' << access[i % t.lanes()] << "="
-             << "(" << value << ");\n";
+      stream << vec << '.' << access[i % t.lanes()] << "=" << "(" << value << ");\n";
     } else {
       std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]);
       stream << ac << "=";
@@ -861,7 +857,23 @@ void CodeGenCUDA::PrintCallExtern(Type ret_type, String global_symbol, const Arr
     }
     os << sret;
   } else {
-    CodeGenC::PrintCallExtern(ret_type, global_symbol, args, skip_first_arg, os);
+    if (ret_dtype.is_float8()) {
+      std::string fp8_type = (ret_dtype.is_e5m2_float8() ? "__NV_E5M2" : "__NV_E4M3");
+      os << "__nv_fp8_" << (ret_dtype.is_e5m2_float8() ? "e5m2" : "e4m3") << "(";
+
+      LOG_INFO << global_symbol;
+      os << global_symbol << "(__half(__nv_cvt_fp8_to_halfraw(";
+      for (size_t i = static_cast<size_t>(skip_first_arg); i < args.size(); ++i) {
+        this->PrintExpr(args[i], os);
+        os << ".__x, " << fp8_type << "))";
+        if (i < args.size() - 1) {
+          os << ", " << "__half(__nv_cvt_fp8_to_halfraw(";
+        }
+      }
+      os << "))";
+    } else {
+      CodeGenC::PrintCallExtern(ret_type, global_symbol, args, skip_first_arg, os);
+    }
   }
 }
 
@@ -1198,8 +1210,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
     this->stream << "\" @!p mov.b32 %0, 0;\\n\"\n";
     this->stream << "\" @p ld.global.nc.f32 %0, [%1];}\\n\"\n";
     // stream << "\" @p ld.global.nc.L2::128B.f32 %0, [%1];}\\n\"\n" ;
-    stream << ": \"=f\"(" << reg << "[" << local_addr << "]"
-           << ")\n";
+    stream << ": \"=f\"(" << reg << "[" << local_addr << "]" << ")\n";
     stream << ": \"l\"((void*)(" << global_buffer << "+" << global_addr << ")), \"r\"((int)"
            << guard << ")\n";
     stream << ");\n";
@@ -1385,8 +1396,7 @@ void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) {
   PrintVecConstructor(op->dtype, os);
   os << "(";
   for (int i = 0; i < lanes; i++) {
-    os << "(" << PrintExpr(op->base) << ")"
-       << "+(" << PrintExpr(op->stride) << "*" << i << ")";
+    os << "(" << PrintExpr(op->base) << ")" << "+(" << PrintExpr(op->stride) << "*" << i << ")";
     if (i != lanes - 1) os << ", ";
   }
   os << ")";
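
With the fp8 branch in PrintCallExtern above, a unary extern call whose result type is fp8 is printed with its argument widened to __half and the result wrapped back into the fp8 constructor. For a single-argument call returning e5m2 the emitted expression has roughly this shape (a sketch, not verbatim output; x and y are placeholder names, and __habs is the intrinsic name resolved for fabs in intrin_rule_cuda.cc below):

// Approximate shape of the generated expression for a one-argument fp8 call
// (placeholder variable names; __habs is the resolved half-precision intrinsic).
__nv_fp8_e5m2 y = __nv_fp8_e5m2(__habs(__half(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E5M2))));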

src/target/source/intrin_rule_cuda.cc (+1 -1)

@@ -52,7 +52,7 @@ struct CUDAMath {
         default:
           return "";
       }
-    } else if (t.is_bfloat16()) {
+    } else if (t.is_bfloat16() || t.is_float8()) {
       if (name == "fabs") {
         return "__habs";
       } else if (name == "round") {
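
Grouping fp8 with bfloat16 here makes the unary TIR intrinsics resolve to the half-style device functions (for example fabs becomes __habs). That is consistent with the codegen change above: the fp8 operand has already been widened to __half, so a half-precision intrinsic is exactly what the generated call needs. A minimal sketch of the resolved call (the wrapper name is made up):

// __habs is the cuda_fp16.h intrinsic the table returns for fabs on fp8/bf16 inputs;
// it operates on __half, hence the widen/narrow wrapper added in PrintCallExtern.
__device__ __half abs_via_half(__half v) { return __habs(v); }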
