Skip to content

[Relax][PyTorch] Add Pixel Shuffle Op Support for Exported Program and FX graph #17886

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions include/tvm/relax/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,15 @@ struct PadAttrs : public tvm::AttrsNode<PadAttrs> {
}
};

/*! \brief Attributes used for the pixel shuffle operator */
struct PixelShuffleAttrs : public tvm::AttrsNode<PixelShuffleAttrs> {
int upscale_factor;

TVM_DECLARE_ATTRS(PixelShuffleAttrs, "relax.attrs.PixelShuffleAttrs") {
TVM_ATTR_FIELD(upscale_factor).describe("Scale factor for spatial upsampling.");
}
};

} // namespace relax
} // namespace tvm

Expand Down
6 changes: 6 additions & 0 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,6 +883,12 @@ def _pad(self, node: fx.Node) -> relax.Var:

return self.block_builder.emit(relax.op.nn.pad(x, pad_width, mode, value))

def _pixel_shuffle(self, node: fx.Node) -> relax.Var:
data = self.env[node.args[0]]
upscale_factor = node.args[1]

return self.block_builder.emit(relax.op.nn.pixel_shuffle(data, upscale_factor))

def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var:
transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3])
query = transpose_S_H(self.env[node.args[0]])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ def create_convert_map(
"log_softmax.int": self._log_softmax,
"neg.default": self._unary_op(relax.op.negative),
"pad.default": self._pad,
"pixel_shuffle.default": self._pixel_shuffle,
"prelu.default": self._prelu,
"reciprocal.default": self._reciprocal,
"relu.default": self._unary_op(relax.op.nn.relu),
Expand Down
9 changes: 9 additions & 0 deletions python/tvm/relax/frontend/torch/fx_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,13 @@ def _max_pool2d_module(self, node: fx.Node) -> relax.Var:

return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)

