Skip to content

Commit d1495ee

Browse files
committed
[BugFix][Codegen, CUDA] Fix faulty codegen for FP8
Fixed a bug where CUDA codegen produces faulty code when a vectorizable BufferLoadNode contains a Float8 type. Codegen generated the invalid call "make___nv_fp8x2_e5m2(param_0[v_.x], param_0[v_.y])", where "param_0" is of type "__nv_fp8_e5m2* __restrict__". This commit adds the missing "is_float8()" check to CodeGenCUDA::PrintVecElemLoadExpr, which is called for vectorizable BufferLoadNodes, so that it instead correctly generates the call "__nv_fp8x2_e5m2(make_float2(static_cast<float>(param_0[v_.x]), static_cast<float>(param_0[v_.y])))". Additionally, this commit removes the "make_" prefix for float8 in CodeGenCUDA::PrintVecConstructor, as the correct way to instantiate an __nv_fp8x2_[e5m2/e4m3] is through the "__nv_fp8x2_[e5m2/e4m3]" constructor itself.
1 parent 432ccfa commit d1495ee

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

src/target/source/codegen_cuda.cc

+16-1
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,9 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
473473
}
474474

475475
void CodeGenCUDA::PrintVecConstructor(DataType t, std::ostream& os) {
476-
os << "make_";
476+
if (!t.is_float8()) {
477+
os << "make_"; // There is no make___nv_fp8 (/usr/local/cuda/include/vector_functions.hpp)
478+
}
477479
PrintType(t, os);
478480
}
479481

@@ -1554,6 +1556,19 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val
15541556
return;
15551557
}
15561558

1559+
if (t.is_float8()) {
1560+
if (i == 0) {
1561+
PrintVecConstructor(t, os);
1562+
os << "(make_float" << t.lanes() << "(";
1563+
}
1564+
if (i != 0) os << ", ";
1565+
os << "static_cast<float>(" << value << ")";
1566+
if (i == t.lanes() - 1) {
1567+
os << "))";
1568+
}
1569+
return;
1570+
}
1571+
15571572
if (i == 0) {
15581573
PrintVecConstructor(t, os);
15591574
os << "(";

0 commit comments

Comments
 (0)