Skip to content

Commit f8d0829

Browse files
fix zcc error in eb5 (#11193)
1 parent b0276d7 commit f8d0829

File tree

2 files changed

+130
-21
lines changed

2 files changed

+130
-21
lines changed

paddlenlp/trainer/trainer.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -471,8 +471,8 @@ def _save_ckpt_func(state_dict, path, signal_path=None):
471471
os.getenv("FLAG_LLM_PDC", "False")
472472
), "Dont support FLAG_LLM_PDC when using zero cost checkpoint"
473473
assert (
474-
self.args.should_save_sharding_stage1_model
475-
), "should_save_sharding_stage1_model should be True when using zero cost checkpoint"
474+
self.args.should_save_sharding_stage1_model or self.args.save_checkpoint_format == "flex_checkpoint"
475+
), "should_save_sharding_stage1_model should be True or save_checkpoint_format is flex_checkpoint when using zero cost checkpoint"
476476
assert (
477477
ShardingOption.FULL_SHARD not in self.args.sharding
478478
), "FULL_SHARD is not supported when using flash save mode"
@@ -810,11 +810,17 @@ def get_metadata_file_name(path):
810810
master_weights_path,
811811
aoa_config=self.args.aoa_config,
812812
offload=self.args.load_via_cpu,
813+
comm_method=self.args.comm_method,
813814
)
814815

815816
self._load_scheduler(resume_from_checkpoint)
816817

817-
should_load_stage1 = self.args.should_load_sharding_stage1_model
818+
from .trainer_utils import ShardingOption
819+
820+
should_load_stage1 = self.args.sharding_parallel_degree > 1 and ShardingOption.SHARD_OP in self.args.sharding
821+
logger.debug(f"should_load_stage1 = {should_load_stage1}")
822+
logger.debug(f"sharded_model_from_ema = {self.args.sharded_model_from_ema}")
823+
818824
if should_load_stage1 and self.args.sharded_model_from_ema:
819825
ema_states_path = os.path.join(resume_from_checkpoint, EMA_STATE_DIC, f"{dist.get_rank()}_0.distcp")
820826
ema_state_dict = paddle.load(ema_states_path)
@@ -829,11 +835,23 @@ def get_metadata_file_name(path):
829835
ema_state_dict = reshard_util.all_gather_state_dict(ema_state_dict, lambda x: True, self.sharding_group)
830836
self.model.set_state_dict(ema_state_dict)
831837
else:
838+
839+
def bf16_filtered_sharded_state_dict(sharded_state_dict):
840+
new_state_dict = {}
841+
for k, v in sharded_state_dict.items():
842+
if v.local_tensor.dtype == paddle.bfloat16:
843+
continue
844+
new_state_dict[k] = v
845+
return new_state_dict
846+
847+
fp32_sharded_state_dict = bf16_filtered_sharded_state_dict(model_sharded_state_dict)
848+
832849
dist.load_state_dict(
833-
model_sharded_state_dict,
850+
fp32_sharded_state_dict,
834851
model_states_path,
835852
aoa_config=self.args.aoa_config,
836853
offload=self.args.load_via_cpu,
854+
comm_method=self.args.comm_method,
837855
)
838856

839857
if self.args.bf16 and (not self.args.ignore_load_lr_and_optim) and should_load_stage1:
@@ -871,7 +889,7 @@ def recover_params_from_master_weight(opt_state_dict, group):
871889

872890
model_state_dict = self.model.state_dict()
873891
for key, param in model_state_dict.items():
874-
if param.name in master_weights:
892+
if param.name in master_weights and param.dtype == paddle.bfloat16:
875893
logger.debug(
876894
f"key {key}, convert master weights {param.name} shape {master_weights[param.name].shape} to param {param.name} shape{param.shape}"
877895
)
@@ -921,6 +939,10 @@ def _save_flex_optimizer_state(self, output_dir):
921939
master_weights_path,
922940
)
923941

