Eztorch Modules

Gather

class eztorch.modules.GatherLayer(*args, **kwargs)

Gathers a tensor from all devices while keeping it in the autograd graph, so gradients flow back through the gather.
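
A gather "with grad" is commonly written as a custom autograd function. The sketch below is an illustration of that general pattern, not the Eztorch source: the class name GatherWithGradSketch is hypothetical, and the single-process fallback (returning a copy of the input when no process group is initialized) is an assumption made so the example runs without a distributed launch.

```python
import torch
import torch.distributed as dist


class GatherWithGradSketch(torch.autograd.Function):
    """Hypothetical stand-in for a gather layer that supports backprop."""

    @staticmethod
    def forward(ctx, x):
        # Assumption: with no process group, behave like a world of size 1.
        if not (dist.is_available() and dist.is_initialized()):
            return (x.clone(),)
        out = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(out, x)
        return tuple(out)

    @staticmethod
    def backward(ctx, *grads):
        if not (dist.is_available() and dist.is_initialized()):
            return grads[0]
        # Sum the gradients computed on every rank, then keep the slice
        # that corresponds to the tensor this rank contributed.
        all_grads = torch.stack(grads)
        dist.all_reduce(all_grads)
        return all_grads[dist.get_rank()]
```

Calling GatherWithGradSketch.apply(x) returns a tuple with one tensor per process; concatenating that tuple with torch.cat gives a single tensor that still carries gradient information.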

eztorch.modules.concat_all_gather_with_backprop(x, dim=0)

Gathers a tensor from all devices and concatenates the results along dim, with gradient support.

Parameters:
  • x (Tensor) – Tensor to gather.

  • dim (int, optional) – Dimension along which to concatenate. Default: 0.

Return type:

Tensor

Returns:

Gathered tensor.
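
As a rough illustration of why a differentiable gather is useful, the sketch below gathers embeddings and uses them in a similarity score, so that the loss can backpropagate into the gathered tensor. It does not call Eztorch: concat_gather_with_grad_sketch is a hypothetical helper built on torch.distributed.nn.functional.all_gather, and returning the input unchanged when no process group is initialized is an assumption.

```python
import torch
import torch.distributed as dist


def concat_gather_with_grad_sketch(x: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """Hypothetical helper: differentiable all-gather followed by a concat."""
    # Assumption: outside a distributed run there is nothing to gather.
    if not (dist.is_available() and dist.is_initialized()):
        return x
    # Imported lazily so the single-process path needs no distributed extras.
    from torch.distributed.nn.functional import all_gather

    return torch.cat(all_gather(x), dim=dim)


# Usage: the gathered keys stay in the autograd graph.
queries = torch.randn(4, 8, requires_grad=True)
keys = concat_gather_with_grad_sketch(queries, dim=0)
similarity = queries @ keys.t()
similarity.sum().backward()
```

With N processes and a local batch of B rows, the gathered tensor would have N * B rows along dim 0.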

eztorch.modules.concat_all_gather_without_backprop(x, dim=0)

Gathers a tensor from all devices and concatenates the results along dim, without gradient support.

Parameters:
  • x (Tensor) – Tensor to gather.

  • dim (int, optional) – Dimension along which to concatenate. Default: 0.

Return type:

Tensor

Returns:

Gathered tensor.
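
A gather without gradients is typically a plain all_gather run under torch.no_grad(). The following is a minimal sketch of that idea rather than the Eztorch implementation; the function name concat_gather_no_grad_sketch and the single-process fallback are assumptions.

```python
import torch
import torch.distributed as dist


@torch.no_grad()
def concat_gather_no_grad_sketch(x: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """Hypothetical helper: all-gather and concat, detached from autograd."""
    # Assumption: with no process group, return a detached view of the input.
    if not (dist.is_available() and dist.is_initialized()):
        return x.detach()
    out = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
    dist.all_gather(out, x)
    return torch.cat(out, dim=dim)
```

The result never requires grad, which suits values such as labels or momentum-encoder outputs that are only read, not optimized.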

eztorch.modules.get_world_size()

Returns the world size, i.e. the number of processes taking part in distributed training.

Return type:

int

Returns:

The world size.
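
For orientation, a world-size helper is often a thin wrapper over torch.distributed. The sketch below shows one plausible shape for it; the name world_size_sketch and the fallback value of 1 when no process group is initialized are assumptions, not documented Eztorch behavior.

```python
import torch.distributed as dist


def world_size_sketch() -> int:
    """Hypothetical helper: number of processes, or 1 outside distributed runs."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    # Assumption: treat a non-distributed run as a single process.
    return 1
```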

Split Batch Norm

eztorch.modules.convert_to_split_batchnorm(module, num_splits)

Converts the BatchNorm layers in module to SplitBatchNorm layers.

Parameters:
  • module (Module) – Module to convert.

  • num_splits (int) – Number of splits for the SplitBatchNorm2D layers.

Return type:

Module

Returns:

The converted module.
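
To make the conversion concrete, here is a simplified sketch of the general technique: a batch norm layer that normalizes each of num_splits chunks of the batch separately, plus a recursive pass that swaps it in for every nn.BatchNorm2d. SketchSplitBatchNorm2d and convert_to_split_bn_sketch are hypothetical names, and the details (2D layers only, chunking along the batch dimension, running statistics updated once per chunk) are assumptions that may differ from what Eztorch does.

```python
import torch
from torch import nn


class SketchSplitBatchNorm2d(nn.BatchNorm2d):
    """Hypothetical layer: batch statistics are computed per chunk of the batch."""

    def __init__(self, num_features, num_splits, **kwargs):
        super().__init__(num_features, **kwargs)
        self.num_splits = num_splits

    def forward(self, x):
        if not self.training:
            # In eval mode, use the running statistics like a regular BatchNorm.
            return nn.BatchNorm2d.forward(self, x)
        # Simplification: running stats are updated once per chunk.
        chunks = x.chunk(self.num_splits, dim=0)
        return torch.cat([nn.BatchNorm2d.forward(self, c) for c in chunks], dim=0)


def convert_to_split_bn_sketch(module: nn.Module, num_splits: int) -> nn.Module:
    """Hypothetical converter: replace each BatchNorm2d, copying its state."""
    if isinstance(module, nn.BatchNorm2d) and not isinstance(module, SketchSplitBatchNorm2d):
        new = SketchSplitBatchNorm2d(
            module.num_features,
            num_splits,
            eps=module.eps,
            momentum=module.momentum,
            affine=module.affine,
            track_running_stats=module.track_running_stats,
        )
        new.load_state_dict(module.state_dict())
        return new
    # Recurse into children and re-register the converted submodules.
    for name, child in list(module.named_children()):
        module.add_module(name, convert_to_split_bn_sketch(child, num_splits))
    return module


model = nn.Sequential(nn.Conv2d(3, 4, 3), nn.BatchNorm2d(4), nn.ReLU())
model = convert_to_split_bn_sketch(model, num_splits=2)
out = model(torch.randn(4, 3, 8, 8))  # batch of 4 is normalized as 2 chunks of 2
```

The batch size should be divisible by num_splits so that every chunk has the same number of samples.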