资讯专栏INFORMATION COLUMN

JavaScript机器学习之KNN算法

enrecul101 / 1877人阅读

摘要:是的缩写,它是一种监督学习算法。每一个机器学习算法都需要数据,这次我将使用数据集。其数据集包含了个样本,都属于鸢尾属下的三个亚属,分别是山鸢尾变色鸢尾和维吉尼亚鸢尾。四个特征被用作样本的定量分析,它们分别是花萼和花瓣的长度和宽度。

译者按: 机器学习原来很简单啊,不妨动手试试!

原文: Machine Learning with JavaScript : Part 2

译者: Fundebug

为了保证可读性,本文采用意译而非直译。另外,本文版权归原作者所有,翻译仅用于学习。另外,我们修正了原文代码中的错误

上图使用plot.ly所画。

上次我们用JavaScript实现了线性规划,这次我们来聊聊KNN算法。

KNN是k-Nearest-Neighbours的缩写,它是一种监督学习算法。KNN算法可以用来做分类,也可以用来解决回归问题。

GitHub仓库: machine-learning-with-js

KNN算法简介

简单地说,KNN算法由那离自己最近的K个点来投票决定待分类数据归为哪一类

如果待分类的数据有这些邻近数据,NY: 7, NJ: 0, IN: 4,即它有7个NY邻居,0个NJ邻居,4个IN邻居,则这个数据应该归类为NY

假设你在邮局工作,你的任务是为邮递员分配信件,目标是最小化到各个社区的投递旅程。不妨假设一共有7个街区。这就是一个实际的分类问题。你需要将这些信件分类,决定它属于哪个社区,比如上东城曼哈顿下城等。

最坏的方案是随意分配信件分配给邮递员,这样每个邮递员会拿到各个社区的信件。

最佳的方案是根据信件地址进行分类,这样每个邮递员只需要负责邻近社区的信件。

也许你是这样想的:"将邻近3个街区的信件分配给同一个邮递员"。这时,邻近街区的个数就是k。你可以不断增加k,直到获得最佳的分配方案。这个k就是分类问题的最佳值。

KNN代码实现

像上次一样,我们将使用mljs的KNN模块ml-knn来实现。

每一个机器学习算法都需要数据,这次我将使用IRIS数据集。其数据集包含了150个样本,都属于鸢尾属下的三个亚属,分别是山鸢尾、变色鸢尾和维吉尼亚鸢尾。四个特征被用作样本的定量分析,它们分别是花萼和花瓣的长度和宽度。

1. 安装模块

</>复制代码

  1. $ npm install ml-knn@2.0.0 csvtojson prompt

ml-knn: k-Nearest-Neighbours模块,不同版本的接口可能不同,这篇博客使用了2.0.0

csvtojson: 用于将CSV数据转换为JSON

prompt: 在控制台输入输出数据

2. 初始化并导入数据

IRIS数据集由加州大学欧文分校提供。

</>复制代码

  1. curl https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data > iris.csv

假设你已经初始化了一个NPM项目,请在index.js中输入以下内容:

</>复制代码

  1. const KNN = require("ml-knn");
  2. const csv = require("csvtojson");
  3. const prompt = require("prompt");
  4. var knn;
  5. const csvFilePath = "iris.csv"; // 数据集
  6. const names = ["sepalLength", "sepalWidth", "petalLength", "petalWidth", "type"];
  7. let seperationSize; // 分割训练和测试数据
  8. let data = [],
  9. X = [],
  10. y = [];
  11. let trainingSetX = [],
  12. trainingSetY = [],
  13. testSetX = [],
  14. testSetY = [];

seperationSize用于分割数据和测试数据

使用csvtojson模块的fromFile方法加载数据:

</>复制代码

  1. csv(
  2. {
  3. noheader: true,
  4. headers: names
  5. })
  6. .fromFile(csvFilePath)
  7. .on("json", (jsonObj) =>
  8. {
  9. data.push(jsonObj); // 将数据集转换为JS对象数组
  10. })
  11. .on("done", (error) =>
  12. {
  13. seperationSize = 0.7 * data.length;
  14. data = shuffleArray(data);
  15. dressData();
  16. });

我们将seperationSize设为样本数目的0.7倍。注意,如果训练数据集太小的话,分类效果将变差。

由于数据集是根据种类排序的,所以需要使用shuffleArray函数对数据进行混淆,这样才能方便分割出训练数据。这个函数的定义请参考StackOverflow的提问How to randomize (shuffle) a JavaScript array?:

