当前位置: 首页 > news >正文

4.2 - 《机器学习基石》Home Work 1 Q.18-20

2019独角兽企业重金招聘Python工程师标准>>> hot3.png

145409_Cknm_1047422.png

第18题要求在第16题 Random PLA 算法的基础上使用 Pocket 算法对数据做二元划分。Pocket算法在第2篇文章介绍过,通常用来处理有杂质的数据集,在每一次更新 Weights(权向量)之后,把当前犯错最少的Weights放在pocket中,直至达到指定迭代次数(50次),pocket中的Weights即为所求。然后用测试数据验证W(pocket)的错误率,进行2000次计算取平均。

#include <fstream>
#include <iostream>
#include <vector>
#include <ctime>

using namespace std;

#define DEMENSION 5                    //数据维度

int index = 0;                         //当前数据条目索引
int step = 0;                          //当前权向量更新次数
char *file = "training_data.txt";
char *file_test = "test_data.txt";

struct record {
	double input[DEMENSION];       //输入
	int output;                    //输出	
};



int sign(double x)
{
	//同Q16
}

//两个向量相加,更新第一个向量
void add(double *v1, double *v2, int demension)
{
	//同Q16
}

//两个向量相乘,返回内积
double multiply(double *v1, double *v2, int demension)
{
	//同Q16
}

//向量与实数相乘,结果通过*result返回,不改变参与计算的向量
void multiply(double *result, double *v, double num, int demension)
{
	//同Q16
}

//对 traininig set 创建一个随机排序
void setRandomOrder(vector<record> &trainingSet, vector<int> &randIndexes)
{
	//同Q16
}

//读取数据
void getData(ifstream & dataFile, vector<record> &data)
{
	//同Q16
}

//错误统计及Pocket向量更新
void errCountAndPocketUpdate(vector<record> &trainingSet, vector<int> &randIndexes, 
        double *weights, double *pocketWeights, double &trainingErrRate, int dataLength)
{
	int errCount = 0;
	double curTrainingErrRate = 1.0;

	for(int i=0;i<dataLength;++i){
		if(trainingSet[randIndexes[i]].output 
		    != sign(multiply(weights,trainingSet[randIndexes[i]].input,DEMENSION))){
			errCount++;	
		}	
	}

	curTrainingErrRate = double(errCount)/double(dataLength);

	if(curTrainingErrRate < trainingErrRate){
		trainingErrRate = curTrainingErrRate;	
		for(int j=0; j<DEMENSION; ++j){
			pocketWeights[j] = weights[j];	
		}	
	}
}


void Pocket(vector<record> &trainingSet, vector<int> &randIndexes, double *weights, 
        double *pocketWeights, double &trainingErrRate)
{
	int length = trainingSet.size();	
	double curInput[DEMENSION];	

	errCountAndPocketUpdate(trainingSet, randIndexes, weights, 
	    pocketWeights, trainingErrRate, length);
	    
        //找到下一个错误记录的index
	while( trainingSet[randIndexes[index]].output == 
	    sign(multiply(weights,trainingSet[randIndexes[index]].input,DEMENSION)) ){
		if(index==length-1)	{index = 0;}
		else				{index++;}
	}
	
	if(step<50){
		step++;
		
		//更新: weights = weights + curOutput * curInput
		multiply( curInput, trainingSet[randIndexes[index]].input, 
		    trainingSet[randIndexes[index]].output, DEMENSION );	
		add( weights, curInput, DEMENSION );	

		if(index==length-1)	{index = 0;}
		else			{index++;}

		Pocket(trainingSet, randIndexes, weights, pocketWeights, trainingErrRate);	
	}else{		
		return;
	}
}

//统计 W(pocket) 在测试数据集上的错误率
double getTestErrRate(vector<record> &testSet, double *pocketWeights, int dataLength)
{
	int errCount = 0;

	for(int i=0;i<dataLength;++i){
		if(testSet[i].output != 
		    sign(multiply(pocketWeights,testSet[i].input,DEMENSION))){
			errCount++;	
		}	
	}

	return double(errCount)/double(dataLength);
}


