概述

K-近邻法是一种分类算法,原理可见《小瓜讲机器学习——分类算法(四)K近邻法算法原理及Python代码实现》。当K=1的时候一般称为最近邻算法。

3.1 sklearn.neighbors

3.1.1 sklearn.neighbors.KNeighborsClassifier
sklearn.neighbors.KNeighborsClassifier(n_neighbors=5, weights='uniform', algorithm='auto', leaf_size=30, p=2
											metric='minkowski', metric_params=None, n_jobs=None, **kwargs)

参数说明:
1.n_neighbors(default=5):K近邻法中的k值,即计算邻域内的样本点数
2.weights(default=‘uniform’):邻域内的样本点的权重系数,可选值‘uniform’、‘distance’、【callable】(function defined by user)
3.algorithm:训练样本点的数据索引方式,有‘auto’、‘ball_tree’、‘KDTree’和‘brute’几种,其中KDTree原理见《小瓜讲机器学习——分类算法(四)K近邻法算法原理及Python代码实现》;
4.leaf_size(default=30):
5.p(default=2):计算距离的次方,p=2相当于欧式距离;
6.metric
7.metric_params
8.n_jobs

方法说明:
1.fit():训练模型;
2.predict():预测测试集;
3.predict_proba():预测测试集中每个样本分属不同类别的概率;
4.kneighbors_graph():获得测试样本的邻域内的训练样本点。

3.1.2 KNeighborsClassifier示例

栗子一
下例子是官网帮助文档中的

from sklearn.neighbors import KNeighborsClassifier

x = [[0], [1], [2], [3]]
y = [0, 0, 1, 1]

neigh = KNeighborsClassifier(n_neighbors=3)
neigh.fit(x, y)
print(neigh.predict([[1.1]]))
print(neigh.predict_proba([[0.9]]))

输出为

[0]
[[0.66666667 0.33333333]]

栗子二
生成随机数

import numpy as np
import matplotlib.pyplot as plt
from  matplotlib.colors import ListedColormap
from sklearn.datasets.samples_generator import make_classification
from sklearn.neighbors import KNeighborsClassifier

# create random samples
feature_X, label_Y = make_classification(n_samples=1000, n_features=2, n_redundant=0, n_clusters_per_class=1, n_classes=3)

#plt.scatter(feature_X[:, 0], feature_X[:, 1], marker='o', c=label_Y)
#plt.show()

#train model
clf=KNeighborsClassifier(n_neighbors=5)
clf.fit(feature_X, label_Y)

#data visiualization
cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF'])
cmap_bold = ListedColormap(['#FF0000', '#00FF00', '#0000FF'])

x1_min = feature_X[:, 0].min() - 1
x1_max = feature_X[:, 0].max() + 1
x2_min = feature_X[:, 1].min() - 1
x2_max = feature_X[:, 1].max() + 1

x1, x2 = np.meshgrid(np.arange(x1_min, x1_max, 0.01), np.arange(x2_min, x2_max, 0.01))

label_predict = clf.predict(np.c_[x1.ravel(), x2.ravel()])
label_predict = label_predict.reshape(x1.shape)
plt.figure()
plt.pcolormesh(x1, x2, label_predict, cmap=cmap_light)
plt.scatter(feature_X[:, 0], feature_X[:, 1], c=label_Y, cmap=cmap_bold)
plt.xlim(x1.min(), x1.max())
plt.ylim(x2.min(), x2.max())

plt.show()

结果如下图
在这里插入图片描述

Logo

永洪科技,致力于打造全球领先的数据技术厂商,具备从数据应用方案咨询、BI、AIGC智能分析、数字孪生、数据资产、数据治理、数据实施的端到端大数据价值服务能力。

更多推荐