在pytorch 中实现真正的 pairwise distances
文章目录问题解决方法问题pairwise distances即输入两个张量,比如张量 AM×D,BN×DA^{M \times D} ,B^{N \times D}AM×D,BN×D,M,N分布代表数据数量,D为特征维数,输出张量A和B 两两之间的距离,即一个 M×NM \times NM×N 的张量.这个在 sklearn 中有个很方便的函数 pairwise_distances,其实这个功能在
问题
pairwise distances即输入两个张量,比如张量
A
M
×
D
,
B
N
×
D
A^{M \times D} ,B^{N \times D}
AM×D,BN×D,M,N分布代表数据数量,D为特征维数,输出张量A和B 两两之间的距离,即一个
M
×
N
M \times N
M×N 的张量.
这个在 sklearn 中有个很方便的函数 pairwise_distances
,其实这个功能在 pytorch 中也有实现.
但是很坑跌的是,torch 中居然要求张量 A,B 的形状一样= =||||.
因此这里记录一下,自己的处理方法:即借用广播机制重新处理一下输入张量
解决方法
import torch
from torch import nn
a=torch.tensor([[1,1,1],[2,2,2]])
b=torch.tensor([
[2,2,2],[1,1,1],[2,2,2],[1,1,1],[2,2,2]
])
print(a.shape)
print(b.shape)
def pdist(a: torch.Tensor, b: torch.Tensor, p: int = 2) -> torch.Tensor:
return (a-b).abs().pow(p).sum(-1).pow(1/p)
a_=a.unsqueeze(1)
b_=b.unsqueeze(0)
print(pdist(a_,b_))
输出:
>>> torch.Size([2, 3])
>>> torch.Size([5, 3])
>>> tensor([[1.7321, 0.0000, 1.7321, 0.0000, 1.7321],
[0.0000, 1.7321, 0.0000, 1.7321, 0.0000]])
更多推荐
所有评论(0)