void main()
{
	double totalTestErrRate = 1.0;

	for(int i=0;i<2000;++i){

		double weights[DEMENSION];              //当前权重向量
		double pocketWeights[DEMENSION];        //当前最优权重向量
		vector<record> trainingSet;             //训练数据
		vector<record> testSet;                 //测试数据
		vector<int> randIndexes;                //访问数据的随机索引列表
		ifstream dataFile(file);
		ifstream testDataFile(file_test);
		double trainingErrRate = 1.0;           //训练集中的错误率[0.0, 1.0]
		double testErrRate = 1.0;               //测试集中的错误率[0.0, 1.0]

		step = 0;			   
		index = 0;			    

		if( dataFile.is_open() && testDataFile.is_open() ){
			getData(dataFile,trainingSet);	
			getData(testDataFile,testSet);	
			setRandomOrder(trainingSet,randIndexes);
		}else{
			cerr<<"ERROR ---> 文件打开失败"<<endl;
			exit(1);
		}

		for(int j=0;j<DEMENSION;++j){ 
			weights[j] = 0.0; 
			pocketWeights[j] = 0.0;
		}

		Pocket(trainingSet, randIndexes, weights, pocketWeights, trainingErrRate);

		testErrRate = getTestErrRate(testSet,pocketWeights,testSet.size());
		totalTestErrRate += testErrRate;

		cout<<"\n  ********************   第 "<<i+1
		<<" 次计算结束   *************************  \n"<<endl;
	}

	cout<<"Average Perportion = "<<totalTestErrRate/2000<<endl;
}

151115_1aMB_1047422.png

151143_U4i0_1047422.png

题19要求用经过50次更新的W(50)进行验证,而不是W(pocket),由于W(50)未必是当下最优,所以平均错误率一定会升高。代码几乎没有改动,只需在调用 getTestErrRate 函数是传入W(50)的指针即可。

151809_iiLD_1047422.png

151917_wLgr_1047422.png

本题要求把 Weights 的更新次数从50增加到100,可以预计平均错误率是降低的。

152128_sSSv_1047422.png

转载于:https://my.oschina.net/findbill/blog/208931

相关文章:

  • 修改SAP系统字段文本
  • 今天我的生日,纪念一下
  • Yii Framework 开发教程(关键是链接)Zii组件-ListView 示例
  • 技术开发频道一周精选2007-7-28
  • maven plugin的execution出错
  • 自动在Web.config中生成连接字符串
  • DNS服务器系列之二:高级配置之-DNS子域授权、区域转发、acl列表及view
  • 解决ASP.NET程序上传到虚拟主机的问题
  • [爱情] 『转载』女生写的追MM秘籍,看了马上告别光棍
  • 物流配送管理自动化采用无线传输解决方案
  • 80后管理:难以理解 但无法回避
  • PHP安全配置
  • 又一套BlogEngine主题Andreas
  • SQL Server Mobile 和 .NET 数据访问接口之间的数据类型映射
  • JavaScript 经典代码大全
  • 9月CHINA-PUB-OPENDAY技术沙龙——IPHONE
  • [case10]使用RSQL实现端到端的动态查询
  • Angular2开发踩坑系列-生产环境编译
  • CSS 三角实现
  • in typeof instanceof ===这些运算符有什么作用
  • java架构面试锦集:开源框架+并发+数据结构+大企必备面试题
  • k个最大的数及变种小结
  • LeetCode541. Reverse String II -- 按步长反转字符串
  • node和express搭建代理服务器(源码)
  • php的插入排序,通过双层for循环
  • Spring Boot MyBatis配置多种数据库
  • Vue UI框架库开发介绍
  • 高程读书笔记 第六章 面向对象程序设计
  • 关于字符编码你应该知道的事情
  • 互联网大裁员:Java程序员失工作,焉知不能进ali?
  • 如何解决微信端直接跳WAP端
  • 软件开发学习的5大技巧,你知道吗?
  • 腾讯优测优分享 | 你是否体验过Android手机插入耳机后仍外放的尴尬?
  • 异步
  • 用Visual Studio开发以太坊智能合约
  • 职业生涯 一个六年开发经验的女程序员的心声。
  • ionic异常记录
  • ​LeetCode解法汇总2696. 删除子串后的字符串最小长度
  • #1014 : Trie树
  • #1015 : KMP算法
  • $(document).ready(function(){}), $().ready(function(){})和$(function(){})三者区别
  • (js)循环条件满足时终止循环
  • (Matalb回归预测)PSO-BP粒子群算法优化BP神经网络的多维回归预测
  • (pojstep1.1.1)poj 1298(直叙式模拟)
  • (附源码)springboot家庭财务分析系统 毕业设计641323
  • (每日持续更新)jdk api之StringBufferInputStream基础、应用、实战
  • (一)Dubbo快速入门、介绍、使用
  • (转)memcache、redis缓存
  • ***php进行支付宝开发中return_url和notify_url的区别分析
  • *setTimeout实现text输入在用户停顿时才调用事件!*
  • .NET C#版本和.NET版本以及VS版本的对应关系
  • .net core webapi 部署iis_一键部署VS插件:让.NET开发者更幸福
  • .net web项目 调用webService
  • .NET中使用Protobuffer 实现序列化和反序列化
  • @Mapper作用