K 近邻算法进行丁香花预测
概述
K 近邻(K-Nearest Neighbors,简称:KNN)算法是最近邻(NN)算法的一个推广,也是机器学习分类算法中最简单的方法之一。KNN 算法的核心思想和最近邻算法思想相似,都是通过寻找和未知样本相似的类别进行分类
工作原理
- 准备数据:收集用于训练和测试的数据集,并对数据进行预处理,如特征选择、特征缩放等。
- 选择距离度量:KNN 算法使用距离度量来计算样本之间的相似度。常用的距离度量有欧氏距离、曼哈顿距离等。根据问题的实际情况选择合适的距离度量。
- 确定 k 值:根据实际需求和数据集的特点,选择一个合适的 k 值。k 值的选择对算法的性能有很大影响,过小的 k 值可能导致过拟合,过大的 k 值可能导致欠拟合。
- 分类决策:对于待分类的样本,计算它与训练集中每个样本的距离,找出距离它最近的 k 个样本。然后根据这 k 个样本的类别进行投票,将待分类样本划分到得票最多的类别中。
应用场景
- 图像识别
- 文本分类
- 推荐系统
案例
丁香花识别
数据预处理
python
# 数据处理
# 读取csv数据
origin_data = pd.read_csv(r'./data/course-9-syringa.csv')
print(origin_data.size)
# 预览前面10行数据
print(origin_data.head(10))
# 得到特征序列 sepal_length sepal_width petal_length petal_width
feature_data = origin_data.iloc[:, :-1]
# 得到 label 列
label_data = origin_data["labels"]
划分训练,测试集
python
x_train, x_test, y_train, y_test = train_test_split(
feature_data, label_data, test_size=0.3, random_state=2, shuffle=True
)
训练模型确定 K 值
python
def sklearn_classify(train_data, label_data, test_data, k_num):
# sklearn 构建KNN预测模型
knn = KNeighborsClassifier(n_neighbors=k_num, algorithm="kd_tree")
# 训练数据集
knn.fit(train_data, label_data)
# 预测
predict_label = knn.predict(test_data)
# 返回预测值
return predict_label
def get_accuracy(test_labels, pred_labels):
# 准确率计算函数
correct = np.sum(test_labels == pred_labels) # 计算预测正确的数据个数
n = len(test_labels) # 总测试集数据个数
accur = correct / n
return accur
normal_accuracy = []
k_value = range(2, 11)
for k in k_value:
y_predict = sklearn_classify(x_train, y_train, x_test, k)
accuracy = get_accuracy(y_test, y_predict)
res = {
"k": k,
"accuracy": accuracy
}
normal_accuracy.append(res)
# 使用sorted函数进行降序排序,找到最佳的K值
sorted_data = sorted(normal_accuracy, key=lambda x: x['accuracy'], reverse=True)
保存最佳模型
python
best_knn = KNeighborsClassifier(n_neighbors=sorted_data[0]["k"], algorithm="kd_tree")
# 训练数据集
best_knn.fit(x_train, y_train)
knn_model_name = 'best_knn_model.pkl'
joblib.dump(best_knn, knn_model_name)
模型预测
python
import joblib
import numpy as np
model = joblib.load('best_knn_model.pkl')
feature = np.array([[5.1, 3.5, 2.4, 2.1]])
label = 'daphne'
result = model.predict(feature)
print(result) # daphne