在神经网络的输出层,典型的是使用softmax函数来近似一个概率分布:

因为指数的缘故,计算起来很费钱。为什么不简单地执行Z变换,使所有输出都是正的,然后通过将所有输出除以所有输出的和来归一化?


当前回答

q_i的值是无界的分数,有时被解释为对数概率。根据这种解释,为了恢复原始概率值,必须对它们求幂。

统计算法经常使用对数似然损失函数的一个原因是它们在数值上更稳定:概率的乘积可以表示为一个非常小的浮点数。使用对数似然损失函数,概率的乘积变成一个和。

另一个原因是,当假设从多元高斯分布中提取随机变量的估计量时,对数似然性自然发生。例如,请参阅最大似然(ML)估计器及其与最小二乘连接的方式。

其他回答

我已经有这个问题好几个月了。似乎我们只是聪明地把softmax猜测为输出函数,然后把softmax的输入解释为对数概率。正如你所说,为什么不简单地通过除以它们的和来规范化所有输出呢?我在Goodfellow、Bengio和Courville(2016)撰写的《深度学习》一书的6.2.2节中找到了答案。

假设最后一个隐藏层给出z作为激活。那么softmax定义为

非常简短的解释

softmax函数中的exp大致抵消了交叉熵损失中的log,导致损失在z_i中大致为线性。这导致了一个大致恒定的梯度,当模型是错误的,允许它迅速纠正自己。因此,一个错误的饱和软最大值不会导致梯度消失。

简短的解释

训练神经网络最流行的方法是最大似然估计。我们以最大化训练数据(大小为m)的似然的方式估计参数theta。由于整个训练数据集的似然是每个样本的似然的乘积,因此更容易最大化数据集的对数似然,从而最大化以k为索引的每个样本的对数似然的和:

现在,我们只关注z已经给定的软最大值,所以我们可以替换

I是KTH样本的正确类。现在,我们看到,当我们取softmax的对数,来计算样本的对数似然时,我们得到:

,对于z相差较大的情况,大致近似于

首先,我们看到线性分量z_i。其次,我们可以在两种情况下检验max(z)的行为:

如果模型正确,那么max(z)将是z_i。因此,随着z_i和z中的其他项之间的差异越来越大,对数似然渐近于0(即似然为1)。 如果模型不正确,则max(z)将是另一个z_j > z_i。因此,z_i的相加并不能完全抵消-z_j,对数似然值大致为(z_i -z_j)。这清楚地告诉模型如何增加log-likelihood:增加z_i,减少z_j。

我们看到,总体对数似然将由样本主导,其中模型是不正确的。此外,即使模型确实不正确,导致饱和软最大值,损失函数也不会饱和。它在z_j中近似是线性的,这意味着我们有一个近似恒定的梯度。这使得模型能够快速自我修正。请注意,这不是均方误差的例子。

长解释

如果softmax对你来说仍然是一个随意的选择,你可以看看在逻辑回归中使用sigmoid的理由:

为什么是s型函数而不是其他函数?

软最大是对多类问题的sigmoid的推广。

选择softmax函数似乎有些武断,因为有许多其他可能的归一化函数。因此,目前还不清楚为什么log-softmax损耗会比其他损耗替代品表现更好。

来自“属于球形损失家族的Softmax替代方案的探索”https://arxiv.org/abs/1511.05042

作者探索了其他一些函数,其中包括泰勒exp展开和所谓的球形软最大值,并发现有时它们可能比通常的软最大值执行得更好。

我认为其中一个原因可能是处理负数并除以0,因为exp(x)总是正的并且大于0。

例如,对于a =[-2, -1, 1,2],和将是0,我们可以使用softmax来避免除0。

虽然它确实有些随意,但softmax具有理想的属性,例如:

易微(df/dx = f*(1-f)) 当用作分类任务的输出层时,输入的分数可以解释为log-odds

与标准归一化相比,Softmax有一个很好的属性。

它对分布均匀的神经网络的低刺激(想象一个模糊的图像)和高刺激(例如。大数字,想想清晰的图像),概率接近0和1。

而标准归一化并不关心,只要比例相同。

看看当soft max有10倍大的输入时会发生什么,即你的神经网络得到一个清晰的图像,许多神经元被激活

>>> softmax([1,2])              # blurry image of a ferret
[0.26894142,      0.73105858])  #     it is a cat perhaps !?
>>> softmax([10,20])            # crisp image of a cat
[0.0000453978687, 0.999954602]) #     it is definitely a CAT !

然后与标准归一化进行比较

>>> std_norm([1,2])                      # blurry image of a ferret
[0.3333333333333333, 0.6666666666666666] #     it is a cat perhaps !?
>>> std_norm([10,20])                    # crisp image of a cat
[0.3333333333333333, 0.6666666666666666] #     it is a cat perhaps !?