sigmoid_focal_loss

paddle.nn.functional. sigmoid_focal_loss ( logit, label, normalizer=None, alpha=0.25, gamma=2.0, reduction='sum', name=None ) [源代码]

Focal Loss 用于解决分类任务中的前景类-背景类数量不均衡的问题。在这种损失函数,易分样本的占比被减少,而难分样本的比重被增加。例如在一阶段的目标检测任务中,前景-背景不均衡表现得非常严重。

该算子通过下式计算 focal loss:

\[Out = -Labels * alpha * {(1 - \sigma(Logit))}^{gamma}\log(\sigma(Logit)) - (1 - Labels) * (1 - alpha) * {\sigma(Logit)}^{gamma}\log(1 - \sigma(Logit))\]

其中 \(\sigma(Logit) = \frac{1}{1 + \exp(-Logit)}\)

normalizer 不为 None 时,该算子会将输出损失 Out 除以 Tensor normalizer

\[Out = \frac{Out}{normalizer}\]

最后,该算子会添加 reduce 操作到前面的输出 Out 上。当 reduction'none' 时,直接返回最原始的 Out 结果。当 reduction'mean' 时,返回输出的均值 \(Out = MEAN(Out)\)。当 reduction'sum' 时,返回输出的求和 \(Out = SUM(Out)\)

注意:标签值 0 表示背景类(即负样本),1 表示前景类(即正样本)。

参数

  • logit (Tensor) - 维度为 \([N, *]\),其中 N 是 batch_size, * 是任意其他维度。输入数据 logit 一般是卷积层的输出,不需要经过 sigmoid 层。数据类型是 float32、float64。

  • label (Tensor) - 维度为 \([N, *]\),标签 label 的维度、数据类型与输入 logit 相同,取值范围 \([0,1]\)。数据类型是 float32、float64。

  • normalizer (Tensor,可选) - 维度为 \([1]\) ,focal loss 的归一化系数,数据类型与输入 logit 相同。若设置为 None,则不会将 focal loss 做归一化操作(即不会将 focal loss 除以 normalizer)。在目标检测任务中,设置为正样本的数量。默认值为 None。

  • alpha (int|float,可选) - 用于平衡正样本和负样本的超参数,取值范围 \([0,1]\)。默认值设置为 0.25。

  • gamma (int|float,可选) - 用于平衡易分样本和难分样本的超参数,默认值设置为 2.0。

  • reduction (str,可选) - 指定应用于输出结果的计算方式,可选值有:'none', 'mean', 'sum'。默认为 'mean',计算 focal loss 的均值;设置为 'sum' 时,计算 focal loss 的总和;设置为 'none' 时,则返回原始 loss。

  • name (str,可选) - 具体用法请参见 Name,一般无需设置,默认值为 None。

返回

  • Tensor,输出的 Tensor。如果 reduction'none',则输出的维度为 \([N, *]\),与输入 logit 的形状相同。如果 reduction'mean''sum',则输出的维度为 \([]\)

代码示例

import paddle

logit = paddle.to_tensor([[0.97, 0.91, 0.03], [0.55, 0.43, 0.71]], dtype='float32')
label = paddle.to_tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], dtype='float32')
one = paddle.to_tensor([1.], dtype='float32')
fg_label = paddle.greater_equal(label, one)
fg_num = paddle.sum(paddle.cast(fg_label, dtype='float32'))
output = paddle.nn.functional.sigmoid_focal_loss(logit, label, normalizer=fg_num)
print(output)  # 0.65782464