Losses
MoCo
- eztorch.losses.compute_moco_loss(q, k, k_global, use_keys, queue, temp=0.2, rank=0)[source]
Compute the MoCo loss.
- Parameters:
  - q (Tensor) – The representations of the queries.
  - k (Tensor) – The representations of the keys.
  - k_global (Tensor) – The global representations of the keys.
  - use_keys (bool) – Whether to use the non-positive elements from the keys.
  - queue (Tensor) – The queue of representations.
  - temp (float, optional) – Temperature applied to the query similarities. Default: 0.2
  - rank (int, optional) – Rank of the device for positive labels. Default: 0
- Return type: Tensor
- Returns: The loss.
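The reference above gives no usage example. Below is a minimal, self-contained sketch in plain PyTorch of the InfoNCE computation a MoCo-style loss performs: one positive logit per query (its own key) concatenated with negative logits against the queue. This is an illustrative reimplementation, not eztorch's actual code; the `k_global`/`use_keys` handling and distributed gathering are omitted, and `moco_loss_sketch` is a hypothetical name.

```python
import torch
import torch.nn.functional as F

def moco_loss_sketch(q, k, queue, temp=0.2):
    """InfoNCE loss in the style of MoCo: for each query the positive is its
    own key; negatives come from the queue of past key representations."""
    q = F.normalize(q, dim=1)
    k = F.normalize(k, dim=1)
    # Positive logits: similarity of each query with its own key, shape (N, 1).
    l_pos = torch.einsum("nc,nc->n", q, k).unsqueeze(-1)
    # Negative logits: similarity with every queue element, shape (N, K).
    l_neg = torch.einsum("nc,kc->nk", q, queue)
    logits = torch.cat([l_pos, l_neg], dim=1) / temp
    # The positive logit is always at column 0.
    labels = torch.zeros(q.shape[0], dtype=torch.long)
    return F.cross_entropy(logits, labels)

torch.manual_seed(0)
q = torch.randn(8, 32)
k = torch.randn(8, 32)
queue = F.normalize(torch.randn(64, 32), dim=1)
loss = moco_loss_sketch(q, k, queue)
```

With a non-zero `rank` in a distributed run, the positive labels would be offset into the gathered keys instead of fixed at 0.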
- eztorch.losses.compute_mocov3_loss(q, k, temp=1.0, rank=0)[source]
Compute the MoCo v3 loss.
- Parameters:
  - q (Tensor) – The representations of the queries.
  - k (Tensor) – The global representations of the keys.
  - temp (float, optional) – Temperature for softmax. Default: 1.0
  - rank (int, optional) – Rank of the device for positive labels. Default: 0
- Return type: Tensor
- Returns: The loss.
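Unlike the queue-based variant, MoCo v3 contrasts queries directly against the batch of (globally gathered) keys. A minimal sketch of that computation, assuming plain PyTorch and a single process (the `rank` offset shows where distributed positives would land; `mocov3_loss_sketch` is a hypothetical name, not eztorch's implementation):

```python
import torch
import torch.nn.functional as F

def mocov3_loss_sketch(q, k, temp=1.0, rank=0):
    """Contrastive loss between local queries and the gathered keys.
    The positive for local query i sits at global index rank * N + i."""
    q = F.normalize(q, dim=1)
    k = F.normalize(k, dim=1)
    logits = q @ k.T / temp  # (N_local, N_global) similarity matrix
    n = q.shape[0]
    labels = torch.arange(n, dtype=torch.long) + rank * n
    return F.cross_entropy(logits, labels)

torch.manual_seed(0)
q = torch.randn(8, 32)
k = torch.randn(8, 32)  # with world_size 1, global keys == local keys
loss = mocov3_loss_sketch(q, k)
```

In practice the loss is typically applied symmetrically over two augmented views, each view serving once as queries and once as keys.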
ReSSL
- eztorch.losses.compute_ressl_loss(q, k, k_global, use_keys, queue, mask, temp=0.1, temp_m=0.04, LARGE_NUM=1000000000.0)[source]
Compute the ReSSL loss.
- Parameters:
  - q (Tensor) – The representations of the queries.
  - k (Tensor) – The representations of the keys.
  - k_global (Tensor) – The global representations of the keys.
  - use_keys (bool) – Whether to use the non-positive elements from the keys.
  - queue (Tensor) – The queue of representations.
  - mask (Tensor) – Mask of positives for the query.
  - temp (float, optional) – Temperature applied to the query similarities. Default: 0.1
  - temp_m (float, optional) – Temperature applied to the key similarities. Default: 0.04
  - LARGE_NUM (float, optional) – Large number used to mask elements. Default: 1e9
- Return type: Tensor
- Returns: The loss.
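ReSSL is relational rather than contrastive: the query's similarity distribution over the queue (sharpened at `temp`) is trained to match the key's distribution (sharpened further at the smaller `temp_m`). A minimal sketch of that core, assuming plain PyTorch and omitting the `use_keys`/`mask`/`LARGE_NUM` masking of the full signature (`ressl_loss_sketch` is a hypothetical name):

```python
import torch
import torch.nn.functional as F

def ressl_loss_sketch(q, k, queue, temp=0.1, temp_m=0.04):
    """Cross-entropy between the key's (teacher) and the query's (student)
    similarity distributions over the queue."""
    q = F.normalize(q, dim=1)
    k = F.normalize(k, dim=1)
    logits_q = q @ queue.T / temp    # student similarities
    logits_k = k @ queue.T / temp_m  # teacher similarities, sharper temperature
    p_k = F.softmax(logits_k, dim=1)
    log_p_q = F.log_softmax(logits_q, dim=1)
    return -(p_k * log_p_q).sum(dim=1).mean()

torch.manual_seed(0)
q = torch.randn(8, 32)
k = torch.randn(8, 32)
queue = F.normalize(torch.randn(64, 32), dim=1)
loss = ressl_loss_sketch(q, k, queue)
```

The much smaller `temp_m` (0.04 vs. 0.1 by default) makes the teacher distribution sharper than the student's, which is what drives the relational learning signal.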
- eztorch.losses.compute_ressl_mask(batch_size, num_negatives, use_keys=True, rank=0, world_size=1, device='cpu')[source]
Precompute the mask for ReSSL.
- Parameters:
  - batch_size (int) – The local batch size.
  - num_negatives (int) – The number of negatives besides the non-positive key elements.
  - use_keys (bool, optional) – Whether to use the non-positive elements from the keys as negatives. Default: True
  - rank (int, optional) – Rank of the current process. Default: 0
  - world_size (int, optional) – Number of processes that perform training. Default: 1
  - device (str, optional) – Device that performs training. Default: 'cpu'
- Return type: Tensor
- Returns: The mask.
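To make the mask's shape concrete, here is a sketch of one plausible construction: one row per local query, columns covering the gathered keys (when `use_keys` is set) followed by the queue negatives, with a 1 marking each query's own key at the rank's offset. This is an assumption about the layout for illustration, not eztorch's actual implementation, and `ressl_mask_sketch` is a hypothetical name.

```python
import torch

def ressl_mask_sketch(batch_size, num_negatives, use_keys=True,
                      rank=0, world_size=1, device="cpu"):
    """Mask of positives: rows are local queries, columns are
    [gathered keys (optional) | queue negatives]."""
    num_keys = batch_size * world_size if use_keys else 0
    mask = torch.zeros(batch_size, num_keys + num_negatives, device=device)
    if use_keys:
        # Query i's positive is its own key, offset by this rank's slice.
        rows = torch.arange(batch_size)
        mask[rows, rows + rank * batch_size] = 1.0
    return mask

m = ressl_mask_sketch(batch_size=4, num_negatives=6)
```

Under these assumptions, `m` has shape `(4, 10)` with exactly one positive per row; queue entries are treated as pure negatives.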
SCE
- eztorch.losses.compute_sce_loss(q, k, k_global, use_keys, queue, mask, coeff, temp=0.1, temp_m=0.07, LARGE_NUM=1000000000.0)[source]
Compute the SCE loss.
- Parameters:
  - q (Tensor) – The representations of the queries.
  - k (Tensor) – The representations of the keys.
  - k_global (Tensor) – The global representations of the keys.
  - use_keys (bool) – Whether to use the non-positive elements from the keys.
  - queue (Tensor) – The queue of representations.
  - mask (Tensor) – Mask of positives for the query.
  - coeff (float) – Coefficient weighting the contrastive and relational aspects.
  - temp (float, optional) – Temperature applied to the query similarities. Default: 0.1
  - temp_m (float, optional) – Temperature applied to the key similarities. Default: 0.07
  - LARGE_NUM (float, optional) – Large number used to mask elements. Default: 1e9
- Return type: Tensor
- Returns: The loss.
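SCE interpolates between the two objectives above: the target distribution is a `coeff`-weighted mix of the contrastive one-hot positive and a ReSSL-style teacher distribution (with the positive masked out via the `LARGE_NUM` trick). A minimal single-process sketch of that combination, assuming plain PyTorch (`sce_loss_sketch` is a hypothetical name, and `k_global`/`use_keys` handling is omitted):

```python
import torch
import torch.nn.functional as F

def sce_loss_sketch(q, k, queue, coeff=0.5, temp=0.1, temp_m=0.07,
                    large_num=1e9):
    """Soft contrastive loss: cross-entropy between the student distribution
    and a mix of a one-hot positive and a masked teacher distribution."""
    q = F.normalize(q, dim=1)
    k = F.normalize(k, dim=1)
    n = q.shape[0]
    # Similarities against the concatenation [keys | queue].
    sim_q = torch.cat([q @ k.T, q @ queue.T], dim=1)
    sim_k = torch.cat([k @ k.T, k @ queue.T], dim=1)
    one_hot = torch.zeros_like(sim_q)
    one_hot[torch.arange(n), torch.arange(n)] = 1.0
    # Teacher distribution with the positive suppressed (LARGE_NUM trick).
    p_k = F.softmax(sim_k / temp_m - one_hot * large_num, dim=1)
    # Mix contrastive (one-hot) and relational (teacher) targets.
    target = coeff * one_hot + (1.0 - coeff) * p_k
    log_p_q = F.log_softmax(sim_q / temp, dim=1)
    return -(target * log_p_q).sum(dim=1).mean()

torch.manual_seed(0)
q = torch.randn(8, 32)
k = torch.randn(8, 32)
queue = F.normalize(torch.randn(64, 32), dim=1)
loss = sce_loss_sketch(q, k, queue)
```

Setting `coeff=1.0` recovers a purely contrastive (MoCo-like) objective; `coeff=0.0` recovers a purely relational (ReSSL-like) one.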
- eztorch.losses.compute_sce_mask(batch_size, num_negatives, use_keys=True, rank=0, world_size=1, device='cpu')[source]
Precompute the mask for SCE.
- Parameters:
  - batch_size (int) – The local batch size.
  - num_negatives (int) – The number of negatives besides the non-positive key elements.
  - use_keys (bool, optional) – Whether to use the non-positive elements from the keys as negatives. Default: True
  - rank (int, optional) – Rank of the current process. Default: 0
  - world_size (int, optional) – Number of processes that perform training. Default: 1
  - device (str, optional) – Device that performs training. Default: 'cpu'
- Return type: