当前位置:首页 > news > 正文

Day 69-70:矩阵分解

代码:

package dl;import java.io.*;
import java.util.Random;/** Matrix factorization for recommender systems.*/public class MatrixFactorization {/*** Used to generate random numbers.*/Random rand = new Random();/*** Number of users.*/int numUsers;/*** Number of items.*/int numItems;/*** Number of ratings.*/int numRatings;/*** Training data.*/Triple[] dataset;/*** A parameter for controlling learning regular.*/double alpha;/*** A parameter for controlling the learning speed.*/double lambda;/*** The low rank of the small matrices.*/int rank;/*** The user matrix U.*/double[][] userSubspace;/*** The item matrix V.*/double[][] itemSubspace;/*** The lower bound of the rating value.*/double ratingLowerBound;/*** The upper bound of the rating value.*/double ratingUpperBound;/*************************** The first constructor.** @param paraFilename*            The data filename.* @param paraNumUsers*            The number of users.* @param paraNumItems*            The number of items.* @param paraNumRatings*            The number of ratings.*************************/public MatrixFactorization(String paraFilename, int paraNumUsers, int paraNumItems,int paraNumRatings, double paraRatingLowerBound, double paraRatingUpperBound) {numUsers = paraNumUsers;numItems = paraNumItems;numRatings = paraNumRatings;ratingLowerBound = paraRatingLowerBound;ratingUpperBound = paraRatingUpperBound;try {readData(paraFilename, paraNumUsers, paraNumItems, paraNumRatings);// adjustUsingMeanRating();} catch (Exception ee) {System.out.println("File " + paraFilename + " cannot be read! 
" + ee);System.exit(0);} // Of try}// Of the first constructor/*************************** Set parameters.** @param paraRank*            The given rank.* @throws IOException*************************/public void setParameters(int paraRank, double paraAlpha, double paraLambda) {rank = paraRank;alpha = paraAlpha;lambda = paraLambda;}// Of setParameters/*************************** Read the data from the file.** @param paraFilename*            The given file.* @throws IOException*************************/public void readData(String paraFilename, int paraNumUsers, int paraNumItems,int paraNumRatings) throws IOException {File tempFile = new File(paraFilename);if (!tempFile.exists()) {System.out.println("File " + paraFilename + " does not exists.");System.exit(0);} // Of ifBufferedReader tempBufferReader = new BufferedReader(new FileReader(tempFile));// Allocate space.dataset = new Triple[paraNumRatings];String tempString;String[] tempStringArray;for (int i = 0; i < paraNumRatings; i++) {tempString = tempBufferReader.readLine();tempStringArray = tempString.split(",");dataset[i] = new Triple(Integer.parseInt(tempStringArray[0]),Integer.parseInt(tempStringArray[1]), Double.parseDouble(tempStringArray[2]));} // Of for itempBufferReader.close();}// Of readData/*************************** Initialize subspaces. 
Each value is in [0, 1].*************************/void initializeSubspaces() {userSubspace = new double[numUsers][rank];for (int i = 0; i < numUsers; i++) {for (int j = 0; j < rank; j++) {userSubspace[i][j] = rand.nextDouble();} // Of for j} // Of for iitemSubspace = new double[numItems][rank];for (int i = 0; i < numItems; i++) {for (int j = 0; j < rank; j++) {itemSubspace[i][j] = rand.nextDouble();} // Of for j} // Of for i}// Of initializeSubspaces/*************************** Predict the rating of the user to the item** @param paraUser*            The user index.*************************/public double predict(int paraUser, int paraItem) {double resultValue = 0;for (int i = 0; i < rank; i++) {// The row vector of an user and the column vector of an itemresultValue += userSubspace[paraUser][i] * itemSubspace[paraItem][i];} // Of for ireturn resultValue;}// Of predict/*************************** Train.** @param paraRounds*            The number of rounds.*************************/public void train(int paraRounds) {initializeSubspaces();for (int i = 0; i < paraRounds; i++) {updateNoRegular();if (i % 50 == 0) {// Show the processSystem.out.println("Round " + i);System.out.println("MAE: " + mae());} // Of if} // Of for i}// Of train/*************************** Update sub-spaces using the training data.*************************/public void updateNoRegular() {for (int i = 0; i < numRatings; i++) {int tempUserId = dataset[i].user;int tempItemId = dataset[i].item;double tempRate = dataset[i].rating;double tempResidual = tempRate - predict(tempUserId, tempItemId); // Residual// Update user subspacedouble tempValue = 0;for (int j = 0; j < rank; j++) {tempValue = 2 * tempResidual * itemSubspace[tempItemId][j];userSubspace[tempUserId][j] += alpha * tempValue;} // Of for j// Update item subspacefor (int j = 0; j < rank; j++) {tempValue = 2 * tempResidual * userSubspace[tempUserId][j];itemSubspace[tempItemId][j] += alpha * tempValue;} // Of for j} // Of for i}// Of 
updateNoRegular/*************************** Compute the RSME.** @return RSME of the current factorization.*************************/public double rsme() {double resultRsme = 0;int tempTestCount = 0;for (int i = 0; i < numRatings; i++) {int tempUserIndex = dataset[i].user;int tempItemIndex = dataset[i].item;double tempRate = dataset[i].rating;double tempPrediction = predict(tempUserIndex, tempItemIndex);// +// DataInfo.mean_rating;if (tempPrediction < ratingLowerBound) {tempPrediction = ratingLowerBound;} else if (tempPrediction > ratingUpperBound) {tempPrediction = ratingUpperBound;} // Of ifdouble tempError = tempRate - tempPrediction;resultRsme += tempError * tempError;tempTestCount++;} // Of for ireturn Math.sqrt(resultRsme / tempTestCount);}// Of rsme/*************************** Compute the MAE.** @return MAE of the current factorization.*************************/public double mae() {double resultMae = 0;int tempTestCount = 0;for (int i = 0; i < numRatings; i++) {int tempUserIndex = dataset[i].user;int tempItemIndex = dataset[i].item;double tempRate = dataset[i].rating;double tempPrediction = predict(tempUserIndex, tempItemIndex);if (tempPrediction < ratingLowerBound) {tempPrediction = ratingLowerBound;} // Of ifif (tempPrediction > ratingUpperBound) {tempPrediction = ratingUpperBound;} // Of ifdouble tempError = tempRate - tempPrediction;resultMae += Math.abs(tempError);// System.out.println("resultMae: " + resultMae);tempTestCount++;} // Of for ireturn (resultMae / tempTestCount);}// Of mae/*************************** Compute the MAE.** @return MAE of the current factorization.*************************/public static void testTrainingTesting(String paraFilename, int paraNumUsers, int paraNumItems,int paraNumRatings, double paraRatingLowerBound, double paraRatingUpperBound,int paraRounds) {try {// Step 1. 
read the training and testing dataMatrixFactorization tempMF = new MatrixFactorization(paraFilename, paraNumUsers,paraNumItems, paraNumRatings, paraRatingLowerBound, paraRatingUpperBound);tempMF.setParameters(5, 0.0001, 0.005);// Step 3. update and predictSystem.out.println("Begin Training ! ! !");tempMF.train(paraRounds);double tempMAE = tempMF.mae();double tempRSME = tempMF.rsme();System.out.println("Finally, MAE = " + tempMAE + ", RSME = " + tempRSME);} catch (Exception e) {e.printStackTrace();} // Of try}// Of testTrainingTesting/*************************** @param args*************************/public static void main(String args[]) {testTrainingTesting("C:\\Users\\86183\\IdeaProjects\\deepLearning\\src\\main\\java\\resources\\movielens-943u1682m.txt", 943, 1682, 10000, 1, 5, 2000);}// Of mainpublic class Triple {public int user;public int item;public double rating;/************************ The constructor.**********************/public Triple() {user = -1;item = -1;rating = -1;}// Of the first constructor/************************ The constructor.**********************/public Triple(int paraUser, int paraItem, double paraRating) {user = paraUser;item = paraItem;rating = paraRating;}// Of the first constructor/************************ Show me.**********************/public String toString() {return "" + user + ", " + item + ", " + rating;}// Of toString}// Of class Triple}// Of class MatrixFactorization

结果:

 

http://www.lryc.cn/news/101083.html

相关文章:

  • 数据结构:树的存储结构
  • Vue前端渲染blob二进制对象图片的方法
  • Java的标记接口(Marker Interface)
  • Kafka基础架构与核心概念
  • 观察者模式与观察者模式实例EventBus
  • 科普 | OSI模型
  • redis相关异常之RedisConnectionException、RedisCommandTimeoutException
  • Merge the squares! 2023牛客暑期多校训练营4-H
  • STM32 串口学习(二)
  • 点大商城V2_2.5.0 全开源版 商家自营+多商户入驻 百度+支付宝+QQ+头条+小程序端+unipp开源前端安装测试教程
  • “深入理解SpringBoot:从入门到精通”
  • PCB绘制时踩的坑 - SOT-223封装
  • Go语法入门 + 项目实战
  • QT控件通过qss设置子控件的对齐方式、大小自适应等
  • 基于java在线收银系统设计与实现
  • Linux--进程的新建状态
  • 区间dp,合并石子模板题
  • C++代码格式化工具clang-format详细介绍
  • CentOS 7安装PostgreSQL 15版本数据库
  • QGraphicsView实现简易地图2『瓦片经纬度』
  • 医学图像重建—第一章笔记
  • python-pytorch基础之神经网络分类
  • 【C++ 程序设计】实战:C++ 变量实践练习题
  • 微软对Visual Studio 17.7 Preview 4进行版本更新,新插件管理器亮相
  • Kafka 入门到起飞 - Kafka怎么做到保障消息不会重复消费的? 消费者组是什么?
  • MongoDB 的增、查、改、删
  • mysql常用操作命令
  • 数学建模常见模型汇总
  • C#使用LINQ查询操作符实例代码(二)
  • jenkinsfile小试牛刀