@@ -149,8 +149,7 @@ std::string CodeGenCUDA::Finish() {
149
149
if (enable_fp16_) {
150
150
decl_stream << " #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)\n " ;
151
151
decl_stream << " #include <cuda_fp16.h>\n " ;
152
- decl_stream << " __device__ half max"
153
- << " (half a, half b)\n "
152
+ decl_stream << " __device__ half max" << " (half a, half b)\n "
154
153
<< " {\n return __hgt(__half(a), __half(b)) ? a : b;\n }\n " ;
155
154
decl_stream << " __device__ half min(half a, half b)\n "
156
155
<< " {\n return __hlt(__half(a), __half(b)) ? a : b;\n }\n " ;
@@ -165,8 +164,7 @@ std::string CodeGenCUDA::Finish() {
165
164
if (enable_bf16_) {
166
165
decl_stream << " #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)\n " ;
167
166
decl_stream << " #include <cuda_bf16.h>\n " ;
168
- decl_stream << " __device__ nv_bfloat16 max"
169
- << " (nv_bfloat16 a, nv_bfloat16 b)\n "
167
+ decl_stream << " __device__ nv_bfloat16 max" << " (nv_bfloat16 a, nv_bfloat16 b)\n "
170
168
<< " {\n return __hgt(a, b) ? a : b;\n }\n " ;
171
169
decl_stream << " __device__ nv_bfloat16 min(nv_bfloat16 a, nv_bfloat16 b)\n "
172
170
<< " {\n return __hlt(a, b) ? a : b;\n }\n " ;
@@ -542,8 +540,7 @@ void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr l
542
540
}
543
541
for (int i = 0 , lanes = t.lanes () / 2 ; i < lanes; ++i) {
544
542
if (isalpha (op[0 ]) || op[0 ] == ' _' ) {
545
- value_temp << op << " 2"
546
- << " (__half2(" ;
543
+ value_temp << op << " 2" << " (__half2(" ;
547
544
PrintVecElemLoad (vlhs, lhs.dtype (), i * lanes, value_temp);
548
545
value_temp << " ), __half2(" ;
549
546
PrintVecElemLoad (vrhs, rhs.dtype (), i * lanes, value_temp);
@@ -653,8 +650,7 @@ void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i,
653
650
ICHECK (i >= 0 && i < (t.bits () == 8 ? 16 : (t.bits () == 16 || t.bits () == 32 ) ? 8 : 4 ));
654
651
if (t.bits () == 8 && (t.is_int () || t.is_uint ())) {
655
652
if (t.lanes () == 2 || t.lanes () == 3 ) {
656
- stream << vec << ' .' << access [i % t.lanes ()] << " ="
657
- << " (" << value << " );\n " ;
653
+ stream << vec << ' .' << access [i % t.lanes ()] << " =" << " (" << value << " );\n " ;
658
654
} else {
659
655
std::string ac = t.lanes () == 4 ? vec : (vec + " ." + access [i / 4 ]);
660
656
stream << ac << " =" ;
@@ -861,7 +857,23 @@ void CodeGenCUDA::PrintCallExtern(Type ret_type, String global_symbol, const Arr
861
857
}
862
858
os << sret;
863
859
} else {
864
- CodeGenC::PrintCallExtern (ret_type, global_symbol, args, skip_first_arg, os);
860
+ if (ret_dtype.is_float8 ()) {
861
+ std::string fp8_type = (ret_dtype.is_e5m2_float8 () ? " __NV_E5M2" : " __NV_E4M3" );
862
+ os << " __nv_fp8_" << (ret_dtype.is_e5m2_float8 () ? " e5m2" : " e4m3" ) << " (" ;
863
+
864
+ LOG_INFO << global_symbol;
865
+ os << global_symbol << " (__half(__nv_cvt_fp8_to_halfraw(" ;
866
+ for (size_t i = static_cast <size_t >(skip_first_arg); i < args.size (); ++i) {
867
+ this ->PrintExpr (args[i], os);
868
+ os << " .__x, " << fp8_type << " ))" ;
869
+ if (i < args.size () - 1 ) {
870
+ os << " , " << " __half(__nv_cvt_fp8_to_halfraw(" ;
871
+ }
872
+ }
873
+ os << " ))" ;
874
+ } else {
875
+ CodeGenC::PrintCallExtern (ret_type, global_symbol, args, skip_first_arg, os);
876
+ }
865
877
}
866
878
}
867
879
@@ -1198,8 +1210,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
1198
1210
this ->stream << " \" @!p mov.b32 %0, 0;\\ n\"\n " ;
1199
1211
this ->stream << " \" @p ld.global.nc.f32 %0, [%1];}\\ n\"\n " ;
1200
1212
// stream << "\" @p ld.global.nc.L2::128B.f32 %0, [%1];}\\n\"\n" ;
1201
- stream << " : \" =f\" (" << reg << " [" << local_addr << " ]"
1202
- << " )\n " ;
1213
+ stream << " : \" =f\" (" << reg << " [" << local_addr << " ]" << " )\n " ;
1203
1214
stream << " : \" l\" ((void*)(" << global_buffer << " +" << global_addr << " )), \" r\" ((int)"
1204
1215
<< guard << " )\n " ;
1205
1216
stream << " );\n " ;
@@ -1385,8 +1396,7 @@ void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) {
1385
1396
PrintVecConstructor (op->dtype , os);
1386
1397
os << " (" ;
1387
1398
for (int i = 0 ; i < lanes; i++) {
1388
- os << " (" << PrintExpr (op->base ) << " )"
1389
- << " +(" << PrintExpr (op->stride ) << " *" << i << " )" ;
1399
+ os << " (" << PrintExpr (op->base ) << " )" << " +(" << PrintExpr (op->stride ) << " *" << i << " )" ;
1390
1400
if (i != lanes - 1 ) os << " , " ;
1391
1401
}
1392
1402
os << " )" ;
0 commit comments