The **torch.chunk()** function splits a tensor into at most "k" chunks along a given axis and returns a tuple of tensors. The chunks are not always the same size: if the size of the tensor along that axis is not divisible by the number of chunks, the last chunk is smaller than the others, and in some cases fewer than "k" chunks are returned.

For example, if a tensor has five columns and you split it into two chunks, the first tensor will have three columns and the second will have two.
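To check the chunk sizes for the five-column case, here is a minimal sketch. The names `t`, `parts`, `sizes`, and `few` are only illustrative, and the second part shows a case where fewer chunks come back than were requested.

```python
import torch

# A 2x5 tensor: five columns, as in the case described above.
t = torch.arange(10).reshape(2, 5)

# Ask for two chunks along the column axis (dim=1).
parts = torch.chunk(t, 2, dim=1)

# Chunk width is ceil(5 / 2) = 3, so the last chunk gets the remaining 2 columns.
sizes = [p.shape[1] for p in parts]
print(type(parts).__name__, sizes)  # tuple [3, 2]

# Fewer chunks than requested can come back:
# 6 elements in 4 chunks gives chunk size ceil(6 / 4) = 2, hence only 3 chunks.
few = torch.chunk(torch.arange(6), 4)
print(len(few))  # 3
```

If you need exactly the requested number of pieces, torch.tensor_split() is worth a look.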

Here is an example that splits a tensor into two chunks, first along the rows (dim=0) and then along the columns (dim=1).

>>> import torch
>>> x = torch.tensor([[1, 2, 3, 4, 5], [11, 12, 13, 14, 15], [21, 22, 23, 24, 25], [31, 32, 33, 34, 35]])
>>> x
tensor([[ 1,  2,  3,  4,  5],
        [11, 12, 13, 14, 15],
        [21, 22, 23, 24, 25],
        [31, 32, 33, 34, 35]])
>>> torch.chunk(x, 2, dim=0)
(tensor([[ 1,  2,  3,  4,  5],
        [11, 12, 13, 14, 15]]), tensor([[21, 22, 23, 24, 25],
        [31, 32, 33, 34, 35]]))
>>> torch.chunk(x, 2, dim=1)
(tensor([[ 1,  2,  3],
        [11, 12, 13],
        [21, 22, 23],
        [31, 32, 33]]), tensor([[ 4,  5],
        [14, 15],
        [24, 25],
        [34, 35]]))
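
As a quick sanity check on the two splits above, the sketch below unpacks the returned tuples, prints the chunk shapes, and joins the chunks back together with torch.cat(). The names `top`, `bottom`, `left`, and `right` are just illustrative.

```python
import torch

x = torch.tensor([[1, 2, 3, 4, 5],
                  [11, 12, 13, 14, 15],
                  [21, 22, 23, 24, 25],
                  [31, 32, 33, 34, 35]])

# Split along the rows (dim=0) and along the columns (dim=1).
top, bottom = torch.chunk(x, 2, dim=0)
left, right = torch.chunk(x, 2, dim=1)

# Row chunks are equal (4 rows / 2); column chunks are uneven (5 columns / 2).
shapes = [tuple(c.shape) for c in (top, bottom, left, right)]
print(shapes)  # [(2, 5), (2, 5), (4, 3), (4, 2)]

# Concatenating the chunks along the same axis restores the original tensor.
rebuilt_rows = torch.cat((top, bottom), dim=0)
rebuilt_cols = torch.cat((left, right), dim=1)
print(torch.equal(rebuilt_rows, x), torch.equal(rebuilt_cols, x))  # True True
```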