1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29
   | a = torch.arange(30).view(2,3,5) b = torch.arange(40).view(2,5,4) torch.einsum('ijk, ikl->ijl', a, b)  
Out:
a -> tensor([[[ 0,  1,  2,  3,  4],
              [ 5,  6,  7,  8,  9],
              [10, 11, 12, 13, 14]],

             [[15, 16, 17, 18, 19],
              [20, 21, 22, 23, 24],
              [25, 26, 27, 28, 29]]])
b -> tensor([[[ 0,  1,  2,  3],
              [ 4,  5,  6,  7],
              [ 8,  9, 10, 11],
              [12, 13, 14, 15],
              [16, 17, 18, 19]],

             [[20, 21, 22, 23],
              [24, 25, 26, 27],
              [28, 29, 30, 31],
              [32, 33, 34, 35],
              [36, 37, 38, 39]]])
result -> tensor([[[ 120,  130,  140,  150],
                   [ 320,  355,  390,  425],
                   [ 520,  580,  640,  700]],

                  [[2420, 2505, 2590, 2675],
                   [3120, 3230, 3340, 3450],
                   [3820, 3955, 4090, 4225]]])
   |