趋近智
K近邻(KNN)算法的运作原理是:通过找出与新数据点最接近的“k”个训练样本,并根据这些邻居中的多数类别进行预测。将演示一个K近邻分类器的实现,使用常用的Python库Scikit-learn,并在一个大家熟知的数据集上操作。
我们将使用知名的鸢尾花数据集。该数据集包含150朵鸢尾花的测量数据,这些花分属三个不同品种:山鸢尾(Setosa)、变色鸢尾(Versicolor)和维吉尼亚鸢尾(Virginica)。每朵花都有以下四个特征:
我们的目的是构建一个K近邻模型,能够根据这四项测量数据预测鸢尾花的品种。这是多类别分类问题的典型例子。
我们将使用Python和Scikit-learn库。如果您之前没有使用过Scikit-learn,它是一个功能强大且广泛用于机器学习 (machine learning)任务的库。您还需要NumPy等库用于数值运算,以及Matplotlib/Seaborn用于绘图(非必需,但有助于理解)。
确保您已安装这些库。通常可以通过pip进行安装:
pip install scikit-learn numpy matplotlib seaborn pandas
Scikit-learn方便地包含了鸢尾花数据集。让我们加载它。
import pandas as pd
from sklearn.datasets import load_iris
import numpy as np
# 加载鸢尾花数据集
iris = load_iris()
# 数据集作为Bunch对象加载(类似于字典)
# iris.data包含特征(NumPy数组)
# iris.target包含标签(0、1、2对应品种)
# iris.feature_names包含特征名称
# iris.target_names包含品种名称
# 为了方便处理,我们将其放入Pandas DataFrame
# 这是可选的,但通常很方便
df = pd.DataFrame(data=np.c_[iris['data'], iris['target']],
columns= iris['feature_names'] + ['target'])
# 将目标数字映射到品种名称以便清晰
df['species'] = df['target'].map({0: 'setosa', 1: 'versicolor', 2: 'virginica'})
print("鸢尾花数据集的前5行:")
print(df.head())
print("\n目标类别(品种):")
print(df['species'].unique())
# 分离特征(X)和目标(y)
X = iris.data # 特征(NumPy数组)
y = iris.target # 目标标签(NumPy数组)
您应该会看到数据的前几行,这些行显示了测量值以及对应的目标标签(0、1或2)和品种名称。
正如在第2章中讲解并在第6章中再次回顾的那样,我们需要拆分数据。我们将在一个部分(训练集)上训练模型,并在一个独立的、未见过部分(测试集)上评估其性能。这有助于我们了解模型在新数据上的泛化能力。
Scikit-learn为此提供了一个方便的函数train_test_split。
from sklearn.model_selection import train_test_split
# 将数据拆分为训练集和测试集
# test_size=0.3 表示30%的数据将用于测试
# random_state 确保结果可复现性(每次运行都得到相同的拆分)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42, stratify=y)
print(f"训练集形状: {X_train.shape}")
print(f"测试集形状: {X_test.shape}")
我们使用stratify=y来确保训练集和测试集中各品种花的比例大致相同,这对于分类任务来说是良好实践。
回顾第6章,K近邻算法依赖于数据点之间的距离计算(例如欧几里得距离)。如果特征的尺度差异很大(例如,一个特征范围是0-1,而另一个是100-1000),范围较大的特征可能会在距离计算中占据主导。因此,将特征缩放到相似范围对于K近邻算法通常很重要。我们将使用Scikit-learn中的StandardScaler,它通过移除均值并缩放到单位方差来标准化特征。
from sklearn.preprocessing import StandardScaler
# 初始化StandardScaler
scaler = StandardScaler()
# 仅在训练数据上拟合缩放器
scaler.fit(X_train)
# 转换训练和测试数据
X_train_scaled = scaler.transform(X_train)
X_test_scaled = scaler.transform(X_test)
# 注意:仅在训练数据上拟合缩放器非常重要,
# 然后使用该拟合的缩放器来转换两个数据集。
# 这可以防止测试集的信息“泄露”到训练过程中。
# 让我们查看缩放后数据的前几行(可选)
# print("\n缩放后训练数据的前5行:")
# print(X_train_scaled[:5])
现在我们可以创建K近邻分类器了。我们需要选择的主要参数 (parameter)是n_neighbors,也就是我们讲解过的“k”值。让我们从一个常用值开始,例如k=5。
from sklearn.neighbors import KNeighborsClassifier
# 使用k=5初始化K近邻分类器
knn = KNeighborsClassifier(n_neighbors=5)
# 使用缩放后的训练数据训练模型
knn.fit(X_train_scaled, y_train)
print("\nK近邻模型已成功训练,k=5。")
对于许多Scikit-learn模型而言,fit方法是“学习”发生的地方。然而,对于K近邻算法,fit非常简单:它主要只是存储训练数据(X_train_scaled和y_train),以便在后续进行预测时可以引用。
有了我们训练好的模型,现在可以预测测试集(X_test_scaled)中鸢尾花的品种了。
# 对缩放后的测试数据进行预测
y_pred = knn.predict(X_test_scaled)
# 显示前10个预测结果及实际标签
print("\n前10个预测值与实际标签对比:")
print(f"Predictions: {y_pred[:10]}")
print(f"Actual: {y_test[:10]}")
# 记住:0=山鸢尾, 1=变色鸢尾, 2=维吉尼亚鸢尾
predict方法接收新数据点(我们缩放后的测试特征),并为每个点找到存储的训练数据中最近的5个邻居。然后,它根据这些邻居中的多数投票来预测类别。
我们的模型表现如何?我们需要将预测结果(y_pred)与实际标签(y_test)进行对比。我们在上一节学习了评估指标。让我们计算准确率并查看混淆矩阵。
from sklearn.metrics import accuracy_score, confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print(f"\n模型准确率: {accuracy:.4f}")
# 计算混淆矩阵
cm = confusion_matrix(y_test, y_pred)
print("\n混淆矩阵:")
print(cm)
# 可视化混淆矩阵
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=iris.target_names)
disp.plot(cmap=plt.cm.Blues) # 使用蓝色颜色图
plt.title("K近邻(k=5)的混淆矩阵")
plt.show()
一个混淆矩阵,显示了K近邻分类器在鸢尾花测试集上的性能。行代表真实类别,列代表预测类别。对角线元素显示了正确预测的数量。
准确率告诉我们正确预测的总比例。混淆矩阵提供了更详细的分解:
在本例中(结果可能因random_state略有不同),K=5的K近邻模型通常在鸢尾花数据集上表现非常好,混淆矩阵中显示的错误分类很少,常能达到高准确率。
k值(邻居数量)的选择会影响模型的性能。小的k值可能使模型对噪声敏感,而非常大的k值可能过度平滑决策边界。
尝试在创建KNeighborsClassifier时更改n_neighbors参数 (parameter)(例如,尝试k=1、k=3、k=10),然后重新运行步骤4、5和6。观察准确率和混淆矩阵如何变化。找到最优的k值通常涉及尝试多个值,并查看哪个在验证集上表现最好(或者使用交叉验证等方法,这些是略微高级的主题)。
例如,让我们快速检查k=3:
# 初始化、训练、预测并评估k=3的模型
knn_k3 = KNeighborsClassifier(n_neighbors=3)
knn_k3.fit(X_train_scaled, y_train)
y_pred_k3 = knn_k3.predict(X_test_scaled)
accuracy_k3 = accuracy_score(y_test, y_pred_k3)
print(f"\nk=3的模型准确率: {accuracy_k3:.4f}")
cm_k3 = confusion_matrix(y_test, y_pred_k3)
disp_k3 = ConfusionMatrixDisplay(confusion_matrix=cm_k3, display_labels=iris.target_names)
disp_k3.plot(cmap=plt.cm.Greens) # 这次使用绿色颜色图
plt.title("K近邻(k=3)的混淆矩阵")
plt.show()
一个混淆矩阵,显示了K=3的K近邻分类器在鸢尾花测试集上的性能。
对比结果。在这个特定测试集上,k=3的表现比k=5更好还是更差?并非所有数据集都总是存在一个“最优”的k值;这通常取决于数据结构。
在本次实践中,您成功实现了一个K近邻分类器:
KNeighborsClassifier实例。k值如何影响结果。这个动手练习展现了使用标准工具将监督学习 (supervised learning)算法应用于分类问题的典型工作流程。您现在拥有了实现一种基本分类算法的实践经验。
这部分内容有帮助吗?
© 2026 ApX Machine LearningAI伦理与透明度•