一个用scikit-learn 框架训练手写数字图片识别图片上的数字 的例子
这是一个学习如何使用scikit-learn 框架,通过它提供的手写图片数据集做机器学习以识别8x8像素图片说包含的数字字符
本项目采用scikit-learn 框架开展机器学习,主要完成三个任务:
- 通过数据集训练模型,并生成模型文件
- 通过训练好的模型文件构建服务端,提供图片数字字符识别预测服务
- 通过一个简单的客户端程序提交需要识别的图片文件(8x8)像素发送到服务端返回识别后的结果
- training.py : 通过数据集训练模型,并生成模型文件
- server.py : 通过训练好的模型文件构建服务端,提供预测服务,服务地址http://localhost:5000/prediction
- client.py : 通过一个简单的客户端程序提交需要识别的图片文件(8x8)像素发送到服务端返回识别后的结果
- model :存放训练好的模型文件
- my_test_images:存放用于测试的10张8x8像素的数字图片,其中的数字可以根据自己的需要修改用于新的测试
- 手动安装项目依赖项。
pip install scikit-learn matplotlib numpy joblib flask
- 通过requirements命令来安装项目依赖项。
pip install -r requirements.txt
以面向过程为例
- 执行训练,并生成训练好的模型文件
python training.py
- 启动服务端 启动服务等待http链接端口5000
python server.py [-p 5000]
- 测试服务端 把my_test_images/digit2.png文件提交给服务端进行数字识别,并返回结果
python client.py -f my_test_images/digit2.png
通常使用 sklearn 进行机器学习任务的步骤如下:
- 准备数据:读取数据集并对数据进行预处理,包括数据清洗、特征提取、特征选择和特征转换等。
- 分离数据:将数据集划分为训练集和测试集,其中训练集用于训练模型,测试集用于评估模型性能。
- 选择模型:根据任务需求选择合适的分类器或回归器模型,并根据训练集数据进行模型训练。
- 模型评估:使用测试集数据对模型进行评估,包括计算模型的精度、召回率、F1 值等性能指标。
- 调整模型:根据模型评估结果调整模型参数或选择不同的模型,重新进行训练和评估,直到达到满意的性能。
- 预测新数据:使用训练好的模型对新数据进行预测,得到分类或回归结果。 在以上步骤中,数据预处理和特征工程对于机器学习任务的结果至关重要,良好的数据预处理和特征工程可以帮助机器学习模型更好地理解数据,提高模型的预测能力。对于模型的选择和调整,需要根据具体问题选择合适的模型,并对模型进行调参,以达到最佳的性能。最后,预测新数据是机器学习任务的最终目标,需要确保模型能够在新数据上具有较好的泛化能力。
sklearn.datasets 包中提供了许多常用的数据集,其中之一就是 load_digits 数据集。load_digits 数据集包含了 8x8 的手写数字图像数据,每个图像都对应一个 0 到 9 的数字。 load_digits 数据集的返回值是一个字典对象,包含了以下属性:
- data:一个二维数组,包含了所有的图像数据,每行代表一个图像,每列代表一个像素的灰度值。
- target:一个一维数组,包含了所有图像对应的数字标签。
- images:一个三维数组,包含了所有的图像数据,每个元素是一个 8x8 的数组,代表一个图像的像素灰度值。
- target_names:一个一维数组,包含了数字标签的名称。
- DESCR:一个字符串,包含了数据集的描述信息。 使用 load_digits 数据集的示例代码如下:
train_test_split 函数是 sklearn 中用于将数据集分割为训练集和测试集的函数。其函数原型为:
train_test_split(*arrays, test_size=None, train_size=None, random_state=None, shuffle=True, stratify=None)svm.SVC 是 sklearn 中实现支持向量机(SVM)分类器的类。支持向量机是一种常用的分类算法,可以用于线性和非线性分类任务。svm.SVC 类使用的是一种基于最大间隔的分类方法,即找到最优的超平面,使得离该超平面最近的样本点(即支持向量)到该超平面的距离最大化。 svm.SVC 的主要参数包括:
- C:正则化参数,用于平衡模型的复杂度和误差之间的权衡。较小的 C 值会使得模型更加简单,但也可能会导致欠拟合;较大的 C 值会使得模型更加复杂,但也可能会导致过拟合。默认值为 1.0。
- kernel:核函数选择,用于处理非线性分类问题。常用的核函数包括线性核函数、多项式核函数和高斯核函数等。默认值为 'rbf'。
- degree:多项式核函数的次数,用于控制多项式核函数的复杂度。默认值为 3。
- gamma:核函数系数,用于控制核函数的宽度。较小的 gamma 值会使得核函数更宽,决策边界更平滑;较大的 gamma 值会使得核函数更窄,决策边界更复杂。默认值为 'scale',表示 gamma 的值为 1 / ( n_features * X.var())。
- coef0:核函数中的常数项,用于控制多项式核函数和 Sigmoid 核函数的偏置。默认值为 0.0。
- shrinking:是否使用启发式方法来加速计算支持向量的选择。默认值为 True。
- probability:是否开启概率估计功能,用于计算每个样本属于各个类别的概率。默认值为 False。
- tol:精度阈值,用于控制迭代的停止条件。默认值为 1e-3。
svm.SVC 类也有一些常用的方法,例如:
- fit(X, y):使用数据集 X 和目标向量 y 训练 SVM 模型。
- predict(X):对数据集 X 进行分类预测,返回预测结果。
- score(X, y):评估模型在数据集 X 和目标向量 y 上的性能,返回分类准确率。
- decision_function(X):计算每个样本到超平面的距离,用于后续的阈值判定和概率计算。
- predict_proba(X):计算每个样本属于各个类别的概率,需要先开启 probability 参数。
- *arrays:需要分割的数据集,可以是多个数组,但必须保证第一个数组为特征矩阵,第二个数组为目标向量。
- test_size:测试集的大小,可以为浮点数(表示比例)或整数(表示样本数量),默认为 0.25。
- train_size:训练集的大小,可以为浮点数(表示比例)或整数(表示样本数量),默认为 1 - test_size。
- random_state:随机数种子,用于保证每次划分结果一致,默认为 None。
- shuffle:是否打乱数据集,以随机选择训练集和测试集,默认为 True。
- stratify:用于分层抽样,将数据集划分为具有相同标签分布的训练集和测试集,适用于分类任务。
train_test_split 函数的返回值是一个元组,包含了划分后的特征矩阵和目标向量,可以通过多个变量进行接收,例如:
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)上述代码将数据集 X 和目标向量 y 划分为训练集和测试集,其中测试集的大小为 0.3,随机数种子为 42。划分后的训练集特征矩阵为 X_train,训练集目标向量为 y_train,测试集特征矩阵为 X_test,测试集目标向量为 y_test。
fit 是 sklearn 中许多模型类的方法之一,用于使用给定的训练数据集和标签数据集拟合(训练)模型参数。该方法是模型训练的核心步骤,通过不断地调整模型参数,使得模型能够对训练数据进行更好的拟合,从而使得在未见过的数据上的表现也更好。 fit 方法的函数原型通常如下:
fit(X, y, **fit_params)其中,X 表示训练数据集的特征矩阵,y 表示训练数据集的标签向量。对于分类问题,y 通常是离散的类别标签,对于回归问题,y 通常是连续的实数值。**fit_params 表示其他的可选参数,不同的模型类可能会有不同的参数,具体可以参考对应的文档。 在 fit 方法被调用之后,模型会根据训练数据集和标签数据集来不断地调整参数,直到达到一定的拟合程度或者达到预设的迭代次数为止。一旦模型训练完成,我们就可以通过其他的方法(例如 predict、score 等)来对新的数据进行预测或者评估模型的性能。 需要注意的是,为了避免模型过拟合或者欠拟合的情况,我们有时需要对训练数据集进行一些预处理操作,例如数据清洗、特征选择、特征缩放等。在进行模型训练之前,我们应该先对数据进行预处理,以获得更好的模型效果。
score 是 sklearn 中许多模型类的方法之一,用于在给定的测试数据集上评估模型的性能。该方法通常用于分类和回归任务中,用于计算模型对于测试数据的准确率、精度、召回率、F1 值、R2 值等评价指标。 score 方法的函数原型通常如下:
score(X, y, sample_weight=None)其中,X 表示测试数据集的特征矩阵,y 表示测试数据集的标签向量,sample_weight 表示样本权重,用于调整不同样本的重要性。对于分类问题,y 通常是离散的类别标签,对于回归问题,y 通常是连续的实数值。 在 score 方法被调用之后,模型会根据给定的测试数据集进行预测,并将预测结果与真实的标签数据进行比较,从而计算出模型在测试数据集上的性能指标。具体的性能指标取决于具体的任务类型和模型类别。例如,在分类问题中,我们通常会计算模型在测试数据集上的准确率(accuracy)、精度(precision)、召回率(recall)、F1 值(F1-score)等指标;在回归问题中,我们通常会计算模型在测试数据集上的 R2 值(决定系数)等指标。 需要注意的是,为了防止模型在测试数据集上过拟合的情况,我们通常会将测试数据集划分成训练集和验证集两部分,用训练集来训练模型,用验证集来调整模型参数和评估模型性能,最后再用测试数据集来评估模型的最终性能。在对模型进行评估时,我们应该选择合适的评价指标,并结合具体任务需求来进行模型选择和参数调整。
predict 是 sklearn 中许多模型类的方法之一,用于对给定的数据集进行预测。该方法通常用于分类和回归任务中,用于将输入数据映射到输出标签或者数值上。 predict 方法的函数原型通常如下:
predict(X)其中,X 表示输入数据集的特征矩阵。对于分类问题,predict 方法会输出每个样本的预测类别标签,对于回归问题,predict 方法会输出每个样本的预测数值。 在 predict 方法被调用之后,模型会根据给定的输入数据集进行预测,并将预测结果输出。具体的预测结果取决于模型的具体类型和训练数据集的特征,例如,在分类问题中,预测结果可能是某个样本属于哪个类别,而在回归问题中,预测结果可能是某个样本对应的数值。 需要注意的是,为了获得更好的预测结果,我们应该对输入数据集进行预处理操作,例如数据清洗、特征选择、特征缩放等。在进行模型预测之前,我们应该先对数据进行预处理,以获得更好的预测效果。此外,为了防止模型在预测数据集上过拟合的情况,我们通常会将预测数据集划分成训练集和验证集两部分,用训练集来训练模型,用验证集来调整模型参数和评估模型性能,最后再用测试数据集来进行最终的预测。
joblib.dump 是 joblib 库中的一个函数,用于将 Python 对象保存到磁盘上。该函数通常用于将训练好的机器学习模型、预处理器、特征选择器等对象保存到磁盘上,以便于在后续的应用中使用。 joblib.dump 函数的函数原型通常如下:
joblib.dump(obj, filename, compress=3, protocol=None, *, cache_size=None)其中,obj 表示要保存的 Python 对象,filename 表示保存的文件路径,compress 表示是否对保存的对象进行压缩(默认为 3,表示中等压缩),protocol 表示序列化协议(默认为 None,表示使用 pickle 协议),cache_size 表示对象在序列化过程中占用的内存大小(默认为 None,表示根据对象大小自动调整缓存区大小)。 在调用 joblib.dump 函数之后,该函数会将 Python 对象序列化并保存到指定的文件路径中。保存的文件格式通常为二进制格式,可以使用 joblib.load 函数进行加载和反序列化。需要注意的是,保存的文件路径应该是一个有效的文件路径,且具有写入权限。 使用 joblib.dump 函数可以方便地将训练好的机器学习模型保存到磁盘上,并在后续的应用中进行加载和使用。该函数还支持对保存的对象进行压缩,以减少磁盘空间的占用。
joblib.load 是 joblib 库中的一个函数,用于从磁盘上加载 Python 对象。该函数通常用于加载之前保存的机器学习模型、预处理器、特征选择器等对象,以便于在后续的应用中使用。 joblib.load 函数的函数原型通常如下:
joblib.load(filename, mmap_mode=None)其中,filename 表示要加载的文件路径,mmap_mode 表示内存映射模式(默认为 None,表示不进行内存映射)。 在调用 joblib.load 函数之后,该函数会从指定的文件路径中加载 Python 对象,并对其进行反序列化。反序列化后的对象可以直接在应用中使用,例如进行机器学习模型的预测等操作。 需要注意的是,加载的文件路径应该是一个有效的文件路径,且具有读取权限。如果文件路径无效或者文件不存在,该函数会抛出相应的异常。 使用 joblib.load 函数可以方便地从磁盘上加载之前保存的机器学习模型、预处理器、特征选择器等对象,并在后续的应用中进行使用。该函数还支持通过内存映射的方式加载大型对象,以减少内存占用。
- Fork 本项目
- 新建 Feat_xxx 分支
- 提交代码
- 新建 Pull Request