PyTorch F.cosine_similarity 的 dim 参数详解:从基础应用到矩阵两两相似度计算

张开发
2026/5/24 20:04:29 15 分钟阅读
PyTorch F.cosine_similarity 的 dim 参数详解:从基础应用到矩阵两两相似度计算
1. 为什么dim参数让人头疼第一次用PyTorch的F.cosine_similarity函数时我也被dim参数搞得晕头转向。明明只是想算两个向量的相似度结果发现输入矩阵后输出的结果完全不符合预期。后来才发现这个dim参数其实决定了计算相似度的方向——是沿着行比较还是沿着列比较。举个生活中的例子假设你有两本菜谱两个矩阵每道菜每行用几种调料列表示。你想比较这两本菜谱的相似度可以有两种方式横向比较看相同位置的菜是否用了相似调料按行比较纵向比较看相同调料的用量在不同菜谱中是否相似按列比较这就是dim参数的核心作用——决定比较的方向。下面我们用代码来验证这个理解。import torch import torch.nn.functional as F # 创建两个2x2的示例矩阵 a torch.tensor([[1, 2], [3, 4]], dtypetorch.float) b torch.tensor([[5, 6], [7, 8]], dtypetorch.float)2. dim参数实战解析2.1 dim0时的列向量比较当设置dim0时函数会比较两个矩阵对应列的相似度。让我们运行代码看看res F.cosine_similarity(a, b, dim0) print(res) # 输出: tensor([0.9558, 0.9839])这个结果是什么意思呢它表示第一个数字0.9558是两矩阵第一列的相似度[1,3]和[5,7]的余弦相似度第二个数字0.9839是两矩阵第二列的相似度[2,4]和[6,8]的余弦相似度为了验证我们可以手动计算第一列的相似度def manual_cosine_sim(vec1, vec2): dot sum(v1*v2 for v1,v2 in zip(vec1,vec2)) norm1 sum(v**2 for v in vec1)**0.5 norm2 sum(v**2 for v in vec2)**0.5 return dot / (norm1 * norm2) print(manual_cosine_sim([1,3], [5,7])) # 输出: 0.95582.2 dim1时的行向量比较默认行为如果不指定dim参数函数默认使用dim1也就是按行比较res F.cosine_similarity(a, b) # 默认dim1 print(res) # 输出: tensor([0.9734, 0.9972])这个结果表示0.9734是第一行[1,2]和[5,6]的相似度0.9972是第二行[3,4]和[7,8]的相似度同样可以手动验证print(manual_cosine_sim([1,2], [5,6])) # 输出: 0.9734 print(manual_cosine_sim([3,4], [7,8])) # 输出: 0.99722.3 为什么默认是dim1在深度学习领域数据通常以(batch_size, features)的形式组织。比如一个32x128的矩阵表示32个样本每个样本有128个特征。这种情况下按行(dim1)比较更符合直觉——比较样本之间的特征相似度。3. 矩阵两两相似度计算实战3.1 问题场景假设现在有个实际需求计算一个矩阵中所有行向量两两之间的相似度。比如在推荐系统中需要计算所有用户之间的相似度或者在NLP中计算所有句子嵌入的相似度。直接使用F.cosine_similarity(a, b, dim1)只能计算对应行的相似度无法得到所有组合。这时候就需要维度扩展技巧。3.2 解决方案unsqueeze扩展维度关键思路是通过unsqueeze增加维度让PyTorch自动进行广播计算# 原始矩阵 a torch.tensor([[1, 2], [3, 4]], dtypetorch.float) b torch.tensor([[5, 6], [7, 8]], dtypetorch.float) # 维度扩展 x a.unsqueeze(1) # 形状变为[2,1,2] y b.unsqueeze(0) # 形状变为[1,2,2] # 计算相似度 res F.cosine_similarity(x, y, dim-1) print(res) 输出: tensor([[0.9734, 0.9676], [0.9987, 0.9972]]) 这个结果矩阵的含义是res[0,0]: a的第0行和b的第0行相似度res[0,1]: a的第0行和b的第1行相似度res[1,0]: a的第1行和b的第0行相似度res[1,1]: a的第1行和b的第1行相似度3.3 原理解析为什么这样能实现两两计算关键在于广播机制x的形状是[2,1,2]可以看作2个1x2的矩阵y的形状是[1,2,2]可以看作1个2x2的矩阵广播后PyTorch会自动进行2x2次比较dim-1表示在最后一个维度特征维度计算相似度3.4 实际应用示例假设我们有一个用户特征矩阵想计算所有用户之间的相似度# 5个用户每个用户有4个特征 user_features torch.randn(5, 4) # 计算所有用户两两相似度 x user_features.unsqueeze(1) # [5,1,4] y user_features.unsqueeze(0) # [1,5,4] sim_matrix F.cosine_similarity(x, y, dim-1) print(sim_matrix.shape) # 输出: torch.Size([5, 5])这样我们就得到了一个5x5的相似度矩阵sim_matrix[i,j]表示用户i和用户j的相似度。4. 高级技巧与常见陷阱4.1 处理不同形状的输入有时候我们需要计算一个矩阵中各行与另一个向量比如查询向量的相似度matrix torch.randn(10, 128) # 10个样本每个128维 query torch.randn(128) # 单个查询向量 # 需要将query扩展为1x128 query query.unsqueeze(0) # 现在是[1,128] sim F.cosine_similarity(matrix, query, dim1)4.2 归一化的重要性余弦相似度计算前如果数据没有归一化可能会得到不符合直觉的结果。建议先进行L2归一化def safe_cosine_sim(a, b, dim1): a_norm F.normalize(a, p2, dimdim) b_norm F.normalize(b, p2, dimdim) return F.cosine_similarity(a_norm, b_norm, dimdim)4.3 性能考量当处理大规模矩阵时比如计算10000x10000的相似度矩阵直接使用这种方法会消耗大量内存。这时可以考虑分块计算或者使用更高效的实现。5. 其他常见问题解答5.1 dim-1是什么意思dim-1表示在最后一个维度计算相似度。对于形状为[m,n,k]的张量dim-1等价于dim2。这种写法更灵活当你不确定张量的维度时特别有用。5.2 可以计算三维张量的相似度吗当然可以。比如处理一批图像特征时# 假设有32个图像每个图像有5个区域每个区域有128维特征 features torch.randn(32, 5, 128) # 计算每个图像内部区域之间的相似度 x features.unsqueeze(2) # [32,5,1,128] y features.unsqueeze(1) # [32,1,5,128] sim F.cosine_similarity(x, y, dim-1) # 结果形状[32,5,5]5.3 与矩阵乘法的关系余弦相似度计算可以看作是在L2归一化后的矩阵乘法# 等价计算方式 a_norm F.normalize(a, p2, dim1) b_norm F.normalize(b, p2, dim1) sim_matrix torch.mm(a_norm, b_norm.T)不过F.cosine_similarity的实现更高效特别是处理广播情况时。

更多文章