
[Bug][TIR] Incorrect transform for the MakePackedAPI pass #17870

Open
@Cookiee235

Description


Actual behavior

Traceback (most recent call last):
  File "/data/qshenaf/remote_pc/TirFuzz/bugs/bug1.py", line 10, in <module>
    mod = tir.transform.FP8StorageLegalize()(mod)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/qshenaf/envs/tvm/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "tvm/_ffi/_cython/./packed_func.pxi", line 339, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 270, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 259, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 185, in tvm._ffi._cy3.core.CHECK_CALL
  File "/data/qshenaf/envs/tvm/python/tvm/_ffi/base.py", line 468, in raise_last_ffi_error
    raise py_err
  File "/data/qshenaf/envs/tvm/src/tir/ir/transform.cc", line 121, in tvm::tir::transform::PrimFuncPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
    func = pass_func(std::move(func), mod, pass_ctx);
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/qshenaf/envs/tvm/src/tir/transforms/unsupported_dtype_legalize.cc", line 796, in operator()
    return FP8StorageLegalizer().Legalize(f);
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/qshenaf/envs/tvm/src/tir/transforms/unsupported_dtype_legalize.cc", line 519, in tvm::tir::StorageLegalizer::Legalize(tvm::tir::PrimFunc)
    ICHECK_EQ(func->buffer_map.size(), 0) << "This pass must be called after MakePackedAPI";
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
tvm.error.InternalError: Traceback (most recent call last):
  2: tvm::tir::transform::PrimFuncPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
        at /data/qshenaf/envs/tvm/src/tir/ir/transform.cc:121
  1: operator()
        at /data/qshenaf/envs/tvm/src/tir/transforms/unsupported_dtype_legalize.cc:796
  0: tvm::tir::StorageLegalizer::Legalize(tvm::tir::PrimFunc)
        at /data/qshenaf/envs/tvm/src/tir/transforms/unsupported_dtype_legalize.cc:519
  File "/data/qshenaf/envs/tvm/src/tir/transforms/unsupported_dtype_legalize.cc", line 519
InternalError: Check failed: func->buffer_map.size() == 0 (2 vs. 0) : This pass must be called after MakePackedAPI

Environment

TVM-0.21.dev0 (latest)

Steps to reproduce

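import tvm
from tvm import te, topi, tir

# 1x3x224x224 NCHW float32 input, max-pooled adaptively down to 112x112
data = te.placeholder((1, 3, 224, 224), dtype='float32', name='data')
op_output = topi.nn.adaptive_pool(data, output_size=(112, 112), pool_type='max', layout='NCHW')
# Build a PrimFunc from the TE compute and attach an LLVM target
sch = tir.Schedule(te.create_prim_func([data, op_output]).with_attr("target", tvm.target.Target("llvm")))
# MakePackedAPI is expected to replace the T.Buffer parameters with packed-API arguments
mod = tir.transform.MakePackedAPI()(sch.mod)
print(mod)  # main still has its two T.Buffer parameters (see the IR below)
# Fails: FP8StorageLegalize requires an empty buffer_map ("must be called after MakePackedAPI")
mod = tir.transform.FP8StorageLegalize()(mod)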
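
The reduction loops `rv0, rv1` with extent 2 in the printed IR come from the adaptive-pooling window. Below is a minimal sketch, assuming the usual adaptive-pooling rule (window start = floor(i * in / out), window end = ceil((i + 1) * in / out)), which I believe `topi.nn.adaptive_pool` follows; `adaptive_window` is a hypothetical helper for illustration, not a TVM API.

```python
def adaptive_window(i: int, in_size: int, out_size: int) -> tuple:
    """Return the [start, end) input range pooled into output index i.

    Hypothetical helper illustrating the common adaptive-pooling rule;
    it is not part of the TVM API.
    """
    start = (i * in_size) // out_size            # floor(i * in / out)
    end = -((-(i + 1) * in_size) // out_size)    # ceil((i + 1) * in / out)
    return start, end


# 224 -> 112, as in topi.nn.adaptive_pool(..., output_size=(112, 112))
windows = [adaptive_window(i, 224, 112) for i in range(112)]
extents = {end - start for start, end in windows}
print(extents)       # a single window extent, matching the rv0/rv1 loop extent
print(windows[:3])   # windows start at 2 * i, i.e. index 2 * ax2 + rv0
```

Because every window has extent 2 and starts at `2 * i`, the input row read for output row `ax2` is `2 * ax2 + rv0`, which is where the `v_ax2 * 448 + v_rv0 * 224` terms in the IR come from.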

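The IR printed by `print(mod)` in the script above. Note that `main` still takes two `T.Buffer` parameters (`data` and `adaptive_pool_max`), so its `buffer_map` has two entries, matching the `(2 vs. 0)` in the check failure: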

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(data: T.Buffer((1, 3, 224, 224), "float32"), adaptive_pool_max: T.Buffer((1, 3, 112, 112), "float32")):
        T.func_attr({"target": T.target({"keys": ["cpu"], "kind": "llvm", "mtriple": "x86_64-unknown-linux-gnu", "tag": ""}), "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, rv0, rv1 in T.grid(1, 3, 112, 112, 2, 2):
            with T.block("adaptive_pool_max"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_rv0, v_rv1 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, rv0, rv1])
                data_1 = T.Buffer((150528,), data=data.data)
                T.reads(data_1[v_ax0 * 150528 + v_ax1 * 50176 + v_ax2 * 448 + v_rv0 * 224 + v_ax3 * 2 + v_rv1:v_ax0 * 150528 + v_ax1 * 50176 + v_ax2 * 448 + v_rv0 * 224 + v_ax3 * 2 + v_rv1 + (v_ax0 * 150528 + v_ax1 * 50176 + v_ax2 * 448 + v_rv0 * 224 + v_ax3 * 2 + v_rv1 + 1 - (v_ax0 * 150528 + v_ax1 * 50176 + v_ax2 * 448 + v_rv0 * 224 + v_ax3 * 2 + v_rv1))])
                adaptive_pool_max_1 = T.Buffer((37632,), data=adaptive_pool_max.data)
                T.writes(adaptive_pool_max_1[v_ax0 * 37632 + v_ax1 * 12544 + v_ax2 * 112 + v_ax3:v_ax0 * 37632 + v_ax1 * 12544 + v_ax2 * 112 + v_ax3 + (v_ax0 * 37632 + v_ax1 * 12544 + v_ax2 * 112 + v_ax3 + 1 - (v_ax0 * 37632 + v_ax1 * 12544 + v_ax2 * 112 + v_ax3))])
                T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_max"})
                with T.init():
                    adaptive_pool_max_1[v_ax0 * 37632 + v_ax1 * 12544 + v_ax2 * 112 + v_ax3] = T.float32(-340282346638528859811704183484516925440.0)
                adaptive_pool_max_1[v_ax0 * 37632 + v_ax1 * 12544 + v_ax2 * 112 + v_ax3] = T.max(adaptive_pool_max_1[v_ax0 * 37632 + v_ax1 * 12544 + v_ax2 * 112 + v_ax3], data_1[v_ax0 * 150528 + v_ax1 * 50176 + v_ax2 * 448 + v_rv0 * 224 + v_ax3 * 2 + v_rv1])
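
The flattened indices in this IR are plain row-major offsets into `data` (shape 1x3x224x224) and `adaptive_pool_max` (shape 1x3x112x112). A small sketch that recomputes the constants 150528, 50176, 224 and 37632, 12544, 112; `row_major_strides` and `flat_offset` are illustrative helpers, not TVM APIs.

```python
from math import prod


def row_major_strides(shape):
    """Element strides of a contiguous row-major (C-order) buffer."""
    return [prod(shape[i + 1:]) for i in range(len(shape))]


def flat_offset(index, shape):
    """Flatten an N-d index into a 1-d offset, like data_1[...] in the IR."""
    return sum(i * s for i, s in zip(index, row_major_strides(shape)))


data_shape = (1, 3, 224, 224)
out_shape = (1, 3, 112, 112)
print(row_major_strides(data_shape))
print(row_major_strides(out_shape))

# Output element (ax0, ax1, ax2, ax3) with window offset (rv0, rv1) reads
# data[ax0, ax1, 2*ax2 + rv0, 2*ax3 + rv1]; its flat offset should equal the
# IR expression v_ax0*150528 + v_ax1*50176 + v_ax2*448 + v_rv0*224 + v_ax3*2 + v_rv1.
ax0, ax1, ax2, ax3, rv0, rv1 = 0, 2, 5, 7, 1, 0
expected = ax0 * 150528 + ax1 * 50176 + ax2 * 448 + rv0 * 224 + ax3 * 2 + rv1
assert flat_offset((ax0, ax1, 2 * ax2 + rv0, 2 * ax3 + rv1), data_shape) == expected
```

The `448` and `2` coefficients are just the row stride (224) and the column stride (1) multiplied by the pooling stride of 2.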

Triage

Please refer to the list of label tags to find the relevant tags and add them below in a bullet format (example below).

  • needs-triage

cc @Hzfengsy @junrushao @quic-sanirudh @shingjan

Metadata

Assignees: none

Labels: needs-triage, type: bug