当前位置: 首页 > news >正文

蒙特卡洛树搜索方法实践

一、算法概述

蒙特卡洛树搜索是一种用于决策过程的启发式搜索算法,特别适用于具有巨大状态空间的游戏和优化问题。其主要结合了:蒙特卡洛方法:通过随机采样来估算复杂问题的解;树搜索:将决策问题建模为树结构;UCB(Upper Confidence Bound):平衡探索与利用的选择策略。

蒙特卡洛树基于两个基本概念:一是可以使用随机模拟来近似某个行动的真实价值;二是可以有效地利用这些价值将策略调整为最佳优先策略。该算法在对博弈树先前探索结果的引导下,逐步构建一个部分博弈树。这棵树用于估计走法的价值,随着树的构建,这些估计值会变得更加准确。

基本算法包括迭代地构建搜索树,直到达到某个预定义的计算预算(通常的是时间、内存或迭代次数限制)。此时,搜索停止,并返回性能最优的根动作。搜索树中的每个节点代表领域的一个状态,指向子节点的有向链接代表通向后续状态的动作。每次搜索迭代应用四个步骤。

(1)选择:从根节点开始,递归应用子节点选择策略,沿着树向下搜索,直到找到最急需扩展的节点。如果一个节点代表非终止状态且有未访问(即未扩展)的子节点,则该节点是可扩展的。

(2)扩展:根据可用的操作,添加一个(或多个)子节点来扩展树。

(3)模拟:根据默认策略从新节点开始进行模拟,以产生一个结果。

(4)反向传播:模拟结果通过选定的节点“回溯”(即反向传播),以更新这些节点的统计信息。

这些可以分为两种截然不同的政策。

(1)树策略:从搜索树中已有的节点中选择或创建一个叶节点。

(2)默认策略:从给定的非终止状态开始执行该领域的操作,以生成价值估计(模拟)。

反向传播步骤本身不使用策略,而是更新节点统计信息,这些信息会为未来的树策略决策提供依据。这些步骤在伪代码Algorithms1中总结如下:这里,v_0是根节点,对应初始状态s_0v_l是树策略阶段最后到达的节点,对应状态s_l;\Delta是从状态s_l开始,用默认策略模拟到终局后获得的奖励。整个MCTS搜索的结果a是指:在根节点v_0的所有子节点中,选择“最优”的那个动作a

请注意,文献中对“模拟”这一术语存在不同的解释。一些作者认为它指的是在树策略和默认策略下每次迭代所选择的完整动作序列,而大多数作者认为它仅指使用默认策略所选择的动作序列。在本文中,我们将把“推演”和“模拟”这两个术语理解为“根据默认策略将任务执行至完成”,即树策略的选择和扩展步骤完成后所选择的动作。

二、详细案例分析

level代表蒙特卡洛树决策的层级,即程序会连续做几次决策

这个案例实现了一个数值优化游戏:游戏有10轮,每轮可以从[-2,2,3,-3]\times n中选择一个数字,目标是让累积值尽可能接近0。一共在10轮里做出10个决策,选择出10个数字。

(1)State类——游戏状态

关键特点:(1)value:当前累积值。(2)turn:剩余轮数。(3)moves:已执行的移动序列

移动机制解析:(1)当前轮数越高,移动的影响越大。(2)第1轮可选:[-20,20,30,-30],第十轮可选:[-2,2,3,-3],早期错误代价更高。

class State():NUM_TURNS = 10GOAL = 0MOVES=[2,-2,3,-3]MAX_VALUE= (5.0*(NUM_TURNS-1)*NUM_TURNS)/2num_moves=len(MOVES)def __init__(self, value=0, moves=[], turn=NUM_TURNS):self.value=valueself.turn=turnself.moves=movesdef next_state(self):nextmove=random.choice([x*self.turn for x  in self.MOVES])next=State(self.value+nextmove, self.moves+[nextmove],self.turn-1)return nextdef terminal(self):if self.turn == 0:return Truereturn Falsedef reward(self):r = 1.0-(abs(self.value-self.GOAL)/self.MAX_VALUE)return rdef __hash__(self):return int(hashlib.md5(str(self.moves).encode('utf-8')).hexdigest(),16)def __eq__(self,other):if hash(self)==hash(other):return Truereturn Falsedef __repr__(self):s="Value: %d; Moves: %s"%(self.value,self.moves)return sdef node_id(self):return str(hash(self))def node_info(self):return f"Value:{self.value}\nMoves:{self.moves}"

