AF3 checkpoint_blocks函数解读
checkpoint_blocks
函数实现了一种分块梯度检查点机制 (checkpoint_blocks
),目的是通过分块(chunking)执行神经网络模块,减少内存使用。在深度学习训练中,梯度检查点(activation checkpointing)是一种显存优化技术。该代码可以:
- 对神经网络的块(blocks)按需分块,并对每块应用梯度检查点。
- 动态调整计算开销与显存占用的权衡。
1. 源代码:
from typing import Any, Tuple, List, Callable, Optional
import torch
import torch.utils.checkpoint
import functoolstry:import deepspeeddeepspeed_is_installed = True
except ImportError:deepspeed_is_installed = FalseBLOCK_ARG = Any
BLOCK_ARGS = Tuple[BLOCK_ARG, ...] # List[BLOCK_ARGS]def get_checkpoint_fn():return torch.utils.checkpoint.checkpoint # deepspeed.checkpointing.checkpointdef checkpoint_blocks(blocks: List[Callable],args: BLOCK_ARGS,blocks_per_ckpt: Optional[int],
) -> BLOCK_ARGS:"""Chunk a list of b