Einops and Einsum Summarized


A brief summary of einops and einsum, their usage, and an implementation of average pooling in CNNs using einops (inspired by the max pooling layer implemented in the original library documentation).

Einops

Visualizing how einops modifies an image is a great way to observe its impact on tensors. Einops provides a universal language for tensor operations. It supports multiple tensor frameworks (PyTorch, TensorFlow, NumPy, etc.) and maps the universal language to each framework's native operations. Because everything compiles down to native operations, einops can be used in deep learning without worrying about backpropagation.
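
For example, the same pattern string runs unchanged across frameworks (a minimal sketch, assuming einops, NumPy and PyTorch are installed; shapes are just illustrative):

import numpy as np
import torch
from einops import rearrange

np_image = np.zeros((30, 40, 3))      # length x breadth x colour
torch_image = torch.zeros(30, 40, 3)

# identical pattern, dispatched to the native ops of each framework
print(rearrange(np_image, "l b c -> b l c").shape)     # (40, 30, 3)
print(rearrange(torch_image, "l b c -> b l c").shape)  # torch.Size([40, 30, 3])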

Rearrange

  1. Can be used to take a transpose
     '''
     Taking a transpose of length and breadth, keeping the colour channel intact.
     '''
     >> rearrange(tensor, "l b c -> b l c")
    
  2. Composition of a new dimension
     '''
     combining the batch and length dimension
     to stack all images one below the other
     '''
     >> rearrange(tensor, "bat l b c -> (bat l) b c")
    
     '''
     combining batch and breadth dimension to stack all images one after the other from left to right.
     '''
     >> rearrange(tensor, "bat l b c -> l (bat b) c")
    
  3. Decomposition (splitting a dimension into 2)
     '''
     Implementing two-level batching: changing from batch*length*breadth*colour to batch1*batch2*length*breadth*colour.
     bat2 is inferred automatically from bat1 and the original batch size
     '''
     >> rearrange(tensor, "(bat1 bat2) l b c -> bat1 bat2 l b c", bat1=2)
    
  4. Stacking and Concatenating
     '''
     Merging 2 batches: the list of tensors is stacked along a new first axis (bat1) before the pattern is applied.
     '''
     >> rearrange([tensor1, tensor2], "bat1 bat2 l b c -> (bat1 bat2) l b c", bat1=2)
    
  5. Expanding Dimensions (adding a unit dimension). Covers what the squeeze() and unsqueeze() functions do in PyTorch.
     '''
     Resulting tensor is of shape (ba, 1, l, b, c, 1).
     Only unit dimensions can be added this way. For non-unit dimensions, see repeat()
     '''
     >> rearrange(tensor, "ba l b c -> ba 1 l b c 1")
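
Putting the patterns above together, a quick shape-check sketch (assuming a toy batch of four 30x40 RGB images; all names are illustrative):

import torch
from einops import rearrange

images = torch.zeros(4, 30, 40, 3)  # batch x length x breadth x colour

print(rearrange(images, "bat l b c -> (bat l) b c").shape)  # [120, 40, 3]
print(rearrange(images, "bat l b c -> l (bat b) c").shape)  # [30, 160, 3]
print(rearrange(images, "(bat1 bat2) l b c -> bat1 bat2 l b c", bat1=2).shape)  # [2, 2, 30, 40, 3]
print(rearrange(images, "ba l b c -> ba 1 l b c 1").shape)  # [4, 1, 30, 40, 3, 1]

# a list of tensors is stacked along a new first axis before the pattern applies
half_batch = torch.zeros(2, 30, 40, 3)
print(rearrange([half_batch, half_batch], "bat1 bat2 l b c -> (bat1 bat2) l b c").shape)  # [4, 30, 40, 3]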
    

Reduce

The axis that is dropped from the output pattern is the one that gets reduced.

'''
Computes the average over the colour channel dimension, i.e. converts a colour image into black-and-white
'''
>> reduce(tensor, "l b c -> l b", "mean")

'''
Reducing a batched set: averages out the colour channel and stacks all images one below the other
'''
>> reduce(tensor, "ba l b c -> (ba l) b", "mean")
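
A short shape-check sketch for the two calls above (assuming the same toy image shapes as before):

import torch
from einops import reduce

image = torch.rand(30, 40, 3)
print(reduce(image, "l b c -> l b", "mean").shape)  # [30, 40] - colour averaged out

images = torch.rand(4, 30, 40, 3)
print(reduce(images, "ba l b c -> (ba l) b", "mean").shape)  # [120, 40]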

Repeat

The opposite of Reduce: here you add a dimension by duplicating the data.

'''
Create duplicate copy of the image across batch dimension
'''
>> repeat(tensor, "l b c -> 5 l b c")

>> repeat(tensor, "l b c -> batch l b c", batch=5) # does the same thing, with a named axis

'''
Duplicates the image 3 times along the breadth (width) dimension
'''
>> repeat(tensor, "l b c -> l (3 b) c")
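
A corresponding sketch for repeat (assuming the same toy image; note that the order inside the parentheses decides tiling versus per-pixel repetition):

import torch
from einops import repeat

image = torch.rand(30, 40, 3)
print(repeat(image, "l b c -> 5 l b c").shape)    # [5, 30, 40, 3]
print(repeat(image, "l b c -> l (3 b) c").shape)  # [30, 120, 3] - image tiled 3 times
print(repeat(image, "l b c -> l (b 3) c").shape)  # [30, 120, 3] - each pixel repeated 3 times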

Demo: Reduce Operation <> Max Pooling Layer

>> feature_map = torch.tensor([[1,2,1,2],[3,4,3,4],[1,2,1,2],[3,4,3,4]])
'''
[
    [1,2,1,2],
    [3,4,3,4],
    [1,2,1,2],
    [3,4,3,4]
]
'''
>> post_pooling = reduce(feature_map, "(le 2) (br 2) -> le br", "max")
'''
Result:
[
    [4,4],
    [4,4]
]
'''
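
To check that this agrees with PyTorch's own pooling layer, a small equivalence sketch (the float cast is needed because nn.MaxPool2d does not support integer tensors):

import torch
from einops import reduce

feature_map = torch.tensor([[1, 2, 1, 2],
                            [3, 4, 3, 4],
                            [1, 2, 1, 2],
                            [3, 4, 3, 4]], dtype=torch.float32)

einops_out = reduce(feature_map, "(le 2) (br 2) -> le br", "max")
# MaxPool2d expects (batch, channel, height, width), hence the two unit axes
torch_out = torch.nn.MaxPool2d(kernel_size=2)(feature_map[None, None])[0, 0]
assert torch.equal(einops_out, torch_out)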

Demo: Reduce Operation <> Avg Pooling Layer

>> feature_map = torch.tensor([[1,2,1,2],[3,4,3,4],[1,2,1,2],[3,4,3,4]])
'''
[
    [1,2,1,2],
    [3,4,3,4],
    [1,2,1,2],
    [3,4,3,4]
]
'''
>> post_pooling = reduce(feature_map, "(le 2) (br 2) -> le br", "sum") / 4.0 # "mean" would need a float tensor
'''
Result:
[
    [2.500,2.500],
    [2.500,2.500]
]
'''
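
The same equivalence check for average pooling; with a float tensor, reduce can use "mean" directly instead of the sum-and-divide workaround above:

import torch
from einops import reduce

feature_map = torch.tensor([[1, 2, 1, 2],
                            [3, 4, 3, 4],
                            [1, 2, 1, 2],
                            [3, 4, 3, 4]], dtype=torch.float32)

einops_out = reduce(feature_map, "(le 2) (br 2) -> le br", "mean")
torch_out = torch.nn.AvgPool2d(kernel_size=2)(feature_map[None, None])[0, 0]
assert torch.allclose(einops_out, torch_out)  # both: [[2.5, 2.5], [2.5, 2.5]]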

Using Layers (Classes) instead of Operations (Functions)

Using the layer classes directly allows you to add einops operations to a model definition.

from torch.nn import Sequential, Conv2d, MaxPool2d, Linear, ReLU
from einops.layers.torch import Reduce

model = Sequential(
    Conv2d(3, 6, kernel_size=5),
    MaxPool2d(kernel_size=2),
    Conv2d(6, 16, kernel_size=5),
    # combined pooling and flattening in a single step
    Reduce('b c (h 2) (w 2) -> b (c h w)', 'max'),
    Linear(16*5*5, 120),
    ReLU(),
    Linear(120, 10),
)
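
A quick sanity check of the model (a sketch assuming CIFAR-sized 3x32x32 inputs, which is what the 16*5*5 flattened size implies):

import torch

dummy_batch = torch.rand(8, 3, 32, 32)  # hypothetical batch of 8 RGB images
print(model(dummy_batch).shape)         # torch.Size([8, 10])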

Einsum

Common Operations

Operation                    | Note                                          | Code
Matrix transpose             |                                               | torch.einsum('ij -> ji', [a])
Sum                          | Sum of all elements of a matrix               | torch.einsum('ij ->', [a])
Column Sum                   | Column-wise sum                               | torch.einsum('ij -> j', [a])
Row Sum                      | Row-wise sum                                  | torch.einsum('ij -> i', [a])
Matrix-vector multiplication |                                               | torch.einsum('ij,j -> i', [a, b])
Matrix-matrix multiplication |                                               | torch.einsum('ij,jk -> ik', [a, b])
Dot product                  | Vector-vector                                 | torch.einsum('i,i ->', [a, b])
Dot product                  | Matrix-matrix                                 | torch.einsum('ij,ij ->', [a, b])
Hadamard product             | Elementwise multiplication between 2 matrices | torch.einsum('ij,ij -> ij', [a, b])
Outer product                |                                               | torch.einsum('i,j -> ij', [a, b])
Batch matrix multiplication  |                                               | torch.einsum('bij,bjk -> bik', [a, b])
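
A few of these identities, verified against the equivalent native PyTorch calls (a minimal sketch; operand shapes are arbitrary):

import torch

a = torch.rand(3, 4)
b = torch.rand(4, 5)
v = torch.rand(4)

assert torch.allclose(torch.einsum('ij -> ji', a), a.T)          # transpose
assert torch.allclose(torch.einsum('ij ->', a), a.sum())         # sum of all elements
assert torch.allclose(torch.einsum('ij,j -> i', a, v), a @ v)    # matrix-vector
assert torch.allclose(torch.einsum('ij,jk -> ik', a, b), a @ b)  # matrix-matrix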

References

  1. Einops Documentation
  2. Einsum Documentation
  3. Neel Nanda’s Getting Started with Interpretability (Inspiration)