首页 >> 大全

朴素贝叶斯分类器与Fisher线性判别实践——水果识别为例

2023-06-26 大全 55 作者:考证青年

开篇:本次博客主要分享二分类水果实现案例,样本构造显然不具备泛化、大数量特性以及背景均为纯白色是为了方便目标提取。大家若需要更好的水果样本,上有许多的优秀案例。本次代码计算效率较低,代码有较多计算可改进的地方。

一、朴素贝叶斯及线性判别识别

(一)朴素贝叶斯最大后验概率分类

(二)线性两类判别

在求解最佳w投影时,其实有两种方法可以实现,其分别为梯度下降法和拉格朗日乘子法,此处主要以拉格朗日乘子法为例,但是也会在代码中给出梯度下降法,从而来比较两种方法的差异。

图一 拉格朗日乘子法

对于梯度而言,我们知道梯度的本意是一个向量(矢量),表示某一函数在该点处的方向导数沿着该方向取得的最大值,即函数在该点处沿着该方向(梯度方向)变化最大,变化率最大(梯度的模)。然而,对于梯度下降来说,通常我们需要优化的损失函数是一个非凸函数,所以其通常会收敛到一个局部极值,它的收敛极值往往就和我们所选的变量初值有关。当然,我们知道若一个函数是凸函数,那么通过梯度下降法它一定会收敛到全局极值,所以我们在解决问题时通常会选择构造一个损失凸函数。

补充点:在二维曲面中,若一阶偏导等于零且二阶偏导大于零,那么此点一定是一个曲面极小值点。

图二 梯度下降法

二、实验构造样本集及测试集

图三训练样本集

图四测试样本集

三、判别实现流程

图五 最大后验概率贝叶斯判别流程 图六 线性判别流程

四、实现代码

(一)判别法

# Fisher线性二分类(LDA)import os
import cv2
import numpy as np
import mathdef preProcess(OriResize):# 转HSV颜色空间利用饱和度去除白色背景HSV = cv2.cvtColor(OriResize, cv2.COLOR_BGR2HSV)# 通过实验发现所使用背景可利用s_min=40的饱和度去除lower = np.array([0, 40, 0])upper = np.array([179, 255, 255])mask = cv2.inRange(HSV, lower, upper)# 进行高斯模糊blur = cv2.GaussianBlur(mask, (7, 7), 1)return blurdef DataGet(imgname, Binary, OriResize):DataList = []MultOJ = cv2.bitwise_and(OriResize, OriResize, mask=Binary)B, G, R = cv2.split(MultOJ)shape = OriResize.shapeheight = shape[0]width = shape[1]# 正式开始获取颜色特征fMB = 0fMG = 0fMR = 0Rcount = 0for i in range(height):for j in range(width):fMB += B[i, j]fMG += G[i, j]fMR += R[i, j]Rcount += 1fMB /= RcountfMG /= RcountfMR /= RcountDataList.append(fMB)DataList.append(fMG)DataList.append(fMR)if imgname.startswith("apple"):DataList.append(0)else:DataList.append(1)return DataListdef test(w, tpro0, tpro1):# 获取样本数据(B、G、R目标均值)SampleFiles1 = os.listdir(r"D:\ThirdGhomework\classification\BayesTest")# 图像目标预处理error = 0true = 0record = 0for imgname1 in SampleFiles1:# 对香蕉和苹果进行训练,其他滤过if not imgname1.startswith("apple") and not imgname1.startswith("banana"):continueimgpath1 = os.path.join(r"D:\ThirdGhomework\classification\BayesTest", imgname1)Img1 = cv2.imread(imgpath1)# 调整图片长宽像素比调整显示大小shape1 = Img1.shapeOriResize1 = cv2.resize(Img1, (shape1[0] // 2, shape1[1] // 2))binary1 = preProcess(OriResize1)L1 = DataGet(imgname, binary1, OriResize1)test1 = L1[0:3]testS = L1[3]testpro = np.matmul(w, np.array(test1).T)if abs(testpro - tpro0) > abs(testpro - tpro1):record = 1else:record = 0if record == testS:true += 1else:error += 1return true/(error + true)# 梯度算子(瑞利商函数对w求偏导,此处矩阵的对向量求偏导不熟练)
def grad(w):# 申明全局变量Sb与Swglobal u0,u1,cov1,cov0# 瑞利商函数对w的偏导,设置学习率为10的-3次幂rate = 0.01Sb = np.outer(u0-u1,u0-u1)Sw = cov1 + cov0sw2 = w.dot(Sw).dot(w.T)sb2 = w.dot(Sb).dot(w.T)nw = w + rate/pow(sw2, 2) * (2 * sw2 * Sb.dot(w.T) + 2 * sb2 * Sw.dot(w.T))ruiliS1 = w.dot(Sb).dot(w.T)/w.dot(Sw).dot(w.T)ruiliS2 = nw.dot(Sb).dot(nw.T)/nw.dot(Sw).dot(nw.T)return nw,ruiliS1,ruiliS2# 获取样本数据(B、G、R目标均值)
SampleFiles = os.listdir(r"D:\ThirdGhomework\classification\BayesSamples")
# 图像目标预处理
DataL = []
for imgname in SampleFiles:# 对香蕉和苹果进行训练,其他滤过if not imgname.startswith("apple") and not imgname.startswith("banana"):continueimgpath = os.path.join(r"D:\ThirdGhomework\classification\BayesSamples", imgname)Img = cv2.imread(imgpath)# 调整图片长宽像素比调整显示大小shape = Img.shapeOriResize = cv2.resize(Img, (shape[0] // 2, shape[1] // 2))binary = preProcess(OriResize)L = DataGet(imgname, binary, OriResize)DataL.append(L)
# 法一:Fisher二类线性判别(拉格朗日极值法)
npl = np.array(DataL)
A = npl.T[0:3]
sign = npl.T[3]
# 两类数据分类
x0 = A.T[sign == 0]
x1 = A.T[sign == 1]
# 返回轴向平均值(aix=0表示对各列求平均值;aix=1表示对各行求平均值)
u0 = x0.mean(axis=0)
u1 = x1.mean(axis=0)
# 求协方差矩阵
cov0 = np.cov(x0, rowvar=False)
cov1 = np.cov(x1, rowvar=False)
# 求解w
sw = cov1 + cov0
wf = np.matmul(np.linalg.inv(sw), u0-u1)
# 计算Fisher(LDA)两类均值的投影
pro0 = np.matmul(wf, u0.T)
pro1 = np.matmul(wf, u1.T)
# 法二:梯度下降法(与拉格朗日极值法共用初次协方差与均值求取)
# 初始化w投影向量
Wgrad = np.ones(3)
while True:Wgrad, s1, s2 = grad(Wgrad)if abs(s1-s2) < 0.000001:break
# 计算梯度下降方法两类投影中心
Gpro0 = np.matmul(Wgrad, u0.T)
Gpro1 = np.matmul(Wgrad, u1.T)# 测试集检验
zql1 = test(wf, pro0, pro1)
zql2 = test(Wgrad, Gpro0, Gpro1)
print("Fisher(LDA)分类正确率为:")
print(zql1)
print("梯度下降法分类正确率为:")
print(zql2)

