• 【无标题】


    # -*- coding: utf-8 -*-
    
    import sys
    import math
    from collections import defaultdict
    
    class MaxEnt:
        def __init__(self):
            self._samples = []      # 样本集, 元素是[y,x1,x2,...,xn]的元组
            self._Y = set([])       # 标签集合,相当于去重之后的y
            self._numXY = defaultdict(int)  # key是(xi,yi)对,value是count(xi,yi)
            self._N = 0         # 样本数量
            self._n = 0         # 特征对(xi,yi)总数量
            self._xyID = {}     # 对(x,y)对做的顺序编号(ID),key是(xi,yi)对,value是ID
            self._C = 0         # 样本最大的特征数量,用于求参数的迭代,见IIS原理说明
            self._ep_ = []      # 样本分布的特征-->期望值
            self._ep = []       # 模型分布的特征-->期望值
            self._w = []        # 对应n个特征的权值
            self._lastw = []    # 上一轮迭代的权值
            self._EPS = 0.01    # 判断是否收敛的阈值
    
        def load_data(self,filename):
            for line in open(filename,"r"):
                sample = line.strip().split("\t")
                if len(sample) < 2: #至少:标签+一个特征
                    continue
                y = sample[0]
                X = sample[1:]
                self._samples.append(sample)
                self._Y.add(y)      #label
                for x in set(X):    #set给X去重
                    self._numXY[(x,y)] += 1
    
        def _initparams(self):
            self._N = len(self._samples)
            self._n = len(self._numXY) # 没有做任何特征提取操作,直接操作特征
            self._C = max([len(sample) -1 for sample in self._samples])
            self._w = [0.0] * self._n
            self._lastw = self._w[:]
            self._sample_ep()
    
        def _convergence(self):
            for w,lw in zip(self._w,self._lastw):
                if math.fabs(w-lw) >= self._EPS:
                    return False
            return True
    
        def _sample_ep(self):
            self._ep_ = [0.0] * self._n
            # 计算方法参见公式(20)
            for i,xy in enumerate(self._numXY):
                self._ep_[i] = self._numXY[xy] * 1.0 / self._N
                self._xyID[xy] = i
    
        def _zx(self,X):
            # calculate Z(x) 计算方法参见公式(15)
            ZX = 0.0
            for y in self._Y:
                sum = 0.0
                for x in X:
                    if(x,y) in self._numXY:
                        sum += self._w[self._xyID[(x,y)]]
                ZX += math.exp(sum)
            return ZX
        def _pyx(self,X):
            # calculate p(y|x), 计算方法参见公式(22)
            ZX = self._zx(X)
            results = []
            for y in self._Y:
                sum = 0.0
                for x in X:
                    if(x,y) in self._numXY:  #这个判断相当于指示函数的作用
                        sum += self._w[self._xyID[(x,y)]]
                pyx = 1.0 / ZX * math.exp(sum)
                results.append((y,pyx))
            return results
        def _model_ep(self):
            self._ep = [0.0] * self._n
            # 参见公式(21)
            for sample in self._samples:
                X = sample[1:]
                pyx = self._pyx(X)
                for y,p in pyx:
                    for x in X:
                        if(x,y) in self._numXY:
                            self._ep[self._xyID[(x,y)]] += p * 1.0 / self._N
        def train(self,maxiter = 1000):
            self._initparams()
            for i in range(0,maxiter):
                print("Iter:%d...."%i)
                self._lastw = self._w[:]    #保存上一轮权值
                self._model_ep()
                # 更新每个特征的权值
                for i,w in enumerate(self._w):
                    # 参考公式(19)
                    self._w[i] += 1.0 / self._C * math.log(self._ep_[i] / self._ep[i])
                print(self._w)
                # 检查是否收敛
                if self._convergence():
                    break
        def predict(self,inp):
            X = inp.strip().split("\t")
            prob = self._pyx(X)
            return prob
    
    if __name__ == "__main__":
        maxent = MaxEnt()
        maxent.load_data('data.txt')
        maxent.train()
        print (maxent.predict("sunny\thot\thigh\tFALSE"))
        print (maxent.predict("overcast\thot\thigh\tFALSE"))
        print (maxent.predict("sunny\tcool\thigh\tTRUE"))
        sys.exit(0)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113

    预测的结果:
    [(‘yes’, 0.004162651871979297), (‘no’, 0.9958373481280207)]
    [(‘yes’, 0.9943682102360447), (‘no’, 0.005631789763955368)]
    [(‘yes’, 1.4464465173635736e-07), (‘no’, 0.9999998553553483)]

    data.txt内容:
    no sunny hot high FALSE
    no sunny hot high TRUE
    yes overcast hot high FALSE
    yes rainy mild high FALSE
    yes rainy cool normal FALSE
    no rainy cool normal TRUE
    yes overcast cool normal TRUE
    no sunny mild high FALSE
    yes sunny cool normal FALSE
    yes rainy mild normal FALSE
    yes sunny mild normal TRUE
    yes overcast mild high TRUE
    yes overcast hot normal FALSE
    no rainy mild high TRUE

    参考:
    https://blog.csdn.net/u014688145/article/details/55003910
    http://www.hankcs.com/ml/the-logistic-regression-and-the-maximum-entropy-model.html

  • 相关阅读:
    go: no such tool “compile“(一次糟糕体验)
    VirtualBox 虚拟机
    java缓存
    IGMP协议
    HTTPS应该搞懂了吧。
    PolarDB-X 的 in 常量查询
    百度地图实现 区域高亮
    设计模式之组合模式-创建层次化的对象结构
    入门力扣自学笔记74 C++ (题目编号710)(未理解)
    阿里云将投入70亿元建国际生态、增设6大海外服务中心
  • 原文地址:https://blog.csdn.net/bian_h_f612701198412/article/details/126287022