逐行解析每一步,NUM_TURN定义了游戏一共有10轮,GOAL定义了目标的值,MOVES定义了动作空间,MAX_VALUE定义了最大的值。num_moves为动作空间的大小.

第一个初始化目标的value值,turn为轮次,moves为移动空间。next_state目标是从目前节点随机选择下一个节点,并且更新到下个节点的状态。terminal就是终止条件。reward就是奖励函数。__hash__和__eq__,Python对象才能放进set、当作dict的key,才能高效查重。例如

s1=State(value=0,moves=[2,-2],turn=8),s2==State(value=0,moves=[2,-2],turn=8)有了__eq__和__hash__能够快速判断这个新状态是不是已经作为子节点存在。

(2)Node类——MCTS节点

介绍一下TreePolicy方法,一共分为四种情况。

class Node():def __init__(self, state, parent=None):self.visits=1self.reward=0.0self.state=stateself.children=[]self.parent=parentdef add_child(self,child_state):child=Node(child_state,self)self.children.append(child)def update(self,reward):self.reward+=rewardself.visits+=1def fully_expanded(self, num_moves_lambda):num_moves = self.state.num_movesif num_moves_lambda != None:num_moves = num_moves_lambda(self)if len(self.children)==num_moves:return Truereturn Falsedef __repr__(self):s="Node; children: %d; visits: %d; reward: %f"%(len(self.children),self.visits,self.reward)return sdef node_info(self):return f"N:{self.visits}\nR:{self.reward:.2f}\n{self.state.moves}"

初始化节点,仿真次数为1,奖励为0,状态为状态,子节点为空,父节点。

add_child为添加子节点,它的状态为输入,父节点为当前节点本身。然后将刚创建的子节点加入children离去。

update为更新该节点的奖励,以及访问次数加一次。

fully_expanded为判断该节点是否展开了子节点。

(3)MCTS关键算法

treepolicy:树策略

def TREEPOLICY(node, num_moves_lambda):#a hack to force 'exploitation' in a game where there are many options, and you may never/not want to fully expand firstwhile node.state.terminal()==False:if len(node.children)==0:return EXPAND(node)elif random.uniform(0,1)<.5:node=BESTCHILD(node,SCALAR)else:if node.fully_expanded(num_moves_lambda)==False:return EXPAND(node)else:node=BESTCHILD(node,SCALAR)return node
def EXPAND(node):tried_children=[c.state for c in node.children]new_state=node.state.next_state()while new_state in tried_children and new_state.terminal()==False:new_state=node.state.next_state()node.add_child(new_state)return node.children[-1]#current this uses the most vanilla MCTS formula it is worth experimenting with THRESHOLD ASCENT (TAGS)
def BESTCHILD(node,scalar):bestscore=0.0bestchildren=[]for c in node.children:exploit=c.reward/c.visitsexplore=math.sqrt(2.0*math.log(node.visits)/float(c.visits))score=exploit+scalar*exploreif score==bestscore:bestchildren.append(c)if score>bestscore:bestchildren=[c]bestscore=scoreif len(bestchildren)==0:logger.warn("OOPS: no best child found, probably fatal")return random.choice(bestchildren)

1.首次访问节点(没有子节点),其动作为直接展开,创建第一个子节点。例如第一轮选择动作20。

2.50%概率利用已有信息,使用UCB公式选择最有希望的子节点,根据之前的模拟结果,发现选择-20的子节点表现最好,就选择它。

3.50%概率继续探索,场景两种情况:未完全展开:还有动作没尝试过,创建新子节点。已完全展开:所有4个动作都试过了,选择最佳的继续向下。

DefaultPolicy:默认策略

def DEFAULTPOLICY(state):while state.terminal()==False:state=state.next_state()return state.reward()

即如果选择子节点20,然后后面不是有9轮,通过随机模拟的情况得到该节点某种情况的回报值。

Backup:回溯

def BACKUP(node,reward):while node!=None:node.visits+=1node.reward+=rewardnode=node.parentreturn

根据扩展计算的结果来更新该节点的历史信息,为Bestchild的选择做准备。

(4)数字案例

例如,根节点初始化状态为[],访问次数为0,奖励值函数为0。根据输入的level来判定做决策到第几层。例如level为1,则就选出第一轮最优的个数即可。例如第一次仿真开始,

