Pytorch 的 Tensor 用法
官方解释:https://pytorch.org/docs/stable/tensors.html?highlight=scatter_add#torch.Tensor.scatter_add_
函数参数:scatter_add_(dim, indexTensor, otherTensor) → 输出Tensor
函数用法:selfTensor.scatter_add_(dim, indexTensor, otherTensor)
要求:
self,indexandothershould have same number of dimensions.index.size(d) <= other.size(d)for all dimensionsdindex.size(d) <= self.size(d)for all dimensionsd != dim.- as for
gather(), the values ofindexmust be between0andself.size(dim) - 1 - all values in a row along the specified dimension
dimmust be unique.
示例代码:final_dist = vocab_dist_.scatter_add(1, enc_batch_extend_vocab, attn_dist_)
该函数将 otherTensor 的所有值加到 selfTensor 中,加入位置由 indexTensor 指明。
self[ index[i][j][k] ][ j ][ k ] += other[ i ][ j ][ k ] # if dim == 0
本文详细解析PyTorch中Tensor的scatter_add_函数,介绍其参数要求及使用场景,通过实例展示如何将otherTensor的值依据indexTensor指定的位置累加到selfTensor中。
344

被折叠的 条评论
为什么被折叠?



