爱因斯坦求和约定的pytorch用法

Reference:

  1. https://zhuanlan.zhihu.com/p/361209187

导入库

1
2
import torch
import numpy as np
from einops import parse_shape, rearrange, reduce, repeat

torch.einsum

  1. 提取矩阵对角线元素
1
2
3
4
5
6
7
a = torch.arange(9).reshape(3, 3)
torch.einsum('ii->i', a)

Out:
tensor([[0, 1, 2],    ->  tensor([0, 4, 8])
        [3, 4, 5],
        [6, 7, 8]])
  2. 矩阵转置
1
2
3
4
5
6
7
a = torch.arange(6).reshape(2,3)
torch.einsum('ij->ji', a)

Out:
tensor([[0, 1, 2],    ->  tensor([[0, 3],
        [3, 4, 5]])               [1, 4],
                                  [2, 5]])
  3. permute 高维张量转置
1
2
3
4
5
6
7
8
9
10
11
12
a = torch.arange(24).reshape(2,3,4)      # torch.Size([2, 3, 4])
torch.einsum('...ij->...ji', a) # torch.Size([2, 4, 3])

Out:
tensor([[[ 0,  1,  2,  3],    ->  tensor([[[ 0,  4,  8],
         [ 4,  5,  6,  7],                 [ 1,  5,  9],
         [ 8,  9, 10, 11]],                [ 2,  6, 10],
        [[12, 13, 14, 15],                 [ 3,  7, 11]],
         [16, 17, 18, 19],                [[12, 16, 20],
         [20, 21, 22, 23]]])               [13, 17, 21],
                                           [14, 18, 22],
                                           [15, 19, 23]]])
  4. reduce sum
1
2
3
4
5
6
7
a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->', a)

Out:
tensor([[0, 1, 2],    ->  tensor(15)
        [3, 4, 5]])

  5. 矩阵按列/按行求和
1
2
3
4
5
6
7
8
9
10
11
12
13
a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->j', a) # 行没了

Out:
tensor([[0, 1, 2],    ->  tensor([3, 5, 7])
        [3, 4, 5]])

a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->i', a) # 列没了

Out:
tensor([[0, 1, 2],    ->  tensor([ 3, 12])
        [3, 4, 5]])
  6. 矩阵向量乘法
1
2
3
4
5
6
7
8
a = torch.arange(6).reshape(2,3)
b = torch.arange(3)
torch.einsum('ik, k->i', a, b)
torch.einsum('ik, k', a, b) # 等价形式

Out:
tensor([[0, 1, 2],    *  tensor([0, 1, 2])  ->  tensor([ 5, 14])
        [3, 4, 5]])
  7. 矩阵乘法
1
2
3
4
5
6
7
8
9
a = torch.arange(6).reshape(2,3)
b = torch.arange(12).reshape(3,4)
torch.einsum('ij, jk -> ik', a,b)
torch.einsum('ij, jk', a, b) # 等价形式

Out:
tensor([[0, 1, 2],    *  tensor([[ 0,  1,  2,  3],    ->  tensor([[20, 23, 26, 29],
        [3, 4, 5]])              [ 4,  5,  6,  7],                [56, 68, 80, 92]])
                                 [ 8,  9, 10, 11]])
  8. 向量内积
1
2
3
4
5
6
7
a = torch.arange(3)
b = torch.arange(3, 6)
torch.einsum('i, i->', a, b)
torch.einsum('i, i', a, b) # 等价形式

Out:
tensor([0, 1, 2]) * tensor([3, 4, 5]) -> tensor(14)
  9. 矩阵元素对应相乘并求reduce sum
1
2
3
4
5
6
7
8
9
10
11
12
13
a = torch.arange(6).reshape(2, 3)
b = torch.arange(6, 12).reshape(2, 3)
torch.einsum('ij, ij->', a, b)
torch.einsum('ij, ij', a, b) # 等价形式

Out:
tensor([[0, 1, 2],    *  tensor([[ 6,  7,  8],    ->  tensor(145)
        [3, 4, 5]])              [ 9, 10, 11]])

torch.einsum('ij, ij->ij', a, b) # 对应相乘
Out:
tensor([[0, 1, 2],    *  tensor([[ 6,  7,  8],    ->  tensor([[ 0,  7, 16],
        [3, 4, 5]])              [ 9, 10, 11]])               [27, 40, 55]])
  10. 向量外积
1
2
3
4
5
6
7
8
a = torch.arange(3)
b = torch.arange(3,7)
torch.einsum('i, j->ij', a, b)

Out:
tensor([0, 1, 2])  *  tensor([3, 4, 5, 6])  ->  tensor([[ 0,  0,  0,  0],
                                                        [ 3,  4,  5,  6],
                                                        [ 6,  8, 10, 12]])
  11. batch 矩阵乘法
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
a = torch.arange(30).view(2,3,5)
b = torch.arange(40).view(2,5,4)
torch.einsum('ijk, ikl->ijl', a, b) # torch.Size([2, 3, 4])

