当前位置: 首页 > news >正文

机器学习10

自定义数据集 使用scikit-learn中svm的包实现svm分类

代码

import numpy as np
import matplotlib.pyplot as pltclass1_points = np.array([[1.9, 1.2],[1.5, 2.1],[1.9, 0.5],[1.5, 0.9],[0.9, 1.2],[1.1, 1.7],[1.4, 1.1]])class2_points = np.array([[3.2, 3.2],[3.7, 2.9],[3.2, 2.6],[1.7, 3.3],[3.4, 2.6],[4.1, 2.3],[3.0, 2.9]])x1_data = np.concatenate((class1_points[:, 0], class2_points[:, 0]))
x2_data = np.concatenate((class1_points[:, 1], class2_points[:, 1]))
y = np.concatenate((np.ones(class1_points.shape[0]), -np.ones(class2_points.shape[0])))w1 = 0.1
w2 = 0.1
b = 0
learning_rate = 0.05l_data = x1_data.sizefig, (ax1, ax2) = plt.subplots(2, 1)step_list = np.array([])  # 初始化为空数组
loss_values = np.array([])  # 初始化为空数组num_iterations = 1000
for n in range(1, num_iterations + 1):z = w1 * x1_data + w2 * x2_data + byz = y * zloss = 1 - yzloss[loss < 0] = 0hinge_loss = np.mean(loss)loss_values = np.append(loss_values, hinge_loss)step_list = np.append(step_list, n)gradient_w1 = 0gradient_w2 = 0gradient_b = 0for i in range(len(y)):if loss[i] > 0:gradient_w1 += -y[i] * x1_data[i]gradient_w2 += -y[i] * x2_data[i]gradient_b += -y[i]gradient_w1 /= len(y)gradient_w2 /= len(y)gradient_b /= len(y)w1 -= learning_rate * gradient_w1w2 -= learning_rate * gradient_w2b -= learning_rate * gradient_b# 显示频率设置frequence_display = 50if n % frequence_display == 0 or n == 1:if np.abs(w2) < 1e-5:continuex1_min, x1_max = 0, 6x2_min, x2_max = -(w1 * x1_min + b) / w2, -(w1 * x1_max + b) / w2ax1.clear()ax1.scatter(x1_data[:len(class1_points)], x2_data[:len(class1_points)], c='red', label='Class 1')ax1.scatter(x1_data[len(class1_points):], x2_data[len(class1_points):], c='blue', label='Class 2')ax1.plot((x1_min, x1_max), (x2_min, x2_max), 'r-')ax1.set_title(f"SVM: w1={round(w1.item(), 3)}, w2={round(w2.item(), 3)}, b={round(b.item(), 3)}")ax2.clear()ax2.plot(step_list, loss_values, 'g-')ax2.set_xlabel("Step")ax2.set_ylabel("Loss")# 显示图形plt.pause(1)plt.show()

效果展示

http://www.lryc.cn/news/530995.html

相关文章:

  • 【Block总结】CoT,上下文Transformer注意力|即插即用
  • linux库函数 gettimeofday() localtime的概念和使用案例
  • 编程题-电话号码的字母组合(中等)
  • EasyExcel使用详解
  • 基于“蘑菇书”的强化学习知识点(二):强化学习中基于策略(Policy-Based)和基于价值(Value-Based)方法的区别
  • 民法学学习笔记(个人向) Part.2
  • 物业管理系统源码驱动社区管理革新提升用户满意度与服务效率
  • 租房管理系统助力数字化转型提升租赁服务质量与用户体验
  • Ollama教程:轻松上手本地大语言模型部署
  • Baklib推动数字化内容管理解决方案助力企业数字化转型
  • DeepSeek-R1 论文. Reinforcement Learning 通过强化学习激励大型语言模型的推理能力
  • DOM 操作入门:HTML 元素操作与页面事件处理
  • 使用 HTTP::Server::Simple 实现轻量级 HTTP 服务器
  • C++滑动窗口技术深度解析:核心原理、高效实现与高阶应用实践
  • 基于构件的软件开发方法
  • 网站快速收录:如何设置robots.txt文件?
  • OpenGL学习笔记(六):Transformations 变换(变换矩阵、坐标系统、GLM库应用)
  • 8.攻防世界Web_php_wrong_nginx_config
  • 【优先算法】专题——位运算
  • qt.qpa.plugin: Could not find the Qt platform plugin “dxcb“ in ““
  • 1-刷力扣问题记录
  • 物联网 STM32【源代码形式-使用以太网】连接OneNet IOT从云产品开发到底层MQTT实现,APP控制 【保姆级零基础搭建】
  • 【单层神经网络】基于MXNet的线性回归实现(底层实现)
  • unity中的动画混合树
  • 《基于deepseek R1开源大模型的电子数据取证技术发展研究》
  • Potplayer常用快捷键
  • C++ Primer 自定义数据结构
  • 35.Word:公积金管理中心文员小谢【37】
  • 北京钟鼓楼:立春“鞭春牛”,钟鼓迎春来
  • 股票入门知识