Eztorch Modules
Gather
- eztorch.modules.concat_all_gather_with_backprop(x, dim=0)
Gather a tensor from all devices and concatenate the results, keeping gradients so that backpropagation works through the gather.
- Parameters:
x (Tensor) – Tensor to gather.
dim (int, optional) – Dimension along which to concatenate. Default: 0
- Return type:
Tensor
- Returns:
Gathered tensor.
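As an illustration of what a gradient-preserving gather can look like, here is a minimal sketch built on `torch.distributed`. It is not Eztorch's implementation: the names `_GatherWithGrad` and `gather_with_grad_sketch` are hypothetical, and the sketch assumes every rank holds a tensor of the same shape. Outside an initialized process group it simply returns the input.

```python
import torch
import torch.distributed as dist


class _GatherWithGrad(torch.autograd.Function):
    """Hypothetical helper: all_gather with a hand-written backward pass."""

    @staticmethod
    def forward(ctx, x):
        world_size = dist.get_world_size()
        # all_gather fills one buffer per rank; shapes must match across ranks.
        gathered = [torch.zeros_like(x) for _ in range(world_size)]
        dist.all_gather(gathered, x.contiguous())
        return tuple(gathered)

    @staticmethod
    def backward(ctx, *grads):
        # Each rank holds gradients for every gathered chunk. Sum them over
        # all ranks, then keep the chunk that belongs to this rank's input.
        all_grads = torch.stack(grads)
        dist.all_reduce(all_grads)  # default reduction is a sum
        return all_grads[dist.get_rank()]


def gather_with_grad_sketch(x, dim=0):
    """Concatenate x from all ranks along dim, keeping the autograd graph."""
    if not (dist.is_available() and dist.is_initialized()):
        # Assumption: with no process group, behave like a single-device run.
        return x
    return torch.cat(_GatherWithGrad.apply(x), dim=dim)


# Single-process usage: the fallback path keeps gradients intact.
x = torch.randn(4, 3, requires_grad=True)
out = gather_with_grad_sketch(x)
out.sum().backward()
```

In a contrastive setup, a gather of this kind would typically be applied to embeddings so that the loss sees the global batch while each rank still receives gradients for its own samples.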
- eztorch.modules.concat_all_gather_without_backprop(x, dim=0)
Gather a tensor from all devices and concatenate the results, without tracking gradients.
- Parameters:
x (Tensor) – Tensor to gather.
dim (int, optional) – Dimension along which to concatenate. Default: 0
- Return type:
Tensor
- Returns:
Gathered tensor.
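A gather without gradients can be sketched as a plain `all_gather` under `torch.no_grad()`. This is an assumption about the general technique, not Eztorch's code; `gather_without_grad_sketch` is a hypothetical name, and all ranks are assumed to hold tensors of the same shape.

```python
import torch
import torch.distributed as dist


@torch.no_grad()
def gather_without_grad_sketch(x, dim=0):
    """Concatenate x from all ranks along dim; the result carries no graph."""
    if not (dist.is_available() and dist.is_initialized()):
        # Assumption: with no process group, return a detached view of x.
        return x.detach()
    gathered = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, x.contiguous())
    return torch.cat(gathered, dim=dim)


# Single-process usage: the output is cut off from the autograd graph.
x = torch.randn(4, 3, requires_grad=True)
out = gather_without_grad_sketch(x)
```

This variant suits tensors that only serve as targets or statistics, where routing gradients back to other ranks is unnecessary.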
- eztorch.modules.get_world_size()
Returns the world size.
- Return type:
int
- Returns:
The world size.
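One common way to obtain a world size, shown here as a sketch with plain `torch.distributed`, is to query the default process group and fall back to 1 when none is initialized. Eztorch may determine the value differently; `world_size_sketch` and the batch-size variables are hypothetical.

```python
import torch.distributed as dist


def world_size_sketch() -> int:
    """Return the number of processes in the default group, or 1 if none."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1


# After a gather along dim 0, the batch grows by a factor of the world size.
per_device_batch = 32
global_batch = per_device_batch * world_size_sketch()
```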
Split Batch Norm
- eztorch.modules.convert_to_split_batchnorm(module, num_splits)
Convert BatchNorm layers to SplitBatchNorm layers in module.
- Parameters:
module (Module) – Module to convert.
num_splits (int) – Number of splits for the SplitBatchNorm2D layers.
- Return type:
Module
- Returns:
The converted module.
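To make the idea of such a conversion concrete, the sketch below defines a split batch norm layer and a recursive converter. It is an assumption about the general technique, not Eztorch's `SplitBatchNorm2D` or its converter: `SplitBatchNorm2dSketch` and `convert_sketch` are hypothetical names, and the layer assumes `affine=True`, `track_running_stats=True`, a numeric `momentum`, and a batch size divisible by `num_splits`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SplitBatchNorm2dSketch(nn.BatchNorm2d):
    """Hypothetical layer: normalizes each of num_splits sub-batches separately."""

    def __init__(self, num_features, num_splits, **kwargs):
        super().__init__(num_features, **kwargs)
        self.num_splits = num_splits

    def forward(self, x):
        if not self.training:
            # At evaluation time, behave like a regular BatchNorm2d.
            return super().forward(x)
        n, c, h, w = x.shape
        s = self.num_splits
        # Fold the splits into the channel axis so that one batch_norm call
        # computes independent statistics for each split.
        mean = self.running_mean.repeat(s)
        var = self.running_var.repeat(s)
        out = F.batch_norm(
            x.reshape(n // s, c * s, h, w), mean, var,
            self.weight.repeat(s), self.bias.repeat(s),
            True, self.momentum, self.eps,
        ).reshape(n, c, h, w)
        # Average the per-split running statistics back into the buffers.
        self.running_mean.copy_(mean.view(s, c).mean(dim=0))
        self.running_var.copy_(var.view(s, c).mean(dim=0))
        return out


def convert_sketch(module, num_splits):
    """Recursively replace BatchNorm2d layers with the split variant."""
    if isinstance(module, nn.BatchNorm2d) and not isinstance(module, SplitBatchNorm2dSketch):
        new = SplitBatchNorm2dSketch(
            module.num_features, num_splits, eps=module.eps, momentum=module.momentum
        )
        new.load_state_dict(module.state_dict())  # keep weights and statistics
        return new
    for name, child in list(module.named_children()):
        setattr(module, name, convert_sketch(child, num_splits))
    return module


model = nn.Sequential(nn.Conv2d(3, 4, 3), nn.BatchNorm2d(4), nn.ReLU())
model = convert_sketch(model, num_splits=2)
out = model(torch.randn(8, 3, 8, 8))  # batch of 8, normalized as 2 splits of 4
```

Splitting the batch this way is sometimes used to mimic per-device batch statistics on a single device; whether that matches Eztorch's motivation is not stated here.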