动手机器学习-实现逻辑回归

逻辑回归完整代码想抄作业的同学直接点这里

一、逻辑回归算法简介

逻辑回归又称对数几率回归

在线性回归模型中,我们学会了使用线性模型完成回归的问题,但是如果任务是分类任务怎么办呢?答案是使用一个单调可微函数将实际标签y和线性回归的预测值联系起来

首先考虑二分类任务,实际标签为0或1,而线性回归模型的预测值为$z = XW + b$为实值,于是我们使用单位阶跃函数,将z转换为0/1值。

阶跃函数如下所示:

\[ y = \begin{cases} 0, z < 0 \\ 0.5, z = 0 \\ 1, z > 0 \\ \end{cases} \]

但是,我们可以发现阶跃函数有一些硬直,不可求导,因此使用更加平滑的Sigmoid function代替。

Sigmoid函数如下所示:

\[ prob = \frac{1}{1 + e^{-z}} \]

如果Sigmoid函数输出的prob值,我们可以进一步确定样本的类别,比如prob>0.5,样本类别为1,prob<=0.5样本类别为0

阶跃函数如下图中的红线所示,而Sigmoid如下图中的黑线所示:

sigmoid函数

二、怎么求解

逻辑回归计算过程很简单就是:将数据带入z = XW + b中,然后将z带入Sigmoid函数中去,最终求得计算结果。

但是我们怎么求解W的值呢?

损失函数

损失函数的定义: 通过Sigmoid我们可以计算出所属类别的概率,因此我们可以更好的定义损失函数。

逻辑回归中的损失函数如下所示,下面我们进行讨论合理性

\[ loss = - y^* log(prob) - (1 - y^*)log(1 - prob) \]

其中\(y^*\)为标签的实际值,prob是Sigmoid函数的输出值。

这个函数很有意思

\(y^*=1\)时,\(loss=-log(prob)\),意思就是prob值越接近1越好。

\(y^*=0\)时,\(loss=- log(1 - prob)\),意思就是prob值越接近0越好。

总之,该损失函数能使prob值和数据的标签尽可能接近。

求解最优权重-求梯度

为了更好理解梯度的求法,我们重新写一下数据进行计算的过程

\[z = w_0 + w_1 x_1 + w_2 x_2 = \ \left[ \begin{matrix} 1 & x_1 & x_2 \end{matrix} \right] \ \left[ \begin{matrix} w_0 \\ w_1 \\ w_2 \end{matrix} \right] \tag{1} \]

\(Z = XW \tag{2}\),其中假设有10组数据每条数据有3个属性1个标签X维度(10,3),Z维度(10,1)

计算Z对W的导数,为下文做铺垫

\[\frac{d Z}{d W} = X^T \]

\( prob = \frac{1}{1 + e^{-z}} \tag{3}\) prob维度为(10,1)

计算prob对z的导数,为下文做铺垫

\[\frac{d prob}{d z} = \frac{e^{-z}} { (1+e^{-z})^2 } = \frac{1}{1 + e^{-z}} \frac{e^{-z}}{1 + e^{-z}} = prob (1 - prob) \]

综合公式1-3,可得 \( prob = \frac{1}{1 + e^{-XW}} \)

\[ loss = -y^*log(prob) -(1-y)*log(1-prob) \tag{4}\]

计算loss对prob的导数,为下文做铺垫

\[\frac{d loss}{d prob} = - \frac{y^*}{prob} - \frac{1 - y^*}{1-prob} * -1 \]

接下来进行链式求导,其中loss ~ prob ~ z ~ WX

注意下面的计算都是元素级元素,不涉及到矩阵运算

\[ grad = \frac{d loss}{d W} = \frac{d loss}{d prob} * \frac{d prob}{d z} * \frac{d z}{W} \\ = X^T (- \frac{y^*}{prob} + \frac{1 - y^*}{1-prob}) * prob(1 - prob) \\ = X^T ( - y^*(1 - prob) + (1 - y^*)prob ) \\ = X^T (prob - y^*) \]

最终结果维度为(3,10)乘以(10,1)等于(3,1)

求解最优权重-梯度下降

步骤1. 初始化W

步骤2. W -= lr * grad (其中lr为学习率,值到0-1之间)

步骤3. 迭代都一定次数,或阈值停止

Warning

纸上得来终决浅,写一边代码便能收获颇多,这里建议你点击我,查看源代码,进行仿写

四、代办事项

  1. 可视化梯度下降的整个过程

参考博客

https://www.bilibili.com/video/BV1As411j7zw?from=search&seid=13110094973136261265