趋近智
我们将实际操作训练和评估监督学习 (supervised learning)模型,使用MLJ.jl。我们将使用一个常见的数据集,实现几种不同的模型,使用交叉验证评估它们的性能,然后观察超参数 (parameter) (hyperparameter)调整如何改善结果。通过学习,您将更熟练地将MLJ.jl的工作流程应用于自己的监督学习任务。
首先,请确保您已安装MLJ.jl及其他所需包。我们将从加载这些包以及本次练习使用的数据集开始。鸢尾花数据集是分类任务的经典选项,MLJ.jl提供了简便的加载方式。
using MLJ
using DataFrames
using PrettyPrinting # 为MLJ对象提供更美观的输出
using StableRNGs # 为了结果可重复
# 加载鸢尾花数据集
X, y = @load_iris; # X是特征表,y是目标分类向量
# 用于数据划分和模型训练的可重复性
rng = StableRNG(123)
# 显示特征和目标的前几行
first(X, 3) |> pretty
first(y, 3) |> pretty
特征X是150朵鸢尾花的萼片长度、萼片宽度、花瓣长度和花瓣宽度的测量值。目标y是每朵花的种类。
接下来,我们将数据分为训练集和测试集。这是评估模型在新数据上表现的标准做法。
# 将数据分为训练集和测试集(70%训练,30%测试)
train_rows, test_rows = partition(eachindex(y), 0.7, rng=rng);
X_train = X[train_rows, :];
y_train = y[train_rows];
X_test = X[test_rows, :];
y_test = y[test_rows];
我们从线性模型:逻辑回归开始(在MLJ.jl中,对于多分类问题也称为MultinomialClassifier)。
@load使模型类型可用。machine对象中。fit!训练模型。predict。# 加载多项式分类器模型类型
LogisticClassifier = @load MultinomialClassifier pkg=MLJLinearModels verbosity=0
# 创建模型实例
logreg_model = LogisticClassifier()
# 将模型和训练数据封装到机器中
logreg_machine = machine(logreg_model, X_train, y_train)
# 训练模型
fit!(logreg_machine, verbosity=0)
# 在测试集上进行预测
y_pred_logreg = predict(logreg_machine, X_test)
# 评估准确率
accuracy_logreg = accuracy(mode.(y_pred_logreg), y_test) # 使用mode是因为predict返回的是分布
println("Logistic Regression Accuracy: $(round(accuracy_logreg, digits=3))")
您应该会看到打印出的准确率分数。对于鸢尾花数据集,逻辑回归通常表现良好。在本次演示中,我们假设它达到了约0.956的准确率。
现在,让我们尝试不同类型的模型,一个决策树分类器。
# 加载决策树分类器模型类型
DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree verbosity=0
# 创建模型实例
tree_model = DecisionTreeClassifier(rng=deepcopy(rng)) # 传入rng以保证树本身的可重复性
# 封装到机器中
tree_machine = machine(tree_model, X_train, y_train)
# 训练模型
fit!(tree_machine, verbosity=0)
# 进行预测
y_pred_tree = predict(tree_machine, X_test)
# 评估准确率
accuracy_tree = accuracy(mode.(y_pred_tree), y_test)
println("Decision Tree Accuracy (default): $(round(accuracy_tree, digits=3))")
默认参数 (parameter)的决策树可能会给出略有不同的结果。例如,我们可能会观察到约0.933的准确率。
在单次训练-测试集划分上评估有时会因为数据划分的特定方式而产生误导。交叉验证提供了更可靠的模型性能估计。MLJ.jl的evaluate!函数使这变得简单。我们将对决策树模型使用6折交叉验证。
# 定义重采样策略:6折交叉验证
cv_strategy = CV(nfolds=6, rng=deepcopy(rng))
# 使用交叉验证评估决策树模型
# 我们直接使用“模型”,而不是已绑定数据的机器
# 我们将verbosity=0设置为在评估期间抑制输出
tree_eval = evaluate(tree_model, X_train, y_train,
resampling=cv_strategy,
measure=accuracy,
verbosity=0)
# 显示评估结果
println("Decision Tree Cross-Validation Results:")
println("Mean Accuracy: $(round(tree_eval.measurement[1], digits=3))")
println("Per-fold Accuracy: $(round.(tree_eval.per_fold[1], digits=3))")
输出将显示每折的准确率以及所有折的平均准确率。这使我们更能了解模型平均的表现。例如,平均准确率可能在0.945左右。
大多数机器学习 (machine learning)模型都有超参数,可以通过调整来改善性能。对于DecisionTreeClassifier,其中一个超参数是max_depth,它控制树的最大深度。让我们使用Grid搜索策略来调整它。
max_depth值。Grid搜索。TunedModel:将基础模型、调整策略、重采样策略和参数范围封装。TunedModel:此过程会为每个超参数组合训练模型并选择最佳模型。# 定义max_depth超参数的范围
tree_model_tunable = DecisionTreeClassifier() # 用于调整的新实例
r_max_depth = range(tree_model_tunable, :max_depth, lower=1, upper=10, scale=:linear);
# 定义调整策略(网格搜索)
tuning_strategy = Grid(resolution=10) # resolution表示范围内有10个值
# 定义用于调整的重采样(例如,3折交叉验证以加速调整)
resampling_strategy_tuning = CV(nfolds=3, rng=deepcopy(rng))
# 创建TunedModel
tuned_tree_model = TunedModel(model=tree_model_tunable,
resampling=resampling_strategy_tuning,
tuning=tuning_strategy,
range=r_max_depth,
measure=accuracy,
train_best=true) # 自动在完整训练数据上重新训练最佳模型
# 将TunedModel封装到机器中并拟合
tuned_tree_machine = machine(tuned_tree_model, X_train, y_train)
fit!(tuned_tree_machine, verbosity=0)
# 检查调整结果报告
tuning_report = report(tuned_tree_machine)
best_model_params = tuning_report.best_model
best_max_depth = best_model_params.max_depth
println("Best max_depth found: $best_max_depth")
# 提取最佳模型的拟合参数
fitted_params(tuned_tree_machine).best_model |> pretty
# 在测试集上评估调整后的模型
y_pred_tuned_tree = predict(tuned_tree_machine, X_test)
accuracy_tuned_tree = accuracy(mode.(y_pred_tuned_tree), y_test)
println("Tuned Decision Tree Accuracy: $(round(accuracy_tuned_tree, digits=3))")
调整后,您可能会发现特定的max_depth(例如3或4)在调整期间使用的交叉验证集上能带来更好的性能。将此调整后的模型应用于我们的保留测试集,可能会使准确率得到提升,例如达到约0.978。
我们现在已经训练并评估了几种模型。可视化它们的性能通常很有用。
逻辑回归、默认决策树以及超参数 (parameter) (hyperparameter)调整后的决策树在鸢尾花测试集上的准确率比较。示例值显示了潜在的改善。
本次动手练习展示了MLJ.jl中监督学习 (supervised learning)的主要工作流程:
建议您进一步尝试。可以尝试MLJ.jl中提供的不同模型,研究其他超参数,或将这些技术应用于其他数据集。您在此练习中掌握的技能是使用Julia构建更复杂机器学习 (machine learning)解决方案的基础。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造