(二)朴素贝叶斯

import os
import cv2
import numpy as np# 本次算法采用朴素贝叶斯分类,各分类特征假定符合相互独立条件
# 采用最大后验概率判别def preProcess(OriResize):# 转HSV颜色空间利用饱和度去除白色背景HSV = cv2.cvtColor(OriResize, cv2.COLOR_BGR2HSV)# 通过实验发现所使用背景可利用s_min=40的饱和度去除lower = np.array([0, 40, 0])upper = np.array([179, 255, 255])mask = cv2.inRange(HSV, lower, upper)# 进行高斯模糊blur = cv2.GaussianBlur(mask, (7, 7), 1)return blur# 计算返回查找表(行从上到下:苹果、香蕉、柠檬、猕猴桃)
def searchT(apl, bal, lel, kil):st = []ap = np.array(apl)ba = np.array(bal)le = np.array(lel)ki = np.array(kil)ap1 = ap.sum(axis=0)ba1 = ba.sum(axis=0)le1 = le.sum(axis=0)ki1 = ki.sum(axis=0)ap2 = list(ap1 / ap1[15])ba2 = list(ba1 / ba1[15])le2 = list(le1 / le1[15])ki2 = list(ki1 / ki1[15])st.append(ap2[0:15])st.append(ba2[0:15])st.append(le2[0:15])st.append(ki2[0:15])return st# p(yi)列表
def pyL():global appleCglobal bananaCglobal lemonCglobal kiwifruitCtotal = appleC + bananaC + lemonC + kiwifruitCp = []p.append(appleC/total)p.append(bananaC/total)p.append(lemonC/total)p.append(kiwifruitC/total)return p# 获取样本数据(B、G、R目标均值)
SampleFiles = os.listdir(r"D:\ThirdGhomework\classification\BayesSamples")
# 图像目标预处理
# 储存DN区间个数表(0,51],[52,103],[104,155],[156,207],[208,254]
appleDN = []
bananaDN = []
lemonDN = []
kiwifruitDN = []
# 各水果类别计数(计算P(yi))
appleC = 0
bananaC = 0
lemonC = 0
kiwifruitC = 0
for imgname in SampleFiles:count = []# 单样本DN区间计数a(0,51],b[52,103],c[104,155],d[156,207],e[208,254]ra, rb, rc, rd, re = 0, 0, 0, 0, 0ba, bb, bc, bd, be = 0, 0, 0, 0, 0ga, gb, gc, gd, ge = 0, 0, 0, 0, 0# 记录提取目标总像素个数tcount = 0imgpath = os.path.join(r"D:\ThirdGhomework\classification\BayesSamples", imgname)Img = cv2.imread(imgpath)# 调整图片长宽像素比调整显示大小shape = Img.shapeOriResize = cv2.resize(Img, (shape[0] // 2, shape[1] // 2))binary = preProcess(OriResize)MultOJ = cv2.bitwise_and(OriResize, OriResize, mask=binary)B, G, R = cv2.split(MultOJ)shape = OriResize.shapeheight = shape[0]width = shape[1]# 正式开始获取颜色特征# Rfor i in range(height):for j in range(width):if not R[i][j] == 0:tcount += 1if 0 < R[i][j] <= 51:ra += 1elif 51 < R[i][j] <= 103:rb += 1elif 103 < R[i][j] <= 155:rc += 1elif 155 < R[i][j] <= 207:rd += 1else:re += 1# Gfor i in range(height):for j in range(width):if 0 < G[i][j] <= 51:ga += 1elif 51 < G[i][j] <= 103:gb += 1elif 103 < G[i][j] <= 155:gc += 1elif 155 < G[i][j] <= 207:gd += 1else:ge += 1# Bfor i in range(height):for j in range(width):if 0 < B[i][j] <= 51:ba += 1elif 51 < B[i][j] <= 103:bb += 1elif 103 < B[i][j] <= 155:bc += 1elif 155 < B[i][j] <= 207:bd += 1else:be += 1count.append(ra)count.append(rb)count.append(rc)count.append(rd)count.append(re)count.append(ga)count.append(gb)count.append(gc)count.append(gd)count.append(ge)count.append(ba)count.append(bb)count.append(bc)count.append(bd)count.append(be)count.append(tcount)if imgname.startswith("apple"):appleC += 1appleDN.append(count)elif imgname.startswith("banana"):bananaDN.append(count)bananaC += 1elif imgname.startswith("lemon"):lemonDN.append(count)lemonC += 1else:kiwifruitDN.append(count)kiwifruitC += 1
ST = searchT(appleDN, bananaDN, lemonDN, kiwifruitDN)
py = pyL()# 样本测试
testpath = input("请输入测试图片的绝对路径(使用\\):")
Imgtest = cv2.imread(testpath)
# 调整图片长宽像素比调整显示大小
OriResizet = cv2.resize(Imgtest, (600, 900))
binaryt = preProcess(OriResizet)
MultOJt = cv2.bitwise_and(OriResizet, OriResizet, mask=binaryt)
Bt, Gt, Rt = cv2.split(MultOJt)
shapett = OriResizet.shape
heightt = shapett[0]
widtht = shapett[1]
testcount = 0
RDN, GDN, BDN = 0, 0, 0
for i in range(heightt):for j in range(widtht):if not Rt[i][j] == 0:testcount += 1RDN += Rt[i][j]
for i in range(heightt):for j in range(widtht):if not Gt[i][j] == 0:GDN += Gt[i][j]
for i in range(heightt):for j in range(widtht):if not Bt[i][j] == 0:BDN += Bt[i][j]
RDN /= testcount
BDN /= testcount
GDN /= testcount
# 取概率标志位
rj = 0
bj = 0
gj = 0
# r标识
if 0 < RDN <= 51:rj = 0
elif 51 < RDN <= 103:rj = 1
elif 103 < RDN <= 155:rj = 2
elif 155 < RDN <= 207:rj = 3
else:rj = 4
# g标识
if 0 < GDN <= 51:gj = 5
elif 51 < GDN <= 103:gj = 6
elif 103 < GDN <= 155:gj = 7
elif 155 < GDN <= 207:gj = 8
else:gj = 9
# b标识
if 0 < BDN <= 51:bj = 10
elif 51 < BDN <= 103:bj = 11
elif 103 < BDN <= 155:bj = 12
elif 155 < BDN <= 207:bj = 13
else:bj = 14
# 开始计算后验概率
PB = []
app = ST[0][rj] * ST[0][gj] * ST[0][bj] * py[0]
bap = ST[1][rj] * ST[1][gj] * ST[1][bj] * py[1]
lep = ST[2][rj] * ST[2][gj] * ST[2][bj] * py[2]
kip = ST[3][rj] * ST[3][gj] * ST[3][bj] * py[3]
PB.append(app)
PB.append(bap)
PB.append(lep)
PB.append(kip)
maxv = max(PB)
maxIdex = PB.index(maxv)
# 打标签
biaoqian = ""
if maxIdex == 0:biaoqian = "苹果"
elif maxIdex == 1:biaoqian = "香蕉"
elif maxIdex == 2:biaoqian = "柠檬"
else:biaoqian = "猕猴桃"
# 显示
print(biaoqian)
# 例子1 D:\\ThirdGhomework\\classification\\BayesTest\\apple1.JPG
# 例子2 D:\\ThirdGhomework\\classification\\BayesSamples\\apple7.jpg

