已屏蔽 原因:{{ notice.reason }}已屏蔽
{{notice.noticeContent}}
~~空空如也

光说不练从来不是科创的风格,下面讲一讲 Logistic Regression 的程序实现和实验。

在实验中我使用 Python,在 scientific computing 上,Python一直是主流,机器学习也不能免俗,而 numpy 提供的便捷的向量化运算,给诸多机器学习的算法实施提供了便利。另外强烈推荐ML新手装 sklearn 这个package。

导入库:

<code class="language-python">import numpy as np
from numpy import linalg as LA
import matplotlib.pyplot as plt
import sklearn.datasets
import sklearn.cross_validation
</code>

以下是计算梯度的函数,使用了numpy的向量化计算特性:

<code class="language-python">def calculate_gradient(w,x_batch,y_batch):
    sigmoid=1/(1+np.dot(x_batch,np.transpose(w)))
    dL=np.dot(sigmoid-y_batch,x_batch)/y_batch.size
    return dL
</code>

计算Loss function,注意,数据溢出是Logistic Regression程序实现的一个主要问题,因为exp函数的输出,用float 表示时,实际上输入被限制在[-750,750]这个区间内,不做处理的话基本上肯定会上溢。这个问题的解决同样在这本书中有讲

<code class="language-python">def calculate_loss(w,x_all,y_all):
    ### Avoid Overflow! ###
    pos_index=np.where(y_all==1)    
    neg_index=np.where(y_all==0)
    Loss=np.sum(-np.log(1+np.exp(-np.dot(x_all[pos_index,:],np.transpose(w)))))+np.sum(-np.log(1+np.exp(-np.dot(x_all[pos_index,:],np.transpose(w)))))
    return Loss   
</code>

训练过程主循环,注意对learning rate采用了annealing,在SGD过程中分段一点点减小步长,不然很容易最后变成在global minimum周围反复徘徊难以收敛:

<code class="language-python">def train(x_train,y_train,alpha,batch_sz,loss_thresh,Max_iter,w0):
    ### bias trick ###
    w=w0
    data_sz=y_train.size
    x_train_b=np.concatenate((x_train,np.ones((data_sz,1))),axis=1)
    Loss_old=0
    Loss=[]
    stepCnt=0
    ### Run SGD ###
    for iter in range(1,Max_iter):
        ### sample a mini batch ###
        batch=np.arange(data_sz)
        np.random.shuffle(batch)
        x_batch=x_train_b[batch[:batch_sz],:]
        y_batch=y_train[batch[:batch_sz]]
        ### update weight ###
        dL=calculate_gradient(w,x_batch,y_batch)
        w-=alpha*dL
        ### record loss changes ###
        Loss.append(calculate_loss(w,x_train_b,y_train))
        ### learning rate annealing ###
        stepCnt+=1
        if stepCnt==10:
            stepCnt=0
            alpha*=0.8

        ### Check if converge ###
        if abs(Loss[-1]-Loss_old)<loss_thresh: break loss_old="Loss[-1]" return w,loss < code></loss_thresh:></code>

使用了sklearn的数据生成函数,之前的图表数据皆来源于此:

<code class="language-python">def make_data():
    centers = [(-10, -10),(10, 10)]
    x, y = sklearn.datasets.make_blobs(n_samples=2000, n_features=2, cluster_std=5.0,
                  centers=centers, shuffle=False, random_state=100)
    x_train, x_test, y_train, y_test = sklearn.cross_validation.train_test_split(x, y, test_size=.4)
    return x_train, x_test, y_train, y_test
</code>

主函数,w被初始化为全零向量,mini batch size 是50 :

<code class="language-python">def main():
    alpha=0.5
    batch_sz=50
    Max_iter=2000
    loss_thresh=1e-5
    w0=[0,0,0]
    x_train, x_test, y_train, y_test = make_data()
    w,Loss = train(x_train,y_train,alpha,batch_sz,loss_thresh,Max_iter,w0)
    plt.plot(Loss)
</code>

以下是记录的学习过程中 Loss function 的收敛过程。
output_9_1.png
得到的w的值为[ 123.01618818 125.42445694 11.78221087]。

画成直线可以看出,跟理想情况非常近似
result.png

文号 / 822839

千古风流
名片发私信
学术分 2
总主题 34 帖总回复 364 楼拥有证书:专家 进士 老干部 学者 机友 笔友
注册于 2012-09-03 13:32最后登录 2024-04-15 13:21
主体类型:个人
所属领域:无
认证方式:手机号
IP归属地:未同步

个人简介

Machine Learning, computer vision enthusiast

Google

文件下载
加载中...
{{errorInfo}}
{{downloadWarning}}
你在 {{downloadTime}} 下载过当前文件。
文件名称:{{resource.defaultFile.name}}
下载次数:{{resource.hits}}
上传用户:{{uploader.username}}
所需积分:{{costScores}},{{holdScores}}下载当前附件免费{{description}}
积分不足,去充值
文件已丢失

当前账号的附件下载数量限制如下:
时段 个数
{{f.startingTime}}点 - {{f.endTime}}点 {{f.fileCount}}
视频暂不能访问,请登录试试
仅供内部学术交流或培训使用,请先保存到本地。本内容不代表科创观点,未经原作者同意,请勿转载。
音频暂不能访问,请登录试试
投诉或举报
加载中...
{{tip}}
请选择违规类型:
{{reason.type}}

空空如也

插入资源
全部
图片
视频
音频
附件
全部
未使用
已使用
正在上传
空空如也~
上传中..{{f.progress}}%
处理中..
上传失败,点击重试
等待中...
{{f.name}}
空空如也~
(视频){{r.oname}}
{{selectedResourcesId.indexOf(r.rid) + 1}}
处理中..
处理失败
插入表情
我的表情
共享表情
Emoji
上传
注意事项
最大尺寸100px,超过会被压缩。为保证效果,建议上传前自行处理。
建议上传自己DIY的表情,严禁上传侵权内容。
点击重试等待上传{{s.progress}}%处理中...已上传,正在处理中
空空如也~
处理中...
处理失败
加载中...
草稿箱
加载中...
此处只插入正文,如果要使用草稿中的其余内容,请点击继续创作。
{{fromNow(d.toc)}}
{{getDraftInfo(d)}}
标题:{{d.t}}
内容:{{d.c}}
继续创作
删除插入插入
插入公式
评论控制
加载中...
文号:{{pid}}
加载中...
详情
详情
推送到专栏从专栏移除
设为匿名取消匿名
查看作者
回复
只看作者
加入收藏取消收藏
收藏
取消收藏
折叠回复
置顶取消置顶
评学术分
鼓励
设为精选取消精选
管理提醒
编辑
通过审核
评论控制
退修或删除
历史版本
违规记录
投诉或举报
加入黑名单移除黑名单
查看IP
{{format('YYYY/MM/DD HH:mm:ss', toc)}}
ID: {{user.uid}}