[A-Xu's Machine Learning in Practice] [15] Automatic Face Completion (Multi-Output Regression), Comparing the Predictions of 5 Different Models


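    The [A-Xu's Machine Learning in Practice] series introduces a range of machine learning algorithms and models together with hands-on case studies. Likes and follows are welcome so we can learn and exchange ideas together.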

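    This article uses five different machine learning models (KNN regression, linear regression, ridge regression, lasso regression, and extremely randomized trees regression) to complete faces automatically, and compares their predictions.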

    Machine Learning in Practice: Face Completion (Multi-Output Prediction)

    Goal

    Predict the lower half of a face from its upper half, thereby completing the face.

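    This is essentially a multi-output regression problem: every pixel of the lower half is a separate target. For the linear models (LinearRegression, Ridge, Lasso), scikit-learn fits an independent set of coefficients for each target, which amounts to one model per pixel; KNeighborsRegressor and ExtraTreesRegressor instead predict all target pixels jointly with a single model.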
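    To make the "one model per output" idea concrete for the linear case, here is a minimal sketch on small synthetic data (the arrays `X` and `Y` are hypothetical stand-ins, not the face data). It fits one `LinearRegression` on a multi-column target and, separately, one `LinearRegression` per target column, then checks that the coefficients agree.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Hypothetical synthetic data: 50 samples, 8 input features, 3 targets
rng = np.random.default_rng(0)
X = rng.normal(size=(50, 8))
Y = X @ rng.normal(size=(8, 3)) + 0.1 * rng.normal(size=(50, 3))

# One estimator fitted on all targets at once; coef_ has shape (n_targets, n_features)
joint = LinearRegression().fit(X, Y)

# One estimator per target column
per_target = [LinearRegression().fit(X, Y[:, k]) for k in range(Y.shape[1])]
coef_per_target = np.vstack([m.coef_ for m in per_target])

# For linear regression both approaches give the same coefficients
print(np.allclose(joint.coef_, coef_per_target))  # True
```

    No such per-output decomposition happens for `KNeighborsRegressor` or `ExtraTreesRegressor`: the same neighbors, or the same tree splits, serve all outputs at once.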

    Dataset

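    We use the Olivetti faces dataset, which contains 400 grayscale face images of 64x64 pixels: ten photos each of 40 different people. Each image is also flattened into a 1-D vector of length 4096.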
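    The link between the 3-D `images` array and the flattened `data` array is a plain row-major reshape. The sketch below uses a small random array `imgs` (a hypothetical stand-in with the same 64x64 image size, so it runs without downloading the dataset) to show that flattening each image yields 4096 values and that reshaping a row back recovers the image.

```python
import numpy as np

# Hypothetical stand-in for faces.images: 5 random 64x64 "images"
rng = np.random.default_rng(0)
imgs = rng.random((5, 64, 64), dtype=np.float32)

# Flatten each image row by row (NumPy's default C order): 64 * 64 = 4096 values
flat = imgs.reshape(len(imgs), -1)
print(flat.shape)  # (5, 4096)

# Reshaping a single row back to 64x64 recovers the original image exactly
restored = flat[3].reshape((64, 64))
print(np.array_equal(restored, imgs[3]))  # True
```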

    from sklearn.neighbors import KNeighborsRegressor
    from sklearn.linear_model import LinearRegression,Ridge,Lasso
    from sklearn.ensemble import ExtraTreesRegressor
    
    from sklearn import datasets
    
    faces = datasets.fetch_olivetti_faces()
    
    faces
    
    {'data': array([[0.30991736, 0.3677686 , 0.41735536, ..., 0.15289256, 0.16115703,
             0.1570248 ],
            [0.45454547, 0.47107437, 0.5123967 , ..., 0.15289256, 0.15289256,
             0.15289256],
            [0.3181818 , 0.40082645, 0.49173555, ..., 0.14049587, 0.14876033,
             0.15289256],
            ...,
            [0.5       , 0.53305787, 0.607438  , ..., 0.17768595, 0.14876033,
             0.19008264],
            [0.21487603, 0.21900827, 0.21900827, ..., 0.57438016, 0.59090906,
             0.60330576],
            [0.5165289 , 0.46280992, 0.28099173, ..., 0.35950413, 0.3553719 ,
             0.38429752]], dtype=float32),
     'images': array([[[0.30991736, 0.3677686 , 0.41735536, ..., 0.37190083,
              0.3305785 , 0.30578512],
             [0.3429752 , 0.40495867, 0.43801653, ..., 0.37190083,
              0.338843  , 0.3140496 ],
             [0.3429752 , 0.41735536, 0.45041323, ..., 0.38016528,
              0.338843  , 0.29752067],
             ...,
             [0.21487603, 0.20661157, 0.2231405 , ..., 0.15289256,
              0.16528925, 0.17355372],
             [0.20247933, 0.2107438 , 0.2107438 , ..., 0.14876033,
              0.16115703, 0.16528925],
             [0.20247933, 0.20661157, 0.20247933, ..., 0.15289256,
              0.16115703, 0.1570248 ]],
     
            [[0.45454547, 0.47107437, 0.5123967 , ..., 0.19008264,
              0.18595041, 0.18595041],
             [0.446281  , 0.48347107, 0.5206612 , ..., 0.21487603,
              0.2107438 , 0.2107438 ],
             [0.49586776, 0.5165289 , 0.53305787, ..., 0.20247933,
              0.20661157, 0.20661157],
             ...,
             [0.77272725, 0.78099173, 0.7933884 , ..., 0.1446281 ,
              0.1446281 , 0.1446281 ],
             [0.77272725, 0.7768595 , 0.7892562 , ..., 0.13636364,
              0.13636364, 0.13636364],
             [0.7644628 , 0.7892562 , 0.78099173, ..., 0.15289256,
              0.15289256, 0.15289256]],
     
            [[0.3181818 , 0.40082645, 0.49173555, ..., 0.40082645,
              0.3553719 , 0.30991736],
             [0.30991736, 0.3966942 , 0.47933885, ..., 0.40495867,
              0.37603307, 0.30165288],
             [0.26859504, 0.34710744, 0.45454547, ..., 0.3966942 ,
              0.37190083, 0.30991736],
             ...,
             [0.1322314 , 0.09917355, 0.08264463, ..., 0.13636364,
              0.14876033, 0.15289256],
             [0.11570248, 0.09504132, 0.0785124 , ..., 0.1446281 ,
              0.1446281 , 0.1570248 ],
             [0.11157025, 0.09090909, 0.0785124 , ..., 0.14049587,
              0.14876033, 0.15289256]],
     
            ...,
     
            [[0.5       , 0.53305787, 0.607438  , ..., 0.28512397,
              0.23966943, 0.21487603],
             [0.49173555, 0.5413223 , 0.60330576, ..., 0.29752067,
              0.20247933, 0.20661157],
             [0.46694216, 0.55785125, 0.6198347 , ..., 0.29752067,
              0.17768595, 0.18595041],
             ...,
             [0.03305785, 0.46280992, 0.5289256 , ..., 0.17355372,
              0.17355372, 0.1694215 ],
             [0.1570248 , 0.5247934 , 0.53305787, ..., 0.16528925,
              0.1570248 , 0.18595041],
             [0.45454547, 0.5206612 , 0.53305787, ..., 0.17768595,
              0.14876033, 0.19008264]],
     
            [[0.21487603, 0.21900827, 0.21900827, ..., 0.71487606,
              0.71487606, 0.6942149 ],
             [0.20247933, 0.20661157, 0.20661157, ..., 0.7107438 ,
              0.7066116 , 0.6942149 ],
             [0.2107438 , 0.20661157, 0.20661157, ..., 0.6859504 ,
              0.69008267, 0.6942149 ],
             ...,
             [0.2644628 , 0.25619835, 0.2603306 , ..., 0.5413223 ,
              0.57438016, 0.59090906],
             [0.26859504, 0.2644628 , 0.26859504, ..., 0.56198347,
              0.58264464, 0.59504133],
             [0.27272728, 0.26859504, 0.27272728, ..., 0.57438016,
              0.59090906, 0.60330576]],
     
            [[0.5165289 , 0.46280992, 0.28099173, ..., 0.5785124 ,
              0.5413223 , 0.60330576],
             [0.5165289 , 0.45041323, 0.29338843, ..., 0.58264464,
              0.553719  , 0.5785124 ],
             [0.5165289 , 0.44214877, 0.29338843, ..., 0.59917355,
              0.5785124 , 0.54545456],
             ...,
             [0.39256197, 0.41322315, 0.38842976, ..., 0.33471075,
              0.37190083, 0.3966942 ],
             [0.39256197, 0.38429752, 0.40495867, ..., 0.3305785 ,
              0.35950413, 0.37603307],
             [0.3677686 , 0.40495867, 0.3966942 , ..., 0.35950413,
              0.3553719 , 0.38429752]]], dtype=float32),
     'target': array([ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  1,  1,  1,  1,  1,  1,
             1,  1,  1,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  3,  3,  3,  3,
             3,  3,  3,  3,  3,  3,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  5,
             5,  5,  5,  5,  5,  5,  5,  5,  5,  6,  6,  6,  6,  6,  6,  6,  6,
             6,  6,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  8,  8,  8,  8,  8,
             8,  8,  8,  8,  8,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9, 10, 10,
            10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11,
            11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13,
            13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15, 15,
            15, 15, 15, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16,
            17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 18, 18, 18, 18, 18, 18, 18,
            18, 18, 18, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 20, 20, 20, 20,
            20, 20, 20, 20, 20, 20, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 22,
            22, 22, 22, 22, 22, 22, 22, 22, 22, 23, 23, 23, 23, 23, 23, 23, 23,
            23, 23, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 25, 25, 25, 25, 25,
            25, 25, 25, 25, 25, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 27, 27,
            27, 27, 27, 27, 27, 27, 27, 27, 28, 28, 28, 28, 28, 28, 28, 28, 28,
            28, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 30, 30, 30, 30, 30, 30,
            30, 30, 30, 30, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 32, 32, 32,
            32, 32, 32, 32, 32, 32, 32, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33,
            34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 35, 35, 35, 35, 35, 35, 35,
            35, 35, 35, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 37, 37, 37, 37,
            37, 37, 37, 37, 37, 37, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 39,
            39, 39, 39, 39, 39, 39, 39, 39, 39]),
     'DESCR': 'Modified Olivetti faces dataset.\n\nThe original database was available from\n\n    http://www.cl.cam.ac.uk/research/dtg/attarchive/facedatabase.html\n\nThe version retrieved here comes in MATLAB format from the personal\nweb page of Sam Roweis:\n\n    http://www.cs.nyu.edu/~roweis/\n\nThere are ten different images of each of 40 distinct subjects. For some\nsubjects, the images were taken at different times, varying the lighting,\nfacial expressions (open / closed eyes, smiling / not smiling) and facial\ndetails (glasses / no glasses). All the images were taken against a dark\nhomogeneous background with the subjects in an upright, frontal position (with\ntolerance for some side movement).\n\nThe original dataset consisted of 92 x 112, while the Roweis version\nconsists of 64x64 images.\n'}
    
    data = faces.data
    target = faces.target
    data.shape
    
    (400, 4096)
    
    faces.images.shape
    
    (400, 64, 64)
    
    import matplotlib.pyplot as plt
    %matplotlib inline
    
    # Display one face image
    plt.imshow(data[100].reshape((64,64)),cmap="gray")
    

    [Figure: face data[100] displayed in grayscale]

    Data splitting

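    Split the data into features and labels: the features are the upper half of each face, and the labels are the lower half. Because each image is flattened row by row, the first 2048 values (32 rows x 64 columns) are exactly the upper half.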

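    # Features: the upper half of each face (first 2048 of the 4096 pixels)
    faces_up = data[:,:2048]
    # Target to predict (labels): the lower half of each face
    faces_down = data[:,2048:]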
    
    plt.figure(figsize=(2,2))
    plt.imshow(faces_up[10].reshape((32,64)),cmap="gray")
    
    [Figure: upper half of face 10]

    plt.figure(figsize=(2,2))
    plt.imshow(faces_down[10].reshape((32,64)),cmap="gray")
    

    [Figure: lower half of face 10]
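    The slicing above can be sanity-checked on a hypothetical random 64x64 image `img` (not a real face): the first 2048 flattened values are the top 32 pixel rows, the rest are the bottom 32 rows, and concatenating the two halves, as the plotting code later does, rebuilds the full image.

```python
import numpy as np

# Hypothetical 64x64 image standing in for one face
rng = np.random.default_rng(1)
img = rng.random((64, 64))
flat = img.reshape(-1)  # 4096 values, row-major

up, down = flat[:2048], flat[2048:]

# 2048 = 32 rows * 64 columns, so the slices are exactly the top and bottom halves
print(np.array_equal(up.reshape((32, 64)), img[:32]))    # True
print(np.array_equal(down.reshape((32, 64)), img[32:]))  # True

# Concatenating the halves rebuilds the full face
print(np.array_equal(np.concatenate([up, down]).reshape((64, 64)), img))  # True
```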

    Train/test split

    # Split into training and test sets
    from sklearn.model_selection import train_test_split
    
    x_train,x_test,y_train,y_test = train_test_split(faces_up,faces_down,test_size=0.02)
    
    y_train[1]
    
    array([0.5082645 , 0.5082645 , 0.5123967 , ..., 0.16115703, 0.17768595,
           0.1694215 ], dtype=float32)
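    With `test_size=0.02`, only 2% of the 400 faces go into the test set, i.e. 8 faces, which is why the comparison plot further down has 8 rows. Since no `random_state` is passed, the split (and therefore every score below) changes from run to run. The sketch uses dummy arrays with 400 samples (hypothetical stand-ins for `faces_up` / `faces_down`) and adds a `random_state`, which is not in the original code, to make the split reproducible.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Dummy arrays with 400 samples, standing in for faces_up / faces_down
rng = np.random.default_rng(0)
up = rng.random((400, 16))
down = rng.random((400, 16))

# random_state is an addition here; the article's split omits it
x_tr, x_te, y_tr, y_te = train_test_split(up, down, test_size=0.02, random_state=42)
print(len(x_tr), len(x_te))  # 392 8

# The same random_state reproduces exactly the same split
x_tr2, x_te2, _, _ = train_test_split(up, down, test_size=0.02, random_state=42)
print(np.array_equal(x_te, x_te2))  # True
```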
    

    Building and training different regression models

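    Here we build a model with each of five different estimators: KNN regression, linear regression, ridge regression, lasso regression, and extremely randomized trees regression.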

    # All five models use scikit-learn's default hyperparameters
    estimators = {
        "knn":KNeighborsRegressor(),
        "linear":LinearRegression(),
        "ridge":Ridge(),
        "lasso":Lasso(),
        "extra":ExtraTreesRegressor()  # extremely randomized trees regression
    }
    
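    # Dictionary that stores each model's predictions
    faces_pre = dict()
    for key,estimator in estimators.items():
        # Train the model
        estimator.fit(x_train,y_train)
        # Predict the lower halves of the test faces
        y_ = estimator.predict(x_test)
        # Save the predictions
        faces_pre[key] = y_
        # Score: R^2 on the test set, averaged over all output pixels
        score = estimator.score(x_test, y_test)
        print(key, score)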
    
    knn 0.4880642098170732
    linear 0.18894319531680143
    ridge 0.5157197923145055
    lasso -0.2100687498661858
    extra 0.35087195680524175
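    The number printed by `estimator.score` is the coefficient of determination R²; for multi-output regressors scikit-learn averages it uniformly over all output columns (here, the 2048 lower-half pixels). A score of 0 means "no better than predicting the mean", and a negative score, like lasso's, means worse than that. The sketch below, on hypothetical random data, checks that `score` matches `sklearn.metrics.r2_score` and that predicting each column's own mean gives exactly 0.

```python
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.metrics import r2_score

# Hypothetical multi-output data: 100 samples, 10 features, 4 targets
rng = np.random.default_rng(0)
X = rng.random((100, 10))
Y = X @ rng.random((10, 4)) + 0.1 * rng.normal(size=(100, 4))

model = Ridge().fit(X, Y)

# estimator.score is R^2 averaged uniformly over the output columns
print(np.isclose(model.score(X, Y), r2_score(Y, model.predict(X))))  # True

# Predicting every column's own mean scores 0; anything worse goes negative
mean_pred = np.tile(Y.mean(axis=0), (len(Y), 1))
print(r2_score(Y, mean_pred))  # 0.0
```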
    
    faces_pre
    
    {'knn': array([[0.4471074 , 0.41652894, 0.42066115, ..., 0.54793394, 0.5355372 ,
             0.546281  ],
            [0.34876034, 0.34214878, 0.346281  , ..., 0.42727274, 0.42809922,
             0.43057853],
            [0.5355372 , 0.546281  , 0.58016527, ..., 0.56611574, 0.56280994,
             0.5644628 ],
            ...,
            [0.64793384, 0.67685956, 0.7049587 , ..., 0.41487604, 0.3586777 ,
             0.36776862],
            [0.3942149 , 0.41322312, 0.43553716, ..., 0.45785123, 0.43471074,
             0.39173552],
            [0.47520667, 0.47024792, 0.51404965, ..., 0.631405  , 0.6256199 ,
             0.59173554]], dtype=float32),
     'linear': array([[0.42212042, 0.35969752, 0.39748642, ..., 0.63096315, 0.5628751 ,
             0.5159277 ],
            [0.4241521 , 0.26758337, 0.16570012, ..., 0.09656662, 0.13010818,
             0.19814485],
            [0.62213266, 0.441006  , 0.48480797, ..., 0.5819658 , 0.69699645,
             0.44033697],
            ...,
            [0.71544605, 0.6732123 , 0.7088314 , ..., 0.37067276, 0.39097485,
             0.45659465],
            [0.2940399 , 0.3306437 , 0.32395566, ..., 0.19252078, 0.21714431,
             0.24263924],
            [0.4138433 , 0.47978985, 0.5166639 , ..., 0.5562554 , 0.4086836 ,
             0.42044348]], dtype=float32),
     'ridge': array([[0.4290133 , 0.37331253, 0.4017402 , ..., 0.5793132 , 0.53899723,
             0.4968022 ],
            [0.3253019 , 0.2301054 , 0.17614344, ..., 0.33642793, 0.3497425 ,
             0.3560007 ],
            [0.5519007 , 0.46847916, 0.5257808 , ..., 0.6301012 , 0.69831306,
             0.5881569 ],
            ...,
            [0.6989316 , 0.6826698 , 0.7077453 , ..., 0.29566136, 0.32281214,
             0.3521443 ],
            [0.31752783, 0.33159164, 0.33879474, ..., 0.24723864, 0.23903543,
             0.23862499],
            [0.39791593, 0.4184358 , 0.52279156, ..., 0.58981174, 0.50477254,
             0.5145724 ]], dtype=float32),
     'lasso': array([[0.5130819 , 0.5360938 , 0.56652683, ..., 0.31880376, 0.31096098,
             0.307535  ],
            [0.5130819 , 0.5360938 , 0.56652683, ..., 0.31880376, 0.31096098,
             0.307535  ],
            [0.5130819 , 0.5360938 , 0.56652683, ..., 0.31880376, 0.31096098,
             0.307535  ],
            ...,
            [0.5130819 , 0.5360938 , 0.56652683, ..., 0.31880376, 0.31096098,
             0.307535  ],
            [0.5130819 , 0.5360938 , 0.56652683, ..., 0.31880376, 0.31096098,
             0.307535  ],
            [0.5130819 , 0.5360938 , 0.56652683, ..., 0.31880376, 0.31096098,
             0.307535  ]], dtype=float32),
     'extra': array([[0.42644627, 0.39462809, 0.40661157, ..., 0.5409091 , 0.53388429,
             0.53966941],
            [0.30619835, 0.33347108, 0.35661157, ..., 0.43057852, 0.42066116,
             0.40909091],
            [0.43842976, 0.47768595, 0.58347108, ..., 0.45867768, 0.40041323,
             0.39380165],
            ...,
            [0.64049588, 0.65702479, 0.6731405 , ..., 0.36157025, 0.37272727,
             0.38429752],
            [0.3161157 , 0.3144628 , 0.37066115, ..., 0.41239669, 0.40206612,
             0.37685951],
            [0.43471075, 0.47272727, 0.51818182, ..., 0.54090908, 0.503719  ,
             0.50041322]])}
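    The `lasso` entry above predicts the same values for every test face. This is consistent with the default `alpha=1.0`: scikit-learn's `Lasso` minimizes (1/(2n))·‖y − Xw‖² + alpha·‖w‖₁, and w = 0 is the optimum whenever every feature's absolute covariance with the target is at most alpha. Pixel values lie in [0, 1], so that covariance can never exceed 0.25; all coefficients are zeroed and the model simply predicts the training mean of each pixel, which also explains its negative R². The sketch reproduces this on hypothetical random data in [0, 1].

```python
import numpy as np
from sklearn.linear_model import Lasso

# Hypothetical data in [0, 1], mimicking pixel intensities
rng = np.random.default_rng(0)
X = rng.random((200, 30))
Y = 0.5 * X[:, :5] + 0.3 * rng.random((200, 5))  # 5 targets correlated with X, still in [0, 1]

# Default alpha=1.0, as in the article
lasso = Lasso().fit(X, Y)

# Every coefficient is shrunk to exactly zero ...
print(np.count_nonzero(lasso.coef_))  # 0

# ... so each prediction is just the training mean of that target column
pred = lasso.predict(rng.random((3, 30)))
print(np.allclose(pred, Y.mean(axis=0)))  # True
```

    A much smaller `alpha` would be needed for lasso to keep any pixel features on data scaled like this.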
    
    faces_pre["knn"]
    
    array([[0.4471074 , 0.41652894, 0.42066115, ..., 0.54793394, 0.5355372 ,
            0.546281  ],
           [0.34876034, 0.34214878, 0.346281  , ..., 0.42727274, 0.42809922,
            0.43057853],
           [0.5355372 , 0.546281  , 0.58016527, ..., 0.56611574, 0.56280994,
            0.5644628 ],
           ...,
           [0.64793384, 0.67685956, 0.7049587 , ..., 0.41487604, 0.3586777 ,
            0.36776862],
           [0.3942149 , 0.41322312, 0.43553716, ..., 0.45785123, 0.43471074,
            0.39173552],
           [0.47520667, 0.47024792, 0.51404965, ..., 0.631405  , 0.6256199 ,
            0.59173554]], dtype=float32)
    

    Comparing each model's predicted faces with the real ones

    import numpy as np
    
    plt.figure(figsize=(6*3,8*3))
    # One row per test face (8 of them), 6 columns: the true face plus the 5 models
    for i in range(8):
        # First column: true upper half + true lower half
        axes = plt.subplot(8,6,i*6+1)
        axes.axis("off")
        face_up = x_test[i]
        face_down = y_test[i]
        face = np.concatenate([face_up,face_down])
        axes.imshow(face.reshape((64,64)),cmap="gray")
        if i==0:
            axes.set_title("True")
        # Remaining columns: stitch each model's predicted lower half onto the true upper half
        for j,key in enumerate(faces_pre):
            axes = plt.subplot(8,6,i*6+2+j)
            axes.axis("off")
            if i==0:
                axes.set_title(key)
            face_up = x_test[i]
            y_pre = faces_pre[key]
            face_down_pre = y_pre[i]
            face = np.concatenate([face_up,face_down_pre])
            axes.imshow(face.reshape((64,64)),cmap="gray")
    

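    [Figure: 8 x 6 grid; column "True" shows the real test faces, the other columns show the completions by knn, linear, ridge, lasso, extra]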

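    Comparing the results: in this example the KNN predictions give better face shapes, but with a clearly visible seam between the two halves that would need further processing. The linear and ridge regression predictions show no obvious seam, but their actual predictions are not as good. The lasso and extremely randomized trees predictions are unsatisfactory; lasso in fact outputs the same lower half for every test face (see the identical `lasso` rows in the `faces_pre` output above). Note that this is a visual judgment: by the test-set R² printed earlier, ridge (0.516) scores slightly above KNN (0.488), and with only 8 test faces and no fixed `random_state`, both the scores and the picture vary between runs.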
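    Beyond eyeballing the grid, predictions stored in a dictionary like `faces_pre` can also be ranked by a numeric error. A minimal sketch, assuming a dict mapping model names to predicted arrays shaped like the true array; the helper `rank_by_mse` is hypothetical, and it is exercised here on made-up arrays rather than on the article's results.

```python
import numpy as np
from sklearn.metrics import mean_squared_error

def rank_by_mse(preds, y_true):
    """Return (model name, MSE) pairs sorted from lowest to highest error.

    preds: dict mapping a model name to an array with the same shape as y_true.
    """
    scores = {name: mean_squared_error(y_true, y_pred) for name, y_pred in preds.items()}
    return sorted(scores.items(), key=lambda item: item[1])

# Made-up example: 8 "faces" with 2048 target pixels each
rng = np.random.default_rng(0)
y_true = rng.random((8, 2048))
preds = {
    "close": y_true + 0.01 * rng.normal(size=y_true.shape),
    "mean_only": np.tile(y_true.mean(axis=0), (8, 1)),
    "noisy": y_true + 0.3 * rng.normal(size=y_true.shape),
}
ranking = rank_by_mse(preds, y_true)
for name, mse in ranking:
    print(f"{name}: {mse:.4f}")
```

    Pixel-wise MSE does not capture a visible seam well, so it complements rather than replaces the visual comparison, and with only 8 test faces the ranking can shift between random splits.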

    If this content helped you, please remember to like and follow!

    More practical content is on the way…

  • Original article: https://blog.csdn.net/qq_42589613/article/details/127667276