分布变化的模仿学习算法
与传统监督学习不同，直接模仿学习在不同时刻所面临的数据分布可能不同。试设计一个考虑不同时刻数据分布变化的模仿学习算法。
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics.pairwise import rbf_kernel
from sklearn.neighbors import KernelDensity
from torch.utils.data import DataLoader, TensorDataset
class TimeAwareImitationLearning:
    """Time-conditioned adversarial imitation learner for time-varying expert data.

    Both the policy pi(a | s, t) and the discriminator D(s, a, t) receive the
    timestamp of every sample as an extra input feature, so the learned policy
    can follow an expert whose state/action distribution changes over time.
    Training combines a GAIL-style adversarial objective with a time-weighted
    maximum mean discrepancy (MMD) regulariser that emphasises samples close to
    a "focus" time which sweeps linearly across the expert time range as
    training progresses.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=64, device='cpu', lr=1e-3):
        """Build the policy and discriminator networks and their optimisers.

        Args:
            state_dim: Number of features in a state vector.
            action_dim: Number of features in an action vector.
            hidden_dim: Width of the two hidden layers in both networks.
            device: Torch device the networks and training tensors live on.
            lr: Adam learning rate used for both networks.
        """
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.device = device

        # Policy network: input is [state, t] (the +1 is the timestamp).
        self.policy = nn.Sequential(
            nn.Linear(state_dim + 1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        ).to(device)

        # Discriminator: input is [state, action, t]; output is the probability
        # that the triple came from the expert.
        self.discriminator = nn.Sequential(
            nn.Linear(state_dim + action_dim + 1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        ).to(device)

        self.policy_optimizer = optim.Adam(self.policy.parameters(), lr=lr)
        self.discriminator_optimizer = optim.Adam(self.discriminator.parameters(), lr=lr)

        # One policy-loss value is appended per training epoch.
        self.train_losses = []

    def _compute_time_weights(self, expert_times, current_time, sigma=1.0):
        """Return normalised Gaussian weights that peak at ``current_time``.

        w_i is proportional to exp(-(t_i - current_time)^2 / (2 * sigma^2)),
        and the weights sum to 1.

        Args:
            expert_times: Array-like of sample timestamps (flattened to 1-D).
            current_time: Time around which the weights are centred.
            sigma: Bandwidth of the Gaussian kernel; must be positive.

        Returns:
            1-D float64 numpy array with one weight per timestamp.

        Raises:
            ValueError: If ``expert_times`` is empty or ``sigma`` is not positive.
        """
        times = np.asarray(expert_times, dtype=np.float64).reshape(-1)
        if times.size == 0:
            raise ValueError("expert_times must not be empty")
        if sigma <= 0:
            raise ValueError("sigma must be positive")
        sq_diffs = (times - current_time) ** 2
        # Subtracting the minimum does not change the normalised weights, but it
        # keeps the largest weight at exp(0) = 1 so the sum cannot underflow to 0.
        logits = -(sq_diffs - sq_diffs.min()) / (2.0 * sigma ** 2)
        weights = np.exp(logits)
        return weights / weights.sum()

    def _compute_mmd_loss(self, expert_states, policy_states, times, current_time,
                          sigma=1.0, policy_times=None):
        """Time-weighted squared MMD between expert and policy samples (RBF kernel).

        Expert samples are weighted with ``_compute_time_weights(times,
        current_time, sigma)``. Policy samples use the same kernel over
        ``policy_times`` when given, and uniform weights otherwise. The estimate is

            w^T K_ee w + v^T K_pp v - 2 w^T K_ep v,

        with k(x, y) = exp(-gamma * ||x - y||^2) and gamma = 1 / n_features
        (the default of sklearn's ``rbf_kernel``).

        Args:
            expert_states: Expert samples, shape (n, d) (numpy array or tensor).
            policy_states: Policy samples, shape (m, d) (numpy array or tensor).
                If this is a tensor that requires grad, the result is
                differentiable with respect to it.
            times: Timestamps of the expert samples, length n.
            current_time: Focus time for the weighting.
            sigma: Bandwidth of the time weighting.
            policy_times: Optional timestamps of the policy samples, length m.

        Returns:
            0-d float32 torch tensor holding the MMD estimate.

        Raises:
            ValueError: On empty inputs, mismatched feature sizes or mismatched
                weight lengths.
        """
        expert = torch.as_tensor(expert_states, dtype=torch.float32, device=self.device)
        policy = torch.as_tensor(policy_states, dtype=torch.float32, device=self.device)
        if expert.numel() == 0 or policy.numel() == 0:
            raise ValueError("MMD needs at least one expert and one policy sample")
        expert = expert.reshape(expert.shape[0], -1)
        policy = policy.reshape(policy.shape[0], -1)
        if expert.shape[1] != policy.shape[1]:
            raise ValueError(
                f"feature size mismatch: expert {expert.shape[1]} vs policy {policy.shape[1]}"
            )

        expert_w = torch.as_tensor(
            self._compute_time_weights(times, current_time, sigma),
            dtype=torch.float32, device=self.device,
        )
        if expert_w.shape[0] != expert.shape[0]:
            raise ValueError("times must have one entry per expert sample")
        if policy_times is None:
            policy_w = torch.full(
                (policy.shape[0],), 1.0 / policy.shape[0],
                dtype=torch.float32, device=self.device,
            )
        else:
            policy_w = torch.as_tensor(
                self._compute_time_weights(policy_times, current_time, sigma),
                dtype=torch.float32, device=self.device,
            )
            if policy_w.shape[0] != policy.shape[0]:
                raise ValueError("policy_times must have one entry per policy sample")

        gamma = 1.0 / expert.shape[1]

        def rbf(a, b):
            # ||a||^2 + ||b||^2 - 2ab^T is used instead of cdist so the gradient
            # stays finite at zero distance (e.g. on the diagonal).
            sq = (
                a.pow(2).sum(dim=1, keepdim=True)
                + b.pow(2).sum(dim=1).unsqueeze(0)
                - 2.0 * a @ b.T
            ).clamp_min(0.0)
            return torch.exp(-gamma * sq)

        return (
            expert_w @ rbf(expert, expert) @ expert_w
            + policy_w @ rbf(policy, policy) @ policy_w
            - 2.0 * (expert_w @ rbf(expert, policy) @ policy_w)
        )

    def train(self, expert_states, expert_actions, expert_times, epochs=100, batch_size=64,
              mmd_weight=0.1, time_sigma=0.2, discriminator_steps=5, log_every=100):
        """Train the time-conditioned policy on expert demonstrations.

        Each epoch performs ``discriminator_steps`` discriminator updates and one
        policy update, all on minibatches sampled with replacement. Every
        sample is paired with its own timestamp. The policy loss is the
        adversarial term plus ``mmd_weight`` times a time-weighted MMD between
        expert and policy (state, action) pairs, centred on a focus time that
        sweeps linearly from the earliest to the latest expert timestamp.

        Args:
            expert_states: Expert states, shape (num_samples, state_dim).
            expert_actions: Expert actions, shape (num_samples, action_dim).
            expert_times: Timestamp of every sample, shape (num_samples,).
            epochs: Number of training epochs.
            batch_size: Minibatch size.
            mmd_weight: Coefficient of the MMD regulariser.
            time_sigma: Bandwidth of the Gaussian time weighting in the MMD term.
            discriminator_steps: Discriminator updates per epoch.
            log_every: Print progress every this many epochs (<= 0 disables it).

        Raises:
            ValueError: If the inputs are empty or their lengths disagree.
        """
        states_np = np.asarray(expert_states, dtype=np.float32)
        actions_np = np.asarray(expert_actions, dtype=np.float32)
        times_np = np.asarray(expert_times, dtype=np.float64).reshape(-1)
        if states_np.size == 0:
            raise ValueError("expert data must not be empty")
        num_samples = states_np.shape[0]
        states_np = states_np.reshape(num_samples, -1)
        actions_np = actions_np.reshape(actions_np.shape[0], -1)
        if actions_np.shape[0] != num_samples or times_np.shape[0] != num_samples:
            raise ValueError("expert_states, expert_actions and expert_times must have equal length")

        states_t = torch.as_tensor(states_np, device=self.device)
        actions_t = torch.as_tensor(actions_np, device=self.device)
        times_t = torch.as_tensor(times_np, dtype=torch.float32, device=self.device).reshape(-1, 1)
        t_start, t_end = float(times_np.min()), float(times_np.max())

        for epoch in range(epochs):
            # Focus time for the MMD weighting. For expert times spanning [0, 1]
            # this equals epoch / epochs.
            current_time = t_start + (t_end - t_start) * epoch / epochs

            # Discriminator updates.
            for _ in range(discriminator_steps):
                idx = np.random.randint(0, num_samples, batch_size)
                s, a, t = states_t[idx], actions_t[idx], times_t[idx]
                # The policy is queried at each sample's own timestamp, so expert
                # and policy triples share (s, t) and differ only in the action.
                # no_grad: this step must not backpropagate into the policy.
                with torch.no_grad():
                    policy_a = self.policy(torch.cat([s, t], dim=1))
                expert_output = self.discriminator(torch.cat([s, a, t], dim=1))
                policy_output = self.discriminator(torch.cat([s, policy_a, t], dim=1))
                d_loss = -torch.mean(
                    torch.log(expert_output + 1e-8) + torch.log(1 - policy_output + 1e-8)
                )
                self.discriminator_optimizer.zero_grad()
                d_loss.backward()
                self.discriminator_optimizer.step()

            # Policy update.
            idx = np.random.randint(0, num_samples, batch_size)
            s, a, t = states_t[idx], actions_t[idx], times_t[idx]
            policy_a = self.policy(torch.cat([s, t], dim=1))
            policy_output = self.discriminator(torch.cat([s, policy_a, t], dim=1))
            adversarial_loss = -torch.mean(torch.log(policy_output + 1e-8))
            batch_times = times_np[idx]
            # Differentiable time-weighted MMD on (state, action) pairs: both sides
            # share the same states and timestamps, so only the actions differ.
            mmd_loss = self._compute_mmd_loss(
                torch.cat([s, a], dim=1),
                torch.cat([s, policy_a], dim=1),
                batch_times,
                current_time,
                sigma=time_sigma,
                policy_times=batch_times,
            )
            p_loss = adversarial_loss + mmd_weight * mmd_loss
            self.policy_optimizer.zero_grad()
            p_loss.backward()
            self.policy_optimizer.step()

            self.train_losses.append(p_loss.item())
            if log_every > 0 and epoch % log_every == 0:
                print(f"Epoch {epoch}, Loss: {p_loss.item():.4f}, MMD: {mmd_loss.item():.4f}")

    def predict(self, state, time):
        """Return the policy's action for one state at a given time.

        Args:
            state: Array-like with ``state_dim`` entries.
            time: Scalar timestamp.

        Returns:
            Numpy array of shape (action_dim,).
        """
        state_tensor = torch.as_tensor(
            np.asarray(state, dtype=np.float32).reshape(1, -1), device=self.device
        )
        time_tensor = torch.tensor([[float(time)]], dtype=torch.float32, device=self.device)
        with torch.no_grad():
            action = self.policy(torch.cat([state_tensor, time_tensor], dim=1))
        return action.cpu().numpy()[0]

    def visualize_training(self):
        """Plot the per-epoch policy loss recorded during ``train``."""
        plt.figure(figsize=(10, 6))
        plt.plot(self.train_losses)
        plt.title('Training Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.grid(True)
        plt.show()


# Example: generate expert data whose distribution changes over time.
def _time_wave(dim_index, phase):
    """Return the periodic profile used for state dimension ``dim_index``.

    Dimensions are paired: dims 0 and 1 use sin/cos of ``phase``, dims 2 and 3
    use sin/cos of ``2 * phase``, and so on. A 2-D state therefore gets exactly
    (sin(phase), cos(phase)), and higher dimensions add faster harmonics.
    """
    harmonic = dim_index // 2 + 1
    scaled = phase * harmonic
    return np.sin(scaled) if dim_index % 2 == 0 else np.cos(scaled)


def generate_time_varying_expert_data(num_samples=1000, state_dim=2, time_period=1.0):
    """Generate synthetic expert data whose distribution drifts over time.

    Timestamps are ``np.linspace(0, time_period, num_samples)``. At time t each
    state is drawn from N(mean(t), diag(var(t))), where for dimension k

        mean_k(t) = _time_wave(k, 2*pi*t),
        var_k(t)  = 0.1 + 0.1 * |_time_wave(k, pi*t)|.

    For state_dim == 2 this is mean = (sin 2*pi*t, cos 2*pi*t) and
    var = (0.1 + 0.1|sin pi*t|, 0.1 + 0.1|cos pi*t|). The sinusoids have period 1
    in t regardless of ``time_period``. The expert action is
    2 * state * (1 + 0.5 * sin(2*pi*t)), so the state-to-action mapping also
    changes over time.

    Args:
        num_samples: Number of (state, action, time) samples.
        state_dim: Dimensionality of states (and actions).
        time_period: Last timestamp; samples are evenly spaced in [0, time_period].

    Returns:
        (states, actions, times) with shapes (num_samples, state_dim),
        (num_samples, state_dim) and (num_samples,).

    Raises:
        ValueError: If ``state_dim`` is smaller than 1.
    """
    if state_dim < 1:
        raise ValueError("state_dim must be at least 1")
    times = np.linspace(0, time_period, num_samples)
    states = np.empty((num_samples, state_dim))
    actions = np.empty((num_samples, state_dim))
    for i, t in enumerate(times):
        # Scalar sin/cos calls keep the 2-D case bit-identical to the
        # hard-coded formula; the SVD inside multivariate_normal is sensitive
        # to ties between the two variances (e.g. at t = 0.25).
        mean = np.array([_time_wave(k, 2 * np.pi * t) for k in range(state_dim)])
        cov = np.diag([0.1 + 0.1 * np.abs(_time_wave(k, np.pi * t)) for k in range(state_dim)])
        state = np.random.multivariate_normal(mean, cov)
        states[i] = state
        actions[i] = 2.0 * state * (1.0 + 0.5 * np.sin(2 * np.pi * t))
    return states, actions, times


# Test the algorithm
def test_time_aware_il():
    """Train the time-aware learner on synthetic data and plot expert vs. policy actions.

    Generates 2000 time-stamped expert samples, trains for 500 epochs and shows
    the training-loss curve. Then, for five evenly spaced times in [0, 1], it
    scatters the expert actions recorded within +/-0.1 of that time against the
    policy's actions for the same states at their own timestamps.
    """
    state_dim = 2
    action_dim = 2
    expert_states, expert_actions, expert_times = generate_time_varying_expert_data(
        num_samples=2000, state_dim=state_dim, time_period=1.0
    )

    model = TimeAwareImitationLearning(state_dim, action_dim)
    model.train(expert_states, expert_actions, expert_times, epochs=500)
    model.visualize_training()

    test_times = np.linspace(0, 1, 5)
    plt.figure(figsize=(12, 8))
    for i, t in enumerate(test_times):
        plt.subplot(2, 3, i + 1)
        window = (expert_times >= t - 0.1) & (expert_times <= t + 0.1)
        # Compare like with like: expert actions vs. policy actions, both in
        # action space and both evaluated at each sample's own timestamp.
        window_actions = expert_actions[window]
        pred_actions = np.array(
            [model.predict(s, ts) for s, ts in zip(expert_states[window], expert_times[window])]
        ).reshape(-1, action_dim)
        plt.scatter(window_actions[:, 0], window_actions[:, 1], c='blue', alpha=0.5, label='Expert')
        plt.scatter(pred_actions[:, 0], pred_actions[:, 1], c='red', alpha=0.5, label='Policy')
        plt.title(f'Time = {t:.2f}')
        plt.xlabel('Action 1')
        plt.ylabel('Action 2')
        plt.legend()
    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    test_time_aware_il()