文章目录

问题

pairwise distances即输入两个张量,比如张量 A M × D , B N × D A^{M \times D} ,B^{N \times D} AM×D,BN×D,M,N分布代表数据数量,D为特征维数,输出张量A和B 两两之间的距离,即一个 M × N M \times N M×N 的张量.
这个在 sklearn 中有个很方便的函数 pairwise_distances,其实这个功能在 pytorch 中也有实现.
但是很坑跌的是,torch 中居然要求张量 A,B 的形状一样= =||||.
因此这里记录一下,自己的处理方法:即借用广播机制重新处理一下输入张量

解决方法

import torch
from torch import nn

a=torch.tensor([[1,1,1],[2,2,2]])
b=torch.tensor([
                [2,2,2],[1,1,1],[2,2,2],[1,1,1],[2,2,2]
])
print(a.shape)
print(b.shape)
def pdist(a: torch.Tensor, b: torch.Tensor, p: int = 2) -> torch.Tensor:
    return (a-b).abs().pow(p).sum(-1).pow(1/p)

a_=a.unsqueeze(1)
b_=b.unsqueeze(0)

print(pdist(a_,b_))

输出:

>>> torch.Size([2, 3])
>>> torch.Size([5, 3])
>>> tensor([[1.7321, 0.0000, 1.7321, 0.0000, 1.7321],
        [0.0000, 1.7321, 0.0000, 1.7321, 0.0000]])
Logo

鸿蒙生态一站式服务平台。

更多推荐