One-hot function in PyTorch

本文介绍并展示了四种在PyTorch中实现One-Hot编码的方法:使用scatter_函数、index_select()函数、Embedding模块以及自定义One_Hot模块。这些方法为不同场景下的One-Hot编码提供了灵活的选择。
Qwen3-32B-Chat 私有部署镜像 | RTX4090D 24G 显存 CUDA12.4 优化版

Qwen3-32B-Chat 私有部署镜像 | RTX4090D 24G 显存 CUDA12.4 优化版

Qwen
文本生成
Qwen3

本镜像基于 RTX 4090D 24GB 显存 + CUDA 12.4 + 驱动 550.90.07 深度优化,内置完整运行环境与 Qwen3-32B 模型依赖,开箱即用。

One-hot function implementation in PyTorch


Method1: use scatter_ function

labels = [0, 1, 4, 7, 3, 2]
one_hot = torch.zeros(6, 8).scatter_(dim = 1, index = labels, value = 1)

Method2: use index_select() funtion

labels = [0, 1, 4, 7, 3, 2]
index = torch.eye(8)
one_hot = torch.index_select(index, dim = 0, index = labels)

Method3: use Embedding module

emb = nn.Embedding(8, 8)
emb.weight.data = torch.eye(8)

then we can get

emb(Variable(torch.LongTensor([1, 2], [3, 4])))
Variable containing:
(0,.,.) = 
0 1 0 0 0 0 0 0 
0 0 1 0 0 0 0 0
(1 ,.,.) =
0 0 0 1 0 0 0 0
0 0 0 0 1 0 0 0

Method4: create a module

class One_Hot(nn.Module):
    def __init__(self, depth):
        super(One_Hot,self).__init__()
        self.depth = depth
        self.ones = torch.sparse.torch.eye(depth)
    def forward(self, X_in):
        X_in = X_in.long()
        return Variable(self.ones.index_select(0,X_in.data))
    def __repr__(self):
        return self.__class__.__name__ + "({})".format(self.depth)

您可能感兴趣的与本文相关的镜像

Qwen3-32B-Chat 私有部署镜像 | RTX4090D 24G 显存 CUDA12.4 优化版

Qwen3-32B-Chat 私有部署镜像 | RTX4090D 24G 显存 CUDA12.4 优化版

Qwen
文本生成
Qwen3

本镜像基于 RTX 4090D 24GB 显存 + CUDA 12.4 + 驱动 550.90.07 深度优化,内置完整运行环境与 Qwen3-32B 模型依赖,开箱即用。

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值