    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
        return module

    pin_memory = is_pin_memory_available()
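
    # Pinned (page-locked) host memory allows asynchronous host-to-device
    # copies, so the `non_blocking=True` transfers in the wrapped forward
    # below can overlap with compute; the helper presumably returns False
    # on platforms where pinning is not supported.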
    # offload parameters to CPU
    # use pin_memory if possible, which helps cudagraph capture speed
    offloaded_parameters = False
    for p in module.parameters():
        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
            # we use per-parameter offloading
            # one module might have some parameters offloaded and some not
            break

        # `torch.empty_like` does not support `pin_memory` argument
        cpu_data = torch.empty_strided(size=p.data.size(),
                                       stride=p.data.stride(),
                                       dtype=p.data.dtype,
                                       layout=p.data.layout,
                                       device='cpu',
                                       pin_memory=pin_memory)
        cpu_data.copy_(p.data)
        p.data = cpu_data
        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
        offloaded_parameters = True
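
    # If anything was offloaded, wrap `forward` so the CPU copies are moved
    # back to the original device on demand for the duration of each call.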
    if offloaded_parameters:
        original_forward = module.forward

        def forward(*args, **kwargs):
            # temporarily restore the original forward so that
            # `functional_call` does not recurse into this wrapper
            module.forward = original_forward
            device_state = {
                # here we blindly call `to(device)`
                # if the parameter is already on the device, it will be a no-op
                k: v.to(device, non_blocking=True)
                for k, v in module.state_dict().items()
            }
            output = functional_call(module,
                                     device_state,
                                     args=args,
                                     kwargs=kwargs)
            module.forward = forward
            return output

        # install the wrapper so subsequent calls go through it
        module.forward = forward

    return module
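
# Usage sketch (assumptions: this body belongs to a wrapper along the lines of
# `maybe_offload_to_cpu(module)`, and the byte budget is configured by a helper
# such as `set_cpu_offload_max_bytes`; neither name appears in this section, so
# treat them as illustrative, not authoritative):
#
#     set_cpu_offload_max_bytes(4 * 1024 ** 3)     # allow up to 4 GiB on CPU
#     layer = maybe_offload_to_cpu(decoder_layer)  # params copied to pinned CPU memory
#     output = layer(hidden_states)                # weights streamed back to GPU per call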