五、实验结果

图七 朴素Bayes分类结果

图八线性判别的拉格朗日极值法和梯度下降分类结果

六、实验分析

(一)朴素贝叶斯

个人编写的朴素贝叶斯分类测试准确率在高亮度测试样本中准确率还可以,但是其在提供的测试样本集中准确率很低。显然,造成朴素贝叶斯分类结果较差的原因有很多,通过分析,我认为主要有如下几点:

1、算法是以统计0-254的五段分段区间的DN值统计来构造先验概率查询表的,而且我的想法认为如果均值落于五段中的一段,那么其数据也应该基本落于该段(未考虑边缘处),此处是明显缺少数学理论支撑的,只能感性认识;

2、DN值受到拍摄角度、拍摄曝光等影响较大,导致其固定区间无法具有推广性;

3、朴素贝叶斯的原理就是大量样本的特征统计,然而我的样本量很少同时质量也较低。

最后,个人编写的朴素贝叶斯代码效率较低,其实里面可以有更多的计算改进地方。

(二)线性判别

线性判别的结果较好,梯度下降法和拉格朗日极值法测试正确率均可达到100%(显然这是由于测试样本过少的原因),但是由于梯度下降法它具有一个函数收敛过程,因此其耗时较长一点。通过对比朴素贝叶斯,我比较疑惑为什么线性判别具有更好的克服光照条件、拍摄角度不一致的能力?很幸运的是,我在学习《定量遥感》的课程中也恰巧分到了“基于小样本分类”的选题。通过学习,我了解了距离判别和线性判别分类器对小样本的分类具有较为不错的准确率,同时其可迁移性较高,普适性较强(后续还需要学习更多的相关知识)。