通过树策略选择30,动作记为[30],通过DefaultPolicy仿真一次选择30这个动作后面10轮的一个轮次,评估奖励函数为0.44,访问次数加1。然后通过Backup将根节点的奖励值加为0.44。然后开启第二轮仿真,此时根据TreePolicy现在有50%的概率是扩展根节点,另外50%是基于现有节点的情况下选择最优节点,然后通过DefaultPolicy继续评估。经过50%的概率筛选节点扩展到[-30],根据DefaultPolicy评估到0.866奖励函数,此时加到根节点的奖励函数值里。开始第三轮仿真,此时过程扩展到[20],评估奖励为0.733,继续添加到根节点的奖励函数值里。然后开始第四轮仿真,此时根据50%的概率是根据目前已经展开的三个节点如[30],[-30],[20]节点去展开,三个节点均被访问了一次,根据公式

p=reward+\sqrt{2ln(n_p)/n_c}

然后选择了目前评估函数最高的[-30],在其下面展开节点[-30,18],然后通过DefaultPolicy计算得到0.676此时将这个奖励值即加到根节点,也加到[-30]这个节点,以此类推。

完整代码

#!/usr/bin/env python
import random
import math
import hashlib
import logging
import argparse
import queue
from graphviz import Digraph"""
A quick Monte Carlo Tree Search implementation.  For more details on MCTS see See http://pubs.doc.ic.ac.uk/survey-mcts-methods/survey-mcts-methods.pdfThe State is a game where you have NUM_TURNS and at turn i you can make
a choice from an integeter [-2,2,3,-3]*(NUM_TURNS+1-i).  So for example in a game of 4 turns, on turn for turn 1 you can can choose from [-8,8,12,-12], and on turn 2 you can choose from [-6,6,9,-9].  At each turn the choosen number is accumulated into a aggregation value.  The goal of the game is for the accumulated value to be as close to 0 as possible.The game is not very interesting but it allows one to study MCTS which is.  Some features 
of the example by design are that moves do not commute and early mistakes are more costly.  In particular there are two models of best child that one can use 
"""#MCTS scalar.  Larger scalar will increase exploitation, smaller will increase exploration. 
SCALAR=1/(2*math.sqrt(2.0))logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger('MyLogger')class State():NUM_TURNS = 10GOAL = 0MOVES=[2,-2,3,-3]MAX_VALUE= (5.0*(NUM_TURNS-1)*NUM_TURNS)/2num_moves=len(MOVES)def __init__(self, value=0, moves=[], turn=NUM_TURNS):self.value=valueself.turn=turnself.moves=movesdef next_state(self):nextmove=random.choice([x*self.turn for x  in self.MOVES])next=State(self.value+nextmove, self.moves+[nextmove],self.turn-1)return nextdef terminal(self):if self.turn == 0:return Truereturn Falsedef reward(self):r = 1.0-(abs(self.value-self.GOAL)/self.MAX_VALUE)return rdef __hash__(self):return int(hashlib.md5(str(self.moves).encode('utf-8')).hexdigest(),16)def __eq__(self,other):if hash(self)==hash(other):return Truereturn Falsedef __repr__(self):s="Value: %d; Moves: %s"%(self.value,self.moves)return sdef node_id(self):return str(hash(self))def node_info(self):return f"Value:{self.value}\nMoves:{self.moves}"class Node():def __init__(self, state, parent=None):self.visits=1self.reward=0.0self.state=stateself.children=[]self.parent=parentdef add_child(self,child_state):child=Node(child_state,self)self.children.append(child)def update(self,reward):self.reward+=rewardself.visits+=1def fully_expanded(self, num_moves_lambda):num_moves = self.state.num_movesif num_moves_lambda != None:num_moves = num_moves_lambda(self)if len(self.children)==num_moves:return Truereturn Falsedef __repr__(self):s="Node; children: %d; visits: %d; reward: %f"%(len(self.children),self.visits,self.reward)return sdef node_info(self):return f"N:{self.visits}\nR:{self.reward:.2f}\n{self.state.moves}"def UCTSEARCH(budget,root,num_moves_lambda = None):for iter in range(int(budget)):if iter%10000==9999:logger.info("simulation: %d"%iter)logger.info(root)front=TREEPOLICY(root, num_moves_lambda)reward=DEFAULTPOLICY(front.state)BACKUP(front,reward)return BESTCHILD(root,0)def TREEPOLICY(node, num_moves_lambda):#a hack to force 'exploitation' in a game where there are many options, and you may never/not want to fully expand firstwhile node.state.terminal()==False:if len(node.children)==0:return EXPAND(node)elif random.uniform(0,1)<.5:node=BESTCHILD(node,SCALAR)else:if node.fully_expanded(num_moves_lambda)==False:return EXPAND(node)else:node=BESTCHILD(node,SCALAR)return nodedef EXPAND(node):tried_children=[c.state for c in node.children]new_state=node.state.next_state()while new_state in tried_children and new_state.terminal()==False:new_state=node.state.next_state()node.add_child(new_state)return node.children[-1]#current this uses the most vanilla MCTS formula it is worth experimenting with THRESHOLD ASCENT (TAGS)
def BESTCHILD(node,scalar):bestscore=0.0bestchildren=[]for c in node.children:exploit=c.reward/c.visitsexplore=math.sqrt(2.0*math.log(node.visits)/float(c.visits))score=exploit+scalar*exploreif score==bestscore:bestchildren.append(c)if score>bestscore:bestchildren=[c]bestscore=scoreif len(bestchildren)==0:logger.warn("OOPS: no best child found, probably fatal")return random.choice(bestchildren)def DEFAULTPOLICY(state):while state.terminal()==False:state=state.next_state()return state.reward()def BACKUP(node,reward):while node!=None:node.visits+=1node.reward+=rewardnode=node.parentreturndef show_search_tree(root):dot = Digraph(comment='Game Search Tree')visited = set()que = queue.Queue()que.put(root)while not que.empty():node = que.get()node_id = str(id(node))if node_id in visited:continuevisited.add(node_id)# 添加节点dot.node(node_id, node.node_info())# 添加子节点和边for child in node.children:child_id = str(id(child))dot.node(child_id, child.node_info())dot.edge(node_id, child_id)que.put(child)with open("a.dot", "w", encoding="utf-8") as writer:writer.write(dot.source)dot.render('search_path', view=False)if __name__=="__main__":parser = argparse.ArgumentParser(description='MCTS research code')parser.add_argument('--num_sims', action="store", required=True, type=int)parser.add_argument('--levels', action="store", required=True, type=int, choices=range(State.NUM_TURNS+1))args=parser.parse_args()current_node=Node(State())for l in range(args.levels):current_node=UCTSEARCH(args.num_sims/(l+1),current_node)print("level %d"%l)print("Num Children: %d"%len(current_node.children))for i,c in enumerate(current_node.children):print(i,c)print("Best Child: %s"%current_node.state)print("--------------------------------")# 可视化搜索树结构show_search_tree(current_node)

