• PyTorch房价预测


    数值稳定性和激活函数总结

    1. relu容易导致梯度爆炸、sigmoid容易导致梯度消失
    2. xavier模型初始化方法
    3. Adam对学习率不太敏感,可用的学习率范围更大一些

    房价预测demo

    下载数据

    import hashlib
    import os
    import tarfile
    import zipfile
    import requests
    
    • 1
    • 2
    • 3
    • 4
    • 5
    # Registry mapping a dataset name to its (url, sha1_hash) pair.
    DATA_HUB = {}
    # Base URL that the dataset file names are appended to.
    DATA_URL = 'http://d2l-data.s3-accelerate.amazonaws.com/'
    
    • 1
    • 2
    • 断言 assert 等价于
        if not expression:
            raise AssertionError
    
    • 1
    • 2
    def download(name, cache_dir=os.path.join('.', 'data')):
        """Download a file registered in DATA_HUB and return its local path.

        A cached copy is reused when its SHA-1 digest matches the registered
        one; otherwise the file is (re-)downloaded.

        Args:
            name: Key into DATA_HUB; the value is a (url, sha1_hash) pair.
            cache_dir: Directory the file is stored in; created if missing.

        Returns:
            Path of the local file inside cache_dir.

        Raises:
            AssertionError: If name is not registered in DATA_HUB.
            requests.HTTPError: If the server answers with an error status.
        """
        assert name in DATA_HUB, f"{name} 不存在于 {DATA_HUB}."
        url, sha1_hash = DATA_HUB[name]
        os.makedirs(cache_dir, exist_ok=True)
        fname = os.path.join(cache_dir, url.split('/')[-1])
        if os.path.exists(fname):
            sha1 = hashlib.sha1()
            with open(fname, 'rb') as f:
                # Hash in 1 MiB chunks so large files are never fully in memory.
                for chunk in iter(lambda: f.read(1048576), b''):
                    sha1.update(chunk)
            if sha1.hexdigest() == sha1_hash:
                return fname
        print(f'正在从{url}下载{fname}...')
        with requests.get(url, stream=True, verify=True) as r:
            # Fail loudly instead of saving an HTTP error page as the dataset.
            r.raise_for_status()
            with open(fname, 'wb') as f:
                # Actually stream the body; r.content would buffer all of it.
                for chunk in r.iter_content(chunk_size=1048576):
                    f.write(chunk)
        return fname
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    def download_extract(name, folder=None):
        """Download a zip/tar archive and extract it next to the archive file.

        Args:
            name: Dataset key passed on to download().
            folder: Optional sub-directory name, joined onto the extraction
                directory to build the returned path.

        Returns:
            os.path.join(base_dir, folder) when folder is given, otherwise the
            archive path with its last extension removed.

        Raises:
            AssertionError: If the file extension is not .zip, .tar or .gz.
        """
        fname = download(name)
        base_dir = os.path.dirname(fname)
        data_dir, ext = os.path.splitext(fname)
        if ext == '.zip':
            fp = zipfile.ZipFile(fname, 'r')
        elif ext in ('.tar', '.gz'):
            fp = tarfile.open(fname, 'r')
        else:
            # Raised explicitly so the check also holds under `python -O`,
            # where a bare `assert` would be stripped.
            raise AssertionError('只有zip/tar文件可以被解压缩。')
        # NOTE(review): extractall() trusts member paths inside the archive;
        # only use this on archives from a trusted source.
        with fp:  # close the archive handle even if extraction fails
            fp.extractall(base_dir)
        return os.path.join(base_dir, folder) if folder else data_dir
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    def download_all():
        """Fetch every dataset currently registered in DATA_HUB."""
        for dataset_name in DATA_HUB:
            download(dataset_name)
    
    • 1
    • 2
    • 3
    • 4
    import numpy as np
    import pandas as pd
    import torch
    from torch import nn
    from d2l import torch as d2l
    
    • 1
    • 2
    • 3
    • 4
    • 5
    # Register the Kaggle house-price train/test CSV files as
    # (url, sha1_hash) pairs.
    DATA_HUB.update({
        'kaggle_house_train': (
            DATA_URL + 'kaggle_house_pred_train.csv',
            '585e9cc93e70b39160e7921475f9bcd7d31219ce'),
        'kaggle_house_test': (
            DATA_URL + 'kaggle_house_pred_test.csv',
            'fa19780a7b011d9b009e8bff8e99922a8ee2eb90'),
    })
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    # Fetch both CSV files (or reuse the cached copies) and load them.
    train_data = pd.read_csv(download('kaggle_house_train'))
    test_data = pd.read_csv(download('kaggle_house_test'))

    # Print the (rows, columns) shape of each frame, one per line.
    print(train_data.shape, test_data.shape, sep='\n')
    
    • 1
    • 2
    • 3
    • 4
    • 5
    (1460, 81)
    (1459, 80)
    
    • 1
    • 2
    # 前四个和最后两个特征,以及相应标签
    print(train_data.iloc[0:4,[0,1,2,3,-3,-2,-1]])
    
    • 1
    • 2
       Id  MSSubClass MSZoning  LotFrontage SaleType SaleCondition  SalePrice
    0   1          60       RL         65.0       WD        Normal     208500
    1   2          20       RL         80.0       WD        Normal     181500
    2   3          60       RL         68.0       WD        Normal     223500
    3   4          70       RL         60.0       WD       Abnorml     140000
    
    • 1
    • 2
    • 3
    • 4
    • 5

    特征工程

    • 需要注意:这里用训练集和测试集合并后的均值和方差来标准化数据;实际中不一定能拿到测试集,此时应只使用训练集的统计量
    # The first column of every sample is the Id, so it is dropped from both
    # frames; the last training column (the label) is dropped as well.
    all_features = pd.concat(
        [train_data.iloc[:, 1:-1], test_data.iloc[:, 1:]], axis=0)
    
    • 1
    • 2
    all_features.head()
    
    • 1
    MSSubClassMSZoningLotFrontageLotAreaStreetAlleyLotShapeLandContourUtilitiesLotConfig...ScreenPorchPoolAreaPoolQCFenceMiscFeatureMiscValMoSoldYrSoldSaleTypeSaleCondition
    060RL65.08450PaveNaNRegLvlAllPubInside...00NaNNaNNaN022008WDNormal
    120RL80.09600PaveNaNRegLvlAllPubFR2...00NaNNaNNaN052007WDNormal
    260RL68.011250PaveNaNIR1LvlAllPubInside...00NaNNaNNaN092008WDNormal
    370RL60.09550PaveNaNIR1LvlAllPubCorner...00NaNNaNNaN022006WDAbnorml
    460RL84.014260PaveNaNIR1LvlAllPubFR2...00NaNNaNNaN0122008WDNormal

    5 rows × 79 columns

    all_features.info()
    
    • 1
    
    Int64Index: 2919 entries, 0 to 1458
    Data columns (total 79 columns):
     #   Column         Non-Null Count  Dtype  
    ---  ------         --------------  -----  
     0   MSSubClass     2919 non-null   int64  
     1   MSZoning       2915 non-null   object 
     2   LotFrontage    2433 non-null   float64
     3   LotArea        2919 non-null   int64  
     4   Street         2919 non-null   object 
     5   Alley          198 non-null    object 
     6   LotShape       2919 non-null   object 
     7   LandContour    2919 non-null   object 
     8   Utilities      2917 non-null   object 
     9   LotConfig      2919 non-null   object 
     10  LandSlope      2919 non-null   object 
     11  Neighborhood   2919 non-null   object 
     12  Condition1     2919 non-null   object 
     13  Condition2     2919 non-null   object 
     14  BldgType       2919 non-null   object 
     15  HouseStyle     2919 non-null   object 
     16  OverallQual    2919 non-null   int64  
     17  OverallCond    2919 non-null   int64  
     18  YearBuilt      2919 non-null   int64  
     19  YearRemodAdd   2919 non-null   int64  
     20  RoofStyle      2919 non-null   object 
     21  RoofMatl       2919 non-null   object 
     22  Exterior1st    2918 non-null   object 
     23  Exterior2nd    2918 non-null   object 
     24  MasVnrType     2895 non-null   object 
     25  MasVnrArea     2896 non-null   float64
     26  ExterQual      2919 non-null   object 
     27  ExterCond      2919 non-null   object 
     28  Foundation     2919 non-null   object 
     29  BsmtQual       2838 non-null   object 
     30  BsmtCond       2837 non-null   object 
     31  BsmtExposure   2837 non-null   object 
     32  BsmtFinType1   2840 non-null   object 
     33  BsmtFinSF1     2918 non-null   float64
     34  BsmtFinType2   2839 non-null   object 
     35  BsmtFinSF2     2918 non-null   float64
     36  BsmtUnfSF      2918 non-null   float64
     37  TotalBsmtSF    2918 non-null   float64
     38  Heating        2919 non-null   object 
     39  HeatingQC      2919 non-null   object 
     40  CentralAir     2919 non-null   object 
     41  Electrical     2918 non-null   object 
     42  1stFlrSF       2919 non-null   int64  
     43  2ndFlrSF       2919 non-null   int64  
     44  LowQualFinSF   2919 non-null   int64  
     45  GrLivArea      2919 non-null   int64  
     46  BsmtFullBath   2917 non-null   float64
     47  BsmtHalfBath   2917 non-null   float64
     48  FullBath       2919 non-null   int64  
     49  HalfBath       2919 non-null   int64  
     50  BedroomAbvGr   2919 non-null   int64  
     51  KitchenAbvGr   2919 non-null   int64  
     52  KitchenQual    2918 non-null   object 
     53  TotRmsAbvGrd   2919 non-null   int64  
     54  Functional     2917 non-null   object 
     55  Fireplaces     2919 non-null   int64  
     56  FireplaceQu    1499 non-null   object 
     57  GarageType     2762 non-null   object 
     58  GarageYrBlt    2760 non-null   float64
     59  GarageFinish   2760 non-null   object 
     60  GarageCars     2918 non-null   float64
     61  GarageArea     2918 non-null   float64
     62  GarageQual     2760 non-null   object 
     63  GarageCond     2760 non-null   object 
     64  PavedDrive     2919 non-null   object 
     65  WoodDeckSF     2919 non-null   int64  
     66  OpenPorchSF    2919 non-null   int64  
     67  EnclosedPorch  2919 non-null   int64  
     68  3SsnPorch      2919 non-null   int64  
     69  ScreenPorch    2919 non-null   int64  
     70  PoolArea       2919 non-null   int64  
     71  PoolQC         10 non-null     object 
     72  Fence          571 non-null    object 
     73  MiscFeature    105 non-null    object 
     74  MiscVal        2919 non-null   int64  
     75  MoSold         2919 non-null   int64  
     76  YrSold         2919 non-null   int64  
     77  SaleType       2918 non-null   object 
     78  SaleCondition  2919 non-null   object 
    dtypes: float64(11), int64(25), object(43)
    memory usage: 1.8+ MB
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    # 存在缺失值的列的数目
    all_features.isnull().any(axis=0).sum()
    
    • 1
    • 2
    34
    
    • 1
    # 存在缺失值的行的数目
    all_features.isnull().any(axis=1).sum()
    
    • 1
    • 2
    2919
    
    • 1
    # Standardize every numeric column to zero mean and unit variance, then
    # fill missing values with 0, which is the column mean after standardizing.
    # In pandas, the "object" dtype is what string columns use.
    numeric_features = all_features.dtypes[all_features.dtypes != "object"].index
    # The parentheses matter: `x - x.mean() / x.std()` would only subtract the
    # constant mean/std and leave the feature scale untouched.
    all_features[numeric_features] = all_features[numeric_features].apply(
        lambda x: (x - x.mean()) / x.std())  # applied column by column
    all_features[numeric_features] = all_features[numeric_features].fillna(0)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    # 再看一下存在缺失值的列的数目
    all_features.isnull().any(axis=0).sum()
    
    • 1
    • 2
    23
    
    • 1

    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-RvstCcp5-1668172706049)(attachment:faedb1f0-d795-4196-861c-164478db64e4.png)]

    # One-hot encode the string columns; dummy_na=True also adds an indicator
    # column for missing values.
    # dtype=float keeps every column numeric: recent pandas versions emit bool
    # dummy columns by default, which makes `.values` an object array that
    # torch.tensor(..., dtype=torch.float32) cannot convert.
    all_features = pd.get_dummies(all_features, dummy_na=True, dtype=float)
    all_features.shape
    
    • 1
    • 2
    • 3
    (2919, 331)
    
    • 1
    # 再看一下存在缺失值的列的数目
    all_features.isnull().any(axis=0).sum()
    
    • 1
    • 2
    0
    
    • 1

    转为张量

    # Pull the NumPy data out of the DataFrame and wrap it in a tensor.
    # Always cast to float32, the dtype tensors most commonly use.
    n_train = len(train_data)  # number of training rows
    train_features = torch.tensor(
        all_features.iloc[:n_train].to_numpy(), dtype=torch.float32)
    train_features.shape
    
    • 1
    • 2
    • 3
    • 4
    • 5
    torch.Size([1460, 331])
    
    • 1
    test_features = torch.tensor(
        all_features.iloc[n_train:].to_numpy(), dtype=torch.float32)
    # The labels are reshaped into a single column; the original author notes
    # that training emits a warning when they are left one-dimensional.
    train_labels = torch.tensor(
        train_data['SalePrice'].to_numpy().reshape(-1, 1), dtype=torch.float32)
    
    • 1
    • 2
    • 3

    模型及训练

    模型

    # Mean squared error, used for training and inside the log-RMSE metric.
    loss = nn.MSELoss()
    # Number of input columns of the feature matrix.
    in_features = train_features.shape[-1]

    def get_net():
        """Build a fresh single-layer linear regression model."""
        linear = nn.Linear(in_features, 1)
        return nn.Sequential(linear)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    评价指标为对数均方根误差(与下面 log_rmse 的实现一致):$\sqrt{\frac{1}{n}\sum_{i=1}^{n}\left(\log y_i-\log \hat{y}_i\right)^2}$

    def log_rmse(net, features, labels):
        """Return the root-mean-squared error between log-predictions and log-labels.

        Args:
            net: Model called as net(features) to produce predictions.
            features: Input tensor passed to net.
            labels: Target tensor; its logarithm is taken, so values are
                presumably positive (TODO confirm against callers).

        Returns:
            The metric as a Python float.
        """
        # This is evaluation only, so skip autograd bookkeeping; the returned
        # value is unchanged because only .item() leaves the function.
        with torch.no_grad():
            # Clamp predictions to at least 1 so the logarithm is defined
            # and never negative.
            clipped_preds = torch.clamp(net(features), 1, float('inf'))
            rmse = torch.sqrt(loss(torch.log(clipped_preds), torch.log(labels)))
        return rmse.item()
    
    • 1
    • 2
    • 3
    • 4

    训练函数

    def train(net, train_features, train_labels, test_features, test_labels,
              num_epochs, learning_rate, weight_decay, batch_size):
        """Fit net with Adam on the module-level loss and track log-RMSE per epoch.

        Args:
            net: Model to optimize in place.
            train_features, train_labels: Training tensors.
            test_features, test_labels: Evaluation tensors; evaluation is
                skipped when test_labels is None.
            num_epochs: Number of passes over the training iterator.
            learning_rate, weight_decay: Passed to torch.optim.Adam.
            batch_size: Passed to d2l.load_array.

        Returns:
            (train_ls, test_ls): lists with one log_rmse value per epoch;
            test_ls stays empty when test_labels is None.
        """
        train_history, test_history = [], []
        train_iter = d2l.load_array((train_features, train_labels), batch_size)
        optimizer = torch.optim.Adam(
            net.parameters(), lr=learning_rate, weight_decay=weight_decay)
        evaluate_test = test_labels is not None
        for _ in range(num_epochs):
            for X, y in train_iter:
                optimizer.zero_grad()
                batch_loss = loss(net(X), y)
                batch_loss.backward()
                optimizer.step()
            # Record the metric on the full sets once per epoch.
            train_history.append(log_rmse(net, train_features, train_labels))
            if evaluate_test:
                test_history.append(log_rmse(net, test_features, test_labels))
        return train_history, test_history
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    k折交叉验证

    注意:我们这里的验证集是从训练集中分出来的

    slice(1,4) # 切片函数 跟python序列数据类型的切片一毛一样
    
    • 1
    slice(1, 4, None)
    
    • 1
    def get_k_fold_data(k,i,X,y):
        """Split X and y into k equal folds and hold out fold i for validation.

        Each fold has X.shape[0] // k rows, so trailing rows beyond
        k * fold_size are not used by any fold.

        Args:
            k: Number of folds; must be greater than 1.
            i: Index of the fold used as the validation set.
            X: Feature tensor, sliced along its first dimension.
            y: Label tensor, sliced along its first dimension.

        Returns:
            (X_train, y_train, X_valid, y_valid)
        """
        assert k > 1
        fold_size = X.shape[0] // k
        train_X_parts, train_y_parts = [], []
        for fold in range(k):
            start = fold * fold_size
            rows = slice(start, start + fold_size)
            if fold == i:
                X_valid, y_valid = X[rows, :], y[rows]
            else:
                train_X_parts.append(X[rows, :])
                train_y_parts.append(y[rows])
        X_train, y_train = None, None
        if len(train_X_parts) == 1:
            # A single training fold is returned as-is, without concatenation.
            X_train, y_train = train_X_parts[0], train_y_parts[0]
        elif train_X_parts:
            X_train = torch.cat(train_X_parts, 0)  # stack folds row-wise
            y_train = torch.cat(train_y_parts, 0)
        return X_train, y_train, X_valid, y_valid
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    def k_fold(k, X_train, y_train, num_epochs, learning_rate, weight_decay,
               batch_size):
        """Run k-fold cross-validation with a fresh model per fold.

        For each fold a new net from get_net() is trained via train(); the
        final-epoch train/validation log-RMSE is printed, and the curves of
        the first fold are passed to d2l.plot.

        Returns:
            (mean final train log-RMSE, mean final validation log-RMSE)
            averaged over the k folds.
        """
        train_total, valid_total = 0, 0
        for fold in range(k):
            fold_data = get_k_fold_data(k, fold, X_train, y_train)
            model = get_net()
            train_curve, valid_curve = train(
                model, *fold_data, num_epochs, learning_rate, weight_decay,
                batch_size)
            final_train = train_curve[-1]
            train_total += final_train
            final_valid = valid_curve[-1]
            valid_total += final_valid
            if fold == 0:
                # Only the first fold's learning curves are plotted.
                epochs = list(range(1, num_epochs + 1))
                d2l.plot(epochs, [train_curve, valid_curve],
                         xlabel='epoch', ylabel='rmse', xlim=[1, num_epochs],
                         legend=['train', 'valid'], yscale='log')
            print(f'fold {fold + 1}, train log rmse {float(final_train):f}, '
                  f'valid log rmse {float(final_valid):f}')
        return train_total / k, valid_total / k
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17

    模型选择

    调整超参数和模型架构

    # Hyperparameters for cross-validation. The original author notes that lr
    # can be this large because the Adam optimizer tolerates a wider range of
    # learning rates.
    k = 5
    num_epochs = 100
    lr = 0.03
    weight_decay = 0.01
    batch_size = 64
    train_l, valid_l = k_fold(
        k, train_features, train_labels, num_epochs, lr, weight_decay,
        batch_size)
    print(f'{k}-折验证: 平均训练log rmse: {float(train_l):f}, '
          f'平均验证log rmse: {float(valid_l):f}')
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    fold 1, train log rmse 0.258938, valid log rmse 0.234451
    fold 2, train log rmse 0.250196, valid log rmse 0.281742
    fold 3, train log rmse 0.255146, valid log rmse 0.254656
    fold 4, train log rmse 0.255734, valid log rmse 0.259894
    fold 5, train log rmse 0.252717, valid log rmse 0.259565
    5-折验证: 平均训练log rmse: 0.254546, 平均验证log rmse: 0.258062
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-u8RytyWu-1668172706052)(output_40_1.svg)]

    提交Kaggle预测

    def train_and_prde(train_features, test_features, train_labels, test_data,
                       num_epochs, lr,weight_decay, batch_size):
        """Train on the full training set, predict the test set, write a CSV.

        Side effects: plots the training curve via d2l.plot, prints the final
        training log-RMSE, adds a 'SalePrice' column to test_data in place, and
        writes the 'Id' and 'SalePrice' columns to 'submission.csv'.
        """
        net = get_net()
        # No held-out data here, so the test arguments of train() are None.
        train_ls, _ = train(
            net, train_features, train_labels, None, None,
            num_epochs, lr, weight_decay, batch_size)

        epochs = np.arange(1, num_epochs + 1)
        d2l.plot(epochs, [train_ls], xlabel='epoch',
                 ylabel='log rmse', xlim=[1, num_epochs], yscale='log')
        final_train = train_ls[-1]
        print(f'train log rmse {float(final_train):f}')  # six decimal places

        # Flatten the model output to one prediction per test row.
        preds = net(test_features).detach().numpy()
        test_data['SalePrice'] = pd.Series(preds.reshape(-1))

        submission = test_data[['Id', 'SalePrice']]
        submission.to_csv('submission.csv', index=False)

    train_and_prde(train_features, test_features, train_labels, test_data,
                   num_epochs, lr, weight_decay, batch_size)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    train log rmse 0.245931
    
    • 1

    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-wFfADgyZ-1668172706053)(output_42_1.svg)]

  • 相关阅读:
    用AI的智慧,传递感恩之心——GPT-4o助力教师节祝福
    声纹技术(四):声纹识别的工程部署
    举报即有机会解锁CSDN限定勋章|2022上半年CSDN社区治理数据公布
    bgp的表与消息
    Unity如何查找两个transform最近的公共parent
    win系统玩游戏出现d3dx9_43.dll错误,找不到d3dx9_43.dll的解决方法
    Java的JDBC编程
    burpsuite+proxifier小程序抓包
    GraphQL & Go,graphql基本知识,go-graphql使用
    C++强制类型转换操作符
  • 原文地址:https://blog.csdn.net/qq_49821869/article/details/127813383