Example: batch matrix multiplication with `torch.einsum`.
import torch

# Two batches of integer matrices; .view() fixes the shapes:
#   a: (2, 3, 5) -> two 3x5 matrices holding 0..29
#   b: (2, 5, 4) -> two 5x4 matrices holding 0..39
a = torch.arange(30).view(2, 3, 5)
b = torch.arange(40).view(2, 5, 4)

# 'ijk,ikl->ijl': i is the shared batch index, k is summed over, so
# result[i] = a[i] @ b[i] (a batched matrix product, same as torch.bmm(a, b)).
# The result has shape (2, 3, 4). The equation is written without spaces
# because older PyTorch releases reject whitespace inside einsum equations.
result = torch.einsum('ijk,ikl->ijl', a, b)

# Bind and print the tensors so a non-interactive run shows the output
# listed below the snippet.
print('a ->', a)
print('b ->', b)
print('result ->', result)
Output:
a -> tensor([[[ 0,  1,  2,  3,  4],
              [ 5,  6,  7,  8,  9],
              [10, 11, 12, 13, 14]],

             [[15, 16, 17, 18, 19],
              [20, 21, 22, 23, 24],
              [25, 26, 27, 28, 29]]])
b -> tensor([[[ 0,  1,  2,  3],
              [ 4,  5,  6,  7],
              [ 8,  9, 10, 11],
              [12, 13, 14, 15],
              [16, 17, 18, 19]],

             [[20, 21, 22, 23],
              [24, 25, 26, 27],
              [28, 29, 30, 31],
              [32, 33, 34, 35],
              [36, 37, 38, 39]]])
result -> tensor([[[ 120,  130,  140,  150],
                   [ 320,  355,  390,  425],
                   [ 520,  580,  640,  700]],

                  [[2420, 2505, 2590, 2675],
                   [3120, 3230, 3340, 3450],
                   [3820, 3955, 4090, 4225]]])
Each result[i] is the matrix product a[i] @ b[i]. For example, result[0, 0, 0] = 0·0 + 1·4 + 2·8 + 3·12 + 4·16 = 120.