@@ -1351,7 +1351,7 @@ def gather_distributed_model_meta(self, model, optimizer):
         if not self.args.use_hybrid_parallel:
             return None

-        if not self.args.should_save_sharding_stage1_model:
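+        # The flex_checkpoint format also saves per-rank shards, so it needs the
+        # distributed model meta as well (comment summarizing the assumed intent).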
+        if not (self.args.should_save_sharding_stage1_model or self.args.save_checkpoint_format == "flex_checkpoint"):
             return None

         nranks = dist.get_world_size()
@@ -1453,7 +1453,7 @@ def saved_ckptmeta(state_dict, ckpt_file_name, process_group=None):
     metadata.state_dict_metadata = merge_state_dict_metadata(global_state_dict_metadata)
     metadata.storage_metadata = balanced_dedup_key_in_dict(global_storage_metadata)
     metadata.flat_mapping = dedup_key_in_dict(global_flatten_mapping)
-    logger.debug(f"metadata:{metadata}")
+    # logger.debug(f"metadata:{metadata}")

     def _gen_filter_map():
         for tensor_index, file_name in metadata.storage_metadata.items():
@@ -1463,7 +1463,7 @@ def _gen_filter_map():
                 local_state_dict_filter_map[tensor_index.tensor_key] = True

     _gen_filter_map()
-    logger.debug(f"local_state_dict_filter_map:{local_state_dict_filter_map}")
+    # logger.debug(f"local_state_dict_filter_map:{local_state_dict_filter_map}")

     return metadata, local_state_dict_filter_map

@@ -1491,13 +1491,14 @@ def _manipulate_state_dict_and_config(self, model_to_save, optimizer):
             filter_sharded_params,
         )

-        filter_sharded_params = sharded_state_dict_compatibility(filter_sharded_params, return_sharded_state_dict=True)
-        exclude_parameters_in_state_dict = sharded_state_dict_compatibility(
-            exclude_parameters_in_state_dict, return_sharded_state_dict=True
-        )
+        # filter_sharded_params = sharded_state_dict_compatibility(filter_sharded_params, return_sharded_state_dict=True)
+        # exclude_parameters_in_state_dict = sharded_state_dict_compatibility(
+        #     exclude_parameters_in_state_dict, return_sharded_state_dict=True
+        # )

-        state_dict = model_to_save.sharded_state_dict()
-        if self.args.should_save_sharding_stage1_model:
+        state_dict = model_to_save.state_dict()
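+        # Note: the plain state_dict() is used here; the sharded views are
+        # restored later in _cache_meta_for_sharded_save via recover_sharded_state_dict().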
+        # Temporary workaround: treat flex_checkpoint the same as should_save_sharding_stage1_model.
+        if self.args.should_save_sharding_stage1_model or self.args.save_checkpoint_format == "flex_checkpoint":
             state_dict = split_model_state(state_dict, group_getter)
             for gid in gids:
                 state_dict[gid] = filter_sharded_params(
@@ -1508,7 +1509,10 @@ def _manipulate_state_dict_and_config(self, model_to_save, optimizer):
                 )
             state_dict = merge_model_state(state_dict)

-        if self.args.bf16 and self.args.should_save_sharding_stage1_model:
+        # Temporary workaround: treat flex_checkpoint the same as should_save_sharding_stage1_model.
+        if self.args.bf16 and (
+            self.args.should_save_sharding_stage1_model or self.args.save_checkpoint_format == "flex_checkpoint"
+        ):
             param_names_in_master_weights = []
             optimzier_state_dict = optimizer.state_dict()
             optimzier_state_dict = split_opt_state(optimzier_state_dict, group_getter)
@@ -1531,11 +1535,18 @@ def _manipulate_state_dict_and_config(self, model_to_save, optimizer):
         return state_dict

     def _cache_meta_for_sharded_save(self, model, optimizer):
-        # TODO(): fix later.
         logger.info("Start caching metas for sharded save...")
         self.manipulated_state_dict = self._manipulate_state_dict_and_config(model, optimizer)

-        logger.debug(f"manipulated_state_dict: {self.manipulated_state_dict.keys()}")
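+        # Rebuild the cached dict from the model's sharded views, keeping only
+        # the manipulated keys, so the cached meta describes sharded tensors
+        # (descriptive comment, assumed intent).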
+        def recover_sharded_state_dict():
+            filtered_sharded_state_dict = {}
+            model_sharded_state_dict = model.sharded_state_dict()
+            for k in self.manipulated_state_dict:
+                filtered_sharded_state_dict[k] = model_sharded_state_dict[k]
+            return filtered_sharded_state_dict
+
+        self.manipulated_state_dict = recover_sharded_state_dict()
+
         logger.info("Cache manipulated state dict done.")

         model_to_save = unwrap_model(model)
@@ -1577,10 +1588,36 @@ def create_ckpt_file_name():
         self.master_weight_ckpt_meta, self.master_weights_filter = saved_ckptmeta(master_weights, self.ckpt_data_name)

         # gen unified name mapping for optimizer
-        self.unified_name_mapping = self._gen_unified_name(optimizer, model.sharded_state_dict())
+        self.unified_name_mapping, self.param_slice_info = self._gen_unified_name(
+            optimizer, model.sharded_state_dict()
+        )
         logger.info("Cache distributed model meta done.")

     def _gen_unified_name(self, optimizer, model_sharded_state_dict):
+        param_slice_info = {}
+        padded_param = set()
+        for buffer in optimizer._comm_buffer_list:
+            for param_name, grad_view in buffer._sharding_param_grad_view.items():
+                numel = grad_view._param.numel().item()
+                param_begin = grad_view._param_begin
+                param_end = grad_view._param_end
+                index = grad_view._index
+                padding_begin = index + numel
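+                # The fused comm buffer stores this param flattened at
+                # [index, index + numel), padded beyond its real data up to
+                # param_end; this rank owns [param_begin, param_end).
+                # flattened_range is the owned span relative to the param start,
+                # clipped to exclude padding: e.g. index=0, numel=10,
+                # param_begin=8, param_end=16 yields slice(8, 10)
+                # (explanatory comment; example values hypothetical).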
+                flattened_range = slice(
+                    param_begin - index,
+                    max(min(padding_begin - index, param_end - index), param_begin - index),
+                )
+                if param_end > padding_begin:
+                    padded_param.add(param_name)
+
+                param_slice_info[param_name] = flattened_range
+
         _FP32_MASTER = "fp32_master_0"
         _optimizer_scalar_name = [
             "beta1_pow_acc_0",
@@ -1600,29 +1637,44 @@ def _generate_base_static_name(vname):
                     return vname[: -(len(name) + 1)], name
             raise ValueError(f"Cannot split variable name: {vname}.")

+        model_sharded_state_dict = dict(sorted(model_sharded_state_dict.items()))
         static_to_struct_mapping = {}
         for k, v in model_sharded_state_dict.items():
             if v.local_tensor.name not in static_to_struct_mapping:
                 static_to_struct_mapping[v.local_tensor.name] = k

         optimizer_state_dict = optimizer.state_dict()
         optimizer_unified_name_mapping = {}
+        unified_slice_info = {}

         master_weights = optimizer_state_dict.pop("master_weights", None)
         optimizer_state_dict.pop("LR_Scheduler", None)
         for key, _ in optimizer_state_dict.items():
             static_name, optim_state_type = _generate_base_static_name(key)
             struct_name = static_to_struct_mapping[static_name]
             unified_name = f"{struct_name}.{optim_state_type}"
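+            # Illustrative example (names hypothetical): an optimizer key
+            # "linear_0.w_0_moment1_0" splits into static name "linear_0.w_0"
+            # and state type "moment1_0"; if "linear_0.w_0" maps to struct name
+            # "layers.0.fc.weight", the unified name is "layers.0.fc.weight.moment1_0".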
+
+            flattened_range = param_slice_info[static_name]
+
+            # if flattened_range.stop - flattened_range.start == 0:
+            #     continue
             optimizer_unified_name_mapping[key] = unified_name
+            unified_slice_info[unified_name] = flattened_range

         if master_weights is not None:
             for key, _ in master_weights.items():
                 struct_name = static_to_struct_mapping[key]
                 unified_name = f"{struct_name}.w_0"
+
+                flattened_range = param_slice_info[key]
+
+                # if flattened_range.stop - flattened_range.start == 0:
+                #     continue
+
                 optimizer_unified_name_mapping[key] = unified_name
+                unified_slice_info[unified_name] = flattened_range

-        return optimizer_unified_name_mapping
+        return optimizer_unified_name_mapping, unified_slice_info

     def _pack_dynamic_objects(self):
         dynamic_objecs = {}
@@ -1641,6 +1693,7 @@ def _pack_dynamic_objects(self):
         dynamic_objecs["master_weights_filter"] = self.master_weights_filter

         dynamic_objecs["unified_name_mapping"] = self.unified_name_mapping
+        dynamic_objecs["param_slice_info"] = self.param_slice_info

         return dynamic_objecs

@@ -1682,6 +1735,7 @@ def process_update_task(self, updates):
         self.master_weights_filter = dynamic_objecs["master_weights_filter"]

         self.unified_name_mapping = dynamic_objecs["unified_name_mapping"]
+        self.param_slice_info = dynamic_objecs["param_slice_info"]

         optimizer_states_meta = dynamic_objecs["optimizer_states_meta"]
         model_states_meta = dynamic_objecs["model_states_meta"]
@@ -1708,12 +1762,34 @@ def _replace_pname_with_unified(self, state_dict):
     def _filter_state_dict(state_dict, filter_map):
         need_remove_keys = []
         for k, _ in state_dict.items():
-            if filter_map[k]:
+            # Two cases lead to removal:
+            # 1. Multiple keys share the same tensor.
+            # 2. The key does not need to be saved on the current rank.
+            if k not in filter_map:
+                logger.debug(f"[ZCC worker] {k} does not exist in filter map.")
+            if (k not in filter_map) or filter_map[k]:
                 need_remove_keys.append(k)
         for k in need_remove_keys:
             state_dict.pop(k)
         return state_dict

+    @staticmethod
+    def _slice_padded_tensor(state_dict, param_slice_info):
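+        # Each value is this rank's flattened shard, possibly carrying trailing
+        # comm-buffer padding; keep only the first (stop - start) elements
+        # (descriptive comment added for clarity).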
+        new_state_dict = {}
+        for k, v in state_dict.items():
+            if k in param_slice_info:
+                logger.info(f"[ZCC worker] Slice padded tensor of {k}")
+                flattened_range = param_slice_info[k]
+                new_state_dict[k] = paddle.slice(
+                    v,
+                    axes=[0],
+                    starts=[0],
+                    ends=[flattened_range.stop - flattened_range.start],
+                )
+            else:
+                new_state_dict[k] = v
+        return new_state_dict
+
     def _save_model_state(self, output_dir):
         data_file_name, meta_file_name = self.distcp_file_name
         self.model_states_path = os.path.join(output_dir, "model_state", data_file_name)
@@ -1722,9 +1798,10 @@ def _save_model_state(self, output_dir):
         if self.dp_rank <= 0 or self.use_expert_parallel:
             with device_guard("cpu"):
                 state_dict = self.param_fusion_storage_helper.state_dict()
+                logger.debug(f"model state keys before filter: {state_dict.keys()}")

                 state_dict = self._filter_state_dict(state_dict, self.model_state_filter)
-
+                logger.debug(f"model state dict length after filter: {len(state_dict)}")
                 paddle.save(state_dict, self.model_states_path)

         if self.device_id == 0:
@@ -1750,12 +1827,21 @@ def _save_opt_state(self, output_dir):
         master_weights = self._replace_pname_with_unified(master_weights)
         logger.info("[ZCC worker] master weights dict replaced pname with unified name.")

+        opt_state_dict = self._slice_padded_tensor(opt_state_dict, self.param_slice_info)
+        logger.info("[ZCC worker] opt state dict: slicing padded tensors complete.")
+        master_weights = self._slice_padded_tensor(master_weights, self.param_slice_info)
+        logger.info("[ZCC worker] master weights: slicing padded tensors complete.")
+
         if self.dp_rank > 0:  # ep
             opt_state_dict = self._filter_moe_no_sync_optimizer_params(self.model_meta_content, opt_state_dict)

         opt_state_dict = self._filter_state_dict(opt_state_dict, self.opt_state_filter)
+        logger.info("[ZCC worker] opt state dict filtered by opt_state_filter.")
         master_weights = self._filter_state_dict(master_weights, self.master_weights_filter)
+        logger.info("[ZCC worker] master weights filtered by master_weights_filter.")

+        logger.debug(f"opt state dict length: {len(opt_state_dict)}")
+        logger.debug(f"master weights length: {len(master_weights)}")
         paddle.save(opt_state_dict, self.opt_state_path)
         paddle.save(master_weights, self.master_weight_path)
         if self.device_id == 0:
@@ -1771,6 +1857,7 @@ def _save_ema_state(self, output_dir):

         if self.dp_rank > 0:
             ema_state_dict = self._filter_moe_no_sync_optimizer_params(self.model_meta_content, ema_state_dict)
+        logger.debug(f"ema state dict length: {len(ema_state_dict)}")
         paddle.save(ema_state_dict, self.ema_name_path)
         logger.info("[ZCC worker] Finished saving ema states.")
