Open
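Description
`tir.transform.FP8StorageLegalize` fails with an `InternalError` (a failed `ICHECK`) when it is applied to a module that has just been processed by `tir.transform.MakePackedAPI`. The check message says the pass must be called after `MakePackedAPI`, and that is exactly the order used in the reproduction below. However, the function printed after `MakePackedAPI` still has its two `T.Buffer` parameters (hence `2 vs. 0` in the failed check) and still contains `T.block` constructs. In other words, `MakePackedAPI` left the function unchanged, and `FP8StorageLegalize` then reports the mismatch as an internal error instead of a user-facing one.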
Actual behavior
Traceback (most recent call last):
File "/data/qshenaf/remote_pc/TirFuzz/bugs/bug1.py", line 10, in <module>
mod = tir.transform.FP8StorageLegalize()(mod)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/qshenaf/envs/tvm/python/tvm/ir/transform.py", line 238, in __call__
return _ffi_transform_api.RunPass(self, mod)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "tvm/_ffi/_cython/./packed_func.pxi", line 339, in tvm._ffi._cy3.core.PackedFuncBase.__call__
File "tvm/_ffi/_cython/./packed_func.pxi", line 270, in tvm._ffi._cy3.core.FuncCall
File "tvm/_ffi/_cython/./packed_func.pxi", line 259, in tvm._ffi._cy3.core.FuncCall3
File "tvm/_ffi/_cython/./base.pxi", line 185, in tvm._ffi._cy3.core.CHECK_CALL
File "/data/qshenaf/envs/tvm/python/tvm/_ffi/base.py", line 468, in raise_last_ffi_error
raise py_err
File "/data/qshenaf/envs/tvm/src/tir/ir/transform.cc", line 121, in tvm::tir::transform::PrimFuncPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
func = pass_func(std::move(func), mod, pass_ctx);
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/qshenaf/envs/tvm/src/tir/transforms/unsupported_dtype_legalize.cc", line 796, in operator()
return FP8StorageLegalizer().Legalize(f);
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/qshenaf/envs/tvm/src/tir/transforms/unsupported_dtype_legalize.cc", line 519, in tvm::tir::StorageLegalizer::Legalize(tvm::tir::PrimFunc)
ICHECK_EQ(func->buffer_map.size(), 0) << "This pass must be called after MakePackedAPI";
^^^^^^^^^^^^^^^^^^^^^^^^^^^
tvm.error.InternalError: Traceback (most recent call last):
2: tvm::tir::transform::PrimFuncPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
at /data/qshenaf/envs/tvm/src/tir/ir/transform.cc:121
1: operator()
at /data/qshenaf/envs/tvm/src/tir/transforms/unsupported_dtype_legalize.cc:796
0: tvm::tir::StorageLegalizer::Legalize(tvm::tir::PrimFunc)
at /data/qshenaf/envs/tvm/src/tir/transforms/unsupported_dtype_legalize.cc:519
File "/data/qshenaf/envs/tvm/src/tir/transforms/unsupported_dtype_legalize.cc", line 519
InternalError: Check failed: func->buffer_map.size() == 0 (2 vs. 0) : This pass must be called after MakePackedAPI
Environment
TVM-0.21.dev0 (latest)
Steps to reproduce
import tvm
from tvm import te, topi, tir

# Reproducer: build a max adaptive-pool compute with TE, turn it into a TIR
# PrimFunc targeting LLVM, run MakePackedAPI, then FP8StorageLegalize.
input_shape = (1, 3, 224, 224)
data = te.placeholder(input_shape, dtype='float32', name='data')
pooled = topi.nn.adaptive_pool(
    data,
    output_size=(112, 112),
    pool_type='max',
    layout='NCHW',
)

prim_func = te.create_prim_func([data, pooled])
prim_func = prim_func.with_attr("target", tvm.target.Target("llvm"))
schedule = tir.Schedule(prim_func)

mod = tir.transform.MakePackedAPI()(schedule.mod)
print(mod)

# Observed: InternalError, "Check failed: func->buffer_map.size() == 0 (2 vs. 0)"
mod = tir.transform.FP8StorageLegalize()(mod)
The IR printed by `print(mod)` after running `MakePackedAPI`:
# NOTE(review): Output of `print(mod)` after `tir.transform.MakePackedAPI`, as
# pasted into the issue (indentation restored; NOTE(review) comments added).
# `main` still takes two `T.Buffer` parameters and still contains `T.block`
# constructs. This matches the "(2 vs. 0)" in the FP8StorageLegalize check
# failure, i.e. the buffer_map was not emptied. Presumably MakePackedAPI
# skipped this function because it had not been lowered yet; confirm.
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(data: T.Buffer((1, 3, 224, 224), "float32"), adaptive_pool_max: T.Buffer((1, 3, 112, 112), "float32")):
        T.func_attr({"target": T.target({"keys": ["cpu"], "kind": "llvm", "mtriple": "x86_64-unknown-linux-gnu", "tag": ""}), "tir.noalias": T.bool(True)})
        # with T.block("root"):
        # NOTE(review): rv0/rv1 are the 2x2 pooling window (reduction axes,
        # "RR" in T.axis.remap below).
        for ax0, ax1, ax2, ax3, rv0, rv1 in T.grid(1, 3, 112, 112, 2, 2):
            with T.block("adaptive_pool_max"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_rv0, v_rv1 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, rv0, rv1])
                # NOTE(review): 1-D views over the parameter buffers' data;
                # 150528 = 1*3*224*224 and 37632 = 1*3*112*112.
                data_1 = T.Buffer((150528,), data=data.data)
                T.reads(data_1[v_ax0 * 150528 + v_ax1 * 50176 + v_ax2 * 448 + v_rv0 * 224 + v_ax3 * 2 + v_rv1:v_ax0 * 150528 + v_ax1 * 50176 + v_ax2 * 448 + v_rv0 * 224 + v_ax3 * 2 + v_rv1 + (v_ax0 * 150528 + v_ax1 * 50176 + v_ax2 * 448 + v_rv0 * 224 + v_ax3 * 2 + v_rv1 + 1 - (v_ax0 * 150528 + v_ax1 * 50176 + v_ax2 * 448 + v_rv0 * 224 + v_ax3 * 2 + v_rv1))])
                adaptive_pool_max_1 = T.Buffer((37632,), data=adaptive_pool_max.data)
                T.writes(adaptive_pool_max_1[v_ax0 * 37632 + v_ax1 * 12544 + v_ax2 * 112 + v_ax3:v_ax0 * 37632 + v_ax1 * 12544 + v_ax2 * 112 + v_ax3 + (v_ax0 * 37632 + v_ax1 * 12544 + v_ax2 * 112 + v_ax3 + 1 - (v_ax0 * 37632 + v_ax1 * 12544 + v_ax2 * 112 + v_ax3))])
                T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_max"})
                # NOTE(review): init sets the output to the lowest float32 value
                # before the max-reduction.
                with T.init():
                    adaptive_pool_max_1[v_ax0 * 37632 + v_ax1 * 12544 + v_ax2 * 112 + v_ax3] = T.float32(-340282346638528859811704183484516925440.0)
                adaptive_pool_max_1[v_ax0 * 37632 + v_ax1 * 12544 + v_ax2 * 112 + v_ax3] = T.max(adaptive_pool_max_1[v_ax0 * 37632 + v_ax1 * 12544 + v_ax2 * 112 + v_ax3], data_1[v_ax0 * 150528 + v_ax1 * 50176 + v_ax2 * 448 + v_rv0 * 224 + v_ax3 * 2 + v_rv1])
Triage
Please refer to the repository's list of label tags (the link was not preserved in this copy) to find the relevant tags, and add them below as bullets (example below).
- needs-triage