动手机器学习-实现逻辑回归

动手实现逻辑回归

逻辑回归完整代码想抄作业的同学直接点这里

逻辑回归算法简介

逻辑回归又称对数几率回归

在线性回归模型中,我们学会了使用线性模型完成回归的问题,但是如果任务是分类任务怎么办呢?答案是使用一个单调可微函数将实际标签y和线性回归的预测值联系起来

首先考虑二分类任务,实际标签为0或1,而线性回归模型的预测值为$z = XW + b$为实值,于是我们使用单位阶跃函数,将z转换为0/1值。

阶跃函数如下所示:

$$ y = \begin{cases} 0, z < 0 \\ 0.5, z = 0 \\ 1, z > 0 \\ \end{cases} $$

但是,我们可以发现阶跃函数有一些硬直,不可求导,因此使用更加平滑的Sigmoid function代替。

Sigmoid函数如下所示:

$$ prob = \frac{1}{1 + e^{-z}} $$

如果Sigmoid函数输出的prob值,我们可以进一步确定样本的类别,比如prob>0.5,样本类别为1,prob<=0.5样本类别为0

怎么求解

逻辑回归计算过程很简单就是:将数据带入z = XW + b中,然后将z带入Sigmoid函数中去,最终求得计算结果。

但是我们怎么求解W的值呢?

损失函数

损失函数的定义: 通过Sigmoid我们可以计算出所属类别的概率,因此我们可以更好的定义损失函数。

逻辑回归中的损失函数如下所示,下面我们进行讨论合理性

$$ loss = - y^* log(prob) - (1 - y^*)log(1 - prob) $$

其中$y^*$为标签的实际值,prob是Sigmoid函数的输出值。

这个函数很有意思

当$y^*=1$时,$loss=-log(prob)$,意思就是prob值越接近1越好。

当$y^*=0$时,$loss=- log(1 - prob)$,意思就是prob值越接近0越好。

总之,该损失函数能使prob值和数据的标签尽可能接近。

求解最优权重-求梯度

为了更好理解梯度的求法,我们重新写一下数据进行计算的过程

$$z = w_0 + w_1 x_1 + w_2 x_2 =
\left[ \begin{matrix} 1 & x_1 & x_2 \end{matrix} \right]
\left[ \begin{matrix} w_0 \\ w_1 \\ w_2 \end{matrix} \right] \tag{1} $$

$$Z = XW \tag{2}$$,其中假设有10组数据每条数据有3个属性1个标签X维度(10,3),Z维度(10,1)

计算Z对W的导数,为下文做铺垫

$$\frac{d Z}{d W} = X^T $$

$$ prob = \frac{1}{1 + e^{-z}} \tag{3}$$ prob维度为(10,1)

计算prob对z的导数,为下文做铺垫

$$\frac{d prob}{d z} = \frac{e^{-z}} { (1+e^{-z})^2 } = \frac{1}{1 + e^{-z}} \frac{e^{-z}}{1 + e^{-z}} = prob (1 - prob) $$

综合公式1-3,可得 $$ prob = \frac{1}{1 + e^{-XW}} $$

$$ loss = -y^*log(prob) -(1-y)*log(1-prob) \tag{4} $$

计算loss对prob的导数,为下文做铺垫

$$ \frac{d loss}{d prob} = - \frac{y^{*}}{prob} - \frac{1 - y^{*}}{1-prob} * -1 $$

接下来进行链式求导,其中loss ~ prob ~ z ~ WX

注意下面的计算都是元素级元素,不涉及到矩阵运算

$$ \begin{equation} \begin{split} grad &= \frac{d loss}{d W} = \frac{d loss}{d prob} * \frac{d prob}{d z} * \frac{d z}{W} \\ &= X^T (- \frac{y^*}{prob} + \frac{1 - y^*}{1-prob}) * prob(1 - prob) \\ &= X^T ( - y^*(1 - prob) + (1 - y^*)prob ) \\ &= X^T (prob - y^*) \end{split} \nonumber \end{equation} $$

最终结果维度为(3,10)乘以(10,1)等于(3,1)

求解最优权重-梯度下降

步骤1. 初始化W

步骤2. W -= lr * grad (其中lr为学习率,值到0-1之间)

步骤3. 迭代都一定次数,或阈值停止

纸上得来终决浅,写一编代码便能收获颇多,这里建议你点击我,查看源代码,进行仿写

参考博客

  1. B站视频推导
  2. Mathjax使用方法
updatedupdated2022-04-052022-04-05