</>复制代码

  1. function shuffleArray(array)
  2. {
  3. for (var i = array.length - 1; i > 0; i--)
  4. {
  5. var j = Math.floor(Math.random() * (i + 1));
  6. var temp = array[i];
  7. array[i] = array[j];
  8. array[j] = temp;
  9. }
  10. return array;
  11. }
3. 转换数据

数据集中每一条数据可以转换为一个JS对象:

</>复制代码

  1. {
  2. sepalLength: ‘5.1’,
  3. sepalWidth: ‘3.5’,
  4. petalLength: ‘1.4’,
  5. petalWidth: ‘0.2’,
  6. type: ‘Iris-setosa’
  7. }

在使用KNN算法训练数据之前,需要对数据进行这些处理:

将属性(sepalLength, sepalWidth,petalLength,petalWidth)由字符串转换为浮点数. (parseFloat)

将分类 (type)用数字表示

</>复制代码

  1. function dressData()
  2. {
  3. let types = new Set();
  4. data.forEach((row) =>
  5. {
  6. types.add(row.type);
  7. });
  8. let typesArray = [...types];
  9. data.forEach((row) =>
  10. {
  11. let rowArray, typeNumber;
  12. rowArray = Object.keys(row).map(key => parseFloat(row[key])).slice(0, 4);
  13. typeNumber = typesArray.indexOf(row.type); // Convert type(String) to type(Number)
  14. X.push(rowArray);
  15. y.push(typeNumber);
  16. });
  17. trainingSetX = X.slice(0, seperationSize);
  18. trainingSetY = y.slice(0, seperationSize);
  19. testSetX = X.slice(seperationSize);
  20. testSetY = y.slice(seperationSize);
  21. train();
  22. }
4. 训练数据并测试

</>复制代码

  1. function train()
  2. {
  3. knn = new KNN(trainingSetX, trainingSetY,
  4. {
  5. k: 7
  6. });
  7. test();
  8. }

train方法需要2个必须的参数: 输入数据,即花萼和花瓣的长度和宽度;实际分类,即山鸢尾、变色鸢尾和维吉尼亚鸢尾。另外,第三个参数是可选的,用于提供调整KNN算法的内部参数。我将k参数设为7,其默认值为5。

训练好模型之后,就可以使用测试数据来检查准确性了。我们主要对预测出错的个数比较感兴趣。

</>复制代码

  1. function test()
  2. {
  3. const result = knn.predict(testSetX);
  4. const testSetLength = testSetX.length;
  5. const predictionError = error(result, testSetY);
  6. console.log(`Test Set Size = ${testSetLength} and number of Misclassifications = ${predictionError}`);
  7. predict();
  8. }

比较预测值与真实值,就可以得到出错个数:

</>复制代码

  1. function error(predicted, expected)
  2. {
  3. let misclassifications = 0;
  4. for (var index = 0; index < predicted.length; index++)
  5. {
  6. if (predicted[index] !== expected[index])
  7. {
  8. misclassifications++;
  9. }
  10. }
  11. return misclassifications;
  12. }
5. 进行预测(可选)

任意输入属性值,就可以得到预测值

</>复制代码

  1. function predict()
  2. {
  3. let temp = [];
  4. prompt.start();
  5. prompt.get(["Sepal Length", "Sepal Width", "Petal Length", "Petal Width"], function(err, result)
  6. {
  7. if (!err)
  8. {
  9. for (var key in result)
  10. {
  11. temp.push(parseFloat(result[key]));
  12. }
  13. console.log(`With ${temp} -- type = ${knn.predict(temp)}`);
  14. }
  15. });
  16. }
6. 完整程序

完整的程序index.js是这样的:

