Eztorch Modules

Gather

class eztorch.modules.GatherLayer(*args, **kwargs)

Gather a tensor across devices while preserving gradients.
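A differentiable gather of this kind is usually written as a custom torch.autograd.Function. The forward pass runs dist.all_gather. The backward pass sums each gathered slice's gradient over all ranks and returns the slice that belongs to the local input. The sketch below assumes that common pattern and adds a single-process fallback through a hypothetical helper, _dist_ready. It is not necessarily eztorch's exact source.

```python
import torch
import torch.distributed as dist


def _dist_ready() -> bool:
    # Hypothetical helper: True only inside an initialized process group.
    return dist.is_available() and dist.is_initialized()


class GatherLayer(torch.autograd.Function):
    """All-gather that lets gradients flow back to each rank's local input (sketch)."""

    @staticmethod
    def forward(ctx, x):
        if not _dist_ready():
            # Single process: the "gathered" result is just a copy of x.
            return (x.clone(),)
        output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(output, x.contiguous())
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        if not _dist_ready():
            return grads[0]
        # Every rank holds a gradient for every gathered slice: sum them across
        # ranks, then keep the slice that corresponds to this rank's input.
        all_grads = torch.stack(grads)
        dist.all_reduce(all_grads)
        return all_grads[dist.get_rank()]
```

On each rank, torch.cat(GatherLayer.apply(x), dim=0) then gives the full cross-device batch, and a loss computed on it still back-propagates into the local x.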

eztorch.modules.concat_all_gather_with_backprop(x, dim=0)

Gather a tensor from all devices and concatenate the results along dim, keeping gradients.

Parameters:
  • x (Tensor) – Tensor to gather.
  • dim (int, optional) – Dimension along which to concatenate.
    Default: 0

Return type:

Tensor

Returns:

Gathered tensor.
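As a hedged alternative to a hand-written layer, the sketch below builds the same behaviour on PyTorch's own autograd-aware collective, torch.distributed.nn.functional.all_gather. When no process group is running, it assumes the local tensor already is the full result and returns it unchanged. This is an illustration, not eztorch's implementation.

```python
import torch
import torch.distributed as dist


def concat_all_gather_with_backprop(x: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """Differentiable all-gather followed by concatenation (sketch)."""
    if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() == 1:
        # Nothing to gather: the local tensor is the whole batch.
        return x
    # Imported lazily so the single-process path never needs torch.distributed.nn.
    from torch.distributed.nn.functional import all_gather

    # all_gather returns one tensor per rank and supports backpropagation.
    return torch.cat(all_gather(x), dim=dim)
```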

eztorch.modules.concat_all_gather_without_backprop(x, dim=0)

Gather a tensor from all devices and concatenate the results along dim, without gradient tracking.

Parameters:
  • x (Tensor) – Tensor to gather.
  • dim (int, optional) – Dimension along which to concatenate.
    Default: 0

Return type:

Tensor

Returns:

Gathered tensor.
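The non-differentiable variant is the familiar pattern used, for example, to gather momentum-encoder keys in MoCo-style training: a plain dist.all_gather under torch.no_grad(). The sketch below assumes that pattern. The single-process fallback, which returns a detached copy of x, is an assumption about how the function behaves outside distributed runs.

```python
import torch
import torch.distributed as dist


@torch.no_grad()
def concat_all_gather_without_backprop(x: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """Gather x from every rank and concatenate along dim with no autograd graph (sketch)."""
    if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() == 1:
        # Single process: detach so the result never carries gradients.
        return x.detach()
    gathered = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, x.contiguous())
    return torch.cat(gathered, dim=dim)
```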

eztorch.modules.get_world_size()

Return the world size, i.e. the number of processes taking part in distributed training.

Return type:

int

Returns:

The world size.
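A common convention, assumed here rather than taken from the eztorch source, is to fall back to a world size of 1 when torch.distributed is unavailable or not initialized. This lets the same training code run on a single device.

```python
import torch.distributed as dist


def get_world_size() -> int:
    """Processes in the default group, or 1 outside a distributed run (sketch)."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
```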

Split Batch Norm

eztorch.modules.convert_to_split_batchnorm(module, num_splits)

Convert the BatchNorm layers in module to SplitBatchNorm layers.

Parameters:
  • module (Module) – Module to convert.
  • num_splits (int) – Number of splits for the SplitBatchNorm2D layers.

Return type:

Module

Returns:

The converted module.
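Split batch norm is commonly used to imitate per-GPU batch statistics on a single device. The batch is divided into num_splits groups, and each group is normalized with its own statistics during training. The sketch below assumes that meaning of num_splits. The class SplitBatchNorm2d is a hypothetical stand-in for eztorch's layer, and only nn.BatchNorm2d layers are converted. The recursive swap mirrors torch.nn.SyncBatchNorm.convert_sync_batchnorm, reusing the original parameters and buffers so their device is preserved.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SplitBatchNorm2d(nn.BatchNorm2d):
    """BatchNorm2d with per-split statistics in training (hypothetical sketch)."""

    def __init__(self, num_features: int, num_splits: int, **kwargs):
        super().__init__(num_features, **kwargs)
        self.num_splits = num_splits

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training:
            # Evaluation uses the ordinary shared running statistics.
            return super().forward(x)
        n, c, h, w = x.shape
        k = self.num_splits
        assert n % k == 0, "batch size must be divisible by num_splits"
        # Fold splits into the channel axis: split s holds samples s, s + k, s + 2k, ...
        folded = x.reshape(n // k, c * k, h, w)
        rm = self.running_mean.repeat(k) if self.running_mean is not None else None
        rv = self.running_var.repeat(k) if self.running_var is not None else None
        weight = self.weight.repeat(k) if self.weight is not None else None
        bias = self.bias.repeat(k) if self.bias is not None else None
        # Assumes a float momentum (the BatchNorm default of 0.1).
        out = F.batch_norm(folded, rm, rv, weight, bias,
                           training=True, momentum=self.momentum, eps=self.eps)
        if rm is not None:
            # Merge the per-split running statistics back into one set per channel.
            with torch.no_grad():
                self.running_mean.copy_(rm.view(k, c).mean(dim=0))
                self.running_var.copy_(rv.view(k, c).mean(dim=0))
        return out.reshape(n, c, h, w)


def convert_to_split_batchnorm(module: nn.Module, num_splits: int) -> nn.Module:
    """Recursively replace nn.BatchNorm2d layers with SplitBatchNorm2d (sketch)."""
    if isinstance(module, nn.BatchNorm2d) and not isinstance(module, SplitBatchNorm2d):
        new = SplitBatchNorm2d(
            module.num_features,
            num_splits,
            eps=module.eps,
            momentum=module.momentum,
            affine=module.affine,
            track_running_stats=module.track_running_stats,
        )
        # Reuse the original tensors so values and device carry over unchanged.
        if module.affine:
            new.weight = module.weight
            new.bias = module.bias
        new.running_mean = module.running_mean
        new.running_var = module.running_var
        new.num_batches_tracked = module.num_batches_tracked
        new.train(module.training)
        return new
    for name, child in module.named_children():
        module.add_module(name, convert_to_split_batchnorm(child, num_splits))
    return module
```

With num_splits=2 and a batch of 4, samples 0 and 2 are normalized together, as are samples 1 and 3. In evaluation mode the layer behaves like a regular BatchNorm2d.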