• 【无标题】


    # -*- coding: utf-8 -*-
    
    import sys
    import math
    from collections import defaultdict
    
    class MaxEnt:
        def __init__(self):
            self._samples = []      # 样本集, 元素是[y,x1,x2,...,xn]的元组
            self._Y = set([])       # 标签集合,相当于去重之后的y
            self._numXY = defaultdict(int)  # key是(xi,yi)对,value是count(xi,yi)
            self._N = 0         # 样本数量
            self._n = 0         # 特征对(xi,yi)总数量
            self._xyID = {}     # 对(x,y)对做的顺序编号(ID),key是(xi,yi)对,value是ID
            self._C = 0         # 样本最大的特征数量,用于求参数的迭代,见IIS原理说明
            self._ep_ = []      # 样本分布的特征-->期望值
            self._ep = []       # 模型分布的特征-->期望值
            self._w = []        # 对应n个特征的权值
            self._lastw = []    # 上一轮迭代的权值
            self._EPS = 0.01    # 判断是否收敛的阈值
    
        def load_data(self,filename):
            for line in open(filename,"r"):
                sample = line.strip().split("\t")
                if len(sample) < 2: #至少:标签+一个特征
                    continue
                y = sample[0]
                X = sample[1:]
                self._samples.append(sample)
                self._Y.add(y)      #label
                for x in set(X):    #set给X去重
                    self._numXY[(x,y)] += 1
    
        def _initparams(self):
            self._N = len(self._samples)
            self._n = len(self._numXY) # 没有做任何特征提取操作,直接操作特征
            self._C = max([len(sample) -1 for sample in self._samples])
            self._w = [0.0] * self._n
            self._lastw = self._w[:]
            self._sample_ep()
    
        def _convergence(self):
            for w,lw in zip(self._w,self._lastw):
                if math.fabs(w-lw) >= self._EPS:
                    return False
            return True
    
        def _sample_ep(self):
            self._ep_ = [0.0] * self._n
            # 计算方法参见公式(20)
            for i,xy in enumerate(self._numXY):
                self._ep_[i] = self._numXY[xy] * 1.0 / self._N
                self._xyID[xy] = i
    
        def _zx(self,X):
            # calculate Z(x) 计算方法参见公式(15)
            ZX = 0.0
            for y in self._Y:
                sum = 0.0
                for x in X:
                    if(x,y) in self._numXY:
                        sum += self._w[self._xyID[(x,y)]]
                ZX += math.exp(sum)
            return ZX
        def _pyx(self,X):
            # calculate p(y|x), 计算方法参见公式(22)
            ZX = self._zx(X)
            results = []
            for y in self._Y:
                sum = 0.0
                for x in X:
                    if(x,y) in self._numXY:  #这个判断相当于指示函数的作用
                        sum += self._w[self._xyID[(x,y)]]
                pyx = 1.0 / ZX * math.exp(sum)
                results.append((y,pyx))
            return results
        def _model_ep(self):
            self._ep = [0.0] * self._n
            # 参见公式(21)
            for sample in self._samples:
                X = sample[1:]
                pyx = self._pyx(X)
                for y,p in pyx:
                    for x in X:
                        if(x,y) in self._numXY:
                            self._ep[self._xyID[(x,y)]] += p * 1.0 / self._N
        def train(self,maxiter = 1000):
            self._initparams()
            for i in range(0,maxiter):
                print("Iter:%d...."%i)
                self._lastw = self._w[:]    #保存上一轮权值
                self._model_ep()
                # 更新每个特征的权值
                for i,w in enumerate(self._w):
                    # 参考公式(19)
                    self._w[i] += 1.0 / self._C * math.log(self._ep_[i] / self._ep[i])
                print(self._w)
                # 检查是否收敛
                if self._convergence():
                    break
        def predict(self,inp):
            X = inp.strip().split("\t")
            prob = self._pyx(X)
            return prob
    
    if __name__ == "__main__":
        maxent = MaxEnt()
        maxent.load_data('data.txt')
        maxent.train()
        print (maxent.predict("sunny\thot\thigh\tFALSE"))
        print (maxent.predict("overcast\thot\thigh\tFALSE"))
        print (maxent.predict("sunny\tcool\thigh\tTRUE"))
        sys.exit(0)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113

    预测的结果:
    [(‘yes’, 0.004162651871979297), (‘no’, 0.9958373481280207)]
    [(‘yes’, 0.9943682102360447), (‘no’, 0.005631789763955368)]
    [(‘yes’, 1.4464465173635736e-07), (‘no’, 0.9999998553553483)]

    data.txt内容:
    no sunny hot high FALSE
    no sunny hot high TRUE
    yes overcast hot high FALSE
    yes rainy mild high FALSE
    yes rainy cool normal FALSE
    no rainy cool normal TRUE
    yes overcast cool normal TRUE
    no sunny mild high FALSE
    yes sunny cool normal FALSE
    yes rainy mild normal FALSE
    yes sunny mild normal TRUE
    yes overcast mild high TRUE
    yes overcast hot normal FALSE
    no rainy mild high TRUE

    参考:
    https://blog.csdn.net/u014688145/article/details/55003910
    http://www.hankcs.com/ml/the-logistic-regression-and-the-maximum-entropy-model.html

  • 相关阅读:
    HttpRunnerManager(四) - 完成安装注册&登录
    校园招聘面试精典博文java
    DJYOS开源往事二:DJYOS开源工作室时期
    Kubernetes部署
    JS轮播图实现
    包管理工具--》发布一个自己的npm包
    阿里5年经验之谈 —— 记录一次jmeter压测的过程!
    Android终于要推出Google官方的二维码扫描库了?
    vue+iview中日期时间选择器不能选择当前日期之前包括时分秒
    Qt实现单例模式:Q_GLOBAL_STATIC和Q_GLOBAL_STATIC_WITH_ARGS
  • 原文地址:https://blog.csdn.net/bian_h_f612701198412/article/details/126287022