错误处理_IncompatibleKeys
错误重现
Exception has occurred: AttributeError
'_IncompatibleKeys' object has no attribute 'act'
File "E:\pycharm\python_doc\learntoconverse_reconstructed\utils.py", line 356, in <listcomp>
results = [async_result.get() for async_result in async_results]
File "E:\pycharm\python_doc\learntoconverse_reconstructed\utils.py", line 356, in rollout_workers
results = [async_result.get() for async_result in async_results]
File "E:\pycharm\python_doc\learntoconverse_reconstructed\train.py", line 65, in <module>
results = rollout_workers(envs_s=envs, epsilon_table=epsilon_table, original_policy=policy, config=config)
AttributeError: '_IncompatibleKeys' object has no attribute 'act'
错误分析
这个错误表明当尝试调用 policy.act()
方法时,传入的 policy
对象实际上是一个 _IncompatibleKeys
对象(这是 PyTorch 在 load_state_dict()
失败时返回的对象),而不是预期的 AttentionPolicy
实例。
问题分析
-
错误根源:
- 在
rollout_workers
函数中,这一行代码有问题:copy.deepcopy(original_policy_copy).load_state_dict(add_state_dict(...))
load_state_dict()
返回的是_IncompatibleKeys
对象(当状态字典不匹配时),而不是模型本身
- 在
-
具体问题:
- 直接使用了
load_state_dict()
的返回值作为 policy 对象 - 应该先加载状态字典,然后返回模型实例
- 直接使用了
解决方案
修改 rollout_workers
函数如下:
def rollout_workers(envs_s:list, epsilon_table:list, original_policy:AttentionPolicy, config:Config):num_task = config.evaluate_time + config.n_directionsnum_worker = config.num_workerresults = []original_policy_copy = copy.deepcopy(original_policy).to("cpu")with Pool(processes=num_worker) as pool:epsilon_ids = range(num_task)async_results = []for i in epsilon_ids:# 创建policy副本policy_copy = copy.deepcopy(original_policy_copy)# 如果是扰动方向,加载扰动后的参数if i < config.n_directions:# 获取扰动后的state_dictperturbed_state_dict = add_state_dict(policy_copy.state_dict(),epsilon_table[i],"cpu")# 加载参数但不使用返回值policy_copy.load_state_dict(perturbed_state_dict)# 提交任务async_results.append(pool.apply_async(rollout_envs,args=(envs_s[i],policy_copy, # 这里传入的是模型实例,不是load_state_dict的返回值config,i)))results = [async_result.get() for async_result in async_results]return results
关键修改点
-
分离模型复制和参数加载:
- 先创建模型副本
policy_copy = copy.deepcopy(original_policy_copy)
- 然后单独加载参数
policy_copy.load_state_dict(perturbed_state_dict)
- 先创建模型副本
-
正确处理返回值:
- 不再使用
load_state_dict()
的返回值 - 直接传递模型实例
policy_copy
给rollout_envs
- 不再使用
-
更清晰的流程:
- 使用显式的循环而不是列表推导式,提高可读性
- 分离扰动参数的应用和任务提交
为什么原代码会失败
PyTorch 的 load_state_dict()
方法在遇到不匹配的键时会返回 _IncompatibleKeys
对象(包含 missing_keys
和 unexpected_keys
),而不是模型本身。当后续代码尝试调用这个返回对象的 act()
方法时,就会抛出 AttributeError
。
这个修改确保了始终传递正确的模型实例给工作进程,同时保持了原有的功能逻辑。