多阶段报童问题动态规划求解,Python 实现
使用 Python 编写了多阶段报童模型的随机动态规划算法。
- 使用了 Python 的装饰器 @dataclass,方便定义类。
- 尝试过使用并行计算,但没有成功,且极易出错。在动态规划中使用并行计算颇具挑战;而且并行计算不一定比串行计算更快。
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Nov 28 00:00:35 2024

@author: zhenchen

@Python version: 3.10

@disp: stochastic dynamic programming for the multi-period newsvendor
    (stochastic inventory) problem.
    - @dataclass is used to define the model classes concisely.
    - An attempt at parallel computing was unsuccessful and highly
      error-prone, so the recursion here runs serially.
"""
import scipy.stats as sp
from dataclasses import dataclass
from functools import lru_cache
import time


@dataclass(frozen=True)
class State:
    """State of the system at the start of a period.

    Attributes:
        t: period index, 1-based (the recursion is started from ``State(1, ...)``).
        iniInventory: inventory level at the start of period ``t``; negative
            values are shortages, charged at ``penaCost`` per unit.
    """

    t: int
    iniInventory: float


@dataclass
class Pmf:
    """Truncated demand probability mass function for each period.

    Attributes:
        truncQuantile: quantile at which each period's demand distribution is
            truncated (e.g. 0.9999); expected to be strictly below 1, since
            the Poisson ppf at 1 is infinite.
        distribution_type: demand distribution family; only ``'poisson'``
            is supported.
    """

    truncQuantile: float
    distribution_type: str

    def get_pmf(self, distribution_parameters):
        """Build the truncated, renormalised demand pmf of every period.

        Parameters
        ----------
        distribution_parameters : list
            Per-period distribution parameters. For ``'poisson'`` this is the
            list of mean demands, one entry per period.

        Returns
        -------
        pmf : list of list of [demand, probability]
            ``pmf[t]`` holds every demand ``0..k_max`` of period ``t + 1``,
            where ``k_max`` is the ``truncQuantile`` quantile, paired with its
            probability. The probabilities of each period sum to 1.

        Raises
        ------
        ValueError
            If ``distribution_type`` is not supported.
        """
        if self.distribution_type != 'poisson':
            raise ValueError(
                f"unsupported distribution type: {self.distribution_type!r}")
        pmf = []
        for mean in distribution_parameters:
            # ppf gives the smallest k with CDF(k) >= truncQuantile, so k
            # itself belongs to the support (range(k) would drop it).
            max_demand = int(sp.poisson.ppf(self.truncQuantile, mean))
            support = list(range(max_demand + 1))
            probs = sp.poisson.pmf(support, mean)
            # Renormalise by the mass actually kept so the truncated
            # distribution sums to exactly 1.
            total = probs.sum()
            pmf.append([[k, float(p / total)] for k, p in zip(support, probs)])
        return pmf


@dataclass(eq=False)
class StochasticInventory:
    """Multi-period stochastic inventory model solved by backward recursion.

    Attributes:
        T: number of periods.
        capacity: maximum ordering quantity per period.
        fixOrderCost: fixed cost charged whenever a positive quantity is ordered.
        variOrderCost: cost per unit ordered.
        holdCost: cost per unit of positive end-of-period inventory.
        penaCost: cost per unit of end-of-period shortage (negative inventory).
        truncationQ: truncation quantile of the demand pmf (stored only; the
            recursion does not read it).
        max_inventory: upper bound to which end-of-period inventory is clipped.
        min_inventory: lower bound to which end-of-period inventory is clipped.
        pmf: per-period demand pmf as built by ``Pmf.get_pmf``; ``pmf[t - 1]``
            lists the ``[demand, probability]`` pairs of period ``t``.

    After construction, ``cache_actions`` maps ``str(state)`` to the optimal
    order quantity that ``f`` found for that state.
    """

    T: int
    capacity: float  # maximum ordering quantity
    fixOrderCost: float
    variOrderCost: float
    holdCost: float
    penaCost: float
    truncationQ: float
    max_inventory: float
    min_inventory: float
    pmf: list

    def __post_init__(self):
        # Per-instance caches. A class-level dict would be shared by every
        # model, and lru_cache on a method would keep each instance alive for
        # the whole process.
        self.cache_actions = {}
        self._value_cache = {}

    def get_feasible_action(self, state: State):
        """Return the order quantities allowed in ``state``: 0..capacity."""
        # int() so that a float capacity (as annotated) is accepted by range.
        return range(int(self.capacity) + 1)

    def _clamp_inventory(self, inventory):
        """Clip an inventory level to [min_inventory, max_inventory]."""
        return max(self.min_inventory, min(self.max_inventory, inventory))

    def state_tran(self, state: State, action, demand):
        """Return the next-period state after ordering ``action`` and facing ``demand``."""
        next_inventory = self._clamp_inventory(state.iniInventory + action - demand)
        return State(state.t + 1, next_inventory)

    def imme_value(self, state: State, action, demand):
        """Return the one-period cost of ordering ``action`` when ``demand`` occurs.

        It is the fixed ordering cost (only if ``action > 0``), plus the
        variable ordering cost, plus holding or shortage cost on the clipped
        end-of-period inventory.
        """
        fix_cost = self.fixOrderCost if action > 0 else 0
        vari_cost = self.variOrderCost * action
        end_inventory = self._clamp_inventory(state.iniInventory + action - demand)
        holding_cost = self.holdCost * max(0, end_inventory)
        penalty_cost = self.penaCost * max(0, -end_inventory)
        return fix_cost + vari_cost + holding_cost + penalty_cost

    # recursion
    def f(self, state: State):
        """Return the minimal expected total cost from ``state`` to period ``T``.

        For every feasible order quantity this sums, over the period's demand
        pmf, the immediate cost plus the optimal cost-to-go of the next state
        (the cost-to-go is added only before the last period). Values are
        memoised per instance. The minimising order quantity is stored in
        ``self.cache_actions[str(state)]``.
        """
        if state in self._value_cache:
            return self._value_cache[state]
        # Use the model's own horizon; a module-level T may be absent or differ.
        has_future = state.t < self.T
        period_pmf = self.pmf[state.t - 1]  # periods are 1-based, pmf is 0-based
        best_value = float('inf')
        best_action = 0
        for action in self.get_feasible_action(state):
            value = 0
            for demand, prob in period_pmf:
                value += prob * self.imme_value(state, action, demand)
                if has_future:
                    value += prob * self.f(self.state_tran(state, action, demand))
            if value < best_value:  # strict '<': ties keep the smaller order
                best_value = value
                best_action = action
        self.cache_actions[str(state)] = best_action
        self._value_cache[state] = best_value
        return best_value


demands = [10, 20, 10, 20]  # mean demand of each period
# ---- problem parameters ----
distribution_type: str = 'poisson'  # demand distribution family handed to Pmf
capacity: int = 100  # largest quantity that can be ordered in one period
fixOrderCost: float = 0  # fixed cost incurred whenever an order is placed
variOderCost: float = 1  # per-unit ordering cost (this spelling is what the driver below uses)
holdCost: float = 2  # per-unit cost of stock left at the end of a period
penaCost: float = 10  # per-unit cost of shortage at the end of a period
truncQuantile: float = 0.9999  # truncated quantile for the demand distribution
maxI: int = 500  # upper bound on the inventory level
minI: int = -300  # lower bound on the inventory level (negative means shortage)

# Truncated demand pmf of every period, built from the mean demands in `demands`.
pmf = Pmf(truncQuantile, distribution_type).get_pmf(demands)
T = len(demands)  # number of periods in the planning horizon

if __name__ == '__main__':
    # Solve from period 1 with zero starting inventory, then report the
    # optimal expected cost, the optimal first-period order and CPU time.
    start = time.process_time()
    model = StochasticInventory(
        T=T,
        capacity=capacity,
        fixOrderCost=fixOrderCost,
        variOrderCost=variOderCost,
        holdCost=holdCost,
        penaCost=penaCost,
        truncationQ=truncQuantile,
        max_inventory=maxI,
        min_inventory=minI,
        pmf=pmf,
    )
    ini_state = State(1, 0)
    expect_total_cost = model.f(ini_state)
    print('****************************************')
    print('final expected total cost is %.2f' % expect_total_cost)
    optQ = model.cache_actions[str(ini_state)]
    print('optimal Q_1 is %.2f' % optQ)
    cpu_time = time.process_time() - start
    print('cpu time is %.4f s' % cpu_time)