def _pixel_shuffle_module(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
upscale_factor = module.upscale_factor

return self.block_builder.emit(relax.op.nn.pixel_shuffle(x, upscale_factor))

########## Linear Interpolation ##########

def _lerp(self, node: fx.Node) -> relax.Var:
Expand Down Expand Up @@ -665,6 +672,7 @@ def create_convert_map(
nn.Linear: self._linear_module,
nn.MaxPool2d: self._max_pool2d_module,
nn.modules.sparse.Embedding: self._embedding_module,
nn.PixelShuffle: self._pixel_shuffle_module,
# tensor manipulation
nn.Flatten: self._flatten_module,
## call_function and call_method
Expand Down Expand Up @@ -703,6 +711,7 @@ def create_convert_map(
"log_softmax": self._log_softmax,
"neg": self._unary_op(relax.op.negative),
"pad": self._pad,
"pixel_shuffle": self._pixel_shuffle,
"prelu": self._prelu,
"reciprocal": self._reciprocal,
"relu": self._unary_op(relax.op.nn.relu),
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/op/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
max_pool3d,
nll_loss,
pad,
pixel_shuffle,
prelu,
relu,
rms_norm,
Expand Down
33 changes: 33 additions & 0 deletions python/tvm/relax/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,39 @@ def pad(
return _ffi_api.pad(data, pad_width, pad_mode, pad_value)


def pixel_shuffle(data: Expr, upscale_factor: int):
r"""
Pixel Shuffle Operator

This operator performs the pixel shuffle operation on the input tensor,
which is often used for efficient sub-pixel convolution in image
super-resolution tasks. It rearranges elements in a tensor of shape
(N, C × r^2, H, W) to a tensor of shape (N, C, H × r, W × r), where `r`
is the upscale factor.

Parameters
----------
data : relax.Expr
The input tensor to the pixel shuffle operator. It must have 4 dimensions
with the format (N, C * r^2, H, W), where `r` is the upscale factor.

upscale_factor : int
The upscaling factor `r`. It determines how much to increase the spatial
resolution (height and width) of the input tensor.

Returns
-------
result : relax.Expr
The transformed tensor with shape (N, C, H * r, W * r).

Example
-------
If the input tensor has shape (1, 8, 10, 15) and `upscale_factor` is 2,
the resulting tensor will have shape (1, 2, 20, 30).
"""
return _ffi_api.pixel_shuffle(data, upscale_factor)


def max_pool1d(
data: Expr,
pool_size: Union[int, Tuple[int, int]] = (1,),
Expand Down
6 changes: 6 additions & 0 deletions python/tvm/relax/transform/legalize_ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,12 @@ def _nn_pad(bb: BlockBuilder, call: Call) -> Expr:
)


@register_legalize("relax.nn.pixel_shuffle")
def _nn_pixel_shuffle(bb: BlockBuilder, call: Call) -> Expr:
upscale_factor = call.attrs.upscale_factor
return bb.call_te(topi.nn.pixel_shuffle, call.args[0], upscale_factor=upscale_factor)


@register_legalize("relax.nn.max_pool1d")
def _nn_max_pool1d(bb: BlockBuilder, call: Call) -> Expr:
if call.attrs.out_layout != call.attrs.layout:
Expand Down
75 changes: 75 additions & 0 deletions python/tvm/topi/nn/pixel_shuffle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""TVM operator pixel shuffle compute."""
from __future__ import absolute_import

import tvm


def pixel_shuffle(data, upscale_factor, name="PixelShuffle"):
"""PixelShuffle operator that rearranges elements in a tensor of shape
[..., C * r * r, H, W] to [..., C, H * r, W * r].

Parameters
----------
data : tvm.te.Tensor
N-D input tensor with at least 3 dimensions. Channel must be at index -3.

upscale_factor : int
The upscale factor (r).

name : str
Name of the output tensor.

Returns
-------
output : tvm.te.Tensor
Pixel shuffled tensor with shape [..., C, H*r, W*r]
"""
assert isinstance(upscale_factor, int) and upscale_factor > 0
ndim = len(data.shape)
assert ndim >= 3, "Input must be at least 3D"

upscale_factor_const = tvm.tir.const(upscale_factor, "int32")
c_in, h_in, w_in = data.shape[-3], data.shape[-2], data.shape[-1]

c_out = tvm.tir.floordiv(c_in, upscale_factor_const * upscale_factor_const)
h_out = h_in * upscale_factor_const
w_out = w_in * upscale_factor_const

out_shape = list(data.shape[:-3]) + [c_out, h_out, w_out]

def _compute(*indices):
batch_indices = indices[:-3]
c_out_idx, h_out_idx, w_out_idx = indices[-3], indices[-2], indices[-1]

h_idx = tvm.tir.floordiv(h_out_idx, upscale_factor_const)
h_offset = h_out_idx % upscale_factor_const

w_idx = tvm.tir.floordiv(w_out_idx, upscale_factor_const)
w_offset = w_out_idx % upscale_factor_const

c_in_idx = (
(c_out_idx * upscale_factor_const * upscale_factor_const)
+ (h_offset * upscale_factor_const)
+ w_offset
)

index_tuple = batch_indices + (c_in_idx, h_idx, w_idx)
return data[index_tuple]

return tvm.te.compute(out_shape, _compute, name=name)
64 changes: 64 additions & 0 deletions src/relax/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,70 @@ TVM_REGISTER_OP("relax.nn.pad")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPad)
.set_attr<Bool>("FPurity", Bool(true));

/* relax.nn.pixel_shuffle */
TVM_REGISTER_NODE_TYPE(PixelShuffleAttrs);

Expr pixel_shuffle(Expr data, int upscale_factor) {
auto attrs = make_object<PixelShuffleAttrs>();
attrs->upscale_factor = upscale_factor;
static const Op& op = Op::Get("relax.nn.pixel_shuffle");
return Call(op, {data}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relax.op.nn.pixel_shuffle").set_body_typed(pixel_shuffle);

StructInfo InferStructInfoPixelShuffle(const Call& call, const BlockBuilder& ctx) {
Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
const auto* attrs = call->attrs.as<PixelShuffleAttrs>();
int r = attrs->upscale_factor;
ICHECK_GT(r, 0) << "Upscale factor must be positive";

const TensorStructInfo& input = input_sinfo[0];
int ndim = input->ndim;
ICHECK_GE(ndim, 3) << "PixelShuffle requires at least 3D input tensor";

if (!input->shape.defined()) {
return TensorStructInfo(input->dtype, ndim);
}

const auto* shape = input->shape.as<ShapeExprNode>();
Array<PrimExpr> in_shape = shape->values;

int channel_idx = ndim - 3;
int h_idx = ndim - 2;
int w_idx = ndim - 1;

PrimExpr c_in = in_shape[channel_idx];
PrimExpr h_in = in_shape[h_idx];
PrimExpr w_in = in_shape[w_idx];

PrimExpr r_expr = IntImm(DataType::Int(32), r);
PrimExpr r_squared = r_expr * r_expr;

// Output shape:
Array<PrimExpr> out_shape;
for (int i = 0; i < ndim; ++i) {
if (i == channel_idx) {
out_shape.push_back(c_in / r_squared);
} else if (i == h_idx) {
out_shape.push_back(h_in * r_expr);
} else if (i == w_idx) {
out_shape.push_back(w_in * r_expr);
} else {
out_shape.push_back(in_shape[i]);
}
}

return TensorStructInfo(ShapeExpr(out_shape), input->dtype);
}

TVM_REGISTER_OP("relax.nn.pixel_shuffle")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_attrs_type<PixelShuffleAttrs>()
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPixelShuffle)
.set_attr<Bool>("FPurity", Bool(true));

/* relax.nn.batchnorm */
bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx,
const Array<TensorStructInfo>& input_sinfo, Array<Integer> axes) {
Expand Down
3 changes: 3 additions & 0 deletions src/relax/op/nn/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ Expr softplus(Expr data, double beta, double threshold);
/*! \brief LogSoftmax function. */
Expr log_softmax(Expr data, int axis);

/*! \brief Pixel Shuffle function. */
Expr pixel_shuffle(Expr data, int upscale_factor);

/*! \brief Compute batch normalization. */
Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var, //
int axis, double epsilon, bool center, bool scale, double momentum, bool training);
Expand Down
36 changes: 36 additions & 0 deletions tests/python/relax/test_frontend_from_exported_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -1985,6 +1985,42 @@ def main(
verify_model(PadModel(pad=[1, 1, 2, 2], mode="circular"), example_args, {}, expected_circular)


def test_pixel_shuffle():
class PixelShuffle1(torch.nn.Module):
def __init__(self, upscale_factor=2):
super().__init__()
self.pixel_shuffle = torch.nn.PixelShuffle(upscale_factor)

def forward(self, x):
return self.pixel_shuffle(x)

class PixelShuffle2(torch.nn.Module):
def __init__(self, upscale_factor=2):
super().__init__()
self.upscale_factor = upscale_factor

def forward(self, x):
return torch.nn.functional.pixel_shuffle(x, self.upscale_factor)

@tvm.script.ir_module
class expected:
@R.function
def main(
x: R.Tensor((1, 8, 10, 15), dtype="float32")
) -> R.Tuple(R.Tensor((1, 2, 20, 30), dtype="float32")):
with R.dataflow():
lv: R.Tensor((1, 2, 20, 30), dtype="float32") = R.nn.pixel_shuffle(
x, upscale_factor=2
)
gv: R.Tuple(R.Tensor((1, 2, 20, 30), dtype="float32")) = (lv,)
R.output(gv)
return gv

example_args = (torch.randn(1, 8, 10, 15, dtype=torch.float32),)
verify_model(PixelShuffle1(upscale_factor=2), example_args, {}, expected)
verify_model(PixelShuffle2(upscale_factor=2), example_args, {}, expected)


def test_einsum():
class Einsum1(Module):
def __init__(self):
Expand Down
36 changes: 36 additions & 0 deletions tests/python/relax/test_frontend_from_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,42 @@ def main(
verify_model(PadModel(pad=[1, 1, 2, 2], mode="circular"), input_infos, {}, expected_circular)


def test_pixel_shuffle():
class PixelShuffle1(torch.nn.Module):
def __init__(self, upscale_factor=2):
super().__init__()
self.pixel_shuffle = torch.nn.PixelShuffle(upscale_factor)

def forward(self, x):
return self.pixel_shuffle(x)

class PixelShuffle2(torch.nn.Module):
def __init__(self, upscale_factor=2):
super().__init__()
self.upscale_factor = upscale_factor

def forward(self, x):
return torch.nn.functional.pixel_shuffle(x, self.upscale_factor)

@tvm.script.ir_module
class expected:
@R.function
def main(
inp_0: R.Tensor((1, 8, 10, 15), dtype="float32")
) -> R.Tensor((1, 2, 20, 30), dtype="float32"):
with R.dataflow():
lv: R.Tensor((1, 2, 20, 30), dtype="float32") = R.nn.pixel_shuffle(
inp_0, upscale_factor=2
)
gv: R.Tensor((1, 2, 20, 30), dtype="float32") = lv
R.output(gv)
return gv

input_infos = [([1, 8, 10, 15], "float32")]
verify_model(PixelShuffle1(2), input_infos, {}, expected)
verify_model(PixelShuffle2(2), input_infos, {}, expected)


def test_linear():
# nn.Linear
class Dense1(Module):
Expand Down
20 changes: 20 additions & 0 deletions tests/python/relax/test_op_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1822,5 +1822,25 @@ def test_pad_infer_struct_info():
)


def test_pixel_shuffle_infer_struct_info():
bb = relax.BlockBuilder()
x1 = relax.Var("x1", R.Tensor((1, 8, 10, 15), "float32"))
x2 = relax.Var("x2", R.Tensor((2, 6, 18, 5, 4), "float32"))

upscale_factor1 = 2
_check_inference(
bb,
relax.op.nn.pixel_shuffle(x1, upscale_factor1),
relax.TensorStructInfo((1, 2, 20, 30), dtype="float32"),
)

upscale_factor2 = 3
_check_inference(
bb,
relax.op.nn.pixel_shuffle(x2, upscale_factor2),
relax.TensorStructInfo((2, 6, 2, 15, 12), dtype="float32"),
)


if __name__ == "__main__":
tvm.testing.main()