diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h index e2ce2be6a882..f0f80ad8f4a0 100644 --- a/include/tvm/relax/attrs/nn.h +++ b/include/tvm/relax/attrs/nn.h @@ -603,6 +603,15 @@ struct PadAttrs : public tvm::AttrsNode { } }; +/*! \brief Attributes used for the pixel shuffle operator */ +struct PixelShuffleAttrs : public tvm::AttrsNode { + int upscale_factor; + + TVM_DECLARE_ATTRS(PixelShuffleAttrs, "relax.attrs.PixelShuffleAttrs") { + TVM_ATTR_FIELD(upscale_factor).describe("Scale factor for spatial upsampling."); + } +}; + } // namespace relax } // namespace tvm diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 33f6ffc3132e..2a244ac0c4e0 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -883,6 +883,12 @@ def _pad(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.op.nn.pad(x, pad_width, mode, value)) + def _pixel_shuffle(self, node: fx.Node) -> relax.Var: + data = self.env[node.args[0]] + upscale_factor = node.args[1] + + return self.block_builder.emit(relax.op.nn.pixel_shuffle(data, upscale_factor)) + def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var: transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3]) query = transpose_S_H(self.env[node.args[0]]) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 0434712050ed..58e060b7595d 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -309,6 +309,7 @@ def create_convert_map( "log_softmax.int": self._log_softmax, "neg.default": self._unary_op(relax.op.negative), "pad.default": self._pad, + "pixel_shuffle.default": self._pixel_shuffle, "prelu.default": self._prelu, "reciprocal.default": self._reciprocal, "relu.default": self._unary_op(relax.op.nn.relu), diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 55abf20fcc03..83a9ad55dfbd 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -460,6 +460,13 @@ def _max_pool2d_module(self, node: fx.Node) -> relax.Var: return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) + def _pixel_shuffle_module(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + upscale_factor = module.upscale_factor + + return self.block_builder.emit(relax.op.nn.pixel_shuffle(x, upscale_factor)) + ########## Linear Interpolation ########## def _lerp(self, node: fx.Node) -> relax.Var: @@ -665,6 +672,7 @@ def create_convert_map( nn.Linear: self._linear_module, nn.MaxPool2d: self._max_pool2d_module, nn.modules.sparse.Embedding: self._embedding_module, + nn.PixelShuffle: self._pixel_shuffle_module, # tensor manipulation nn.Flatten: self._flatten_module, ## call_function and call_method @@ -703,6 +711,7 @@ def create_convert_map( "log_softmax": self._log_softmax, "neg": self._unary_op(relax.op.negative), "pad": self._pad, + "pixel_shuffle": self._pixel_shuffle, "prelu": self._prelu, "reciprocal": self._reciprocal, "relu": self._unary_op(relax.op.nn.relu), diff --git a/python/tvm/relax/op/nn/__init__.py b/python/tvm/relax/op/nn/__init__.py index 14b5dcfc0681..08ecda275c3e 100644 --- a/python/tvm/relax/op/nn/__init__.py +++ b/python/tvm/relax/op/nn/__init__.py @@ -43,6 +43,7 @@ max_pool3d, nll_loss, pad, + pixel_shuffle, prelu, relu, rms_norm, diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index e201b596f936..e234e8ad7b18 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -549,6 +549,39 @@ def pad( return _ffi_api.pad(data, pad_width, pad_mode, pad_value) +def pixel_shuffle(data: Expr, upscale_factor: int): + r""" + Pixel Shuffle Operator + + This operator performs the pixel shuffle operation on the input tensor, + which is often used for efficient sub-pixel convolution in image + super-resolution tasks. It rearranges elements in a tensor of shape + (N, C × r^2, H, W) to a tensor of shape (N, C, H × r, W × r), where `r` + is the upscale factor. + + Parameters + ---------- + data : relax.Expr + The input tensor to the pixel shuffle operator. It must have 4 dimensions + with the format (N, C * r^2, H, W), where `r` is the upscale factor. + + upscale_factor : int + The upscaling factor `r`. It determines how much to increase the spatial + resolution (height and width) of the input tensor. + + Returns + ------- + result : relax.Expr + The transformed tensor with shape (N, C, H * r, W * r). + + Example + ------- + If the input tensor has shape (1, 8, 10, 15) and `upscale_factor` is 2, + the resulting tensor will have shape (1, 2, 20, 30). + """ + return _ffi_api.pixel_shuffle(data, upscale_factor) + + def max_pool1d( data: Expr, pool_size: Union[int, Tuple[int, int]] = (1,), diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py index 6a6f0ed6cb93..f18ad6097f06 100644 --- a/python/tvm/relax/transform/legalize_ops/nn.py +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -249,6 +249,12 @@ def _nn_pad(bb: BlockBuilder, call: Call) -> Expr: ) +@register_legalize("relax.nn.pixel_shuffle") +def _nn_pixel_shuffle(bb: BlockBuilder, call: Call) -> Expr: + upscale_factor = call.attrs.upscale_factor + return bb.call_te(topi.nn.pixel_shuffle, call.args[0], upscale_factor=upscale_factor) + + @register_legalize("relax.nn.max_pool1d") def _nn_max_pool1d(bb: BlockBuilder, call: Call) -> Expr: if call.attrs.out_layout != call.attrs.layout: diff --git a/python/tvm/topi/nn/pixel_shuffle.py b/python/tvm/topi/nn/pixel_shuffle.py new file mode 100644 index 000000000000..78966ee4d9d7 --- /dev/null +++ b/python/tvm/topi/nn/pixel_shuffle.py @@ -0,0 +1,75 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""TVM operator pixel shuffle compute.""" +from __future__ import absolute_import + +import tvm + + +def pixel_shuffle(data, upscale_factor, name="PixelShuffle"): + """PixelShuffle operator that rearranges elements in a tensor of shape + [..., C * r * r, H, W] to [..., C, H * r, W * r]. + + Parameters + ---------- + data : tvm.te.Tensor + N-D input tensor with at least 3 dimensions. Channel must be at index -3. + + upscale_factor : int + The upscale factor (r). + + name : str + Name of the output tensor. + + Returns + ------- + output : tvm.te.Tensor + Pixel shuffled tensor with shape [..., C, H*r, W*r] + """ + assert isinstance(upscale_factor, int) and upscale_factor > 0 + ndim = len(data.shape) + assert ndim >= 3, "Input must be at least 3D" + + upscale_factor_const = tvm.tir.const(upscale_factor, "int32") + c_in, h_in, w_in = data.shape[-3], data.shape[-2], data.shape[-1] + + c_out = tvm.tir.floordiv(c_in, upscale_factor_const * upscale_factor_const) + h_out = h_in * upscale_factor_const + w_out = w_in * upscale_factor_const + + out_shape = list(data.shape[:-3]) + [c_out, h_out, w_out] + + def _compute(*indices): + batch_indices = indices[:-3] + c_out_idx, h_out_idx, w_out_idx = indices[-3], indices[-2], indices[-1] + + h_idx = tvm.tir.floordiv(h_out_idx, upscale_factor_const) + h_offset = h_out_idx % upscale_factor_const + + w_idx = tvm.tir.floordiv(w_out_idx, upscale_factor_const) + w_offset = w_out_idx % upscale_factor_const + + c_in_idx = ( + (c_out_idx * upscale_factor_const * upscale_factor_const) + + (h_offset * upscale_factor_const) + + w_offset + ) + + index_tuple = batch_indices + (c_in_idx, h_idx, w_idx) + return data[index_tuple] + + return tvm.te.compute(out_shape, _compute, name=name) diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index 8c0b86fe5f8e..3519cbcf59b8 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -224,6 +224,70 @@ TVM_REGISTER_OP("relax.nn.pad") .set_attr("FInferStructInfo", InferStructInfoPad) .set_attr("FPurity", Bool(true)); +/* relax.nn.pixel_shuffle */ +TVM_REGISTER_NODE_TYPE(PixelShuffleAttrs); + +Expr pixel_shuffle(Expr data, int upscale_factor) { + auto attrs = make_object(); + attrs->upscale_factor = upscale_factor; + static const Op& op = Op::Get("relax.nn.pixel_shuffle"); + return Call(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.pixel_shuffle").set_body_typed(pixel_shuffle); + +StructInfo InferStructInfoPixelShuffle(const Call& call, const BlockBuilder& ctx) { + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + const auto* attrs = call->attrs.as(); + int r = attrs->upscale_factor; + ICHECK_GT(r, 0) << "Upscale factor must be positive"; + + const TensorStructInfo& input = input_sinfo[0]; + int ndim = input->ndim; + ICHECK_GE(ndim, 3) << "PixelShuffle requires at least 3D input tensor"; + + if (!input->shape.defined()) { + return TensorStructInfo(input->dtype, ndim); + } + + const auto* shape = input->shape.as(); + Array in_shape = shape->values; + + int channel_idx = ndim - 3; + int h_idx = ndim - 2; + int w_idx = ndim - 1; + + PrimExpr c_in = in_shape[channel_idx]; + PrimExpr h_in = in_shape[h_idx]; + PrimExpr w_in = in_shape[w_idx]; + + PrimExpr r_expr = IntImm(DataType::Int(32), r); + PrimExpr r_squared = r_expr * r_expr; + + // Output shape: + Array out_shape; + for (int i = 0; i < ndim; ++i) { + if (i == channel_idx) { + out_shape.push_back(c_in / r_squared); + } else if (i == h_idx) { + out_shape.push_back(h_in * r_expr); + } else if (i == w_idx) { + out_shape.push_back(w_in * r_expr); + } else { + out_shape.push_back(in_shape[i]); + } + } + + return TensorStructInfo(ShapeExpr(out_shape), input->dtype); +} + +TVM_REGISTER_OP("relax.nn.pixel_shuffle") + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_attrs_type() + .set_attr("FInferStructInfo", InferStructInfoPixelShuffle) + .set_attr("FPurity", Bool(true)); + /* relax.nn.batchnorm */ bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx, const Array& input_sinfo, Array axes) { diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h index a9c3dd0a5767..018741430199 100644 --- a/src/relax/op/nn/nn.h +++ b/src/relax/op/nn/nn.h @@ -75,6 +75,9 @@ Expr softplus(Expr data, double beta, double threshold); /*! \brief LogSoftmax function. */ Expr log_softmax(Expr data, int axis); +/*! \brief Pixel Shuffle function. */ +Expr pixel_shuffle(Expr data, int upscale_factor); + /*! \brief Compute batch normalization. */ Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var, // int axis, double epsilon, bool center, bool scale, double momentum, bool training); diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 108617991b1f..93ecc454902d 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -1985,6 +1985,42 @@ def main( verify_model(PadModel(pad=[1, 1, 2, 2], mode="circular"), example_args, {}, expected_circular) +def test_pixel_shuffle(): + class PixelShuffle1(torch.nn.Module): + def __init__(self, upscale_factor=2): + super().__init__() + self.pixel_shuffle = torch.nn.PixelShuffle(upscale_factor) + + def forward(self, x): + return self.pixel_shuffle(x) + + class PixelShuffle2(torch.nn.Module): + def __init__(self, upscale_factor=2): + super().__init__() + self.upscale_factor = upscale_factor + + def forward(self, x): + return torch.nn.functional.pixel_shuffle(x, self.upscale_factor) + + @tvm.script.ir_module + class expected: + @R.function + def main( + x: R.Tensor((1, 8, 10, 15), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 2, 20, 30), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 2, 20, 30), dtype="float32") = R.nn.pixel_shuffle( + x, upscale_factor=2 + ) + gv: R.Tuple(R.Tensor((1, 2, 20, 30), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 8, 10, 15, dtype=torch.float32),) + verify_model(PixelShuffle1(upscale_factor=2), example_args, {}, expected) + verify_model(PixelShuffle2(upscale_factor=2), example_args, {}, expected) + + def test_einsum(): class Einsum1(Module): def __init__(self): diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index cb69398e0a00..2989164f1259 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -592,6 +592,42 @@ def main( verify_model(PadModel(pad=[1, 1, 2, 2], mode="circular"), input_infos, {}, expected_circular) +def test_pixel_shuffle(): + class PixelShuffle1(torch.nn.Module): + def __init__(self, upscale_factor=2): + super().__init__() + self.pixel_shuffle = torch.nn.PixelShuffle(upscale_factor) + + def forward(self, x): + return self.pixel_shuffle(x) + + class PixelShuffle2(torch.nn.Module): + def __init__(self, upscale_factor=2): + super().__init__() + self.upscale_factor = upscale_factor + + def forward(self, x): + return torch.nn.functional.pixel_shuffle(x, self.upscale_factor) + + @tvm.script.ir_module + class expected: + @R.function + def main( + inp_0: R.Tensor((1, 8, 10, 15), dtype="float32") + ) -> R.Tensor((1, 2, 20, 30), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 2, 20, 30), dtype="float32") = R.nn.pixel_shuffle( + inp_0, upscale_factor=2 + ) + gv: R.Tensor((1, 2, 20, 30), dtype="float32") = lv + R.output(gv) + return gv + + input_infos = [([1, 8, 10, 15], "float32")] + verify_model(PixelShuffle1(2), input_infos, {}, expected) + verify_model(PixelShuffle2(2), input_infos, {}, expected) + + def test_linear(): # nn.Linear class Dense1(Module): diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py index 1c03d8fe4649..bb61329da3e0 100644 --- a/tests/python/relax/test_op_nn.py +++ b/tests/python/relax/test_op_nn.py @@ -1822,5 +1822,25 @@ def test_pad_infer_struct_info(): ) +def test_pixel_shuffle_infer_struct_info(): + bb = relax.BlockBuilder() + x1 = relax.Var("x1", R.Tensor((1, 8, 10, 15), "float32")) + x2 = relax.Var("x2", R.Tensor((2, 6, 18, 5, 4), "float32")) + + upscale_factor1 = 2 + _check_inference( + bb, + relax.op.nn.pixel_shuffle(x1, upscale_factor1), + relax.TensorStructInfo((1, 2, 20, 30), dtype="float32"), + ) + + upscale_factor2 = 3 + _check_inference( + bb, + relax.op.nn.pixel_shuffle(x2, upscale_factor2), + relax.TensorStructInfo((2, 6, 2, 15, 12), dtype="float32"), + ) + + if __name__ == "__main__": tvm.testing.main()