RISC-V MCU中文社区

【分享】 使用matlab搭建BP从零搭建BP神经网络完成鸢尾花数据集分类

发表于 全国大学生集成电路创新创业大赛 2023-05-26 10:11:23
0
507
0

小组名:啊啊对对对队
报名编号:CICC2400
iris_training.mat文件如下
链接:https://pan.baidu.com/s/14vb1c0noPB4YKCCdOCsofA?pwd=ozmz
提取码:ozmz
不赘述,正确率96.7%


load("iris_training.mat")
X = [iristraining(1:30,1:4);iristraining(41:70,1:4);iristraining(81:110,1:4)];
D = [iristraining(1:30,5);iristraining(41:70,5);iristraining(81:110,5)];
X_test = [iristraining(31:40,1:4);iristraining(71:80,1:4);iristraining(111:120,1:4)];
D_test = [iristraining(31:40,5);iristraining(71:80,5);iristraining(111:120,5)];
%%Bp neural network
%三层bp神经网络,四维特征,三分类问题
%最后三个节点的输出分别代表三个类别,当输入类别为i是,当且仅当第i个神经元输出大于0.5时,认为判别正确


%%bp神经网络共分为三层,输入层,中间层,输出层,其中输入层4节点,中间层4节点,输出层3节点
w1 = 2*rand(5,4)-1;             %连接权重
w2 = 2*rand(5,4)-1;
w3 = 2*rand(5,3)-1;
s = 0.01;                       %误差
a = 0.05;                       %学习率
d_correct = 0.9;                
d_wrong = 0.1;
err = 1;
gen = 0;
%BP训练
while err>s && gen<1000
    gen = gen + 1
    err = 0;
    for i = 1:90

        %前馈
        for m = 1:4
            x(m) = logsig(w1(1,m)*X(i,1) + w1(2,m)*X(i,2) + w1(3,m)*X(i,3) + w1(4,m)*X(i,4) - w1(5,m));   %输入层
        end
        for m = 1:4
            y(m) = logsig(w2(1,m)*x(1) + w2(2,m)*x(2) + w2(3,m)*x(3) + w2(4,m)*x(4) - w2(5,m));           %中间层
        end
        for m = 1:3
            z(m) = logsig(w3(1,m)*y(1) + w3(2,m)*y(2) + w3(3,m)*y(3) + w3(4,m)*y(4) - w3(5,m));              %输出层
        end



        for m =1:3
            if (D(i)+1)==m
               err = err + (z(m)-d_correct)^2;
            else
               err = err + (z(m)-d_wrong)^2;
            end
        end




        %反馈
        for m = 1:3
            delta_w3(m) = z(m)*(1-z(m))*(z(m)-(d_wrong+(d_correct-d_wrong)*((D(i)+1)==m)));                                
        end
        for m = 1:4
            delta_w2(m) = y(m)*(1-y(m))*(delta_w3(1)*w3(m,1) + delta_w3(2)*w3(m,2) + delta_w3(3)*w3(m,3));                                
        end
        for m = 1:4
            delta_w1(m) = x(m)*(1-x(m))*(delta_w2(1)*w2(m,1) + delta_w2(2)*w2(m,2) + delta_w2(3)*w2(m,3) + delta_w2(4)*w2(m,4));    
        end

        %更新网络权重
        w1 = w1 - a*[X(i,:) -1]'*delta_w1;
        w2 = w2 - a*[x -1]'*delta_w2;
        w3 = w3 - a*[y -1]'*delta_w3;
    end   
    err = err/270;
end

%测试效果
correct = 0;
for i = 1:30
    corr = 1;
    %前馈
    for m = 1:4
        x(m) = logsig(w1(1,m)*X_test(i,1) + w1(2,m)*X_test(i,2) + w1(3,m)*X_test(i,3) + w1(4,m)*X_test(i,4) - w1(5,m));   %输入层
    end
    for m = 1:4
        y(m) = logsig(w2(1,m)*x(1) + w2(2,m)*x(2) + w2(3,m)*x(3) + w2(4,m)*x(4) - w2(5,m));           %中间层
    end
    for m = 1:3
        z(m) = logsig(w3(1,m)*y(1) + w3(2,m)*y(2) + w3(3,m)*y(3) + w3(4,m)*y(4) - w3(5,m));              %输出层
    end
    for m = 1:3
        if m==(D_test(i)+1) && z(m)<0.5
            corr = 0;
        end
        if m~=(D_test(i)+1) && z(m)>0.5
            corr = 0;
        end
    end
    if corr==1
        correct = correct + 1;
    end


end
correct = correct/30;
fprintf("correction is %d%%",correct*100)
喜欢0
用户评论
铭………

铭……… 实名认证

凭来去

积分
问答
粉丝
关注
  • RV-STAR 开发板
  • RISC-V处理器设计系列课程
  • 培养RISC-V大学土壤 共建RISC-V教育生态
RV-STAR 开发板