[Bug][TIR] Incorrect transform for the MakePackedAPI pass #17870

Open
Cookiee235 opened this issue Apr 21, 2025 · 0 comments
Labels
needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it), type: bug

Comments

Contributor

Cookiee235 commented Apr 21, 2025

Actual behavior

Traceback (most recent call last):
  File "/data/qshenaf/remote_pc/TirFuzz/bugs/bug1.py", line 10, in <module>
    mod = tir.transform.FP8StorageLegalize()(mod)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/qshenaf/envs/tvm/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "tvm/_ffi/_cython/./packed_func.pxi", line 339, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 270, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 259, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 185, in tvm._ffi._cy3.core.CHECK_CALL
  File "/data/qshenaf/envs/tvm/python/tvm/_ffi/base.py", line 468, in raise_last_ffi_error
    raise py_err
  File "/data/qshenaf/envs/tvm/src/tir/ir/transform.cc", line 121, in tvm::tir::transform::PrimFuncPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
    func = pass_func(std::move(func), mod, pass_ctx);
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/qshenaf/envs/tvm/src/tir/transforms/unsupported_dtype_legalize.cc", line 796, in operator()
    return FP8StorageLegalizer().Legalize(f);
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/qshenaf/envs/tvm/src/tir/transforms/unsupported_dtype_legalize.cc", line 519, in tvm::tir::StorageLegalizer::Legalize(tvm::tir::PrimFunc)
    ICHECK_EQ(func->buffer_map.size(), 0) << "This pass must be called after MakePackedAPI";
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
tvm.error.InternalError: Traceback (most recent call last):
  2: tvm::tir::transform::PrimFuncPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
        at /data/qshenaf/envs/tvm/src/tir/ir/transform.cc:121
  1: operator()
        at /data/qshenaf/envs/tvm/src/tir/transforms/unsupported_dtype_legalize.cc:796
  0: tvm::tir::StorageLegalizer::Legalize(tvm::tir::PrimFunc)
        at /data/qshenaf/envs/tvm/src/tir/transforms/unsupported_dtype_legalize.cc:519
  File "/data/qshenaf/envs/tvm/src/tir/transforms/unsupported_dtype_legalize.cc", line 519
InternalError: Check failed: func->buffer_map.size() == 0 (2 vs. 0) : This pass must be called after MakePackedAPI
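
The failing check at `unsupported_dtype_legalize.cc:519` requires `func->buffer_map` to be empty before the storage legalizer runs. A minimal sketch of inspecting that precondition from Python before calling `FP8StorageLegalize` is given here. It assumes the module exposes its functions through `mod.functions` (keyed by objects with a `name_hint`) and that each `PrimFunc` exposes `buffer_map`; `functions_with_buffer_map` is a hypothetical helper name, not a TVM API.

```python
def functions_with_buffer_map(mod):
    """Return {function name: buffer_map size} for functions with a non-empty buffer_map.

    Assumption: `mod.functions` maps GlobalVar-like keys (with `name_hint`)
    to functions, and PrimFuncs expose a sized `buffer_map` attribute.
    """
    offenders = {}
    for gvar, func in mod.functions.items():
        # Functions that are not PrimFuncs may not have a buffer_map at all.
        buffer_map = getattr(func, "buffer_map", None)
        if buffer_map is not None and len(buffer_map) > 0:
            name = getattr(gvar, "name_hint", str(gvar))
            offenders[name] = len(buffer_map)
    return offenders
```

Called on the module returned by `MakePackedAPI` in the reproduction script, this would be expected to report `main` with 2 entries, which corresponds to the `(2 vs. 0)` in the error message.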

Environment

TVM-0.21.dev0 (latest)

Steps to reproduce

import tvm
from tvm import te, topi, tir

data = te.placeholder((1, 3, 224, 224), dtype='float32', name='data')
op_output = topi.nn.adaptive_pool(data, output_size=(112, 112), pool_type='max', layout='NCHW')
sch = tir.Schedule(te.create_prim_func([data, op_output]).with_attr("target", tvm.target.Target("llvm")))
mod = tir.transform.MakePackedAPI()(sch.mod)
print(mod)
mod = tir.transform.FP8StorageLegalize()(mod)

The IR printed after executing the `MakePackedAPI` pass:

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(data: T.Buffer((1, 3, 224, 224), "float32"), adaptive_pool_max: T.Buffer((1, 3, 112, 112), "float32")):
        T.func_attr({"target": T.target({"keys": ["cpu"], "kind": "llvm", "mtriple": "x86_64-unknown-linux-gnu", "tag": ""}), "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, rv0, rv1 in T.grid(1, 3, 112, 112, 2, 2):
            with T.block("adaptive_pool_max"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_rv0, v_rv1 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, rv0, rv1])
                data_1 = T.Buffer((150528,), data=data.data)
                T.reads(data_1[v_ax0 * 150528 + v_ax1 * 50176 + v_ax2 * 448 + v_rv0 * 224 + v_ax3 * 2 + v_rv1:v_ax0 * 150528 + v_ax1 * 50176 + v_ax2 * 448 + v_rv0 * 224 + v_ax3 * 2 + v_rv1 + (v_ax0 * 150528 + v_ax1 * 50176 + v_ax2 * 448 + v_rv0 * 224 + v_ax3 * 2 + v_rv1 + 1 - (v_ax0 * 150528 + v_ax1 * 50176 + v_ax2 * 448 + v_rv0 * 224 + v_ax3 * 2 + v_rv1))])
                adaptive_pool_max_1 = T.Buffer((37632,), data=adaptive_pool_max.data)
                T.writes(adaptive_pool_max_1[v_ax0 * 37632 + v_ax1 * 12544 + v_ax2 * 112 + v_ax3:v_ax0 * 37632 + v_ax1 * 12544 + v_ax2 * 112 + v_ax3 + (v_ax0 * 37632 + v_ax1 * 12544 + v_ax2 * 112 + v_ax3 + 1 - (v_ax0 * 37632 + v_ax1 * 12544 + v_ax2 * 112 + v_ax3))])
                T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_max"})
                with T.init():
                    adaptive_pool_max_1[v_ax0 * 37632 + v_ax1 * 12544 + v_ax2 * 112 + v_ax3] = T.float32(-340282346638528859811704183484516925440.0)
                adaptive_pool_max_1[v_ax0 * 37632 + v_ax1 * 12544 + v_ax2 * 112 + v_ax3] = T.max(adaptive_pool_max_1[v_ax0 * 37632 + v_ax1 * 12544 + v_ax2 * 112 + v_ax3], data_1[v_ax0 * 150528 + v_ax1 * 50176 + v_ax2 * 448 + v_rv0 * 224 + v_ax3 * 2 + v_rv1])
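
The flat indices in the dump above are consistent with row-major strides of the two buffer shapes, `(1, 3, 224, 224)` and `(1, 3, 112, 112)`. The following sketch checks that arithmetic in plain Python, independent of TVM; `row_major_strides` and `flat_index` are illustrative helper names, not part of any library.

```python
def row_major_strides(shape):
    """Strides (in elements) of a densely packed row-major buffer."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides


def flat_index(indices, shape):
    """Flatten a multi-dimensional index using row-major strides."""
    return sum(i * s for i, s in zip(indices, row_major_strides(shape)))


data_shape = (1, 3, 224, 224)
out_shape = (1, 3, 112, 112)

print(row_major_strides(data_shape))  # [150528, 50176, 224, 1]
print(row_major_strides(out_shape))   # [37632, 12544, 112, 1]

# The 2x2 pooling window reads input element (n, c, 2*h + rh, 2*w + rw),
# which flattens to n*150528 + c*50176 + h*448 + rh*224 + w*2 + rw,
# the same expression that appears in T.reads.
n, c, h, w, rh, rw = 0, 1, 3, 5, 1, 0
from_dump = n * 150528 + c * 50176 + h * 448 + rh * 224 + w * 2 + rw
assert from_dump == flat_index((n, c, 2 * h + rh, 2 * w + rw), data_shape)
```

The leading strides also equal the flattened buffer sizes in the dump, `T.Buffer((150528,))` and `T.Buffer((37632,))`, because the batch dimension is 1.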

Triage

Please refer to the list of label tags here to find the relevant tags and add them below in a bullet format (example below).

  • needs-triage

cc @Hzfengsy @junrushao @quic-sanirudh @shingjan

Cookiee235 added the needs-triage and type: bug labels on Apr 21, 2025.