vllm 的 cpu offload 参数卸载逻辑代码简析

本文最后更新于 2025年3月2日 下午

最近在看 vllm 的代码,cpu offloading 这部分它的实现还是比较简单的,这里简单记录一下。

由于大模型的参数量实在很大,所以如果想在单机上运行一般都需要跑量化蒸馏后的模型,但是有时又不想牺牲模型质量,于是CPU/SSD 卸载成为一种折衷方案,通过增加推理时间来降低内存需求。

vllm 也实现了一个简单的 cpu offload 的机制,可以通过 --cpu-offload-gb 启用。

这里发现他的实现还是比较简单的,主要就是通过这个 PR ,添加了一个 func 叫 maybe_offload_to_gpu:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
device = next(module.parameters()).device

if device == torch.device("cpu"):
return module

global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
return module

pin_memory = is_pin_memory_available()

# offload parameters to CPU
# use pin_memory if possible, which helps cudagraph capture speed
offloaded_parameters = False
for p in module.parameters():
if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
# we use per-parameter offloading
# one module might have some parameters offloaded and some not
break

# `torch.empty_like` does not support `pin_memory` argument
cpu_data = torch.empty_strided(size=p.data.size(),
stride=p.data.stride(),
dtype=p.data.dtype,
layout=p.data.layout,
device='cpu',
pin_memory=pin_memory)
cpu_data.copy_(p.data)
p.data = cpu_data
_CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
offloaded_parameters = True

if offloaded_parameters:
original_forward = module.forward

def forward(*args, **kwargs):
module.forward = original_forward
device_state = {
# here we blindly call `to(device)`
# if the parameter is already on the device, it will be a no-op
k: v.to(device, non_blocking=True)
for k, v in module.state_dict().items()
}
output = functional_call(module,
device_state,
args=args,
kwargs=kwargs)
module.forward = forward
return output

module.forward = forward

return module

这个函数只有一个地方调用了:

1
2
3
4
5
modules = torch.nn.ModuleList(
[PPMissingLayer() for _ in range(start_layer)] + [
maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
for idx in range(start_layer, end_layer)
] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])

cpu offload 的整个流程可以概括为:

  1. 将传入的 cpu_offload_gb 读取为 _CPU_OFFLOAD_MAX_BYTES
  2. 然后在构建 Module 时对每个 Layer 里面的参数从前往后依次塞到 CPU 的 PIN_MEMORY 上,并累加参数大小 (_CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size(), 即参数数量 x 字节大小) 到 _CPU_OFFLOAD_BYTES,直到超过用户配置的可 offload 大小。
  3. 替换对应 Module 的 forward 函数,新的 forward 函数和原来的区别就是:在每次 forward 的时候,将 CPU 上的参数复制到 GPU 上,计算完后再释放

一些可能的理解难点

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
if offloaded_parameters:
original_forward = module.forward

def forward(*args, **kwargs):
module.forward = original_forward
device_state = {
# here we blindly call `to(device)`
# if the parameter is already on the device, it will be a no-op
k: v.to(device, non_blocking=True)
for k, v in module.state_dict().items()
}
output = functional_call(module,
device_state,
args=args,
kwargs=kwargs)
module.forward = forward
return output

module.forward = forward

这里的 forward 替换设计挺有意思:

  1. original_forward = module.forward):保存模块原始的前向传播函数到变量 original_forward,为后续恢复原始实现做准备。

  2. (在 forward 函数内的 module.forward = original_forward):在自定义的 forward 函数内部首行,当函数被调用时,立即将模块的前向传播方法恢复为原始方法。这一步确保 functional_call 能够使用原始的前向传播逻辑,避免递归调用。

  3. (在 forward 函数内末尾的 module.forward = forward):前向传播完成后,再次将模块的前向传播方法设置回自定义的 forward 函数,确保下次调用时仍能触发这个包含参数加载逻辑的自定义函数。

  4. (函数末尾的 module.forward = forward):为了确保每次都能 hook 到 cpu -> gpu 这部分的逻辑 (参考原 PR 的解释:https://github.com/vllm-project/vllm/pull/6496/files#r1682069375)

这里的核心是使用 functional_call 配合临时的 device_state,实现参数从 CPU 到 GPU 的动态加载。初看可能会疑惑为什么没有将参数从 GPU 释放回 CPU 的逻辑,其实这是因为 device_state 是个局部变量,函数执行完毕后其引用的 GPU 张量会自动被 Python 的垃圾回收机制释放。而原始参数依然保存在 CPU 内存中,从而实现了内存优化的目的。

同时还有一个小知识点:

  1. 对于 nn.Module, .cuda(), .cpu().to(device) 方法都是就地操作,因此可以直接 model.cuda()
  2. 而对于 torch.Tensor , nn.Parameter 这些,.to(device), .cpu(), .gpu() 方法本质都是创建一个新的 Tensor,原始 Tensor 不变,因此也必须使用 a = b.cuda()

vllm 的 cpu offload 参数卸载逻辑代码简析
https://moreality.net/posts/28374/
作者
Moreality
发布于
2025年3月2日
许可协议