942+
saved_signal_path = os.path.join(output_dir, f"saved_signal_{dist.get_rank()}")
943+
with open(saved_signal_path, mode="w+") as f:
944+
f.write("1")
945+
924946
def _load_from_checkpoint(self, resume_from_checkpoint=None):
925947
"""load state_dict from_checkpoint, Only load model state dict.
926948

paddlenlp/trainer/utils/zero_cost_checkpoint.py

Lines changed: 103 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1351,7 +1351,7 @@ def gather_distributed_model_meta(self, model, optimizer):
13511351
if not self.args.use_hybrid_parallel:
13521352
return None
13531353

1354-
if not self.args.should_save_sharding_stage1_model:
1354+
if not (self.args.should_save_sharding_stage1_model or self.args.save_checkpoint_format == "flex_checkpoint"):
13551355
return None
13561356

13571357
nranks = dist.get_world_size()
@@ -1453,7 +1453,7 @@ def saved_ckptmeta(state_dict, ckpt_file_name, process_group=None):
14531453
metadata.state_dict_metadata = merge_state_dict_metadata(global_state_dict_metadata)
14541454
metadata.storage_metadata = balanced_dedup_key_in_dict(global_storage_metadata)
14551455
metadata.flat_mapping = dedup_key_in_dict(global_flatten_mapping)
1456-
logger.debug(f"metadata:{metadata}")
1456+
# logger.debug(f"metadata:{metadata}")
14571457

14581458
def _gen_filter_map():
14591459
for tensor_index, file_name in metadata.storage_metadata.items():
@@ -1463,7 +1463,7 @@ def _gen_filter_map():
14631463
local_state_dict_filter_map[tensor_index.tensor_key] = True
14641464

14651465
_gen_filter_map()
1466-
logger.debug(f"local_state_dict_filter_map:{local_state_dict_filter_map}")
1466+
# logger.debug(f"local_state_dict_filter_map:{local_state_dict_filter_map}")
14671467

14681468
return metadata, local_state_dict_filter_map
14691469

@@ -1491,13 +1491,14 @@ def _manipulate_state_dict_and_config(self, model_to_save, optimizer):
14911491
filter_sharded_params,
14921492
)
14931493

1494-
filter_sharded_params = sharded_state_dict_compatibility(filter_sharded_params, return_sharded_state_dict=True)
1495-
exclude_parameters_in_state_dict = sharded_state_dict_compatibility(
1496-
exclude_parameters_in_state_dict, return_sharded_state_dict=True
1497-
)
1494+
# filter_sharded_params = sharded_state_dict_compatibility(filter_sharded_params, return_sharded_state_dict=True)
1495+
# exclude_parameters_in_state_dict = sharded_state_dict_compatibility(
1496+
# exclude_parameters_in_state_dict, return_sharded_state_dict=True
1497+
# )
14981498

1499-
state_dict = model_to_save.sharded_state_dict()
1500-
if self.args.should_save_sharding_stage1_model:
1499+
state_dict = model_to_save.state_dict()
1500+
# temporary workaround: should_save_sharding_stage1_model
1501+
if self.args.should_save_sharding_stage1_model or self.args.save_checkpoint_format == "flex_checkpoint":
15011502
state_dict = split_model_state(state_dict, group_getter)
15021503
for gid in gids:
15031504
state_dict[gid] = filter_sharded_params(
@@ -1508,7 +1509,10 @@ def _manipulate_state_dict_and_config(self, model_to_save, optimizer):
15081509
)
15091510
state_dict = merge_model_state(state_dict)
15101511

1511-
if self.args.bf16 and self.args.should_save_sharding_stage1_model:
1512+
# temporary workaround: should_save_sharding_stage1_model
1513+
if self.args.bf16 and (
1514+
self.args.should_save_sharding_stage1_model or self.args.save_checkpoint_format == "flex_checkpoint"
1515+
):
15121516
param_names_in_master_weights = []
15131517
optimzier_state_dict = optimizer.state_dict()
15141518
optimzier_state_dict = split_opt_state(optimzier_state_dict, group_getter)
@@ -1531,11 +1535,18 @@ def _manipulate_state_dict_and_config(self, model_to_save, optimizer):
15311535
return state_dict
15321536

15331537
def _cache_meta_for_sharded_save(self, model, optimizer):
1534-
# TODO(): fix later.
15351538
logger.info("Start caching metas for sharded save...")
15361539
(self.manipulated_state_dict) = self._manipulate_state_dict_and_config(model, optimizer)
15371540

1538-
logger.debug(f"manipulated_state_dict: {self.manipulated_state_dict.keys()}")
1541+
def recover_sharded_state_dict():
1542+
filtered_sharded_state_dict = {}
1543+
model_sharded_state_dict = model.sharded_state_dict()
1544+
for k, v in self.manipulated_state_dict.items():
1545+
filtered_sharded_state_dict[k] = model_sharded_state_dict[k]
1546+
return filtered_sharded_state_dict
1547+
1548+
self.manipulated_state_dict = recover_sharded_state_dict()
1549+
15391550
logger.info("Cache manipulated static dict done.")
15401551

15411552
model_to_save = unwrap_model(model)
@@ -1577,10 +1588,36 @@ def create_ckpt_file_name():
15771588
self.master_weight_ckpt_meta, self.master_weights_filter = saved_ckptmeta(master_weights, self.ckpt_data_name)
15781589

15791590
# gen unified name mapping for optimizer
1580-
self.unified_name_mapping = self._gen_unified_name(optimizer, model.sharded_state_dict())
1591+
self.unified_name_mapping, self.param_slice_info = self._gen_unified_name(
1592+
optimizer, model.sharded_state_dict()
1593+
)
15811594
logger.info("Cache distributed model meta done.")
15821595

15831596
def _gen_unified_name(self, optimizer, model_sharded_state_dict):
1597+
param_slice_info = {}
1598+
padded_param = set()
1599+
for buffer in optimizer._comm_buffer_list:
1600+
for (
1601+
param_name,
1602+
grad_view,
1603+
) in buffer._sharding_param_grad_view.items():
1604+
numel = grad_view._param.numel().item()
1605+
param_begin = grad_view._param_begin
1606+
param_end = grad_view._param_end
1607+
index = grad_view._index
1608+
padding_begin = index + numel
1609+
flattened_range = slice(
1610+
param_begin - index,
1611+
max(
1612+
min(padding_begin - index, param_end - index),
1613+
param_begin - index,
1614+
),
1615+
)
1616+
if param_end > padding_begin:
1617+
padded_param.add(param_name)
1618+
1619+
param_slice_info[param_name] = flattened_range
1620+
15841621
_FP32_MASTER = "fp32_master_0"
15851622
_optimizer_scalar_name = [
15861623
"beta1_pow_acc_0",
@@ -1600,29 +1637,44 @@ def _generate_base_static_name(vname):
16001637
return vname[: -(len(name) + 1)], name
16011638
raise ValueError(f"Cannot split variable name: {vname}.")
16021639

1640+
model_sharded_state_dict = dict(sorted(model_sharded_state_dict.items()))
16031641
static_to_struct_mapping = {}
16041642
for k, v in model_sharded_state_dict.items():
16051643
if v.local_tensor.name not in static_to_struct_mapping:
16061644
static_to_struct_mapping[v.local_tensor.name] = k
16071645

16081646
optimizer_state_dict = optimizer.state_dict()
16091647
optimizer_unified_name_mapping = {}
1648+
unified_slice_info = {}
16101649

16111650
master_weights = optimizer_state_dict.pop("master_weights", None)
16121651
optimizer_state_dict.pop("LR_Scheduler", None)
16131652
for key, _ in optimizer_state_dict.items():
16141653
static_name, optim_state_type = _generate_base_static_name(key)
16151654
struct_name = static_to_struct_mapping[static_name]
16161655
unified_name = f"{struct_name}.{optim_state_type}"
1656+
1657+
flattened_range = param_slice_info[static_name]
1658+
1659+
# if flattened_range.stop - flattened_range.start == 0:
1660+
# continue
16171661
optimizer_unified_name_mapping[key] = unified_name
1662+
unified_slice_info[unified_name] = flattened_range
16181663

16191664
if master_weights is not None:
16201665
for key, _ in master_weights.items():
16211666
struct_name = static_to_struct_mapping[key]
16221667
unified_name = f"{struct_name}.w_0"
1668+
1669+
flattened_range = param_slice_info[key]
1670+
1671+
# if flattened_range.stop - flattened_range.start == 0:
1672+
# continue
1673+
16231674
optimizer_unified_name_mapping[key] = unified_name
1675+
unified_slice_info[unified_name] = flattened_range
16241676

1625-
return optimizer_unified_name_mapping
1677+
return optimizer_unified_name_mapping, unified_slice_info
16261678

16271679
def _pack_dynamic_objects(self):
16281680
dynamic_objecs = {}
@@ -1641,6 +1693,7 @@ def _pack_dynamic_objects(self):
16411693
dynamic_objecs["master_weights_filter"] = self.master_weights_filter
16421694

16431695
dynamic_objecs["unified_name_mapping"] = self.unified_name_mapping
1696+
dynamic_objecs["param_slice_info"] = self.param_slice_info
16441697

16451698
return dynamic_objecs
16461699

@@ -1682,6 +1735,7 @@ def process_update_task(self, updates):
16821735
self.master_weights_filter = dynamic_objecs["master_weights_filter"]
16831736

16841737
self.unified_name_mapping = dynamic_objecs["unified_name_mapping"]
1738+
self.param_slice_info = dynamic_objecs["param_slice_info"]
16851739

16861740
optimizer_states_meta = dynamic_objecs["optimizer_states_meta"]
16871741
model_states_meta = dynamic_objecs["model_states_meta"]
@@ -1708,12 +1762,34 @@ def _replace_pname_with_unified(self, state_dict):
17081762
def _filter_state_dict(state_dict, filter_map):
17091763
need_remove_keys = []
17101764
for k, _ in state_dict.items():
1711-
if filter_map[k]:
1765+
# two cases:
1766+
# 1. Multiple keys share the same tensor.
1767+
# 2. Don't need to be saved in current rank.
1768+
if k not in filter_map.keys():
1769+
logger.debug(f"[ZCC worker] {k} not exist in filter map.")
1770+
if (k not in filter_map.keys()) or filter_map[k]:
17121771
need_remove_keys.append(k)
17131772
for k in need_remove_keys:
17141773
state_dict.pop(k)
17151774
return state_dict
17161775

1776+
@staticmethod
1777+
def _slice_padded_tensor(static_dict, param_slice_info):
1778+
new_static_dict = {}
1779+
for k, v in static_dict.items():
1780+
if k in param_slice_info:
1781+
logger.info(f"[ZCC worker] Slice padded tensor of {k}")
1782+
flattened_range = param_slice_info[k]
1783+
new_static_dict[k] = paddle.slice(
1784+
v,
1785+
axes=[0],
1786+
starts=[0],
1787+
ends=[flattened_range.stop - flattened_range.start],
1788+
)
1789+
else:
1790+
new_static_dict[k] = v
1791+
return new_static_dict
1792+
17171793
def _save_model_state(self, output_dir):
17181794
data_file_name, meta_file_name = self.distcp_file_name
17191795
self.model_states_path = os.path.join(output_dir, "model_state", data_file_name)
@@ -1722,9 +1798,10 @@ def _save_model_state(self, output_dir):
17221798
if self.dp_rank <= 0 or self.use_expert_parallel:
17231799
with device_guard("cpu"):
17241800
state_dict = self.param_fusion_storage_helper.state_dict()
1801+
logger.debug(f"model states key before filter is {state_dict.keys()}")
17251802

17261803
state_dict = self._filter_state_dict(state_dict, self.model_state_filter)
1727-
1804+
logger.debug(f"model states length is {len(state_dict)}")
17281805
paddle.save(state_dict, self.model_states_path)
17291806

17301807
if self.device_id == 0:
@@ -1750,12 +1827,21 @@ def _save_opt_state(self, output_dir):
17501827
master_weights = self._replace_pname_with_unified(master_weights)
17511828
logger.info("[ZCC worker] master weightsdict replace pname using unified name.")
17521829

1830+
opt_state_dict = self._slice_padded_tensor(opt_state_dict, self.param_slice_info)
1831+
logger.info("[ZCC worker] opt state dict slice padded tensor complete.")
1832+
master_weights = self._slice_padded_tensor(master_weights, self.param_slice_info)
1833+
logger.info("[ZCC worker] master weights slice padded tensor complete.")
1834+
17531835
if self.dp_rank > 0: # ep
17541836
opt_state_dict = self._filter_moe_no_sync_optimizer_params(self.model_meta_content, opt_state_dict)
17551837

17561838
opt_state_dict = self._filter_state_dict(opt_state_dict, self.opt_state_filter)
1839+
logger.info("[ZCC worker] opt state dict filter by opt_state_filter complete.")
17571840
master_weights = self._filter_state_dict(master_weights, self.master_weights_filter)
1841+
logger.info("[ZCC worker] master weights dict filter by master_weights_filter complete.")
17581842

1843+
logger.debug(f"opt states length is {len(opt_state_dict)}")
1844+
logger.debug(f"master weights length is {len(master_weights)}")
17591845
paddle.save(opt_state_dict, self.opt_state_path)
17601846
paddle.save(master_weights, self.master_weight_path)
17611847
if self.device_id == 0:
@@ -1771,6 +1857,7 @@ def _save_ema_state(self, output_dir):
17711857

17721858
if self.dp_rank > 0:
17731859
ema_state_dict = self._filter_moe_no_sync_optimizer_params(self.model_meta_content, ema_state_dict)
1860+
logger.debug(f"ema states length is {len(ema_state_dict)}")
17741861
paddle.save(ema_state_dict, self.ema_name_path)
17751862
logger.info("[ZCC worker] Finish ema states saved.")
17761863

0 commit comments

Comments
 (0)