运行方式终端

python mcts.py --num_sim 100 --levels 2

如果想知道10步决策的每个决策则将终端改为

python mcts.py --num_sims 100 --levels 10  

得到最终最优的选择策略。

http://www.lryc.cn/news/586613.html

相关文章:

  • 蓝牙调试抓包工具--nRF Connect移动端 使用详细总结
  • 生成式对抗网络(GAN)模型原理概述
  • Java生产带文字、带边框的二维码
  • 牛客:HJ19 简单错误记录[华为机考][字符串]
  • 009 ST表:静态区间最值的极致优化
  • 面试现场:奇哥扮猪吃老虎,RocketMQ高级原理吊打面试官
  • MyBatis实现分页查询-苍穹外卖笔记
  • comfyUI-controlNet-线稿软边缘
  • python-enumrate函数
  • HarmonyOS从入门到精通:动画设计与实现之六 - 动画曲线与运动节奏控制
  • houdini 用 vellum 制作一个最简单的布料
  • 洛谷题解 | UVA1485 Permutation Counting
  • C++结构体数组应用
  • Spring Boot 中使用 Lombok 进行依赖注入的示例
  • 基于springboot+Vue的二手物品交易的设计与实现(免费分享)
  • 2025年亚太杯(中文赛项)数学建模B题【疾病的预测与大数据分析】原创论文讲解(含完整python代码)
  • jieba 库:中文分词的利器
  • JAVA--双亲委派机制
  • 【springcloud】快速搭建一套分布式服务springcloudalibaba(四)
  • 【一起来学AI大模型】RAG系统流程:查询→向量化→检索→生成
  • 【AI News | 20250711】每日AI进展
  • 【TOOL】ubuntu升级cmake版本
  • AI产品经理面试宝典第12天:AI产品经理的思维与转型路径面试题与答法
  • 功耗校准数据PowerProfile测试方法建议
  • 【深度剖析】致力“四个最”的君乐宝数字化转型(下篇:转型成效5-打造数字化生存能力探索可持续发展路径)
  • VUE3 el-table 主子表 显示
  • Transformer基础
  • Openpyxl:Python操作Excel的利器
  • Qt 多线程编程:单例任务队列的设计与实现
  • 五、深度学习——CNN