1. torch.topk()基础入门
torch.topk()是PyTorch中一个非常实用的函数,它能够帮助我们快速找到张量中的前K个最大或最小值。这个函数在数据处理和模型训练中经常用到,特别是在需要筛选重要数据的场景下。
先来看一个最简单的例子。假设我们有一个包含5个数字的一维张量:
import torch
x = torch.tensor([10, 20, 5, 40, 30])
values, indices = torch.topk(x, 3)
print("前3大值:", values) # 输出: tensor([40, 30, 20])
print("对应索引:", indices) # 输出: tensor([3, 4, 1])
这个例子中,我们找到了数组中最大的3个数字40、30、20,以及它们所在的位置索引3、4、1。这个功能看似简单,但在实际应用中非常有用。
torch.topk()的函数签名是这样的:
torch.topk(input, k, dim=None, largest=True, sorted=True) -> (Tensor, LongTensor)
参数说明:
- input:输入张量,可以是任意维度的
- k:要返回的元素数量
- dim:操作的维度,默认是最后一个维度
- largest:True返回最大值,False返回最小值
- sorted:是否对结果排序,默认True
2. 多维张量的topk操作
在实际项目中,我们更多时候处理的是多维张量。torch.topk()可以指定在哪个维度上进行操作,这给了我们很大的灵活性。
2.1 二维张量的行列操作
假设我们有一个3x4的矩阵:

1万+

被折叠的 条评论
为什么被折叠?



