数据挖掘:K最近邻(KNN)算法的java实现
版权声明:本文为博主原创文章,未经博主允许不得转载。KNN(K最近邻)算法。给定一些已经训练好的数据,输入一个新的测试数据点,计算包含于此测试数据点的最近的点的分类情况,哪个分类的类型占多数,则此测试点的分类与此相同,所以在这里,有的时候可以复制不同的分类点不同的权重。近的点的权重大点,远的点自然就小点。KNN算法的介绍见上一篇博文:数据挖掘:k最近邻(KNN)
·
版权声明:本文为博主原创文章,未经博主允许不得转载。
KNN(K最近邻)算法。给定一些已经训练好的数据,输入一个新的测试数据点,计算包含于此测试数据点的最近的点的分类情况,哪个分类的类型占多数,则此测试点的分类与此相同,所以在这里,有的时候可以复制不同的分类点不同的权重。近的点的权重大点,远的点自然就小点。
KNN算法的介绍见上一篇博文:数据挖掘:k最近邻(KNN)分类法介绍
本算法只适合学习使用,可以大致了解一下KNN算法的原理。
算法作了如下的假定与简化处理:
1.小规模数据集
2.假设所有数据及类别都是数值类型的
3.直接根据数据规模设定了k值
4.对原训练集进行测试
KNN实现代码如下:
- package KNN;
- /**
- * KNN结点类,用来存储最近邻的k个元组相关的信息
- * @author Rowen
- * @qq 443773264
- * @mail luowen3405@163.com
- * @blog blog.csdn.net/luowen3405
- * @data 2011.03.25
- */
- public class KNNNode {
- private int index; // 元组标号
- private double distance; // 与测试元组的距离
- private String c; // 所属类别
- public KNNNode(int index, double distance, String c) {
- super();
- this.index = index;
- this.distance = distance;
- this.c = c;
- }
- public int getIndex() {
- return index;
- }
- public void setIndex(int index) {
- this.index = index;
- }
- public double getDistance() {
- return distance;
- }
- public void setDistance(double distance) {
- this.distance = distance;
- }
- public String getC() {
- return c;
- }
- public void setC(String c) {
- this.c = c;
- }
- }
- package KNN;
- import java.util.ArrayList;
- import java.util.Comparator;
- import java.util.HashMap;
- import java.util.List;
- import java.util.Map;
- import java.util.PriorityQueue;
- /**
- * KNN算法主体类
- * @author Rowen
- * @qq 443773264
- * @mail luowen3405@163.com
- * @blog blog.csdn.net/luowen3405
- * @data 2011.03.25
- */
- public class KNN {
- /**
- * 设置优先级队列的比较函数,距离越大,优先级越高
- */
- private Comparator<KNNNode> comparator = new Comparator<KNNNode>() {
- public int compare(KNNNode o1, KNNNode o2) {
- if (o1.getDistance() >= o2.getDistance()) {
- return 1;
- } else {
- return 0;
- }
- }
- };
- /**
- * 获取K个不同的随机数
- * @param k 随机数的个数
- * @param max 随机数最大的范围
- * @return 生成的随机数数组
- */
- public List<Integer> getRandKNum(int k, int max) {
- List<Integer> rand = new ArrayList<Integer>(k);
- for (int i = 0; i < k; i++) {
- int temp = (int) (Math.random() * max);
- if (!rand.contains(temp)) {
- rand.add(temp);
- } else {
- i--;
- }
- }
- return rand;
- }
- /**
- * 计算测试元组与训练元组之前的距离
- * @param d1 测试元组
- * @param d2 训练元组
- * @return 距离值
- */
- public double calDistance(List<Double> d1, List<Double> d2) {
- double distance = 0.00;
- for (int i = 0; i < d1.size(); i++) {
- distance += (d1.get(i) - d2.get(i)) * (d1.get(i) - d2.get(i));
- }
- return distance;
- }
- /**
- * 执行KNN算法,获取测试元组的类别
- * @param datas 训练数据集
- * @param testData 测试元组
- * @param k 设定的K值
- * @return 测试元组的类别
- */
- public String knn(List<List<Double>> datas, List<Double> testData, int k) {
- PriorityQueue<KNNNode> pq = new PriorityQueue<KNNNode>(k, comparator);
- List<Integer> randNum = getRandKNum(k, datas.size());
- for (int i = 0; i < k; i++) {
- int index = randNum.get(i);
- List<Double> currData = datas.get(index);
- String c = currData.get(currData.size() - 1).toString();
- KNNNode node = new KNNNode(index, calDistance(testData, currData), c);
- pq.add(node);
- }
- for (int i = 0; i < datas.size(); i++) {
- List<Double> t = datas.get(i);
- double distance = calDistance(testData, t);
- KNNNode top = pq.peek();
- if (top.getDistance() > distance) {
- pq.remove();
- pq.add(new KNNNode(i, distance, t.get(t.size() - 1).toString()));
- }
- }
- return getMostClass(pq);
- }
- /**
- * 获取所得到的k个最近邻元组的多数类
- * @param pq 存储k个最近近邻元组的优先级队列
- * @return 多数类的名称
- */
- private String getMostClass(PriorityQueue<KNNNode> pq) {
- Map<String, Integer> classCount = new HashMap<String, Integer>();
- for (int i = 0; i < pq.size(); i++) {
- KNNNode node = pq.remove();
- String c = node.getC();
- if (classCount.containsKey(c)) {
- classCount.put(c, classCount.get(c) + 1);
- } else {
- classCount.put(c, 1);
- }
- }
- int maxIndex = -1;
- int maxCount = 0;
- Object[] classes = classCount.keySet().toArray();
- for (int i = 0; i < classes.length; i++) {
- if (classCount.get(classes[i]) > maxCount) {
- maxIndex = i;
- maxCount = classCount.get(classes[i]);
- }
- }
- return classes[maxIndex].toString();
- }
- }
- package KNN;
- import java.io.BufferedReader;
- import java.io.File;
- import java.io.FileReader;
- import java.util.ArrayList;
- import java.util.List;
- /**
- * KNN算法测试类
- * @author Rowen
- * @qq 443773264
- * @mail luowen3405@163.com
- * @blog blog.csdn.net/luowen3405
- * @data 2011.03.25
- */
- public class TestKNN {
- /**
- * 从数据文件中读取数据
- * @param datas 存储数据的集合对象
- * @param path 数据文件的路径
- */
- public void read(List<List<Double>> datas, String path){
- try {
- BufferedReader br = new BufferedReader(new FileReader(new File(path)));
- String data = br.readLine();
- List<Double> l = null;
- while (data != null) {
- String t[] = data.split(" ");
- l = new ArrayList<Double>();
- for (int i = 0; i < t.length; i++) {
- l.add(Double.parseDouble(t[i]));
- }
- datas.add(l);
- data = br.readLine();
- }
- } catch (Exception e) {
- e.printStackTrace();
- }
- }
- /**
- * 程序执行入口
- * @param args
- */
- public static void main(String[] args) {
- TestKNN t = new TestKNN();
- String datafile = new File("").getAbsolutePath() + File.separator + "datafile";
- String testfile = new File("").getAbsolutePath() + File.separator + "testfile";
- try {
- List<List<Double>> datas = new ArrayList<List<Double>>();
- List<List<Double>> testDatas = new ArrayList<List<Double>>();
- t.read(datas, datafile);
- t.read(testDatas, testfile);
- KNN knn = new KNN();
- for (int i = 0; i < testDatas.size(); i++) {
- List<Double> test = testDatas.get(i);
- System.out.print("测试元组: ");
- for (int j = 0; j < test.size(); j++) {
- System.out.print(test.get(j) + " ");
- }
- System.out.print("类别为: ");
- System.out.println(Math.round(Float.parseFloat((knn.knn(datas, test, 3)))));
- }
- } catch (Exception e) {
- e.printStackTrace();
- }
- }
- }
训练数据文件:
- 1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 1
- 1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 1
- 1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 1
- 1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 0
- 1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 1
- 1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 0
- 1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5
- 1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8
- 1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2
- 1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5
- 1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5
- 1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5
程序运行结果:
- 测试元组: 1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 类别为: 1
- 测试元组: 1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 类别为: 1
- 测试元组: 1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 类别为: 1
- 测试元组: 1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 类别为: 0
- 测试元组: 1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 类别为: 1
- 测试元组: 1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 类别为: 0
由结果可以看出,分类的测试结果是比较准确的!
更多推荐
所有评论(0)