</>复制代码

  1. const KNN = require("ml-knn");
  2. const csv = require("csvtojson");
  3. const prompt = require("prompt");
  4. var knn;
  5. const csvFilePath = "iris.csv"; // 数据集
  6. const names = ["sepalLength", "sepalWidth", "petalLength", "petalWidth", "type"];
  7. let seperationSize; // 分割训练和测试数据
  8. let data = [],
  9. X = [],
  10. y = [];
  11. let trainingSetX = [],
  12. trainingSetY = [],
  13. testSetX = [],
  14. testSetY = [];
  15. csv(
  16. {
  17. noheader: true,
  18. headers: names
  19. })
  20. .fromFile(csvFilePath)
  21. .on("json", (jsonObj) =>
  22. {
  23. data.push(jsonObj); // 将数据集转换为JS对象数组
  24. })
  25. .on("done", (error) =>
  26. {
  27. seperationSize = 0.7 * data.length;
  28. data = shuffleArray(data);
  29. dressData();
  30. });
  31. function dressData()
  32. {
  33. let types = new Set();
  34. data.forEach((row) =>
  35. {
  36. types.add(row.type);
  37. });
  38. let typesArray = [...types];
  39. data.forEach((row) =>
  40. {
  41. let rowArray, typeNumber;
  42. rowArray = Object.keys(row).map(key => parseFloat(row[key])).slice(0, 4);
  43. typeNumber = typesArray.indexOf(row.type); // Convert type(String) to type(Number)
  44. X.push(rowArray);
  45. y.push(typeNumber);
  46. });
  47. trainingSetX = X.slice(0, seperationSize);
  48. trainingSetY = y.slice(0, seperationSize);
  49. testSetX = X.slice(seperationSize);
  50. testSetY = y.slice(seperationSize);
  51. train();
  52. }
  53. // 使用KNN算法训练数据
  54. function train()
  55. {
  56. knn = new KNN(trainingSetX, trainingSetY,
  57. {
  58. k: 7
  59. });
  60. test();
  61. }
  62. // 测试训练的模型
  63. function test()
  64. {
  65. const result = knn.predict(testSetX);
  66. const testSetLength = testSetX.length;
  67. const predictionError = error(result, testSetY);
  68. console.log(`Test Set Size = ${testSetLength} and number of Misclassifications = ${predictionError}`);
  69. predict();
  70. }
  71. // 计算出错个数
  72. function error(predicted, expected)
  73. {
  74. let misclassifications = 0;
  75. for (var index = 0; index < predicted.length; index++)
  76. {
  77. if (predicted[index] !== expected[index])
  78. {
  79. misclassifications++;
  80. }
  81. }
  82. return misclassifications;
  83. }
  84. // 根据输入预测结果
  85. function predict()
  86. {
  87. let temp = [];
  88. prompt.start();
  89. prompt.get(["Sepal Length", "Sepal Width", "Petal Length", "Petal Width"], function(err, result)
  90. {
  91. if (!err)
  92. {
  93. for (var key in result)
  94. {
  95. temp.push(parseFloat(result[key]));
  96. }
  97. console.log(`With ${temp} -- type = ${knn.predict(temp)}`);
  98. }
  99. });
  100. }
  101. // 混淆数据集的顺序
  102. function shuffleArray(array)
  103. {
  104. for (var i = array.length - 1; i > 0; i--)
  105. {
  106. var j = Math.floor(Math.random() * (i + 1));
  107. var temp = array[i];
  108. array[i] = array[j];
  109. array[j] = temp;
  110. }
  111. return array;
  112. }

在控制台执行node index.js

</>复制代码

  1. $ node index.js

输出如下:

</>复制代码

  1. Test Set Size = 45 and number of Misclassifications = 2
  2. prompt: Sepal Length: 1.7
  3. prompt: Sepal Width: 2.5
  4. prompt: Petal Length: 0.5
  5. prompt: Petal Width: 3.4
  6. With 1.7,2.5,0.5,3.4 -- type = 2
参考链接

K NEAREST NEIGHBOR 算法

安德森鸢尾花卉数据集

欢迎加入我们Fundebug的全栈BUG监控交流群: 622902485

版权声明:
转载时请注明作者Fundebug以及本文地址:
https://blog.fundebug.com/2017/07/10/javascript-machine-learning-knn/

文章版权归作者所有,未经允许请勿转载,若此文章存在违规行为,您可以联系管理员删除。

转载请注明本文地址:https://www.ucloud.cn/yun/84018.html

相关文章

  • 机器习之 K-近邻算法

    摘要:近邻算法通过测量不同特征值之间的距离方法进行分类。对于近邻算法来说,它是一个特殊的没有模型的算法,但是我们将其训练数据集看作是模型。算法优缺点近邻算法是一个比较简单的算法,有其优点但也有缺点。 k-近邻算法通过测量不同特征值之间的距离方法进行分类。 k-近邻算法原理 对于一个存在标签的训练样本集,输入没有标签的新数据后,将新数据的每个特征与样本集中数据对应的特征进行比较,根据算法选择样...

    Jensen 评论0 收藏0
  • 机器习之多项式回归与模型泛化

    摘要:还提供了,将多项式特征数据归一化和线性回归组合在了一起,大大方便的编程的过程。在机器学习算法中,主要的挑战来自方差,解决的方法主要有降低模型复杂度降维增加样本数使用验证集模型正则化。 多项式回归 多项式回归使用线性回归的基本思路 非线性曲线如图: showImg(https://segmentfault.com/img/bVbkn4q?w=372&h=252); 假设曲线表达式为:$y...

    huhud 评论0 收藏0

发表评论

0条评论

最新活动
阅读需要支付1元查看
<