Eztorch Modules
Gather
- eztorch.modules.concat_all_gather_with_backprop(x, dim=0)
Gather a tensor across devices, with gradients propagated through the gather.
- Parameters:
x (Tensor) – Tensor to gather.
dim (int, optional) – Dimension along which to concatenate. Default: 0.
- Return type:
Tensor
- Returns:
Gathered tensor.
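A minimal sketch of how a gather with backpropagation can be built on `torch.distributed`, assuming an initialized process group and the same size along `dim` on every rank. The names `_GatherWithGrad` and `concat_all_gather_with_backprop_sketch` are hypothetical; this is not eztorch's source. The forward pass all-gathers and concatenates; the backward pass sums the incoming gradient over ranks and returns this rank's slice.

```python
import torch
import torch.distributed as dist


class _GatherWithGrad(torch.autograd.Function):
    """All-gather whose backward returns this rank's slice of the summed gradient."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, dim: int) -> torch.Tensor:
        ctx.dim = dim
        ctx.rank = dist.get_rank()
        ctx.local_size = x.shape[dim]  # assumption: equal on every rank
        x = x.contiguous()
        gathered = [torch.empty_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, x)
        return torch.cat(gathered, dim=dim)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Every rank's loss may depend on this rank's input, so sum the gradients.
        grad = grad_output.contiguous().clone()
        dist.all_reduce(grad, op=dist.ReduceOp.SUM)
        # Keep only the slice matching this rank's input; `dim` gets no gradient.
        start = ctx.rank * ctx.local_size
        return grad.narrow(ctx.dim, start, ctx.local_size), None


def concat_all_gather_with_backprop_sketch(x: torch.Tensor, dim: int = 0) -> torch.Tensor:
    # Single process or no process group: there is nothing to gather.
    if not (dist.is_available() and dist.is_initialized()):
        return x
    return _GatherWithGrad.apply(x, dim)
```

Summing the gradient over ranks in the backward pass is one common choice; some implementations skip the all-reduce and return only the local slice, which drops the contributions coming from other ranks' losses.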
- eztorch.modules.concat_all_gather_without_backprop(x, dim=0)
Gather a tensor across devices without gradient tracking.
- Parameters:
x (Tensor) – Tensor to gather.
dim (int, optional) – Dimension along which to concatenate. Default: 0.
- Return type:
Tensor
- Returns:
Gathered tensor.
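Without backpropagation, the gather can be a plain `all_gather` under `torch.no_grad()`. A minimal sketch with a hypothetical function name, assuming the same tensor shape on every rank (not eztorch's source):

```python
import torch
import torch.distributed as dist


@torch.no_grad()
def concat_all_gather_without_backprop_sketch(x: torch.Tensor, dim: int = 0) -> torch.Tensor:
    # no_grad does not detach an existing tensor, so detach explicitly in the fallback.
    if not (dist.is_available() and dist.is_initialized()):
        return x.detach()
    x = x.contiguous()
    gathered = [torch.empty_like(x) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, x)  # gathered tensors carry no autograd history
    return torch.cat(gathered, dim=dim)
```

This variant fits cases where the gathered values only serve as targets, for example keys from a momentum encoder in MoCo-style training.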
- eztorch.modules.get_world_size()
Return the world size, i.e. the number of processes taking part in distributed training.
- Return type:
int
- Returns:
The world size.
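A common pattern for such a helper, sketched here under a hypothetical name and not taken from eztorch's source, is to fall back to 1 when `torch.distributed` is unavailable or not initialized, so the same code runs on a single device:

```python
import torch.distributed as dist


def get_world_size_sketch() -> int:
    """Number of processes in the default group, or 1 when not running distributed."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1


# Example use: the effective batch size is the per-device batch size times the world size.
per_device_batch_size = 32
global_batch_size = per_device_batch_size * get_world_size_sketch()
```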
Split Batch Norm
- eztorch.modules.convert_to_split_batchnorm(module, num_splits)
Convert the BatchNorm layers of a module to SplitBatchNorm layers.
- Parameters:
module (Module) – Module to convert.
num_splits (int) – Number of splits for the SplitBatchNorm2D layers.
- Return type:
Module
- Returns:
The converted module.
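Split batch norm normalizes `num_splits` sub-batches with separate statistics during training, which is commonly used to mimic per-device batch statistics of multi-GPU training on a single device. The sketch below is not eztorch's source: it assumes 4D inputs whose batch size is divisible by `num_splits` and a float `momentum`, and uses the hypothetical names `SplitBatchNorm2dSketch` and `convert_to_split_batchnorm_sketch`. It reshapes `(N, C, H, W)` into `(N/k, k*C, H, W)` so that a single `F.batch_norm` call computes per-group statistics, then averages the per-group running statistics back into one set.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SplitBatchNorm2dSketch(nn.BatchNorm2d):
    """BatchNorm2d that normalizes `num_splits` sub-batches independently in training."""

    def __init__(self, num_features: int, num_splits: int, **kwargs):
        super().__init__(num_features, **kwargs)
        self.num_splits = num_splits

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training and self.running_mean is not None:
            # Evaluation: plain batch norm with the shared running statistics.
            return super().forward(x)
        n, c, h, w = x.shape
        k = self.num_splits
        assert n % k == 0, "batch size must be divisible by num_splits"
        # Viewing (N, C, H, W) as (N/k, k*C, H, W) gives sample group i % k its own
        # block of C channels, so batch norm computes separate statistics per group.
        rm = self.running_mean.repeat(k) if self.running_mean is not None else None
        rv = self.running_var.repeat(k) if self.running_var is not None else None
        weight = self.weight.repeat(k) if self.weight is not None else None
        bias = self.bias.repeat(k) if self.bias is not None else None
        out = F.batch_norm(
            x.reshape(n // k, k * c, h, w), rm, rv, weight, bias,
            training=True, momentum=self.momentum, eps=self.eps,
        ).reshape(n, c, h, w)
        if rm is not None:
            # Fold per-group running statistics back into one set by averaging.
            # (num_batches_tracked is not updated in this sketch.)
            with torch.no_grad():
                self.running_mean.copy_(rm.view(k, c).mean(dim=0))
                self.running_var.copy_(rv.view(k, c).mean(dim=0))
        return out


def convert_to_split_batchnorm_sketch(module: nn.Module, num_splits: int) -> nn.Module:
    """Recursively replace every nn.BatchNorm2d in `module` by the split variant."""
    if isinstance(module, nn.BatchNorm2d):
        new = SplitBatchNorm2dSketch(
            module.num_features, num_splits, eps=module.eps, momentum=module.momentum,
            affine=module.affine, track_running_stats=module.track_running_stats,
        )
        # Copy parameters and running statistics. The new layer is created on the CPU,
        # so convert before moving the model to a device.
        new.load_state_dict(module.state_dict())
        new.train(module.training)
        return new
    for name, child in module.named_children():
        module.add_module(name, convert_to_split_batchnorm_sketch(child, num_splits))
    return module
```

Because of the reshape, samples are assigned to groups by index modulo `num_splits` (interleaved) rather than in contiguous chunks; either assignment gives each group its own statistics.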