Commit e8d0e04 (parent ef31780)

Fix build error (FP8 Dtypes mismatch)
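
This commit renames TVM's FP8 dtypes to the ml_dtypes-style spellings used by the rest of the ecosystem:

- dtype strings: "e4m3_float8" → "float8_e4m3fn", "e5m2_float8" → "float8_e5m2"
- DataType predicates: is_e4m3_float8() → is_float8_e4m3fn(), is_e5m2_float8() → is_float8_e5m2()

The remaining churn in codegen_cuda.cc is formatting-only line re-wrapping, removal of the apparently unused fp8_lanes local, and a small fix that always closes the __nv_fp8x2_/__nv_fp8x4_ constructor call (previously the trailing ")" was only emitted for 2-lane values).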

4 files changed: +37 -33 lines

python/tvm/relax/transform/legalize_ops/nn.py (+1 -1)

@@ -475,7 +475,7 @@ def te_gelu(x: te.Tensor):
         dtype = x.dtype
         erf_inp = x * tir.const(0.5**0.5, dtype)
 
-        if dtype == "float16" or dtype == "e5m2_float8" or dtype == "e4m3_float8":
+        if dtype == "float16" or dtype == "float8_e5m2" or dtype == "float8_e4m3fn":
             erf = topi.math.cast(topi.erf(topi.math.cast(erf_inp, "float32")), dtype)
         else:
             erf = topi.erf(erf_inp)
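
For context, te_gelu computes gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))), and erf here is evaluated in float32 for the narrow dtypes, then cast back. A standalone sketch of the pattern this hunk touches (the gelu tail is paraphrased, not copied from nn.py):

from tvm import te, tir, topi

def gelu_like(x: te.Tensor) -> te.Tensor:
    # Mirrors te_gelu above: widen fp16/fp8 inputs to float32 for the
    # erf evaluation, then narrow the result back to the input dtype.
    dtype = x.dtype
    erf_inp = x * tir.const(0.5**0.5, dtype)
    if dtype in ("float16", "float8_e5m2", "float8_e4m3fn"):
        erf = topi.math.cast(topi.erf(topi.math.cast(erf_inp, "float32")), dtype)
    else:
        erf = topi.erf(erf_inp)
    half = tir.const(0.5, dtype)
    return x * (half + erf * half)  # 0.5 * x * (1 + erf(x / sqrt(2)))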

src/target/source/codegen_c.cc (+1 -1)

@@ -879,7 +879,7 @@ void CodeGenC::VisitStmt_(const BufferStoreNode* op) {
       stream << "] = ";
       if (op->value.dtype().is_float8()) {
         ICHECK(value_dtype.lanes() == 2);
-        std::string fp8_type = op->value.dtype().is_e5m2_float8() ? "e5m2" : "e4m3";
+        std::string fp8_type = op->value.dtype().is_float8_e5m2() ? "e5m2" : "e4m3";
         static const char access[] = {'x', 'y'};
         stream << "__nv_fp8_" << fp8_type << "(__half2(";
         PrintVecElemLoad(value, op->value.dtype(), i, stream);

src/target/source/codegen_cuda.cc (+25 -21)

@@ -149,7 +149,8 @@ std::string CodeGenCUDA::Finish() {
   if (enable_fp16_) {
     decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)\n";
     decl_stream << "#include <cuda_fp16.h>\n";
-    decl_stream << "__device__ half max" << "(half a, half b)\n"
+    decl_stream << "__device__ half max"
+                << "(half a, half b)\n"
                 << "{\n return __hgt(__half(a), __half(b)) ? a : b;\n}\n";
     decl_stream << "__device__ half min(half a, half b)\n"
                 << "{\n return __hlt(__half(a), __half(b)) ? a : b;\n}\n";
@@ -164,7 +165,8 @@ std::string CodeGenCUDA::Finish() {
   if (enable_bf16_) {
     decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)\n";
     decl_stream << "#include <cuda_bf16.h>\n";
-    decl_stream << "__device__ nv_bfloat16 max" << "(nv_bfloat16 a, nv_bfloat16 b)\n"
+    decl_stream << "__device__ nv_bfloat16 max"
+                << "(nv_bfloat16 a, nv_bfloat16 b)\n"
                 << "{\n return __hgt(a, b) ? a : b;\n}\n";
     decl_stream << "__device__ nv_bfloat16 min(nv_bfloat16 a, nv_bfloat16 b)\n"
                 << "{\n return __hlt(a, b) ? a : b;\n}\n";
@@ -531,16 +533,16 @@ void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr l
 
   if (t.is_float8()) {
     std::ostringstream value_temp;
-    std::string fp8_lanes = (t.lanes() == 4) ? "x4" : ((t.lanes() == 2) ? "x2" : "");
-    ICHECK(t.is_e4m3_float8() || t.is_e5m2_float8());
+    ICHECK(t.is_float8_e4m3fn() || t.is_float8_e5m2());
     if (t.lanes() == 2) {
-      value_temp << "__nv_fp8x2_" << (t.is_e5m2_float8() ? "e5m2" : "e4m3") << "(";
+      value_temp << "__nv_fp8x2_" << (t.is_float8_e5m2() ? "e5m2" : "e4m3") << "(";
     } else {
-      value_temp << "__nv_fp8x4_" << (t.is_e5m2_float8() ? "e5m2" : "e4m3") << "(";
+      value_temp << "__nv_fp8x4_" << (t.is_float8_e5m2() ? "e5m2" : "e4m3") << "(";
     }
     for (int i = 0, lanes = t.lanes() / 2; i < lanes; ++i) {
       if (isalpha(op[0]) || op[0] == '_') {
-        value_temp << op << "2" << "(__half2(";
+        value_temp << op << "2"
+                   << "(__half2(";
         PrintVecElemLoad(vlhs, lhs.dtype(), i * lanes, value_temp);
         value_temp << "), __half2(";
         PrintVecElemLoad(vrhs, rhs.dtype(), i * lanes, value_temp);
@@ -557,9 +559,7 @@ void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr l
         value_temp << ", ";
       }
       if (i == lanes - 1) {
-        if (t.lanes() == 2) {
-          value_temp << ")";
-        }
+        value_temp << ")";
         PrintVecElemStore(sret, t, i, value_temp.str());
       }
     }
@@ -595,7 +595,7 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i,
   }
 
   static const char access[] = {'x', 'y', 'z', 'w'};
-  std::string fp8_type = (t.is_float8()) ? (t.is_e4m3_float8() ? "e4m3" : "e5m2") : "";
+  std::string fp8_type = (t.is_float8()) ? (t.is_float8_e4m3fn() ? "e4m3" : "e5m2") : "";
   ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4));
   if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
     std::string type_name = t.is_int() ? "char" : "unsigned char";
@@ -615,7 +615,7 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i,
     os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
   } else if (t.is_float8()) {
     os << "__nv_cvt_fp8x2_to_halfraw2(" << vec << ".__x,"
-       << (t.is_e5m2_float8() ? "__NV_E5M2" : "__NV_E4M3") << ")";
+       << (t.is_float8_e5m2() ? "__NV_E5M2" : "__NV_E4M3") << ")";
   } else if (t.lanes() > 4 && t.lanes() <= 8) {
     std::string type_name;
     if (t.bits() == 16) {
@@ -650,7 +650,8 @@ void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i,
   ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4));
   if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
     if (t.lanes() == 2 || t.lanes() == 3) {
-      stream << vec << '.' << access[i % t.lanes()] << "=" << "(" << value << ");\n";
+      stream << vec << '.' << access[i % t.lanes()] << "="
+             << "(" << value << ");\n";
    } else {
       std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]);
       stream << ac << "=";
@@ -858,16 +859,17 @@ void CodeGenCUDA::PrintCallExtern(Type ret_type, String global_symbol, const Arr
     os << sret;
   } else {
     if (ret_dtype.is_float8()) {
-      std::string fp8_type = (ret_dtype.is_e5m2_float8() ? "__NV_E5M2" : "__NV_E4M3");
-      os << "__nv_fp8_" << (ret_dtype.is_e5m2_float8() ? "e5m2" : "e4m3") << "(";
+      std::string fp8_type = (ret_dtype.is_float8_e5m2() ? "__NV_E5M2" : "__NV_E4M3");
+      os << "__nv_fp8_" << (ret_dtype.is_float8_e5m2() ? "e5m2" : "e4m3") << "(";
 
       LOG_INFO << global_symbol;
      os << global_symbol << "(__half(__nv_cvt_fp8_to_halfraw(";
      for (size_t i = static_cast<size_t>(skip_first_arg); i < args.size(); ++i) {
        this->PrintExpr(args[i], os);
        os << ".__x, " << fp8_type << "))";
        if (i < args.size() - 1) {
-          os << ", " << "__half(__nv_cvt_fp8_to_halfraw(";
+          os << ", "
+             << "__half(__nv_cvt_fp8_to_halfraw(";
        }
      }
      os << "))";
@@ -1210,7 +1212,8 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
       this->stream << "\" @!p mov.b32 %0, 0;\\n\"\n";
       this->stream << "\" @p ld.global.nc.f32 %0, [%1];}\\n\"\n";
       // stream << "\" @p ld.global.nc.L2::128B.f32 %0, [%1];}\\n\"\n" ;
-      stream << ": \"=f\"(" << reg << "[" << local_addr << "]" << ")\n";
+      stream << ": \"=f\"(" << reg << "[" << local_addr << "]"
+             << ")\n";
       stream << ": \"l\"((void*)(" << global_buffer << "+" << global_addr << ")), \"r\"((int)"
              << guard << ")\n";
       stream << ");\n";
@@ -1396,7 +1399,8 @@ void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) {
   PrintVecConstructor(op->dtype, os);
   os << "(";
   for (int i = 0; i < lanes; i++) {
-    os << "(" << PrintExpr(op->base) << ")" << "+(" << PrintExpr(op->stride) << "*" << i << ")";
+    os << "(" << PrintExpr(op->base) << ")"
+       << "+(" << PrintExpr(op->stride) << "*" << i << ")";
     if (i != lanes - 1) os << ", ";
   }
   os << ")";
@@ -1762,9 +1766,9 @@ inline void PrintBinaryExpr(const T* op, const char* opstr,
                             CodeGenCUDA* p) {
   if (op->dtype.lanes() == 1) {
     if (op->dtype.is_float8()) {
-      std::string fp8_type = (op->dtype.is_e5m2_float8() ? "__NV_E5M2" : "__NV_E4M3");
+      std::string fp8_type = (op->dtype.is_float8_e5m2() ? "__NV_E5M2" : "__NV_E4M3");
       if (isalpha(opstr[0]) || opstr[0] == '_') {
-        os << "__nv_fp8_" << (op->dtype.is_e5m2_float8() ? "e5m2" : "e4m3") << "(";
+        os << "__nv_fp8_" << (op->dtype.is_float8_e5m2() ? "e5m2" : "e4m3") << "(";
         os << opstr << "(";
         os << "__half(__nv_cvt_fp8_to_halfraw(";
         p->PrintExpr(op->a, os);
@@ -1773,7 +1777,7 @@ inline void PrintBinaryExpr(const T* op, const char* opstr,
       os << ".__x, " << fp8_type << ")))";
       os << ")";
     } else {
-      os << "__nv_fp8_" << (op->dtype.is_e5m2_float8() ? "e5m2" : "e4m3") << "(";
+      os << "__nv_fp8_" << (op->dtype.is_float8_e5m2() ? "e5m2" : "e4m3") << "(";
       os << "__half(__nv_cvt_fp8_to_halfraw(";
       p->PrintExpr(op->a, os);
       os << ".__x, " << fp8_type << ")) " << opstr << " __half(__nv_cvt_fp8_to_halfraw(";

tests/python/codegen/test_target_codegen_cuda_fp8_operators.py (+10 -10)

@@ -28,7 +28,7 @@
 
 @tvm.testing.requires_cuda_compute_version(8, 9)
 @pytest.mark.parametrize("original_dtype", ["float16", "float32"])
-@pytest.mark.parametrize("dtype", ["e4m3_float8", "e5m2_float8", "float16"])
+@pytest.mark.parametrize("dtype", ["float8_e4m3fn", "float8_e5m2", "float16"])
 @pytest.mark.parametrize("batch_size", [1, 64])
 def test_fp8_matmul_compile(dtype, original_dtype, batch_size):
     bb = relax.BlockBuilder()
@@ -66,7 +66,7 @@ def test_fp8_matmul_compile(dtype, original_dtype, batch_size):
 
 @tvm.testing.requires_cuda_compute_version(8, 9)
 @pytest.mark.parametrize("original_dtype", ["float16", "float32"])
-@pytest.mark.parametrize("dtype", ["e4m3_float8", "e5m2_float8"])
+@pytest.mark.parametrize("dtype", ["float8_e4m3fn", "float8_e5m2"])
 @pytest.mark.parametrize("batch_size", [1, 64])
 def test_fp8_conv2d_compile(dtype, original_dtype, batch_size):
     bb = relax.BlockBuilder()
@@ -116,7 +116,7 @@ def test_fp8_conv2d_compile(dtype, original_dtype, batch_size):
 
 @tvm.testing.requires_cuda_compute_version(8, 9)
 @pytest.mark.parametrize("original_dtype", ["float16", "float32"])
-@pytest.mark.parametrize("dtype", ["e4m3_float8", "e5m2_float8"])
+@pytest.mark.parametrize("dtype", ["float8_e4m3fn", "float8_e5m2"])
 @pytest.mark.parametrize("batch_size", [1, 64])
 def test_fp8_maxpool2d_compile(dtype, original_dtype, batch_size):
     bb = relax.BlockBuilder()
@@ -164,7 +164,7 @@ def test_fp8_maxpool2d_compile(dtype, original_dtype, batch_size):
 
 @tvm.testing.requires_cuda_compute_version(8, 9)
 @pytest.mark.parametrize("original_dtype", ["float16", "float32"])
-@pytest.mark.parametrize("dtype", ["e4m3_float8", "e5m2_float8"])
+@pytest.mark.parametrize("dtype", ["float8_e4m3fn", "float8_e5m2"])
 @pytest.mark.parametrize("batch_size", [1, 64])
 def test_fp8_add_compile(dtype, original_dtype, batch_size):
     bb = relax.BlockBuilder()
@@ -202,7 +202,7 @@ def test_fp8_add_compile(dtype, original_dtype, batch_size):
 
 @tvm.testing.requires_cuda_compute_version(8, 9)
 @pytest.mark.parametrize("original_dtype", ["float16", "float32"])
-@pytest.mark.parametrize("dtype", ["e4m3_float8", "e5m2_float8"])
+@pytest.mark.parametrize("dtype", ["float8_e4m3fn", "float8_e5m2"])
 @pytest.mark.parametrize("batch_size", [1, 64])
 def test_fp8_relu_compile(dtype, original_dtype, batch_size):
     bb = relax.BlockBuilder()
@@ -238,7 +238,7 @@ def test_fp8_relu_compile(dtype, original_dtype, batch_size):
 
 @tvm.testing.requires_cuda_compute_version(8, 9)
 @pytest.mark.parametrize("original_dtype", ["float16", "float32"])
-@pytest.mark.parametrize("dtype", ["e4m3_float8", "e5m2_float8"])
+@pytest.mark.parametrize("dtype", ["float8_e4m3fn", "float8_e5m2"])
 @pytest.mark.parametrize("batch_size", [1, 64])
 def test_fp8_gelu_compile(dtype, original_dtype, batch_size):
     bb = relax.BlockBuilder()
@@ -274,7 +274,7 @@ def test_fp8_gelu_compile(dtype, original_dtype, batch_size):
 
 @tvm.testing.requires_cuda_compute_version(8, 9)
 @pytest.mark.parametrize("original_dtype", ["float16", "float32"])
-@pytest.mark.parametrize("dtype", ["e4m3_float8", "e5m2_float8"])
+@pytest.mark.parametrize("dtype", ["float8_e4m3fn", "float8_e5m2"])
 @pytest.mark.parametrize("batch_size", [1, 64])
 def test_fp8_gelu_tanh_compile(dtype, original_dtype, batch_size):
     bb = relax.BlockBuilder()
@@ -310,7 +310,7 @@ def test_fp8_gelu_tanh_compile(dtype, original_dtype, batch_size):
 
 @tvm.testing.requires_cuda_compute_version(8, 9)
 @pytest.mark.parametrize("original_dtype", ["float16", "float32"])
-@pytest.mark.parametrize("dtype", ["e4m3_float8", "e5m2_float8"])
+@pytest.mark.parametrize("dtype", ["float8_e4m3fn", "float8_e5m2"])
 @pytest.mark.parametrize("batch_size", [1, 64])
 def test_fp8_sigmoid_compile(dtype, original_dtype, batch_size):
     bb = relax.BlockBuilder()
@@ -346,7 +346,7 @@ def test_fp8_sigmoid_compile(dtype, original_dtype, batch_size):
 
 @tvm.testing.requires_cuda_compute_version(8, 9)
 @pytest.mark.parametrize("original_dtype", ["float16", "float32"])
-@pytest.mark.parametrize("dtype", ["e4m3_float8", "e5m2_float8"])
+@pytest.mark.parametrize("dtype", ["float8_e4m3fn", "float8_e5m2"])
 @pytest.mark.parametrize("batch_size", [1, 64])
 def test_fp8_silu_compile(dtype, original_dtype, batch_size):
     bb = relax.BlockBuilder()
@@ -382,7 +382,7 @@ def test_fp8_silu_compile(dtype, original_dtype, batch_size):
 
 @tvm.testing.requires_cuda_compute_version(8, 9)
 @pytest.mark.parametrize("original_dtype", ["float16", "float32"])
-@pytest.mark.parametrize("dtype", ["e4m3_float8", "e5m2_float8"])
+@pytest.mark.parametrize("dtype", ["float8_e4m3fn", "float8_e5m2"])
 @pytest.mark.parametrize("batch_size", [1, 64])
 def test_fp8_softmax_compile(dtype, original_dtype, batch_size):
     bb = relax.BlockBuilder()
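
All ten tests share the same skeleton; a minimal sketch of it (shapes and variable names here are illustrative, not copied from the test file):

import tvm
from tvm import relax

bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo((64, 128), "float32"))
with bb.function("main", [x]):
    # Cast to the renamed fp8 dtype, apply the op under test,
    # and emit the function output.
    x_fp8 = bb.emit(relax.op.astype(x, "float8_e4m3fn"))
    out = bb.emit(relax.op.nn.relu(x_fp8))
    bb.emit_func_output(out)
mod = bb.get()

Per the requires_cuda_compute_version(8, 9) marker, actually building such a module for CUDA needs an sm_89 (Ada) or newer GPU.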
