当前位置: 首页 > news >正文

自组织映射Python实现

自组织映射(Self-organizing map)Python实现。仅供学习。

#!/usr/bin/env python3"""
Self-organizing map
"""from math import expimport toolzimport numpy as np
import numpy.linalg as LAfrom sklearn.base import ClusterMixin
from sklearn.preprocessing import MinMaxScaler
scaler = MinMaxScaler()class Node:"""NodeAttributes:location (np.ndarray): location of the nodeweight (np.ndarray): weight of the node, in the data sp."""def __init__(self, weight, location=None):self.weight = weightself.location = locationdef normalize(self):return self.weight / LA.norm(self.weight)def output(self, x):# similarity between the node and the input `x`return LA.norm(x - self.weight)def near(self, other, d=0.2):# judge the neighborhood of the nodes by locationsif self.location is not None and other.location is not None:return LA.norm(self.location - other.location) < delse:return 0def update(self, x, eta=0.1):"""update the weight of the nodew += r (x-w)"""self.weight += eta *(x - self.weight)@staticmethoddef random(n=2):weight = np.random.random(n)location = np.random.random(2)node = Node(weight, location)node.normalize()return nodedef plot(self, axes, i1=0, i2=1, *args, **kwargs):x1, x2 = self.weight[i1], self.weight[i2]axes.plot(x1, x2, *args, **kwargs)class Layer(ClusterMixin):"""Layer of SOMA Grid of nodes"""def __init__(self, nodes):self.nodes = list(nodes)@staticmethoddef random(n_nodes=100, *args, **kwargs):return Layer([Node.random(*args, **kwargs) for _ in range(n_nodes)])def output(self, x):# all outputs(similarity to x) of the nodesreturn [node.output(x) for node in self.nodes]def champer(self, x):"""champer node: best matching unit (BMU)"""return self.nodes[self.predict(x)]def predict(self, x):"""the index of best matching unit (BMU)"""return np.argmin(self.output(x))def update(self, x, eta=0.5, d=0.5):# update the nerighors of the best nodec = self.champer(x)for node in self.nodes:if node.near(c, d):node.update(x, eta)def plot(self, axes, i1=0, i2=1, *args, **kwargs):x1 = [node.weight[i1] for node in self.nodes]x2 = [node.weight[i2] for node in self.nodes]axes.scatter(x1, x2, *args, **kwargs)def fit(self, data, eta=0.2, d=0.2, max_iter=100):data = scaler.fit_transform(data)for t in range(max_iter):for x in data:self.update(x, eta=eta*exp(-t/10), d=d*exp(-t/10))if __name__ == '__main__':try:import pandas as pddf = pd.read_csv('heart.csv')  # input your dataexcept Exception as e:printe(e)raise Exception('Please input your data!')def _grid(size=(5, 5), *args, **kwargs):grid = []r, c = sizefor k in range(1,r):row = []for l in range(1,c):weight = np.array((k/r, l/c))# weight = np.random.random(kwargs['dim']) # for randomly generatinglocation = np.array((k/r, l/c))node = Node(weight=weight, location=location)row.append(node)grid.append(row)return griddf = df[['trestbps', 'chol']]N, p = df.shapeX = df.values.astype('float')import matplotlib.pyplot as pltfig = plt.figure()ax = fig.add_subplot(111)X_ = scaler.fit_transform(X)ax.plot(X_[:,0], X_[:,1], 'o')g = _grid(size=(5,5), dim=p)for row in g:x = [node.weight[0] for node in row]y = [node.weight[1] for node in row]ax.plot(x, y, 'g--')for col in zip(*g):x = [node.weight[0] for node in col]y = [node.weight[1] for node in col]ax.plot(x, y, 'g--')l = Layer(nodes=toolz.concat(g))l.plot(ax, marker='s', color='g', alpha=0.2)l.fit(X[:N//2,:], max_iter=50)l.plot(ax, marker='+', color='r')for row in g:x = [node.weight[0] for node in row]y = [node.weight[1] for node in row]ax.plot(x, y, 'r')for col in zip(*g):x = [node.weight[0] for node in col]y = [node.weight[1] for node in col]ax.plot(x, y, 'r')ax.set_title('Demo of SOM')ax.legend(('Data', 'Initial nodes', 'Terminal nodes'))plt.show()

在这里插入图片描述

http://www.lryc.cn/news/204485.html

相关文章:

  • 如何避免阿里云对象储存OSS被盗刷
  • 产品研发团队协作神器!10款提效工具大盘点!
  • LSTM 与 GRU
  • 代码评审CheckList
  • [尚硅谷React笔记]——第5章 React 路由
  • 如何去掉不够优雅的IF-ELSE
  • Python中defaultdict的使用
  • 【ccc3.8】虚拟列表
  • 【23种设计模式】单一职责原则
  • DNS入门学习:什么是TTL值?如何设置合适的TTL值?
  • ilr normalize isometric log-ratio transformation
  • el表单的简单查询方法
  • 【USRP】通信总的分支有哪些
  • 关于服务器网络代理解决方案(1024)
  • Linux下 /etc/shadow内容详解
  • Go学习第二章——变量与数据类型
  • 【剑指Offer】:循环有序列表的插入(涉及链表的知识)
  • 【Django 04】Django-DRF(ModelViewSet)
  • ubuntu命令
  • C++学习之强制类型转换
  • 在Linux中,可以使用以下命令来查看进程
  • 【算法训练-动态规划 一】【应用DP问题】零钱兑换、爬楼梯、买卖股票的最佳时机I、打家劫舍
  • 2023年中职组“网络安全”赛项云南省竞赛任务书
  • Modeling Deep Learning Accelerator Enabled GPUs
  • 《动手学深度学习 Pytorch版》 9.5 机器翻译与数据集
  • 网络入门基础
  • Towards a Rigorous Evaluation of Time-series Anomaly Detection(论文翻译)
  • 理解Python装饰器
  • VR智慧景区,为游客开启智慧旅游新时代
  • 蓝桥杯 Java 青蛙过河