七、感悟与收获

(一)在工程数学应用中,在我们解求某个目标时,我们可以进行一些合理的假设与非严谨完整的数学推理以获得我们的目标即可。在本次线性二分类判别应用中,由于我们关注的只是投影方向,因此我们对w的模并不关注,对其进行了一定的限定假设,由此,我们虽然未完全严格求解拉格朗日极值法,但是我们得到了我们想要的向量w的方向,极大的省略了不必要的计算。

(二)对于极值问题,当条件适合时,对求极大值转化为求其负值的极小值可能会有一定的简化帮助,降低极值求解难度。

(三)通过学习,我发现学习深度学习真的需要太多的数学知识,而且结合实际运用巧妙灵活。

参考资料:

[1]朴素贝叶斯分类的博客-CSDN博客_朴素贝叶斯分类

[2]2020 机器学习 LDA(线性判别分析)_哔哩哔哩

[3]贝叶斯决策(Bayes rule)的算法分析_哔哩哔哩

[4]2.梯度下降法_哔哩哔哩(线性判别实现方法之一)

[5]浅谈梯度下降_哔哩哔哩

[6]工程数学.线性代数/同济大学数学系编.—6版.—北京:高等教育出版社,2014.6.

[7]超简洁! 线性判别分析LDA 纯实现 非套用 哔哩哔哩

[8]第3章-二分类线性判别分析_哔哩哔哩(理论讲解较为清晰,极力推荐)

[9](矩阵运算、标量对向量求偏导、矩阵对矩阵求偏导等)

[10]百度百科.

[11]维基百科.

[12]简单易懂的梯度下降算法讲解,带你简单入门_哔哩哔哩

[13]高中生就能听懂的人工智能科普:梯度下降法。_哔哩哔哩

关于我们

最火推荐

小编推荐

联系我们


版权声明:本站内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 88@qq.com 举报,一经查实,本站将立刻删除。备案号:桂ICP备2021009421号
Powered By Z-BlogPHP.
复制成功
微信号:
我知道了