博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
一维GMM的Python代码实现
阅读量:5221 次
发布时间:2019-06-14

本文共 1967 字,大约阅读时间需要 6 分钟。

import numpy as npclass GMM(object):    """Gaussian Mixture Model     """    def __init__(self, data, K):        """        K: the number of gaussian models        alpha: the weight for corresponding gaussian model        mu: the vector of means        sigma2: the vector of variances        N: the number of samples        K: the number of gaussian models        """        self.data = data        self.K = K        self.alpha = (np.ones(K) / K)        self.mu = (np.arange(K) - K // 2) * (data.max() - data.min()) / K        self.sigma2 = (np.ones(K))        self.N = len(data)        self.gamma = np.ones((self.N, K)) / K    def phi(self):        # phi.shape(K, N)        phi = (1 / np.sqrt(2 * np.pi * self.sigma2.reshape(self.K, 1)) * np.exp(- (self.data - self.mu.reshape(self.K, 1)) ** 2 / (2 * self.sigma2.reshape(self.K, 1))))        return phi    def fit(self):        sigma2_ = self.sigma2        mu_ = self.mu        while True:            # gamma.shape(N, K)            self.gamma = (0.1 * self.gamma + 0.9 * self.phi().T * self.alpha / (self.phi().T * self.alpha).sum(axis=1).reshape(self.N, 1))            # mu.shape(1, K)            self.mu = (0.1 * self.mu + 0.9 * np.matmul(self.data, self.gamma) / self.gamma.sum(axis=0))            # sigma2.shape(1,K)            self.sigma2 = (0.1 * self.sigma2 + 0.9 * (self.gamma * (data.reshape(self.N, 1) - self.mu) ** 2).sum(axis=0) / self.gamma.sum(axis=0))            # alpha.shape(1, K)            self.alpha = (0.1 * self.alpha + 0.9 * self.gamma.sum(axis=0) / self.N)            if (np.sum((self.mu - mu_) ** 2) + np.abs(self.sigma2 - sigma2_).sum()) < 10 ** (-10):                break            mu_ = self.mu            sigma2_ = self.sigma2            print(self.gamma.argmax(axis=1))        return self.gamma.argmax(axis=1)data = np.concatenate((np.random.normal(-4, 1, 2000), np.random.normal(4, 1, 2000)))gmm = GMM(data, 2)label = gmm.fit()

 

转载于:https://www.cnblogs.com/ningjing213/p/10477018.html

你可能感兴趣的文章
使用arcpy添加grb2数据到镶嵌数据集中
查看>>
[转载] MySQL的四种事务隔离级别
查看>>
QT文件读写
查看>>
C语言小项目-火车票订票系统
查看>>
15.210控制台故障分析(解决问题的思路)
查看>>
BS调用本地应用程序的步骤
查看>>
常用到的多种锁(随时可能修改)
查看>>
用UL标签+CSS实现的柱状图
查看>>
mfc Edit控件属性
查看>>
Linq使用Join/在Razor中两次反射取属性值
查看>>
[Linux]PHP-FPM与NGINX的两种通讯方式
查看>>
Java实现二分查找
查看>>
优秀员工一定要升职吗
查看>>
[LintCode] 462 Total Occurrence of Target
查看>>
springboot---redis缓存的使用
查看>>
架构图-模型
查看>>
sql常见面试题
查看>>
jQuery总结第一天
查看>>
Java -- Swing 组件使用
查看>>
Software--Architecture--DesignPattern IoC, Factory Method, Source Locator
查看>>