Out:
a -> tensor([[[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14]],
[[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24],
[25, 26, 27, 28, 29]]])

b -> tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15],
[16, 17, 18, 19]],
[[20, 21, 22, 23],
[24, 25, 26, 27],
[28, 29, 30, 31],
[32, 33, 34, 35],
[36, 37, 38, 39]]])

result -> tensor([[[ 120, 130, 140, 150],
[ 320, 355, 390, 425],
[ 520, 580, 640, 700]],
[[2420, 2505, 2590, 2675],
[3120, 3230, 3340, 3450],
[3820, 3955, 4090, 4225]]])
  12. 张量收缩
1
2
3
4
a = torch.randn(2,3,5,7)
b = torch.randn(11,13,3,17,5)
# 对 r q 所在维度进行reduce sum,调整其他维度
torch.einsum('pqrs,tuqvr->pstuv', a, b) # torch.Size([2, 7, 11, 13, 17])
  13. 双线性变换 (bilinear transformation)
1
2
3
4
a = torch.randn(2,3)
b = torch.randn(5,3,7)
c = torch.randn(2,7)
torch.einsum('ik, jkl, il->ij', a, b, c) # torch.Size([2, 5])

einops.rearrange

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# suppose we have a set of 32 images in "h w c" format (height-width-channel)
>>> images = [np.random.randn(30, 40, 3) for _ in range(32)]
# stack along first (batch) axis, output is a single array
>>> rearrange(images, 'b h w c -> b h w c').shape
(32, 30, 40, 3)
# concatenate images along height (vertical axis), 960 = 32 * 30
>>> rearrange(images, 'b h w c -> (b h) w c').shape
(960, 40, 3)
# concatenated images along horizontal axis, 1280 = 32 * 40
>>> rearrange(images, 'b h w c -> h (b w) c').shape
(30, 1280, 3)
# reordered axes to "b c h w" format for deep learning
>>> rearrange(images, 'b h w c -> b c h w').shape
(32, 3, 30, 40)
# flattened each image into a vector, 3600 = 30 * 40 * 3
>>> rearrange(images, 'b h w c -> b (c h w)').shape
(32, 3600)
# split each image into 4 smaller (top-left, top-right, bottom-left, bottom-right), 128 = 32 * 2 * 2
>>> rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape
(128, 15, 20, 3)
# space-to-depth operation
>>> rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape
(32, 15, 20, 12)

einops.repeat

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# a grayscale image (of shape height x width)
>>> image = np.random.randn(30, 40)
# change it to RGB format by repeating in each channel
>>> repeat(image, 'h w -> h w c', c=3).shape
(30, 40, 3)
# repeat image 2 times along height (vertical axis)
>>> repeat(image, 'h w -> (repeat h) w', repeat=2).shape
(60, 40)
# repeat image 3 times along width (horizontal axis)
>>> repeat(image, 'h w -> h (repeat w)', repeat=3).shape
(30, 120)
# convert each pixel to a small square 2x2. Upsample image by 2x
>>> repeat(image, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape
(60, 80)
# pixelate image first by downsampling by 2x, then upsampling
>>> downsampled = reduce(image, '(h h2) (w w2) -> h w', 'mean', h2=2, w2=2)
>>> repeat(downsampled, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape
(30, 40)

einops.reduce

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
>>> x = np.random.randn(100, 32, 64)
# perform max-reduction on the first axis
>>> y = reduce(x, 't b c -> b c', 'max')
# same as previous, but with clearer axes meaning
>>> y = reduce(x, 'time batch channel -> batch channel', 'max')
>>> x = np.random.randn(10, 20, 30, 40)
# 2d max-pooling with kernel size = 2 * 2 for image processing
>>> y1 = reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h2=2, w2=2)
# if one wants to go back to the original height and width, depth-to-space trick can be applied
>>> y2 = rearrange(y1, 'b (c h2 w2) h1 w1 -> b c (h1 h2) (w1 w2)', h2=2, w2=2)
>>> assert parse_shape(x, 'b _ h w') == parse_shape(y2, 'b _ h w')
# Adaptive 2d max-pooling to 3 * 4 grid
>>> reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h1=3, w1=4).shape
(10, 20, 3, 4)
# Global average pooling
>>> reduce(x, 'b c h w -> b c', 'mean').shape
(10, 20)
# Subtracting mean over batch for each channel
>>> y = x - reduce(x, 'b c h w -> () c () ()', 'mean')
# Subtracting per-image mean for each channel
>>> y = x - reduce(x, 'b c h w -> b c () ()', 'mean')