diff --git a/paddlenlp/taskflow/taskflow.py b/paddlenlp/taskflow/taskflow.py index 399206c62775..6d7040883990 100644 --- a/paddlenlp/taskflow/taskflow.py +++ b/paddlenlp/taskflow/taskflow.py @@ -14,7 +14,6 @@ # limitations under the License. import threading -import warnings import paddle @@ -43,7 +42,10 @@ from .word_segmentation import SegJiebaTask, SegLACTask, SegWordTagTask from .zero_shot_text_classification import ZeroShotTextClassificationTask -warnings.simplefilter(action="ignore", category=Warning, lineno=0, append=False) +# import warnings + + +# warnings.simplefilter(action="ignore", category=Warning, lineno=0, append=False) TASKS = { "dependency_parsing": { diff --git a/paddlenlp/trainer/plugins/timer.py b/paddlenlp/trainer/plugins/timer.py new file mode 100644 index 000000000000..0f0cdd4c6171 --- /dev/null +++ b/paddlenlp/trainer/plugins/timer.py @@ -0,0 +1,124 @@ +# Copyright 2020-present the HuggingFace Inc. team. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import paddle + +from paddlenlp.utils.log import logger + + +class _Timer: + """Profile timer for recording time taken by forward/backward/reduce/step.""" + + def __init__(self, name): + self.name = name + self.elapsed_ = 0.0 + self.started_ = False + self.start_time = time.time() + + def start(self): + """Start the timer.""" + assert not self.started_, "timer has already started" + paddle.device.cuda.synchronize() + self.start_time = time.time() + self.started_ = True + + def stop(self): + """Stop the timer.""" + assert self.started_, "timer is not started." + paddle.device.cuda.synchronize() + self.elapsed_ += time.time() - self.start_time + self.started_ = False + + def reset(self): + """Reset timer.""" + self.elapsed_ = 0.0 + self.started_ = False + + def elapsed(self, reset=True): + """Calculate the elapsed time.""" + started_ = self.started_ + # If timing is in progress, stop it first. + if self.started_: + self.stop() + # Get the elapsed time. + elapsed_ = self.elapsed_ + # Reset the elapsed time + if reset: + self.reset() + # If timing was in progress, set it back.
+ if started_: + self.start() + return elapsed_ + + +class Timers: + """Group of timers.""" + + def __init__(self): + self.timers = {} + + def __call__(self, name): + if name not in self.timers: + self.timers[name] = _Timer(name) + return self.timers[name] + + def write(self, names, writer, iteration, normalizer=1.0, reset=True): + """Write timers to a tensorboard writer""" + assert normalizer > 0.0 + for name in names: + value = self.timers[name].elapsed(reset=reset) / normalizer + writer.add_scalar("timers/" + name, value, iteration) + + def log(self, names, normalizer=1.0, reset=True): + """Log a group of timers.""" + assert normalizer > 0.0 + # string = "time (ms) / rate" + string = "time (ms)" + + time_dict = {} + for name in names: + time_dict[name] = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer + + # total_time = sum(list(time_dict.values())) + # string += " | total_time : {:.2f} ".format(total_time) + time_dict = sorted(time_dict.items(), key=lambda x: x[1], reverse=True) + + for time_tuple in time_dict: + name, value = time_tuple + # string += " | {} : {:.2f} ({:.2f}%) ".format(name, value, value * 100.0 / total_time) + string += " | {} : {:.2f}".format(name, value) + return string + + +_GLOBAL_TIMERS = None + + +def get_timers(): + global _GLOBAL_TIMERS + return _GLOBAL_TIMERS + + +def set_timers(): + global _GLOBAL_TIMERS + logger.info("enable PaddleNLP timer") + _GLOBAL_TIMERS = Timers() + + +def disable_timers(): + global _GLOBAL_TIMERS + logger.info("disable PaddleNLP timer") + _GLOBAL_TIMERS = None diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index 5f6297c6fd27..947edf8fa6dc 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -18,8 +18,11 @@ import collections import contextlib +import copy import inspect +import json import math +import multiprocessing import os import random import re @@ -27,6 +30,7 @@ import sys import time import types +from collections import OrderedDict from collections.abc import Mapping from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -38,15 +42,38 @@ import paddle.nn as nn from packaging import version from paddle.distributed import fleet +from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import ( + DygraphShardingOptimizer, +) + +try: + from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import ( + DygraphShardingOptimizerV2, + ) +except: + DygraphShardingOptimizerV2 = None + +from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.hybrid_parallel_optimizer import ( + HybridParallelOptimizer, +) from paddle.distributed.fleet.utils.hybrid_parallel_util import ( fused_allreduce_gradients, + obtain_optimizer_parameters_list, ) +from paddle.distributed.fleet.utils.timer_helper import get_timers as paddle_get_timers from paddle.io import DataLoader, Dataset, DistributedBatchSampler from tqdm.auto import tqdm from ..data import DataCollator, DataCollatorWithPadding, default_data_collator from ..peft import LoRAModel, PrefixModelForCausalLM -from ..transformers.model_utils import PretrainedModel, _add_variant, unwrap_model +from ..transformers.model_utils import ( + PretrainedModel, + _add_variant, + exlclude_paramters_in_state_dict, + filter_sharded_params, + unwrap_model, + unwrap_optimizer, +) from ..transformers.tokenizer_utils import PretrainedTokenizer from ..utils import device_guard from ..utils.batch_sampler import DistributedBatchSampler as 
NlpDistributedBatchSampler @@ -58,6 +85,7 @@ from ..utils.import_utils import is_datasets_available from ..utils.log import logger from .integrations import get_reporting_integration_callbacks +from .plugins.timer import get_timers, set_timers from .trainer_callback import ( CallbackHandler, DefaultFlowCallback, @@ -86,6 +114,7 @@ speed_metrics, ) from .training_args import TrainingArguments +from .utils import reshard as reshard_util from .utils.helper import ( # nested_truncate, distributed_concat, nested_concat, @@ -93,6 +122,7 @@ nested_numpify, nested_truncate, ) +from .utils.reshard import SHARDING_STRATEGY_V1 DEFAULT_CALLBACKS = [DefaultFlowCallback] DEFAULT_PROGRESS_CALLBACK = ProgressCallback @@ -104,6 +134,8 @@ OPTIMIZER_NAME = "optimizer.pdopt" SCHEDULER_NAME = "scheduler.pdparams" SCALER_NAME = "scaler.pdparams" +MODEL_META_NAME = "model_meta.json" +SHARDING_META_NAME = "shard_meta.json" if is_datasets_available(): @@ -120,6 +152,44 @@ except: from paddle.fluid.dataloader.dataloader_iter import _DataLoaderIterBase +async_save_queue = [] + + +def _save_func(obj, path, saved_signal_path, protocol): + paddle.save(obj, path, protocol) + # dump saved_signal + with open(saved_signal_path, mode="w+") as f: + f.write("1") + + +def clear_async_save_task_queue(): + """ + Wait until all async save tasks are done. + """ + while len(async_save_queue) > 0: + task = async_save_queue.pop() + if task and task.is_alive(): + task.join() + + +def async_save_optimizer(optimizer_state_dict, path, saved_signal_path, protocol=4, sync_other_task=False): + cpu_optimizer_state_dict = {} + for k, v in optimizer_state_dict.items(): + if k == "master_weights": + cpu_optimizer_state_dict[k] = {} + for kk, vv in v.items(): + cpu_optimizer_state_dict[k][kk] = vv.pin_memory() + elif k == "LR_Scheduler": + cpu_optimizer_state_dict[k] = copy.deepcopy(v) + else: + cpu_optimizer_state_dict[k] = v.pin_memory() + paddle.device.cuda.synchronize() + if sync_other_task: + clear_async_save_task_queue() + p = multiprocessing.Process(target=_save_func, args=(cpu_optimizer_state_dict, path, saved_signal_path, protocol)) + p.start() + async_save_queue.append(p) + def paddlenlp_load(path, return_numpy=False): if return_numpy: @@ -216,6 +286,8 @@ def __init__( args = TrainingArguments(output_dir=output_dir) self.args = args + # TODO(@tiangexiao): use async save in framework instead when use_async_save==True + self.save_func = paddle.save self.is_in_train = False # self.do_grad_scaling = args.fp16 @@ -251,6 +323,9 @@ def __init__( self.train_dataset = train_dataset self.eval_dataset = eval_dataset self.tokenizer = tokenizer + if not args.skip_profile_timer: + set_timers() + self.timers = get_timers() self.model_wrapped = model self.model = model @@ -322,6 +397,8 @@ def __init__( if self.amp_dtype == "float16" or self.amp_dtype == "bfloat16": if ShardingOption.SHARD_OP in self.args.sharding: self.scaler = fleet.distributed_scaler(self.scaler) + if self.args.amp_master_grad: + mix_precision_utils.MixPrecisionScaler(self.scaler) # return value has no use else: # scaler for stage2 and stage3 from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_utils import ( @@ -329,6 +406,7 @@ ) self.scaler = GroupShardedScaler(self.scaler) + else: self.do_grad_scaling = False self.use_cuda_amp = False @@ -356,7 +434,6 @@ def fn(layer): self.control = self.callback_handler.on_init_end(self.args, self.state, self.control) self.print_config() - # very last self._memory_tracker.stop_and_update_metrics() @@ -403,39 +480,165
@@ def load_state_dict_from_checkpoint(self, resume_from_checkpoint=None): `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance of [`Trainer`]. Only load model state dict. """ - resume_from_checkpoint = None if not resume_from_checkpoint else resume_from_checkpoint + resume_from_checkpoint = self.check_resume_from_checkpoint(resume_from_checkpoint) + + if resume_from_checkpoint is None: + return + + state_dict = None + if self.args.load_sharded_model: + state_dict = self.load_state_dict_from_checkpoint_with_reshard(resume_from_checkpoint) + if self.args.bf16: + state_dict = self.recover_params_from_master_weights(state_dict) + else: + if self.args.dataset_rank == 0 or self.args.use_moe: + state_dict = self.load_one_state_dict_from_checkpoint( + resume_from_checkpoint, self.args.old_weight_name_suffix + ) + else: + logger.info(f"not loading ckpt :{self.args.dataset_rank}") + + # If the model is on the GPU, it still works! + if state_dict is not None: + self._set_state_dict_in_model(state_dict) + # release memory + del state_dict + + def recover_params_from_master_weights(self, state_dict): + opt_state_dict = self.optimizer.state_dict() + assert "master_weights" in opt_state_dict + master_weights = opt_state_dict["master_weights"] + tmp = OrderedDict() + (master_weights, tmp) = (tmp, master_weights) + # cast to before + for (k, v) in tmp.items(): + master_weights[k] = paddle.cast(v.cuda(), paddle.bfloat16).cpu() + + if self.args.load_sharding_stage1_model: + structure_name_map = {k: v.name for (k, v) in self.model.state_dict().items()} + node_model_state = reshard_util.NodeModelState() + node_model_state_tmp = reshard_util.NodeModelState() + node_model_state_tmp.add_master_weights(master_weights) + node_model_state_tmp.pack_keys(structure_name_map) + node_model_state.merge_from(node_model_state_tmp, self.sharding_group.rank) + del node_model_state_tmp + assert reshard_util.is_sharding_opt(self.optimizer) + sharding_strategy = reshard_util.get_sharding_strategy(self.optimizer) + restore_func = ( + reshard_util.sharding_v1.restore + if sharding_strategy == SHARDING_STRATEGY_V1 + else reshard_util.sharding_v2.restore + ) + node_model_state = restore_func(node_model_state, self.model, self.optimizer, self.hcg) + node_model_state.unpack_keys() + master_weights = node_model_state.master_weights + + def filter_func(name): + return True + + master_weights = reshard_util.all_gather_state_dict(master_weights, filter_func, self.sharding_group) + + model_state_dict = self.model.state_dict() + logger.info("before recover, model_state_dict number: {}".format(len(model_state_dict))) + for key, param in model_state_dict.items(): + if param.name in master_weights: + assert param.shape == master_weights[param.name].shape + paddle.assign(master_weights[param.name].cuda(), model_state_dict[key]) + + logger.info("after recover, casted model_state_dict number: {}".format(len(model_state_dict))) + state_dict.update(model_state_dict) + return state_dict + + def load_state_dict_from_checkpoint_with_reshard(self, resume_from_checkpoint): + """load state_dict from_checkpoint with reshard, Only load model state dict.""" + parallel_config = self._load_distributed_strategy(resume_from_checkpoint) + pp_degree = parallel_config["pp_degree"] + mp_degree = parallel_config["mp_degree"] + sharding_degree = parallel_config["sharding_degree"] + self.args.pipeline_parallel_degree == pp_degree + self.args.tensor_parallel_degree == mp_degree + cur_sharding_degree = 
self.args.sharding_parallel_degree + + state_dict = OrderedDict() + + def get_name_suffix(i): + name = [] + if self.args.tensor_parallel_degree > 1: + name.append(f"tp{self.args.tensor_parallel_rank:0>2d}") + if self.args.pipeline_parallel_degree > 1: + name.append(f"pp{self.args.pipeline_parallel_rank:0>2d}") + name.append(f"shard{i:0>2d}") + return "_".join(name) + for i in range(self.args.sharding_parallel_rank, sharding_degree, cur_sharding_degree): + tmp = self.load_one_state_dict_from_checkpoint(resume_from_checkpoint, get_name_suffix(i)) + for (k, v) in tmp.items(): + state_dict[k] = v + del tmp + + def filter_func(name): + return True + + if self.args.load_sharding_stage1_model: + state_dict = reshard_util.all_gather_state_dict(state_dict, filter_func, self.sharding_group) + + return state_dict + + def load_one_state_dict_from_checkpoint(self, resume_from_checkpoint, weight_name_suffix): + """ + load state_dict of one shard from_checkpoint, Only load model state dict. + """ + if isinstance(self.model, LoRAModel): + weight_name = LORA_WEIGHT_FILE_NAME + elif isinstance(self.model, PrefixModelForCausalLM): + weight_name = PREFIX_WEIGHT_FILE_NAME + else: + weight_name = PADDLE_WEIGHT_FILE_NAME + file_path = os.path.join(resume_from_checkpoint, _add_variant(weight_name, weight_name_suffix)) + if not os.path.isfile(file_path): + raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}, no {file_path}") + + logger.info(f"Loading model from {resume_from_checkpoint} .") + + # We load the model state dict on the CPU to avoid an OOM error. + state_dict = paddle.load( + os.path.join(resume_from_checkpoint, _add_variant(weight_name, weight_name_suffix)), + return_numpy=True, + ) + return state_dict + + def check_resume_from_checkpoint(self, resume_from_checkpoint): + resume_from_checkpoint = None if not resume_from_checkpoint else resume_from_checkpoint # Load potential model checkpoint if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint: resume_from_checkpoint = get_last_checkpoint(self.args.output_dir) if resume_from_checkpoint is None: raise ValueError(f"No valid checkpoint found in output directory ({self.args.output_dir})") + return resume_from_checkpoint - if resume_from_checkpoint is not None: - if isinstance(self.model, LoRAModel): - weight_name = LORA_WEIGHT_FILE_NAME - elif isinstance(self.model, PrefixModelForCausalLM): - weight_name = PREFIX_WEIGHT_FILE_NAME - else: - weight_name = PADDLE_WEIGHT_FILE_NAME - - if not os.path.isfile( - os.path.join(resume_from_checkpoint, _add_variant(weight_name, self.args.weight_name_suffix)) - ): - raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}") - - logger.info(f"Loading model from {resume_from_checkpoint} .") - - # We load the model state dict on the CPU to avoid an OOM error. - state_dict = paddle.load( - os.path.join(resume_from_checkpoint, _add_variant(weight_name, self.args.weight_name_suffix)), - return_numpy=True, - ) - # If the model is on the GPU, it still works! 
- self._set_state_dict_in_model(state_dict) + def _load_check_point(self, resume_from_checkpoint, delay_optimizer_creation, max_steps): + # Check if saved optimizer or scheduler states exist + self._load_optimizer_and_scheduler(resume_from_checkpoint) + self.load_state_dict_from_checkpoint(resume_from_checkpoint) + model = self._wrap_model(self.model_wrapped) + # for the rest of this function `model` is the outside model, whether it was wrapped or not + if model is not self.model: + self.model_wrapped = model + if delay_optimizer_creation: + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + return model - # release memory - del state_dict + def _load_sharded_check_point(self, resume_from_checkpoint, delay_optimizer_creation, max_steps): + model = self._wrap_model(self.model_wrapped) + # for the rest of this function `model` is the outside model, whether it was wrapped or not + if model is not self.model: + self.model_wrapped = model + if delay_optimizer_creation: + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + # Check if saved optimizer or scheduler states exist + self._load_optimizer_and_scheduler(resume_from_checkpoint) + self.load_state_dict_from_checkpoint(resume_from_checkpoint) + return model def train( self, @@ -456,42 +659,11 @@ def train( """ args = self.args self.is_in_train = True - resume_from_checkpoint = None if not resume_from_checkpoint else resume_from_checkpoint + resume_from_checkpoint = self.check_resume_from_checkpoint(resume_from_checkpoint) # memory metrics - must set up as early as possible self._memory_tracker.start() - # Load potential model checkpoint - if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint: - resume_from_checkpoint = get_last_checkpoint(args.output_dir) - if resume_from_checkpoint is None: - raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})") - - if resume_from_checkpoint is not None: - if isinstance(self.model, LoRAModel): - weight_name = LORA_WEIGHT_FILE_NAME - elif isinstance(self.model, PrefixModelForCausalLM): - weight_name = PREFIX_WEIGHT_FILE_NAME - else: - weight_name = PADDLE_WEIGHT_FILE_NAME - if not os.path.isfile( - os.path.join(resume_from_checkpoint, _add_variant(weight_name, self.args.weight_name_suffix)) - ): - raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}") - - logger.info(f"Loading model from {resume_from_checkpoint} .") - - # TODO: Need to load the model state dict on the CPU to avoid an OOM error. - state_dict = paddle.load( - os.path.join(resume_from_checkpoint, _add_variant(weight_name, self.args.weight_name_suffix)), - return_numpy=True, - ) - # If the model is on the GPU, it still works! 
- self._set_state_dict_in_model(state_dict) - - # release memory - del state_dict - train_dataloader = self.get_train_dataloader() total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.dataset_world_size @@ -536,23 +708,15 @@ def train( # and ShardingOption.SHARD_OP in self.args.sharding # ) delay_optimizer_creation = False - if not delay_optimizer_creation: self.create_optimizer_and_scheduler(num_training_steps=max_steps) self.state = TrainerState() - model = self._wrap_model(self.model_wrapped) - - # for the rest of this function `model` is the outside model, whether it was wrapped or not - if model is not self.model: - self.model_wrapped = model - - if delay_optimizer_creation: - self.create_optimizer_and_scheduler(num_training_steps=max_steps) - - # Check if saved optimizer or scheduler states exist - self._load_optimizer_and_scheduler(resume_from_checkpoint) + if self.args.load_sharded_model: + model = self._load_sharded_check_point(resume_from_checkpoint, delay_optimizer_creation, max_steps) + else: + model = self._load_check_point(resume_from_checkpoint, delay_optimizer_creation, max_steps) logger.info("***** Running training *****") logger.info(f" Num examples = {num_examples}") @@ -623,9 +787,13 @@ def train( epoch_iterator = train_dataloader # steps_in_epoch = len(epoch_iterator) - steps_in_epoch = ( - len(epoch_iterator) if len_dataloader is not None else args.max_steps * args.gradient_accumulation_steps - ) + global_steps_in_epoch = len(epoch_iterator) if len_dataloader is not None else args.max_steps + if len_dataloader is not None: + if self.args.gradient_accumulation_steps > len(epoch_iterator): + logger.warning( + f"changing accumulation step from `{self.args.gradient_accumulation_steps}` to `{len(epoch_iterator)}` to avoid, cross epoch accumulate" + ) + self.args.gradient_accumulation_steps = len(epoch_iterator) self.callback_handler.model = self.model self.callback_handler.optimizer = self.optimizer @@ -648,17 +816,21 @@ def train( npu_accelerate_plugin(self.optimizer) + self.timers and self.timers("read-data").start() + for epoch in range(epochs_trained, num_train_epochs): if isinstance(train_dataloader, paddle.io.DataLoader) and isinstance( train_dataloader.batch_sampler, DistributedBatchSampler ): train_dataloader.batch_sampler.set_epoch(epoch) - step = -1 + step = 0 self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control) - for step, inputs in enumerate(epoch_iterator): + for _, inputs in enumerate(epoch_iterator): + self.timers and self.timers("read-data").stop() self.callback_handler.on_load_data_end(args, self.state, self.control, inputs=inputs) + # Skip past any already trained steps if resuming training # for paddlenlp.utils.batch_sampler.DistributedBatchSampler # We use consumed_samples to reset the status @@ -671,7 +843,7 @@ def train( steps_trained_progress_bar.close() steps_trained_progress_bar = None self._load_rng_state(resume_from_checkpoint) - step += steps_trained_in_current_epoch + # step += steps_trained_in_current_epoch elif steps_trained_in_current_epoch > 0: steps_trained_in_current_epoch -= 1 if steps_trained_progress_bar is not None: @@ -685,20 +857,22 @@ def train( if step % args.gradient_accumulation_steps == 0: self.control = self.callback_handler.on_step_begin(args, self.state, self.control) - + self.timers and self.timers("forward-backward").start() dp_enabled = ( self.args.data_parallel_degree > 1 if self.args.use_hybrid_parallel else args.local_rank != -1 ) forbidden_no_sync = 
False - # stage2 and stage3 should not no_sync, because the is no DDP wrapper and no_sync API - if self.sharding and (ShardingOption.SHARD_OP not in self.args.sharding): + # stage2 and stage3 should not no_sync, because the is no DDP wrapper and no_sync API + # hybrid_parallel (tp or pp or sharding stage 1) should not no_sync + if self.args.use_hybrid_parallel: forbidden_no_sync = True # hybrid_parallel (tp or pp) should not no_sync if self.args.use_hybrid_parallel and ( self.args.tensor_parallel_degree > 1 or self.args.pipeline_parallel_degree > 1 ): forbidden_no_sync = True - + if self.args.use_moe: + forbidden_no_sync = True availiable_no_sync = dp_enabled and not forbidden_no_sync is_no_sync = ( @@ -710,27 +884,37 @@ def train( # stage1. the same as ddp # stage2. manualy collect gradient on dp group + hack_dp_master_grad = self.args.amp_master_grad and not self.args.use_hybrid_parallel + if hack_dp_master_grad: + is_no_sync = False + if is_no_sync: # Avoid unnecessary DDP synchronization since there will be no backward pass on this example. with model.no_sync(): - tr_loss_step = self.training_step(model, inputs) + tr_loss_step, outputs = self.training_step(model, inputs) else: - tr_loss_step = self.training_step(model, inputs) + tr_loss_step, outputs = self.training_step(model, inputs) - tr_loss += tr_loss_step + def fused_allreduce_gradients_no_sync(paramlist, hcg): + paramlist = list(paramlist) + nonmoe_list = [p for p in paramlist if not getattr(p, "no_sync", False)] + moelist = [p for p in paramlist if getattr(p, "no_sync", False)] + if moelist and not self.args.use_moe: + logger.warning("found `no sync` param when `use_moe=False`") + fused_allreduce_gradients(nonmoe_list, hcg) - if (step + 1) % args.gradient_accumulation_steps == 0 or ( - # last step in epoch but step is always smaller than gradient_accumulation_steps - steps_in_epoch <= args.gradient_accumulation_steps - and (step + 1) == steps_in_epoch - ): + tr_loss += tr_loss_step + if (step + 1) % args.gradient_accumulation_steps == 0: + self.timers and self.timers("forward-backward").stop() # Maunally collect gradients when group_sharded_parallel can't accept dp_group # Case 1: Use sharding stage 2/3 with dp # Case 2: Use recompute and dp # local_rank != -1 don't means dp in networks. + self.timers and self.timers("all-reduce").start() + if self.sharding and ShardingOption.SHARD_OP not in self.args.sharding: if self.args.data_parallel_degree > 1 and not is_dp_group_support_in_group_sharded_parallel(): - fused_allreduce_gradients(model.parameters(), fleet.get_hybrid_communicate_group()) + fused_allreduce_gradients_no_sync(model.parameters(), fleet.get_hybrid_communicate_group()) if ShardingOption.FULL_SHARD in self.args.sharding: # Why need sync on parm again ? # TODO: fix this. @@ -743,16 +927,46 @@ def train( # Case 2: Use recompute and dp / sharding stage1, # manualy collect gradient for dp. 
elif args.recompute and availiable_no_sync: - fused_allreduce_gradients(list(model.parameters()), None) + assert not self.args.use_moe, "moe must `no_sync`" + fused_allreduce_gradients_no_sync(list(model.parameters()), None) + + # Case 2.1: only pure dp + moe manually aggregates gradients here. + elif args.use_moe and not args.use_hybrid_parallel: + fused_allreduce_gradients_no_sync(list(model.parameters()), None) + + pipeline_parallel_config = set(args.pipeline_parallel_config.split(" ")) + enable_delay_scale_loss = "enable_delay_scale_loss" in pipeline_parallel_config + enable_dp_comm_overlap = "enable_dp_comm_overlap" in pipeline_parallel_config + + if isinstance(self.optimizer, HybridParallelOptimizer) and not self.do_grad_scaling: + parameters_list = obtain_optimizer_parameters_list(self.optimizer._inner_opt) + + if not enable_dp_comm_overlap: + if self.optimizer._sharding_enable: + assert reshard_util.is_sharding_opt(self.optimizer) + self.optimizer._inner_opt.reduce_gradients(list(parameters_list), self.optimizer._hcg) + + if self.optimizer._dp_enable: + fused_allreduce_gradients_no_sync(list(parameters_list), self.optimizer._hcg) + else: + assert self.args.use_moe, "moe should not `enable_dp_comm_overlap`" + + self.timers and self.timers("all-reduce").stop() + self.timers and self.timers("optimizer-step").start() + + # Case 3: hack dp with master_grad + if hack_dp_master_grad and not (args.recompute and availiable_no_sync): + fused_allreduce_gradients_no_sync(list(model.parameters()), None) # pipeline parallel mode, handle gradient merge here - if args.pipeline_parallel_degree > 1 and getattr(model, "_delay_scale_loss", False): + if args.pipeline_parallel_degree > 1 and enable_delay_scale_loss: for p in model._layers.parameters(): - if hasattr(p, "main_grad") and p.main_grad is not None: - assert p.grad is None - p.main_grad = p.main_grad.scale(1.0 / self.args.gradient_accumulation_steps) - elif p.grad is not None: - p.grad = p.grad.scale(1.0 / self.args.gradient_accumulation_steps) + with paddle.no_grad(): + if hasattr(p, "main_grad") and p.main_grad is not None: + assert p.grad is None + p.main_grad.scale_(1.0 / self.args.gradient_accumulation_steps) + elif p.grad is not None: + p.grad.scale_(1.0 / self.args.gradient_accumulation_steps) # Optimizer step self.callback_handler.on_optimizer_begin( @@ -760,6 +974,8 @@ ) optimizer_was_run = True if self.do_grad_scaling: + if args.pipeline_parallel_degree > 1: + assert not self.args.use_moe, "pipeline moe does not work under fp16" scale_before = self.scaler._scale.numpy() self.scaler.step(self.optimizer) self.scaler.update() @@ -769,9 +985,13 @@ logger.warning( f"optimizer not run, scale_before: {scale_before[0]}, scale_after: {scale_after[0]}" ) + elif isinstance(self.optimizer, HybridParallelOptimizer): + self.optimizer._step(parameters_list) else: self.optimizer.step() + self.timers and self.timers("optimizer-step").stop() + if optimizer_was_run: self.lr_scheduler.step() @@ -781,15 +1001,21 @@ ) self.state.global_step += 1 - self.state.epoch = epoch + (step + 1) / steps_in_epoch - + self.state.epoch = epoch + self.state.global_step / global_steps_in_epoch self.control = self.callback_handler.on_step_end(args, self.state, self.control) - self._maybe_log_save_evaluate(tr_loss, model, epoch, ignore_keys_for_eval, inputs=inputs) + self._maybe_log_save_evaluate( + tr_loss, model, epoch, ignore_keys_for_eval, inputs=inputs, outputs=outputs + ) + self._print_timer() + + step = 0 else: self.control =
self.callback_handler.on_substep_end(args, self.state, self.control) + step += 1 if self.control.should_epoch_stop or self.control.should_training_stop: break + self.timers and self.timers("read-data").start() if step < 0: logger.warning( @@ -877,7 +1103,27 @@ def _get_train_sampler(self) -> Optional[paddle.io.Sampler]: def _set_state_dict_in_model(self, state_dict): # TODO @ZHUI paddle need return the results of set_state_dict. - self.model.set_state_dict(state_dict) + logger.info(f"set state-dict :{self.model.set_state_dict(state_dict)}") + + def _print_timer(self): + """print timer and clear states""" + paddle_timer_info = "" + try: + paddle_pipeline_timers = paddle_get_timers() + for name, timer in paddle_pipeline_timers.timers.items(): + elapsed_time = timer.elapsed(reset=False) * 1000.0 + paddle_timer_info += f" | {name}: {elapsed_time:.2f}" + paddle_pipeline_timers.log(paddle_pipeline_timers.timers.keys(), reset=True) + except AssertionError: + pass + + if self.timers is not None: + timer_info = self.timers.log(self.timers.timers.keys(), reset=True) + else: + timer_info = "" + + if timer_info or paddle_timer_info: + logger.info(f"[Profile global_step: {self.state.global_step}] {timer_info} {paddle_timer_info}") def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval, **kwargs): if self.control.should_log: @@ -1286,7 +1532,12 @@ def _wrap_model(self, model, training=True): # Multi-gpu training if self.args.world_size > 1 and not self.args.use_hybrid_parallel: - model = paddle.DataParallel(model) + if self.args.amp_master_grad: + mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype) # return value has no use + logger.warning("Note amp_master_grad using in dp is an experimental support!") + self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer) + else: + model = paddle.DataParallel(model) # Distributed training (should be after fp16 initialization) in_pipeline_parallel_mode = self.args.pipeline_parallel_degree > 1 @@ -1352,11 +1603,17 @@ def get_expected_keys(inputs, keys): model = paddle.distributed.fleet.meta_parallel.TensorParallel(model, hcg, strategy=None) if ShardingOption.SHARD_OP in self.args.sharding: + if self.args.amp_master_grad: + mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype) # return value has no use model = fleet.distributed_model(model) + if self.args.amp_master_grad: + self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer) self.optimizer = fleet.distributed_optimizer(self.optimizer) else: # sync params (broadcast) buffers in dp group - if not is_dp_group_support_in_group_sharded_parallel() and self.args.data_parallel_degree > 1: + if ( + not is_dp_group_support_in_group_sharded_parallel() or self.args.use_moe + ) and self.args.data_parallel_degree > 1: try: from paddle.fluid.dygraph.parallel import sync_params_buffers except ImportError: @@ -1380,7 +1637,7 @@ def get_expected_keys(inputs, keys): # add dp_group and exclude_layer params # https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/distributed/sharding/group_sharded_parallel_cn.html#group-sharded-parallel extra_kwargs = {} - if is_dp_group_support_in_group_sharded_parallel(): + if is_dp_group_support_in_group_sharded_parallel() and not self.args.use_moe: extra_kwargs["dp_group"] = self.dp_group extra_kwargs["exclude_layer"] = ["GroupNorm"] @@ -1497,7 +1754,12 @@ def compute_loss(self, model, inputs, return_outputs=False): self._past = outputs[self.args.past_index] # We don't use .loss here since 
the model may return tuples instead of ModelOutput. - loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] + if isinstance(outputs, dict): + loss = outputs.pop("loss") + outputs = {k: nested_detach(v) for k, v in outputs.items()} + else: + loss = outputs[0] + outputs = nested_detach(outputs[1:]) return (loss, outputs) if return_outputs else loss @@ -1524,10 +1786,8 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, model.train() inputs = self._prepare_inputs(inputs) - with self.autocast_smart_context_manager(): - loss = self.compute_loss(model, inputs) - + loss, outputs = self.compute_loss(model, inputs, return_outputs=True) if self.args.gradient_accumulation_steps > 1: loss = loss / self.args.gradient_accumulation_steps @@ -1536,7 +1796,10 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, else: loss.backward() - return loss.detach() + if isinstance(outputs, dict) and "loss" in outputs: + loss = outputs.pop("loss") / self.args.gradient_accumulation_steps + + return loss.detach(), outputs def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor: """ @@ -1561,10 +1824,10 @@ def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle self._pp_data_buffer = [] self._pp_data_buffer.append(inputs) if len(self._pp_data_buffer) != self.args.gradient_accumulation_steps: - return paddle.zeros([]) + return paddle.zeros([]), {} - for v in self._pp_data_buffer[0].values(): - assert isinstance(v, paddle.Tensor), f"Only support tensor as pipeline mode input, got type {type(v)}" + # for v in self._pp_data_buffer[0].values(): + # assert isinstance(v, paddle.Tensor), f"Only support tensor as pipeline mode input, got type {type(v)}" inputs = model._prepare_pipeline_inputs_func(self._pp_data_buffer) self._pp_data_buffer = [] @@ -1572,23 +1835,7 @@ def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle # hack _prepare_training, remove additional optimizer or scheduler check # https://github.com/PaddlePaddle/Paddle/blob/4695122492eee3cc9e9c585e33429c0f98dbdbb0/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py#L241 - def _prepare_training(self, data): - from paddle import framework - - # reset the virtual pp rank for each run - self.set_virtual_pipeline_rank(0) - assert framework._dygraph_tracer()._has_grad, "Please enable the generation of gradients." - if self.is_pipeline_first_stage(ignore_virtual=True) or self.is_pipeline_last_stage(ignore_virtual=True): - assert data is not None, "For the first and the last stage, the data must be set." - else: - data = None - - self._layers.train() - - return data - model.train() - # hack pipeline-layers # since the pipeline layer will check input is valid every iter. # in same case, for example, batch size warmup, we need dynamic change gradient_accumulation_steps to implement. 
@@ -1596,16 +1843,51 @@ def _prepare_training(self, data): model.micro_batch_size = self.args.per_device_train_batch_size model.accumulate_steps = self.args.gradient_accumulation_steps - inputs = _prepare_training(model, inputs) + if model._dp_comm_overlap or model._sharding_comm_overlap: + for _, buffers in model._chunk_2_comm_buffers.items(): + for buffer in buffers: + buffer._acc_steps = self.args.gradient_accumulation_steps + + inputs = model._prepare_training( + inputs, self.optimizer, self.lr_scheduler + ) # None, None => [optimizer, lr_scheduler] + model.optimizer = None # we do not use `PipelineParallel` to handle the optimizer step + model.lr_scheduler = None with self.autocast_smart_context_manager(): loss = model.forward_backward_pipeline(inputs, self.scaler if self.do_grad_scaling else None) model.micro_batch_size, model.accumulate_steps = config_backup - - return loss.detach() - - def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Optional[bool] = False): + if not hasattr(model._layers._loss_fn, "info"): + return loss.detach(), {} + + if model.is_pipeline_last_stage(): + buf = [ + { + k: (v.item() if isinstance(v, paddle.Tensor) else v) / self.args.gradient_accumulation_steps + for k, v in model._layers._loss_fn.info.items() + } + ] + else: + buf = [None] + hcg = fleet.get_hybrid_communicate_group() + dist.broadcast_object_list(buf, src=hcg._pp_comm_group.ranks[-1], group=hcg.get_pipe_parallel_group()) + info = buf[0] + + # When the pipeline model needs to return and log multiple losses, insert a dict `info` into the network's `model._layers._loss_fn`. + # `info` holds the name-tensor pairs that should be logged. + model._layers._loss_fn.info = {} + assert isinstance(info, dict), f"expect info to be a dict, got {type(info)}" + info = {k: v.detach() if isinstance(v, paddle.Tensor) else v for k, v in info.items()} + if "loss" in info: + loss = paddle.to_tensor(info.pop("loss")) + return loss.detach(), info + + def save_model( + self, + output_dir: Optional[str] = None, + merge_tensor_parallel: Optional[bool] = False, + ): """ Will save the model, so you can reload it using `from_pretrained()`.
@@ -1618,8 +1900,22 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op if self.args.should_save_model_state: self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel) + def _save_moe_weights(self, output_dir): + # save moe optimizer and model state # TODO: stored redundantly by default + self.save_func( + self.model.state_dict(), + os.path.join(output_dir, _add_variant(PADDLE_WEIGHT_FILE_NAME, self.args.weight_name_suffix)), + ) + self.save_func( + self.optimizer.state_dict(), + os.path.join(output_dir, _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)), + ) + def _save_checkpoint(self, model, metrics=None): # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model" + if self.args.use_async_save: + # paddle.clear_async_save_task_queue() + clear_async_save_task_queue() # Save model checkpoint checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" @@ -1636,24 +1932,15 @@ optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix) - if self.args.use_hybrid_parallel: - if self.dp_group.rank <= 0: - os.makedirs(output_dir, exist_ok=True) - paddle.save( - self.optimizer.state_dict(), - os.path.join(output_dir, optimizer_name), - ) - if self.args.should_save: if not self.args.use_hybrid_parallel: - paddle.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) + self.save_func(self.optimizer.state_dict(), os.path.join(output_dir, optimizer_name)) # FIXME: manybe only save one copy - paddle.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) + self.save_func(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) if self.do_grad_scaling: - paddle.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME)) - + self.save_func(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME)) # Determine the new best metric / best model checkpoint if metrics is not None and self.args.metric_for_best_model is not None: metric_to_check = self.args.metric_for_best_model @@ -1693,9 +1980,30 @@ if self.args.world_size > 1: # use global process_index to save process_index = self.args.process_index - paddle.save(rng_states, os.path.join(output_dir, f"rng_state_{process_index}.pth")) + self.save_func(rng_states, os.path.join(output_dir, f"rng_state_{process_index}.pth")) else: - paddle.save(rng_states, os.path.join(output_dir, "rng_state.pth")) + self.save_func(rng_states, os.path.join(output_dir, "rng_state.pth")) + + saved_signal_path = os.path.join(output_dir, f"saved_signal_{dist.get_rank()}") + if self.args.use_hybrid_parallel: + if self.dp_group.rank <= 0: + os.makedirs(output_dir, exist_ok=True) + if self.args.use_async_save: + assert not self.args.use_moe, "moe does not support async save" + async_save_optimizer( + self.optimizer.state_dict(), + os.path.join(output_dir, optimizer_name), + saved_signal_path=saved_signal_path, + sync_other_task=True, + ) + + else: + self.save_func(self.optimizer.state_dict(), os.path.join(output_dir, optimizer_name)) + with open(saved_signal_path, mode="w+") as f: + f.write("1") + + if self.args.use_moe and self.args.data_parallel_rank > 0: + self._save_moe_weights(output_dir) # Maybe delete some older checkpoints.
if self.args.should_save and (True if not self.args.use_hybrid_parallel else self.args.local_rank == 0): @@ -1767,6 +2075,130 @@ def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None: logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") shutil.rmtree(checkpoint) + def _save_distributed_model_meta(self, dir): + if not self.args.use_hybrid_parallel: + return + + if not self.args.save_sharding_stage1_model: + return + + nranks = dist.get_world_size() + if nranks <= 1: + return + + model_meta = {} + parallel_config = self._get_distributed_strategy() + if parallel_config: + model_meta["parallel_config"] = parallel_config + sharding_metas = self._gather_sharding_metas() + if sharding_metas: + model_meta["sharding_metas"] = sharding_metas + + if dist.get_rank(): + return + + path = os.path.join(dir, MODEL_META_NAME) + with open(path, "w") as f: + json.dump(model_meta, f, indent=4) + + def _get_distributed_strategy(self): + pp_degree = 1 + mp_degree = 1 + sharding_degree = 1 + vpp_degree = 1 + nranks = dist.get_world_size() + if self.args.use_hybrid_parallel and nranks > 1: + if dist.get_rank(): + return + hcg = fleet.get_hybrid_communicate_group() + mp_degree = hcg.get_model_parallel_world_size() + pp_degree = hcg.get_pipe_parallel_world_size() + sharding_degree = hcg.get_sharding_parallel_world_size() + """ + if pp_degree > 1: + assert isinstance(model, fleet.meta_parallel.PipelineParallel), "must be pipeline model" + vpp_degree = model._layers.get_num_virtual_stages() + """ + parallel_config = { + "pp_degree": pp_degree, + "mp_degree": mp_degree, + "sharding_degree": sharding_degree, + "vpp_degree": vpp_degree, + } + return parallel_config + + def _load_model_meta(self, dir): + meta_path = os.path.join(dir, MODEL_META_NAME) + assert os.path.exists(meta_path), f"{meta_path} not exist" + with open(meta_path, "r") as handle: + model_dist_meta = json.load(handle) + assert "parallel_config" in model_dist_meta + return model_dist_meta + + def _load_distributed_strategy(self, dir): + model_dist_meta = self._load_model_meta(dir) + parallel_config = model_dist_meta["parallel_config"] + assert "pp_degree" in parallel_config + assert "mp_degree" in parallel_config + assert "sharding_degree" in parallel_config + return parallel_config + + def _gather_sharding_metas(self): + nranks = dist.get_world_size() + if not self.args.use_hybrid_parallel or nranks <= 1: + return None + if self.args.sharding_parallel_rank != 0: + return None + if self.args.data_parallel_rank != 0: + return None + if not reshard_util.is_sharding_opt(self.optimizer): + return None + + sharding_strategy = reshard_util.get_sharding_strategy(self.optimizer) + param2rank = {} + + if sharding_strategy == SHARDING_STRATEGY_V1: + optimizer = unwrap_optimizer(self.optimizer, DygraphShardingOptimizer) + param2rank = {k: v for (k, v) in optimizer._param2rank.items()} + + model = self.model + structure_name_mapping = {k: v.name for (k, v) in model.state_dict().items()} + + sharding_metas = {} + sharding_meta = {} + + sharding_meta["param2rank"] = param2rank + sharding_meta["structure_name_mapping"] = structure_name_mapping + sharding_meta["sharding_strategy"] = sharding_strategy + suffix = f"tp{self.args.tensor_parallel_rank:0>2d}_pp{self.args.pipeline_parallel_rank:0>2d}" + sharding_metas[suffix] = sharding_meta + sharding_metas_list = self._all_gather_simple_object(sharding_metas, self.hcg.get_model_parallel_group()) + sharding_metas = {k: v for e in sharding_metas_list for (k, v) in e.items()} + 
if self.args.tensor_parallel_rank != 0: + return None + if self.args.pipeline_parallel_degree > 1: + sharding_metas_list = self._all_gather_simple_object(sharding_metas, self.hcg.get_pipe_parallel_group()) + sharding_metas = {k: v for e in sharding_metas_list for (k, v) in e.items()} + return sharding_metas + + def _load_sharding_meta(self, dir): + suffix = f"tp{self.args.tensor_parallel_rank:0>2d}_pp{self.args.pipeline_parallel_rank:0>2d}" + distributed_model_meta = self._load_model_meta(dir) + if "sharding_metas" in distributed_model_meta: + sharding_metas = distributed_model_meta["sharding_metas"] + assert suffix in sharding_metas + sharding_meta = sharding_metas[suffix] + assert "param2rank" in sharding_meta + return sharding_meta + + # for backward compatibility + meta_path = os.path.join(dir, _add_variant(SHARDING_META_NAME, suffix)) + assert os.path.exists(meta_path), f"{meta_path} not exist" + with open(meta_path, "r") as f: + sharding_meta = json.load(f) + assert "param2rank" in sharding_meta + return sharding_meta + def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_parallel=False): # If we are executing this function, we are the process zero, so we don't check for that. output_dir = output_dir if output_dir is not None else self.args.output_dir @@ -1774,9 +2206,18 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_ logger.info(f"Saving model checkpoint to {output_dir}") # Save a trained model and configuration using `save_pretrained()`. # They can then be reloaded using `from_pretrained()` + is_bf16 = self.args.bf16 + param_names_in_master_weights = [] + if is_bf16: + optimzier_state_dict = self.optimizer.state_dict() + assert "master_weights" in optimzier_state_dict + param_names_in_master_weights = list(optimzier_state_dict["master_weights"].keys()) merge_tensor_parallel = merge_tensor_parallel and self.args.use_hybrid_parallel + sharding_group = None + if paddle.distributed.get_world_size() > 1 and self.args.use_hybrid_parallel: + sharding_group = self.sharding_group if ( not isinstance(self.model, PretrainedModel) and not isinstance(self.model, LoRAModel) @@ -1788,6 +2229,11 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_ merge_tensor_parallel=merge_tensor_parallel, variant=self.args.weight_name_suffix, is_main_process=self.args.should_save, + is_bf16=is_bf16, + param_names_in_master_weights=param_names_in_master_weights, + sharding_group=sharding_group, + save_sharding_stage1_model=self.args.save_sharding_stage1_model, + optimizer=self.optimizer, ) else: logger.info("Trainer.model is not a `PretrainedModel`, only saving its state dict.") @@ -1795,6 +2241,14 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_ logger.warning("Trainer.model is not a `PretrainedModel`, not suppor for merge_tensor_parallel.") if state_dict is None: state_dict = self.model.state_dict() + if self.args.save_sharding_stage1_model: + state_dict = filter_sharded_params(state_dict, self.optimizer, sharding_group) + if is_bf16: + logger.info("before exclude state_dict_to_save len:{}".format(len(state_dict))) + state_dict = exlclude_paramters_in_state_dict( + state_dict, param_names_in_master_weights, sharding_group + ) + logger.info("after exclude state_dict len:{}".format(len(state_dict))) paddle.save( state_dict, os.path.join(output_dir, _add_variant(PADDLE_WEIGHT_FILE_NAME, self.args.weight_name_suffix)), @@ -1805,8 +2259,16 @@ def _save(self, output_dir: Optional[str] = 
None, state_dict=None, merge_tensor_ merge_tensor_parallel=merge_tensor_parallel, variant=self.args.weight_name_suffix, is_main_process=self.args.should_save, + is_bf16=is_bf16, + param_names_in_master_weights=param_names_in_master_weights, + sharding_group=sharding_group, + save_sharding_stage1_model=self.args.save_sharding_stage1_model, + optimizer=self.optimizer, + sharding_degree=self.args.sharding_parallel_degree, + use_async_save=self.args.use_async_save, ) + self._save_distributed_model_meta(output_dir) if self.args.should_save: if self.tokenizer is not None: self.tokenizer.save_pretrained(output_dir) @@ -1814,22 +2276,184 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_ # Good practice: save your training arguments together with the trained model paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) + def _all_gather_simple_object(self, obj, group=None): + if group is None: + group = self.hcg.get_sharding_parallel_group() + res = [] + paddle.distributed.all_gather_object(res, obj, group) + return res + + def _map_optimizer_state_to_param(self, optimizer_state_names): + optimizer = unwrap_optimizer(self.optimizer, DygraphShardingOptimizer) + all_names = list(optimizer._param2rank.keys()) + all_names.extend(list(optimizer_state_names)) + all_names.sort() + pre_p_name = "" + opt_to_p = {} + for n in all_names: + if n in optimizer._param2rank: + # we get a param + pre_p_name = n + else: + assert pre_p_name, n + opt_to_p[n] = pre_p_name + return opt_to_p + + def _load_optimizer_state_of_one_shard(self, checkpoint, optimizer_name_suffix): + optimizer_name = _add_variant(OPTIMIZER_NAME, optimizer_name_suffix) + path = os.path.join(checkpoint, optimizer_name) + logger.info(f"load optimizer state from {path}") + if os.path.isfile(path): + return paddlenlp_load(path, return_numpy=True) + logger.info(f"{path} not exists") + return None + + def _need_reshard(self, checkpoint): + parallel_config = self._load_distributed_strategy(checkpoint) + sharding_meta = self._load_sharding_meta(checkpoint) + sharding_degree = parallel_config["sharding_degree"] + sharding_strategy = SHARDING_STRATEGY_V1 + if "sharding_strategy" in sharding_meta: + sharding_strategy = parallel_config["sharding_strategy"] + cur_sharding_degree = self.args.sharding_parallel_degree + cur_sharding_strategy = reshard_util.get_sharding_strategy(self.optimizer) + if sharding_degree != cur_sharding_degree or sharding_strategy != cur_sharding_strategy: + return True + if sharding_strategy == SHARDING_STRATEGY_V1: + param2rank = sharding_meta["param2rank"] + optimizer = unwrap_optimizer(self.optimizer, DygraphShardingOptimizer) + assert optimizer + assert len(param2rank) == len(optimizer._param2rank) + for (k, v) in param2rank.items(): + assert k in optimizer._param2rank + if optimizer._param2rank[k] != int(v): + return True + return False + + def _load_optimizer_state_with_reshard(self, checkpoint): + """load state_dict of multiple shard from_checkpoint, Only load model state dict.""" + + if not self._need_reshard(checkpoint): + logger.info("do not need reshard") + return self._load_optimizer_state_of_one_shard(checkpoint, self.args.optimizer_name_suffix) + + parallel_config = self._load_distributed_strategy(checkpoint) + sharding_meta = self._load_sharding_meta(checkpoint) + pp_degree = parallel_config["pp_degree"] + mp_degree = parallel_config["mp_degree"] + sharding_degree = parallel_config["sharding_degree"] + sharding_strategy = SHARDING_STRATEGY_V1 + if "sharding_strategy" in 
sharding_meta: + sharding_strategy = sharding_meta["sharding_strategy"] + assert self.args.pipeline_parallel_degree == pp_degree + assert self.args.tensor_parallel_degree == mp_degree + cur_sharding_degree = self.args.sharding_parallel_degree + cur_sharding_strategy = reshard_util.get_sharding_strategy(self.optimizer) + + logger.info("reshard optimizer state") + node_model_state = reshard_util.NodeModelState() + + def get_name_suffix(i): + name = [] + if self.args.tensor_parallel_degree > 1: + name.append(f"tp{self.args.tensor_parallel_rank:0>2d}") + if self.args.pipeline_parallel_degree > 1: + name.append(f"pp{self.args.pipeline_parallel_rank:0>2d}") + name.append(f"shard{i:0>2d}") + return "_".join(name) + + structure_name_map = {k: v.name for (k, v) in self.model.state_dict().items()} + for i in range(self.args.sharding_parallel_rank, sharding_degree, cur_sharding_degree): + tmp = self._load_optimizer_state_of_one_shard(checkpoint, get_name_suffix(i)) + node_model_state_tmp = reshard_util.NodeModelState() + node_model_state_tmp.add_opts(tmp) + node_model_state_tmp.pack_keys(structure_name_map) + node_model_state.merge_from(node_model_state_tmp, i) + del tmp + del node_model_state_tmp + + restore_func = ( + reshard_util.sharding_v1.restore + if sharding_strategy == SHARDING_STRATEGY_V1 + else reshard_util.sharding_v2.restore + ) + node_model_state = restore_func(node_model_state, self.model, self.optimizer, self.hcg) + + if self.args.load_sharding_stage1_model: + shard_func = ( + reshard_util.sharding_v1.shard + if cur_sharding_strategy == SHARDING_STRATEGY_V1 + else reshard_util.sharding_v2.shard + ) + node_model_state = shard_func(node_model_state, self.model, self.optimizer, self.hcg) + + # drop structural name in the key + node_model_state.unpack_keys() + + return node_model_state.get_opt_state_dict() + + def _load_optimizer_state(self, checkpoint): + if self.args.load_sharded_model: + return self._load_optimizer_state_with_reshard(checkpoint) + else: + return self._load_optimizer_state_of_one_shard(checkpoint, self.args.optimizer_name_suffix) + def _load_optimizer_and_scheduler(self, checkpoint): """If optimizer and scheduler states exist, load them.""" if checkpoint is None: return - optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix) + opt_state_dict = self._load_optimizer_state(checkpoint) + + if opt_state_dict and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)): + # Note(GuoxiaWang): The checkpoint is not saved by ClipGradByAdaptiveNorm during training. + # To avoid errors, add a temporary empty dict to the checkpoint. 
+ if ( + hasattr(self.args, "adaptive_norm_clip") + and self.args.adaptive_norm_clip + and "LR_Scheduler" in opt_state_dict + and "adaptive_norm" not in opt_state_dict["LR_Scheduler"] + ): + opt_state_dict["LR_Scheduler"]["adaptive_norm"] = {} - if os.path.isfile(os.path.join(checkpoint, optimizer_name)) and os.path.isfile( - os.path.join(checkpoint, SCHEDULER_NAME) - ): # Load in optimizer and scheduler states - self.optimizer.set_state_dict(paddlenlp_load(os.path.join(checkpoint, optimizer_name), return_numpy=True)) + self.optimizer.set_state_dict(opt_state_dict) + + # Note(GuoxiaWang): Hold correct adaptive_norm state dict + if ( + hasattr(self.args, "adaptive_norm_clip") + and self.args.adaptive_norm_clip + and hasattr(self.optimizer._learning_rate, "adaptive_norm") + ): + adaptive_norm = self.optimizer._learning_rate.adaptive_norm + + sched_state_dict = paddle.load(os.path.join(checkpoint, SCHEDULER_NAME)) + # Note(GuoxiaWang): The checkpoint is not saved by ClipGradByAdaptiveNorm during training. + # To avoid errors, add a temporary empty dict to the checkpoint. + if ( + hasattr(self.args, "adaptive_norm_clip") + and self.args.adaptive_norm_clip + and "adaptive_norm" not in sched_state_dict + ): + sched_state_dict["adaptive_norm"] = {} + + self.lr_scheduler.set_state_dict(sched_state_dict) + + # Note(GuoxiaWang): Because the state dict of lr_scheduler has been set to the state on the global rank 0 at this time + # so, it need restore correctly adaptive_norm state dict + if ( + hasattr(self.args, "adaptive_norm_clip") + and self.args.adaptive_norm_clip + and hasattr(self.optimizer._learning_rate, "adaptive_norm") + ): + self.optimizer._learning_rate.adaptive_norm = adaptive_norm - self.lr_scheduler.set_state_dict(paddle.load(os.path.join(checkpoint, SCHEDULER_NAME))) if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)): self.scaler.load_state_dict(paddle.load(os.path.join(checkpoint, SCALER_NAME), return_numpy=True)) + else: + raise ValueError( + f"optimizer-state-dict not found, opt:{checkpoint} scheduler:{os.path.join(checkpoint, SCHEDULER_NAME)}" + ) def log(self, logs: Dict[str, float], **kwargs) -> None: """ @@ -1841,9 +2465,15 @@ def log(self, logs: Dict[str, float], **kwargs) -> None: logs (`Dict[str, float]`): The values to log. 
""" + + try: + paddle_pipeline_timers = paddle_get_timers() + except AssertionError: + paddle_pipeline_timers = None + kwargs.update(timer=self.timers, paddle_pipeline_timers=paddle_pipeline_timers) + if self.state.epoch is not None: logs["epoch"] = round(self.state.epoch, 4) - output = {**logs, **{"step": self.state.global_step}} self.state.log_history.append(output) self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs, **kwargs) diff --git a/paddlenlp/trainer/trainer_utils.py b/paddlenlp/trainer/trainer_utils.py index f9a88a7a7c2b..5bddb01f5ce0 100644 --- a/paddlenlp/trainer/trainer_utils.py +++ b/paddlenlp/trainer/trainer_utils.py @@ -218,10 +218,10 @@ def speed_metrics(split, start_time, num_samples=None, num_steps=None): result = {f"{split}_runtime": round(runtime, 4)} if num_samples is not None: samples_per_second = num_samples / runtime - result[f"{split}_samples_per_second"] = round(samples_per_second, 3) + result[f"{split}_samples_per_second"] = samples_per_second if num_steps is not None: steps_per_second = num_steps / runtime - result[f"{split}_steps_per_second"] = round(steps_per_second, 3) + result[f"{split}_steps_per_second"] = steps_per_second return result diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index a734eab090fe..c2f5a7699bab 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -20,6 +20,7 @@ import json import math import os +import time import types import warnings from dataclasses import asdict, dataclass, field @@ -219,6 +220,10 @@ class TrainingArguments: disable_partial_send_recv, optmize send speed for tensor parallel. enable_delay_scale_loss, accumulate gradients util optimizer step, all gradients div by inner pipeline accumute step. instead of div accumute step on loss directly. enable_dp_comm_overlap, fuse data parallel gradient communication. + sharding_parallel_config (`str`, *optional*)( + Some additional config it highly affect the useage of sharding parallel, we provide some option to config it. + following config is support: + split_param, use Sharding stage1 V2 recompute (`bool`, *optional*, defaults to `False`): Recompute the forward pass to calculate gradients. Used for saving memory. Only support for networks with transformer blocks. @@ -293,6 +298,8 @@ class TrainingArguments: scripts](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples) for more details. flatten_param_grads (`bool`, *optional*): Whether use flatten_param_grads method in optimizer, only used on NPU devices. Default is `False`. + skip_profile_timer (`bool`, *optional*): + Whether skip profile timer, timer will record time usage of forward/ backward/ step, etc. """ output_dir: str = field( @@ -463,6 +470,15 @@ class TrainingArguments: ) }, ) + save_sharded_model: bool = field( + default=False, + metadata={"help": ("whether saved sharded model when sharding_parallel_degree > 1")}, + ) + + load_sharded_model: bool = field( + default=False, + metadata={"help": ("whether load a sharded model when sharding_parallel_degree > 1")}, + ) tensor_parallel_degree: int = field( default=-1, metadata={ @@ -500,6 +516,17 @@ class TrainingArguments: ) }, ) + sharding_parallel_config: str = field( + default="", + metadata={ + "help": ( + "Some additional config it highly affect the useage of sharding parallel, we provide some option to config it." 
+ "following config is support: \n" + "split_param, sharding split param\n" + ) + }, + ) + recompute: bool = field( default=False, metadata={ @@ -589,6 +616,18 @@ class TrainingArguments: default=True, metadata={"help": "Whether use lazy data processing."}, ) + use_async_save: Optional[bool] = field( + default=False, + metadata={"help": "Whether to use paddle.async_save instead of paddle.save."}, + ) + skip_profile_timer: Optional[bool] = field( + default=True, + metadata={"help": "enable framework timer, will output timeline informatoin in logging and visualdl"}, + ) + use_moe: Optional[bool] = field( + default=False, + metadata={"help": "开启moe训练"}, + ) def __post_init__(self): env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1)) @@ -695,14 +734,14 @@ def __post_init__(self): if self.amp_master_grad: if ( - self.pipeline_parallel_degree <= 1 and self.tensor_parallel_degree <= 1 - ) or self.fp16_opt_level != "O2": + self.pipeline_parallel_degree <= 1 + and self.tensor_parallel_degree <= 1 + and not (self.sharding and ShardingOption.SHARD_OP in self.sharding) + ): raise ValueError( - "Temporarily amp master grad only suport for tensor/pipeline parallel with fp16_opt_level O2. please set amp_master_grad to False." + "Temporarily amp master grad only support for tensor/pipeline/sharding parallel. " + "Please set amp_master_grad to False." ) - if not (self.bf16 or self.fp16): - logger.warning("set amp_master_grad to false since amp is disabled.") - self.amp_master_grad = False if self.use_hybrid_parallel: world_size = paddle.distributed.get_world_size() @@ -753,6 +792,7 @@ def __post_init__(self): "disable_partial_send_recv", "enable_delay_scale_loss", "enable_dp_comm_overlap", + "enable_timer", ]: raise ValueError( f"Found unknown pipeline mode config {x}, accpet config is disable_p2p_cache_shape, disable_partial_send_recv." @@ -766,9 +806,17 @@ def __post_init__(self): # "delay_scale_loss": True, Fix ME } logger.info(f"PP configs:{strategy.pipeline_configs}, use master_grad: {self.amp_master_grad}") + + debug_hang = int(os.environ.get("PADDLE_DEBUG_HANG", 0)) dygraph_pp_configs = { "delay_scale_loss": True if "enable_delay_scale_loss" in pipeline_parallel_config else False, - "dp_comm_overlap": "enable_dp_comm_overlap" in pipeline_parallel_config, + "dp_comm_overlap": "enable_dp_comm_overlap" in pipeline_parallel_config + and not debug_hang + and self.data_parallel_degree > 1, + "sharding_comm_overlap": "enable_dp_comm_overlap" in pipeline_parallel_config + and not debug_hang + and self.sharding_parallel_degree > 1, + "enable_timer": "enable_timer" in pipeline_parallel_config, } if self.do_eval: @@ -780,23 +828,50 @@ def __post_init__(self): "Please set per_device_eval_batch_size=per_device_train_batch_size * gradient_accumulation_steps." 
                 )
 
+            if sharding_parallel_degree > 1:
+                sharding_parallel_config = set(self.sharding_parallel_config.split(" "))
+                for x in sharding_parallel_config:
+                    if len(x) > 0:
+                        if x not in [
+                            "split_param",
+                        ]:
+                            raise ValueError(f"Found unknown sharding parallel config {x}, accepted config: split_param.")
+                sharding_split_param = "split_param" in sharding_parallel_config
+
             if tensor_parallel_degree > 1:
                 strategy.tensor_parallel_configs = {"tensor_init_seed": self.seed}
 
+            if self.use_moe:
+                order = ["sharding", "pp", "dp", "mp"]
+            elif tensor_parallel_degree == 1 and sharding_parallel_degree == 1:
+                order = ["pp", "dp", "sharding", "mp"]
+            else:
+                order = ["dp", "sharding", "pp", "mp"]
             hybrid_configs = {
                 "dp_degree": self.data_parallel_degree,
                 "mp_degree": tensor_parallel_degree,
                 "pp_degree": pipeline_parallel_degree,
+                "order": order,
                 "sharding_degree": sharding_parallel_degree,
             }
 
             if pipeline_parallel_degree > 1:
                 hybrid_configs["pp_configs"] = dygraph_pp_configs
-                logger.info(f"using pipline configs:{dygraph_pp_configs}")
+                logger.info(f"using pipeline configs:{dygraph_pp_configs}")
 
             # setter once https://github.com/PaddlePaddle/Paddle/blob/b7295120b0e78b293cd7ae29706e21769d06a3cc/python/paddle/distributed/fleet/base/distributed_strategy.py#L1692
             strategy.hybrid_configs = hybrid_configs
+
+            if sharding_parallel_degree > 1:
+                if sharding_split_param:
+                    strategy.hybrid_configs["sharding_configs"].split_param = True
+
+            paddle.device.cuda.synchronize()
+            start_time = time.time()
             fleet.init(is_collective=True, strategy=strategy)
+            paddle.device.cuda.synchronize()
+            elapsed = time.time() - start_time
+            logger.info("NCCL connection costs {:.2f} s.".format(elapsed))
             logger.info(strategy)
@@ -936,9 +1011,28 @@ def optimizer_name_suffix(self):
                 name.append(f"pp{self.pipeline_parallel_rank:0>2d}")
             if self.sharding_parallel_degree > 1:
                 name.append(f"shard{self.sharding_parallel_rank:0>2d}")
+            if self.use_moe:
+                name.append(f"moe{self.data_parallel_rank:0>2d}")
+            return "_".join(name)
+        else:
+            if self.use_moe:
+                return f"moe{self.data_parallel_rank:0>2d}"
+            return None
+
+    @property
+    def old_weight_name_suffix(self):
+        if self.use_hybrid_parallel:
+            name = []
+            if self.tensor_parallel_degree > 1:
+                name.append(f"tp{self.tensor_parallel_rank:0>2d}")
+            if self.pipeline_parallel_degree > 1:
+                name.append(f"pp{self.pipeline_parallel_rank:0>2d}")
+            if self.use_moe:
+                name.append(f"moe{self.data_parallel_rank:0>2d}")
             return "_".join(name)
         else:
+            if self.use_moe:
+                return f"moe{self.data_parallel_rank:0>2d}"
             return None
@@ -949,8 +1043,14 @@ def weight_name_suffix(self):
                 name.append(f"tp{self.tensor_parallel_rank:0>2d}")
             if self.pipeline_parallel_degree > 1:
                 name.append(f"pp{self.pipeline_parallel_rank:0>2d}")
+            if self.save_sharding_stage1_model:
+                name.append(f"shard{self.sharding_parallel_rank:0>2d}")
+            if self.use_moe:
+                name.append(f"moe{self.data_parallel_rank:0>2d}")
             return "_".join(name)
         else:
+            if self.use_moe:
+                return f"moe{self.data_parallel_rank:0>2d}"
             return None
@@ -1011,7 +1111,9 @@ def should_save_model_state(self):
         if self.save_on_each_node:
             return self.local_process_index == 0
         else:
-            if self.tensor_parallel_degree > 1:
+            if self.save_sharding_stage1_model:
+                return True
+            elif self.use_hybrid_parallel:
                 # save on dataset rank 0
                 return self.sharding_parallel_rank == 0 and self.data_parallel_rank == 0
             else:
@@ -1024,6 +1126,18 @@ def _no_sync_in_gradient_accumulation(self):
         """
         return True
 
+    @property
+    def save_sharding_stage1_model(self):
+        return (
+            ShardingOption.SHARD_OP in
self.sharding and self.sharding_parallel_degree > 1 and self.save_sharded_model + ) + + @property + def load_sharding_stage1_model(self): + return ( + ShardingOption.SHARD_OP in self.sharding and self.sharding_parallel_degree > 1 and self.load_sharded_model + ) + @contextlib.contextmanager def main_process_first(self, local=True, desc="work"): """ diff --git a/paddlenlp/trainer/utils/__init__.py b/paddlenlp/trainer/utils/__init__.py index d432b9716764..b78376c12d04 100644 --- a/paddlenlp/trainer/utils/__init__.py +++ b/paddlenlp/trainer/utils/__init__.py @@ -12,10 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .helper import * - from .doc import ( add_end_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, ) +from .helper import * diff --git a/paddlenlp/trainer/utils/reshard/__init__.py b/paddlenlp/trainer/utils/reshard/__init__.py new file mode 100644 index 000000000000..8fff8ab230b7 --- /dev/null +++ b/paddlenlp/trainer/utils/reshard/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import sharding_v1, sharding_v2 +from .common import ( + SHARDING_STRATEGY_V1, + SHARDING_STRATEGY_V2, + NodeModelState, + all_gather_state_dict, + get_sharding_strategy, + is_sharding_opt, +) diff --git a/paddlenlp/trainer/utils/reshard/common.py b/paddlenlp/trainer/utils/reshard/common.py new file mode 100644 index 000000000000..22bb36a43602 --- /dev/null +++ b/paddlenlp/trainer/utils/reshard/common.py @@ -0,0 +1,523 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
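To make the gating of the new sharded save/load paths concrete, a minimal sketch follows (stand-in values instead of a real TrainingArguments instance; the literal "stage1" stands in for ShardingOption.SHARD_OP, and the real checks are the save_sharding_stage1_model / load_sharding_stage1_model properties added above):

def save_sharding_stage1_model(sharding, sharding_parallel_degree, save_sharded_model):
    # Mirrors the property above: sharded saving only kicks in when stage1 sharding
    # is enabled, there is more than one sharding rank, and save_sharded_model is set.
    return "stage1" in sharding and sharding_parallel_degree > 1 and save_sharded_model

print(save_sharding_stage1_model(["stage1"], 4, True))   # True  -> each sharding rank saves its own shard
print(save_sharding_stage1_model(["stage1"], 1, True))   # False -> no sharding group, normal save path
print(save_sharding_stage1_model(["stage2"], 4, True))   # False -> only stage1 (SHARD_OP) is covered here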
+ +import os +from collections import OrderedDict + +import numpy as np +import paddle +from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import ( + DygraphShardingOptimizer, +) +from paddle.distributed.fleet.utils.log_util import logger + +try: + from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import ( + DygraphShardingOptimizerV2, + ) +except: + DygraphShardingOptimizerV2 = None + + +from ....transformers.model_utils import unwrap_optimizer + +SHARDING_STRATEGY_V1 = "ShardingV1" +SHARDING_STRATEGY_V2 = "ShardingV2" + + +def is_sharding_opt(optimizer): + def check(cls): + tmp = unwrap_optimizer(optimizer, cls) + if tmp is not None: + return True + return False + + if check(DygraphShardingOptimizer): + return True + + if DygraphShardingOptimizerV2 is not None: + if check(DygraphShardingOptimizerV2): + return True + + return False + + +def get_sharding_strategy(optimizer): + if DygraphShardingOptimizerV2 is not None: + tmp = unwrap_optimizer(optimizer, DygraphShardingOptimizerV2) + if tmp is not None: + return SHARDING_STRATEGY_V2 + return SHARDING_STRATEGY_V1 + + +class NodeModelState: + def __init__(self, mp_rank=None, sharding_rank=None, pp_rank=None): + self._model_weights = OrderedDict() + self._opt_state = OrderedDict() + self._master_weights = OrderedDict() + self._lr_scheduler = None + self.set_node_rank(mp_rank, sharding_rank, pp_rank) + + def set_node_rank(self, mp_rank, sharding_rank, pp_rank): + self._mp_rank = mp_rank + self._sharding_rank = sharding_rank + self._pp_rank = pp_rank + + def _add_kv(self, d, k, v): + assert k not in d + d[k] = v + + @property + def model_weights(self): + return self._model_weights + + def add_weight(self, k, v): + self._add_kv(self._model_weights, k, v) + + def add_weights(self, model_state_dict, rank=None): + for (k, v) in model_state_dict.items(): + if rank is not None: + k = (k, rank) + self.add_weight(k, v) + + def set_weights(self, model_state_dict): + self._master_weights = model_state_dict + + def set_opt_state(self, opt_state_dict): + self._opt_state = opt_state_dict + + def set_master_weights(self, master_weights): + self._master_weights = master_weights + + @property + def opt_state(self): + return self._opt_state + + def add_opt(self, k, v): + self._add_kv(self._opt_state, k, v) + + def add_opts(self, opts, rank=None): + if "master_weights" in opts: + s_master = opts["master_weights"] + opts.pop("master_weights") + self.add_master_weights(s_master, rank) + + if "LR_Scheduler" in opts: + lr_scheduler = opts["LR_Scheduler"] + opts.pop("LR_Scheduler") + self.set_lr_scheduler(lr_scheduler) + + for (k, v) in opts.items(): + if rank is not None: + k = (k, rank) + self.add_opt(k, v) + + @property + def master_weights(self): + return self._master_weights + + def add_master_weight(self, k, v): + self._add_kv(self._master_weights, k, v) + + def add_master_weights(self, master, rank=None): + for (k, v) in master.items(): + if rank is not None: + k = (k, rank) + self.add_master_weight(k, v) + + @property + def lr_scheduler(self): + return self._lr_scheduler + + def set_lr_scheduler(self, lr_scheduler): + if lr_scheduler is not None: + self._lr_scheduler = lr_scheduler + + def map_names(self, map_func): + # model weighs + model_weights_tmp = OrderedDict() + (self._model_weights, model_weights_tmp) = (model_weights_tmp, self._model_weights) + for key in list(model_weights_tmp.keys()): + structure_name, t_name = key + t_name_new = map_func(structure_name, t_name) + 
self._model_weights[(structure_name, t_name_new)] = model_weights_tmp[key] + del model_weights_tmp[key] + + # opt + opt_tmp = OrderedDict() + (self._opt_state, opt_tmp) = (opt_tmp, self._opt_state) + for key in list(opt_tmp.keys()): + structure_name, t_name, opt_name = key + t_name_new = map_func(structure_name, t_name) + if self._model_weights: + assert (structure_name, t_name_new) in self._model_weights + opt_name_new = t_name_new + opt_name[len(t_name) :] + self._opt_state[(structure_name, t_name_new, opt_name_new)] = opt_tmp[key] + del opt_tmp[key] + + # master weights + master_weights_tmp = OrderedDict() + (self._master_weights, master_weights_tmp) = (master_weights_tmp, self._master_weights) + for key in list(master_weights_tmp.keys()): + structure_name, master_weights_name = key + master_weights_name_new = map_func(structure_name, master_weights_name) + if self._model_weights: + assert (structure_name, t_name_new) in self._model_weights + self._master_weights[(structure_name, master_weights_name_new)] = master_weights_tmp[key] + del master_weights_tmp[key] + + return self + + def drop_rank(self): + def drop(state, l=2): + tmp_state = OrderedDict() + (state, tmp_state) = (tmp_state, state) + for key in list(tmp_state.keys()): + k, rank = key + assert len(key) == 2 + assert len(k) == l + state[k] = tmp_state[key] + del tmp_state[key] + return state + + self._model_weights = drop(self._model_weights, 2) + self._opt_state = drop(self._opt_state, 3) + self._master_weights = drop(self._master_weights, 2) + return self + + def collapse_key(self): + def collapse(state, l): + tmp_state = OrderedDict() + (state, tmp_state) = (tmp_state, state) + state_keys = list(tmp_state.keys()) + state_keys = sorted(state_keys) + pre = None + for key in state_keys: + assert len(key) == 2 + k, rank = key + assert len(k) == l + if k != pre: + pre = k + state[k] = [] + state[k].append((rank, tmp_state[key])) + del tmp_state[key] + return state + + self._model_weights = collapse(self._model_weights, 2) + self._opt_state = collapse(self._opt_state, 3) + self._master_weights = collapse(self._master_weights, 2) + return self + + def flatten_key(self): + def flatten(state, l): + tmp_state = OrderedDict() + (state, tmp_state) = (tmp_state, state) + state_keys = list(tmp_state.keys()) + for key in state_keys: + assert len(key) == l + for (rank, items) in tmp_state[key]: + state[(key, rank)] = items + tmp_state[key] + return state + + self._model_weights = flatten(self._model_weights, 2) + self._opt_state = flatten(self._opt_state, 3) + self._master_weights = flatten(self._master_weights, 2) + return self + + def pack_keys(self, structure_name_mapping=None): + + # pack key for pp convert + def _opt_name_to_tname(tensor_names, opt_names): + tensor_names = set(tensor_names) + all_names = [] + all_names.extend(list(tensor_names)) + all_names.extend(opt_names) + all_names.sort() + pre_t_name = "" + opt_to_t = {} + for n in all_names: + if n in tensor_names: + # we get a param + pre_t_name = n + else: + assert pre_t_name + opt_to_t[n] = pre_t_name + return opt_to_t + + if structure_name_mapping is not None: + tname_to_structure_name = {v: k for (k, v) in structure_name_mapping.items()} + else: + structure_name_mapping = {k: v for (k, v) in self._model_weights.items()} + tname_to_structure_name = {v: k for (k, v) in structure_name_mapping.items()} + + tensor_names = list(tname_to_structure_name.keys()) + opt_names = list(self._opt_state.keys()) + opt_name_to_tname = _opt_name_to_tname(tensor_names, opt_names) + + # model 
state + model_weights_tmp = OrderedDict() + (self._model_weights, model_weights_tmp) = (model_weights_tmp, self._model_weights) + for k in list(model_weights_tmp.keys()): + t_name = structure_name_mapping[k] + self._model_weights[(k, t_name)] = model_weights_tmp[k] + del model_weights_tmp[k] + + # opt + opt_tmp = OrderedDict() + (self._opt_state, opt_tmp) = (opt_tmp, self._opt_state) + for opt_name in list(opt_tmp.keys()): + assert opt_name in opt_name_to_tname + t_name = opt_name_to_tname[opt_name] + assert t_name in tname_to_structure_name + structure_name = tname_to_structure_name[t_name] + self._opt_state[(structure_name, t_name, opt_name)] = opt_tmp[opt_name].cpu() + del opt_tmp[opt_name] + + # master weights + master_weights_tmp = OrderedDict() + (self._master_weights, master_weights_tmp) = (master_weights_tmp, self._master_weights) + for master_weights_name in list(master_weights_tmp.keys()): + assert master_weights_name in tname_to_structure_name + structure_name = tname_to_structure_name[master_weights_name] + self._master_weights[(structure_name, master_weights_name)] = master_weights_tmp[master_weights_name].cpu() + del master_weights_tmp[master_weights_name] + + return self + + def unpack_keys(self): + # model weights + model_weights_tmp = OrderedDict() + (self._model_weights, model_weights_tmp) = (model_weights_tmp, self._model_weights) + for key in list(model_weights_tmp.keys()): + structure_name, t_name = key + self._model_weights[structure_name] = model_weights_tmp[key] + self._model_weights[structure_name].name = t_name + del model_weights_tmp[key] + # opt + opt_tmp = OrderedDict() + (self._opt_state, opt_tmp) = (opt_tmp, self._opt_state) + for key in list(opt_tmp.keys()): + structure_name, t_name, opt_name = key + if structure_name in self._model_weights: + assert self._model_weights[structure_name].name == t_name + self._opt_state[opt_name] = opt_tmp[key] + del opt_tmp[key] + + # master weights + master_weights_tmp = OrderedDict() + (self._master_weights, master_weights_tmp) = (master_weights_tmp, self._master_weights) + for key in list(master_weights_tmp.keys()): + structure_name, master_weights_name = key + if structure_name in self._model_weights: + assert self._model_weights[structure_name].name == master_weights_name + self._master_weights[master_weights_name] = master_weights_tmp[key] + del master_weights_tmp[key] + return self + + def _model_file(self, dir): + return os.path.join( + dir, + f"model_state.tp{self._mp_rank:0>2d}_pp{self._pp_rank:0>2d}_shard{self._sharding_rank:0>2d}.pdparams", + ) + + def _optimizer_file(self, dir): + return os.path.join( + dir, f"optimizer.tp{self._mp_rank:0>2d}_pp{self._pp_rank:0>2d}_shard{self._sharding_rank:0>2d}.pdopt" + ) + + def save(self, dir): + model_file = self._model_file(dir) + paddle.save(self._model_weights, model_file) + + def load(self, dir): + model_file = self._model_file(dir) + model_state = paddle.load(model_file) + self.add_weights(model_state) + opt_file = self._optimizer_file(dir) + opt_state = paddle.load(opt_file) + self.add_opts(opt_state) + + def split_state(self, split_func): + node_model_states = {} + for (k, v) in self._model_weights.items(): + rank = split_func(k) + if rank not in node_model_states: + node_model_states[rank] = NodeModelState() + node_model_states[rank].add_weight(k, v) + + for (k, v) in self._opt_state.items(): + rank = split_func(k) + if rank not in node_model_states: + node_model_states[rank] = NodeModelState() + node_model_states[rank].add_opt(k, v) + + for (k, v) in 
self._master_weights.items(): + rank = split_func(k) + if rank not in node_model_states: + node_model_states[rank] = NodeModelState() + node_model_states[rank].add_master_weight(k, v) + + return node_model_states + + def even_distribute(self, group): + def distribute(get_state): + self.collapse_key() + state = get_state() + state_keys_list = all_gather_simple_object(list(state.keys()), group) + total_state_key = set() + for keys in state_keys_list: + for k in keys: + total_state_key.add(k) + total_state_key = list(total_state_key) + total_state_key = sorted(total_state_key) + key_to_rank = {} + for (i, key) in enumerate(total_state_key): + key_to_rank[key] = i % group.nranks + + def filter_func(key): + assert key[0] in key_to_rank, key + dst_rank = key_to_rank[key[0]] + return dst_rank == group.rank + + self.flatten_key() + state = get_state() + return _all_gather_state_dict(state, filter_func, group) + + self._model_weights = distribute(lambda: self._model_weights) + self._opt_state = distribute(lambda: self._opt_state) + self._master_weights = distribute(lambda: self._master_weights) + return self + + def reshard(self, group, filter_func): + self._model_weights = _all_gather_state_dict(self._model_weights, filter_func, group) + self._opt_state = _all_gather_state_dict(self._opt_state, filter_func, group) + self._master_weights = _all_gather_state_dict(self._master_weights, filter_func, group) + lr_schedulers = all_gather_simple_object(self._lr_scheduler, group) + self._lr_scheduler = lr_schedulers[0] + return self + + def split_items(self, split_func): + def split(state, l): + tmp_state = OrderedDict() + (state, tmp_state) = (tmp_state, state) + state_keys = list(tmp_state.keys()) + for key in state_keys: + assert len(key) == l + v = tmp_state[key] + state[key] = split_func(key, v) + del tmp_state[key] + return state + + self._model_weights = split(self._model_weights, 2) + self._opt_state = split(self._opt_state, 3) + self._master_weights = split(self._master_weights, 2) + return self + + def merge_items(self, merge_func): + def merge(state, l): + tmp_state = OrderedDict() + (state, tmp_state) = (tmp_state, state) + state_keys = list(tmp_state.keys()) + for key in state_keys: + assert len(key) == l + v = tmp_state[key] + v = sorted(v, key=lambda x: x[0]) + state[key] = merge_func(key, v) + del tmp_state[key] + return state + + self._model_weights = merge(self._model_weights, 2) + self._opt_state = merge(self._opt_state, 3) + self._master_weights = merge(self._master_weights, 2) + return self + + def merge_from(self, other, rank=None): + self.add_opts(other.opt_state, rank) + self.add_master_weights(other.master_weights, rank) + self.set_lr_scheduler(other.lr_scheduler) + + def get_opt_state_dict(self): + opt_state_dict = OrderedDict() + for (k, v) in self.opt_state.items(): + opt_state_dict[k] = v + if self._lr_scheduler is not None: + opt_state_dict["LR_Scheduler"] = self._lr_scheduler + opt_state_dict["master_weights"] = self._master_weights + return opt_state_dict + + +def all_gather_simple_object(obj, group): + res = [] + paddle.distributed.all_gather_object(res, obj, group) + return res + + +def all_gather_state_dict(state_dict, filter_func, group): + res = OrderedDict() + + def map_func(weight): + if isinstance(weight, paddle.Tensor): + weight = weight.numpy() + return weight + + state_dict = {k: map_func(v) for (k, v) in state_dict.items()} + + meta_dict = {} + for (k, v) in state_dict.items(): + # src rank + meta_dict[k] = (v.dtype, v.shape, group.rank) + + meta_dict_list = 
all_gather_simple_object(meta_dict, group) + + total_meta_dict = {} + for meta_dict in meta_dict_list: + for (k, v) in meta_dict.items(): + assert k not in total_meta_dict + total_meta_dict[k] = v + + meta_list = list(total_meta_dict.items()) + meta_list = sorted(meta_list, key=lambda x: x[0]) + for (k, meta) in meta_list: + dtype, shape, rank = meta + if rank == group.rank: + assert k in state_dict + tensor = paddle.to_tensor(state_dict[k]) + del state_dict[k] + else: + tensor = paddle.to_tensor(np.empty(shape, dtype)) + logger.info(f"broadcast {k} from {rank}") + # broadcast the tensor + paddle.distributed.broadcast( + tensor, + src=group.ranks[rank], + group=group, + sync_op=True, + ) + if filter_func(k): + res[k] = tensor.cpu() + del tensor + return res + + +def _all_gather_state_dict(state_dict, filter_func, group): + remote_state_dict_keys = [k for k in state_dict.keys() if not filter_func(k)] + tmp_state_dict = OrderedDict() + for k in remote_state_dict_keys: + tmp_state_dict[k] = state_dict[k] + state_dict.pop(k) + tmp_state_dict = all_gather_state_dict(tmp_state_dict, filter_func, group) + for (k, v) in tmp_state_dict.items(): + state_dict[k] = v + return state_dict diff --git a/paddlenlp/trainer/utils/reshard/sharding_v1.py b/paddlenlp/trainer/utils/reshard/sharding_v1.py new file mode 100644 index 000000000000..6c7e637ec6b3 --- /dev/null +++ b/paddlenlp/trainer/utils/reshard/sharding_v1.py @@ -0,0 +1,42 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import ( + DygraphShardingOptimizer, +) + +from ....transformers.model_utils import unwrap_optimizer + + +def shard(node_model_state, model, optimizer, hcg): + group = hcg.get_sharding_parallel_group() + cur_rank = group.rank + optimizer = unwrap_optimizer(optimizer, DygraphShardingOptimizer) + assert optimizer is not None + param2rank = optimizer._param2rank + + def filter_func(key): + names = key + param_name = names[1] + assert param_name in param2rank + dst_rank = param2rank[param_name] + return dst_rank == cur_rank + + node_model_state.reshard(group, filter_func) + return node_model_state + + +def restore(node_model_state, model, optimizer, hcg): + node_model_state.drop_rank() + return node_model_state diff --git a/paddlenlp/trainer/utils/reshard/sharding_v2.py b/paddlenlp/trainer/utils/reshard/sharding_v2.py new file mode 100644 index 000000000000..5729ab22e408 --- /dev/null +++ b/paddlenlp/trainer/utils/reshard/sharding_v2.py @@ -0,0 +1,162 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle +from paddle.distributed.fleet.model import PipelineParallel + +from ....transformers.model_utils import unwrap_optimizer + +try: + from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import ( + DygraphShardingOptimizerV2, + ) +except: + DygraphShardingOptimizerV2 = None + + +def shard(node_model_state, model, optimizer, hcg): + assert DygraphShardingOptimizerV2 is not None + group = hcg.get_sharding_parallel_group() + cur_rank = group.rank + split_infos = collect_split_info(optimizer, model) + + def split_func(k, v): + param_name = k[1] + opt_name = k[-1] + assert param_name in split_infos + is_beta = is_bata(opt_name) + index, padded_size, buffer_size, has_slice_grad = split_infos[param_name] + + if not is_beta: + v = pad_tensor(v, padded_size) + + def get_slice(v, begin, end): + if is_beta: + return v + return slice_tensor(v, begin, end) + + assert buffer_size % group.nranks == 0, f"buffer_size {buffer_size} group.nranks {group.nranks}" + buffer_slice = buffer_size // group.nranks + + # has slice grad in cur rank + if has_slice_grad: + assert index < (cur_rank + 1) * buffer_slice + assert index + padded_size > cur_rank * buffer_slice + + offset = buffer_slice - index % buffer_slice + tensors = [] + tensors.append((index // buffer_slice, get_slice(v, 0, offset))) + left_size = padded_size - offset + for _ in range((left_size + buffer_slice - 1) // buffer_slice): + end = min(offset + buffer_slice, padded_size) + assert end <= buffer_size + tensors.append(((offset + index) // buffer_slice, get_slice(v, offset, end))) + offset = end + + return tensors + + node_model_state.split_items(split_func).flatten_key() + + def filter_func(k): + names, rank = k + assert rank < group.nranks + return rank == cur_rank + + # reshard + node_model_state.reshard(group, filter_func) + node_model_state.drop_rank() + return node_model_state + + +def restore(node_model_state, model, optimizer, hcg): + group = hcg.get_sharding_parallel_group() + # evenly distribute param + node_model_state.even_distribute(group) + param_shapes = {k: v.shape for (k, v) in model.state_dict().items()} + + def merge_func(k, v): + structure_name = k[0] + opt_name = k[-1] + assert structure_name in param_shapes, structure_name + tensor_list = [e[1] for e in v] + # do not merge beta acc + if is_bata(opt_name): + return tensor_list[0] + shape = param_shapes[structure_name] + return merge_tensors(tensor_list, shape) + + node_model_state.collapse_key().merge_items(merge_func) + return node_model_state + + +def merge_tensors(tensor_list, shape): + assert len(tensor_list) > 0 + if len(tensor_list) == 1: + t = tensor_list[0] + else: + assert len(tensor_list[0].shape) == 1 + t = paddle.concat(x=tensor_list, axis=0) + tensor_size = np.prod(shape) + padded_size = t._numel() + assert padded_size >= tensor_size + t = t._slice(0, tensor_size) + t.get_tensor()._set_dims(shape) + return t + + +def pad_tensor(tensor, padded_size): + tensor_shape = tensor.shape + tensor_size = np.prod(tensor_shape) + assert tensor_size <= padded_size + t = paddle.zeros([padded_size], 
dtype=tensor.dtype) + tensor.flatten_() + t[0:tensor_size] = tensor + tensor.get_tensor()._set_dims(tensor_shape) + return t + + +def slice_tensor(tensor, begin, end): + return tensor[begin:end] + + +def collect_split_info(optimizer, model): + split_infos = {} + + def gather_infos(comm_buffer): + for (k, v) in comm_buffer._sharding_param_grad_view.items(): + index = v._index + padded_size = v._padded_size + buffer_size = v._param_buffer._numel() + has_slice_grad = v._slice_grad is not None + split_infos[k] = (index, padded_size, buffer_size, has_slice_grad) + + if isinstance(model, PipelineParallel) and len(model._chunk_2_comm_buffers) > 0: + for (k, v) in model._chunk_2_comm_buffers.items(): + for comm_buffer in v: + gather_infos(comm_buffer) + else: + optimizer = unwrap_optimizer(optimizer, DygraphShardingOptimizerV2) + for comm_buffer in optimizer._comm_buffer_list: + gather_infos(comm_buffer) + assert len(split_infos) + return split_infos + + +def is_bata(name): + if "_beta1_pow_acc_" in name: + return True + if "_beta2_pow_acc_" in name: + return True + return False diff --git a/paddlenlp/transformers/conversion_utils.py b/paddlenlp/transformers/conversion_utils.py index 5cef6990b93d..d0ca4e0355bd 100644 --- a/paddlenlp/transformers/conversion_utils.py +++ b/paddlenlp/transformers/conversion_utils.py @@ -256,7 +256,65 @@ def get_diff_keys(self, return_all_diff: bool = False) -> List[str]: return all_diff_keys -def merge_tensor_parallel_weight(weight_list, is_column=True): +def naive_fuse_merge_tp(weight_list, is_column=True, fuse_tensor_parts=2): + """ + + [A1 B1],[A2 B2] => [A1, A2, B1, B2] + + Args: + weight_list (List[np.ndarray]): The splited tensor parallel weight list. + is_column (bool, optional): Is ColumnLinear or RowLinear. Defaults to True. + + Returns: + weight (np.ndarray): the merged weight. + """ + if is_column: + axis = -1 + else: + axis = 0 + + reorder = [] + for item in weight_list: + reorder.extend(np.split(item, fuse_tensor_parts, axis=axis)) + # 0 1 2 3 -> 0 2 1 3 + index = ( + np.transpose(np.arange(len(reorder)).reshape([len(weight_list), fuse_tensor_parts]), [1, 0]) + .reshape(-1) + .tolist() + ) + return np.concatenate([reorder[i] for i in index], axis=axis) + + +def naive_fuse_split_tp( + weight, tensor_parallel_degree, tensor_parallel_rank=None, is_column=True, fuse_tensor_parts=2 +): + """ + + [A1, A2, B1, B2] => [A1 B1],[A2 B2] + + Args: + weight (numpy.ndarray): the tensor weight, + tensor_parallel_degree (int): tensor_parallel_degree + tensor_parallel_rank (int): tensor_parallel_rank + is_column (bool, optional): is ColumnLinear . Defaults to True. + + Returns: + tensor (numpy.ndarray): splited weight. 
+ + """ + axis = -1 if is_column else 0 + splited = np.split(weight, fuse_tensor_parts * tensor_parallel_degree, axis=axis) + + if tensor_parallel_rank is None: + ret = [] + for tensor_parallel_rank in range(tensor_parallel_degree): + ret.append(np.concatenate(splited[tensor_parallel_rank::tensor_parallel_degree], axis=axis)) + return ret + + return np.concatenate(splited[tensor_parallel_rank::tensor_parallel_degree], axis=axis) + + +def normal_fuse_merge_tp(weight_list, is_column=True): """ [A1],[A2] => [A1, A2] @@ -274,7 +332,7 @@ def merge_tensor_parallel_weight(weight_list, is_column=True): return np.concatenate(weight_list, axis=0) -def split_tensor_parallel_weight(weight, tensor_parallel_degree, tensor_parallel_rank=None, is_column=True): +def normal_fuse_split_tp(weight, tensor_parallel_degree, tensor_parallel_rank=None, is_column=True): """ [A1, A2] => [A1],[A2] @@ -288,6 +346,29 @@ def split_tensor_parallel_weight(weight, tensor_parallel_degree, tensor_parallel Returns: tensor (numpy.ndarray): splited weight. """ + dim = -1 if is_column else 0 + if "PySafeSlice" in str(type(weight)): + size = weight.get_shape()[dim] + block_size = size // tensor_parallel_degree + start = tensor_parallel_rank * block_size + stop = (tensor_parallel_rank + 1) * block_size + assert ( + size % tensor_parallel_degree == 0 + ), f"The choosen size {size} is not compatible with sharding on {tensor_parallel_degree} shards" + + if dim == 0 or len(weight.get_shape()) == 1: + tensor = weight[start:stop] + elif dim == -1: + tensor = weight[:, start:stop] + else: + raise NotImplementedError("Let's make that generic when needed") + return tensor + + size = weight.shape[dim] + assert ( + size % tensor_parallel_degree == 0 + ), f"The choosen size {size} is not compatible with sharding on {tensor_parallel_degree} shards. for tensor shape {weight.shape}" + if is_column: splited_weights = np.split(weight, tensor_parallel_degree, axis=-1) else: @@ -362,13 +443,17 @@ def splited_qkv_to_tensor_parallel_qkv(weight_list, num_attention_heads): def get_tensor_parallel_merge_func(tensor_parallel_degree, tensor_parallel_rank, num_attention_heads=None): - def fn(x, is_column=True, transpose=False, is_old_qkv=False): + def fn(x, is_column=True, transpose=False, is_old_qkv=False, is_naive_2fuse=False, is_naive_3fuse=False): if x is None: return None - x = merge_tensor_parallel_weight( - x, - is_column=is_column, - ) + + if is_naive_2fuse: + return naive_fuse_merge_tp(x, is_column=is_column, fuse_tensor_parts=2) + elif is_naive_3fuse: + return naive_fuse_merge_tp(x, is_column=is_column, fuse_tensor_parts=3) + else: + x = normal_fuse_merge_tp(x, is_column=is_column) + if is_old_qkv: assert is_column, "QKV tensor should be column parallel linear." assert num_attention_heads is not None, "is_old_qkv need num_attention_heads" @@ -382,7 +467,7 @@ def fn(x, is_column=True, transpose=False, is_old_qkv=False): def get_tensor_parallel_split_func(tensor_parallel_degree, tensor_parallel_rank, num_attention_heads=None): - def fn(x, is_column=True, transpose=False, is_old_qkv=False): + def fn(x, is_column=True, transpose=False, is_old_qkv=False, is_naive_2fuse=False, is_naive_3fuse=False): if x is None: return None if transpose: @@ -391,12 +476,16 @@ def fn(x, is_column=True, transpose=False, is_old_qkv=False): assert is_column, "QKV tensor should be column parallel linear." 
assert num_attention_heads is not None, "is_old_qkv need num_attention_heads" x = naive_merged_qkv_to_tensor_parallel_qkv(x, num_attention_heads) - return split_tensor_parallel_weight( - x, - tensor_parallel_degree=tensor_parallel_degree, - tensor_parallel_rank=tensor_parallel_rank, - is_column=is_column, - ) + if is_naive_2fuse: + return naive_fuse_split_tp( + x, tensor_parallel_degree, tensor_parallel_rank, is_column=is_column, fuse_tensor_parts=2 + ) + if is_naive_3fuse: + return naive_fuse_split_tp( + x, tensor_parallel_degree, tensor_parallel_rank, is_column=is_column, fuse_tensor_parts=3 + ) + + return normal_fuse_split_tp(x, tensor_parallel_degree, tensor_parallel_rank, is_column=is_column) return fn @@ -925,6 +1014,14 @@ def _get_name_mappings(cls, config: PretrainedConfig) -> List[StateDictNameMappi """ raise NotImplementedError + @classmethod + def get_tensor_parallel_convert_actions(cls, config: PretrainedConfig, loaded_state_dict_keys, ignore_error=False): + name_action_mappings = cls._get_tensor_parallel_mappings(config) + state_keys_map = cls._resolve_prefix_keys(name_action_mappings.keys(), loaded_state_dict_keys, ignore_error) + for k, v in state_keys_map.items(): + name_action_mappings[v] = name_action_mappings.pop(k) + return name_action_mappings + @classmethod def convert_tensor_parallel( cls, weight_file: str, config: PretrainedConfig, state_dict=None, ignore_error=False @@ -932,13 +1029,14 @@ def convert_tensor_parallel( """the entry of converting config and converting model file Args: - input_dir (str | None): the input dir which contains `pytorch_model.bin` and `config.json` file + weight_file (str | None): the weight file path of `model_state.pdparams` file config (PretrainedConfig): the PretrainedConfig instance of model """ name_action_mappings = cls._get_tensor_parallel_mappings(config) if state_dict is None: with device_guard("cpu"): state_dict = paddle.load(weight_file, return_numpy=False) + logger.info("Starting to convert orignal state_dict to tensor parallel state_dict.") state_keys_map = cls._resolve_prefix_keys(name_action_mappings.keys(), state_dict.keys(), ignore_error) @@ -948,7 +1046,7 @@ def convert_tensor_parallel( for name, action in name_action_mappings.items(): if name not in state_dict: if not ignore_error: - logger.warning(f"key<{name}> not in the model state weight file.") + logger.warning(f"Key <{name}> not in the model state weight file.") continue tensor = state_dict.pop(name) new_tensor = action(tensor) @@ -986,6 +1084,11 @@ def merge_tensor_parallel(cls, state_dict, config) -> None: else: tensor = tensor.numpy() if is_dst else None + # keep state dict use paddle.tensor + if isinstance(tensor, np.ndarray): + with device_guard("cpu"): + tensor = paddle.Tensor(tensor, zero_copy=True) + state_dict_to_save[key] = tensor if len(name_action_mappings) > 0: @@ -1024,7 +1127,7 @@ def _resolve_prefix_keys(state_keys_base, state_keys_real, ignore_error=False): break if key not in state_keys_map: if not ignore_error: - logger.error(f"could not find name {key} in loaded state dict!") + logger.error(f"tensor parallel conversion: could not find name {key} in loaded state dict!") else: state_keys_real.remove(state_keys_map[key]) diff --git a/paddlenlp/transformers/generation_utils.py b/paddlenlp/transformers/generation_utils.py index 230c30c19221..280796393ccf 100644 --- a/paddlenlp/transformers/generation_utils.py +++ b/paddlenlp/transformers/generation_utils.py @@ -1151,6 +1151,7 @@ def sample( else: next_tokens = paddle.multinomial(probs) + # 
paddle.distributed.broadcast(next_tokens, src=0) next_scores = paddle.index_sample(origin_probs, next_tokens) if eos_token_id is not None: diff --git a/paddlenlp/transformers/model_utils.py b/paddlenlp/transformers/model_utils.py index 341d01f83391..9e2192fccec1 100644 --- a/paddlenlp/transformers/model_utils.py +++ b/paddlenlp/transformers/model_utils.py @@ -19,6 +19,7 @@ import re import shutil import tempfile +from collections import OrderedDict from contextlib import contextmanager from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type @@ -79,6 +80,79 @@ ] +def unwrap_optimizer(optimizer, optimizer_instances=()): + if optimizer is None: + return None + while hasattr(optimizer, "_inner_opt") and not isinstance(optimizer, optimizer_instances): + optimizer = optimizer._inner_opt + if isinstance(optimizer, optimizer_instances): + return optimizer + return None + + +def filter_sharded_params(state_dict, optimizer, sharding_group): + sharding_rank = sharding_group.rank + sharding_world_size = sharding_group.nranks + from paddlenlp.trainer.utils import reshard as reshard_util + + logger.info(f"filter sharded_params not placed in sharding_rank {sharding_rank} .") + if not reshard_util.is_sharding_opt(optimizer): + return state_dict + filtered_state_dict = OrderedDict() + if reshard_util.get_sharding_strategy(optimizer) == reshard_util.SHARDING_STRATEGY_V1: + from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import ( + DygraphShardingOptimizer, + ) + + optimizer = unwrap_optimizer(optimizer, DygraphShardingOptimizer) + for (k, v) in state_dict.items(): + assert v.name in optimizer._param2rank + sharded_rank = optimizer._param2rank[v.name] + if sharded_rank != sharding_rank: + continue + filtered_state_dict[k] = v + else: + from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import ( + DygraphShardingOptimizerV2, + ) + + optimizer = unwrap_optimizer(optimizer, DygraphShardingOptimizerV2) + parameters = optimizer._parameter_list + filtered_parameters = [p.name for (i, p) in enumerate(parameters) if i % sharding_world_size == sharding_rank] + filtered_parameters = set(filtered_parameters) + for (k, v) in state_dict.items(): + if v.name in filtered_parameters: + filtered_state_dict[k] = v + + return filtered_state_dict + + +def exlclude_paramters_in_state_dict( + model_state_dict, param_names_in_master_weights, sharding_group, save_sharding_stage1_model=True +): + assert sharding_group is not None + assert isinstance(model_state_dict, dict) and isinstance( + param_names_in_master_weights, (list, set) + ), "param_names_in_master_weights type:{}".format(type(param_names_in_master_weights)) + state_param_names = [v.name for k, v in model_state_dict.items()] + logger.debug( + "param_names_in_master_weights:{}, state_param_names:{}".format( + param_names_in_master_weights, state_param_names + ) + ) + # allgather parameter names in sharding group + tmp = [] + paddle.distributed.all_gather_object(tmp, param_names_in_master_weights, group=sharding_group) + param_names_in_master_weights = set([v for item in tmp for v in item]) + logger.info("sharding_group_param_names:{}".format(param_names_in_master_weights)) + non_parameters_state_dict = copy.copy(model_state_dict) + for k, v in model_state_dict.items(): + if v.name in param_names_in_master_weights: + non_parameters_state_dict.pop(k) + + return non_parameters_state_dict + + def prune_linear_layer(layer: nn.Linear, index: paddle.Tensor, dim: int = 0) -> 
nn.Linear: """ Prune a linear layer to keep only entries in index. @@ -281,6 +355,7 @@ def _find_weight_file_path( cache_dir: str, model_class: Type[PretrainedModel], resource_uri: Optional[str] = None, + config: Optional[PretrainedConfig] = None, ) -> str | None: """find the target weight file under the cache dir, because there are some conflicts about weight file names. @@ -305,7 +380,7 @@ def _find_weight_file_path( # 3. find the target weight file name for splited tensor parallel # fix for load hybrid parallel hybrid_parallel_weight_file_path = os.path.join( - cache_dir, _add_variant(resource_weight_file_name, weight_name_suffix()) + cache_dir, _add_variant(resource_weight_file_name, weight_name_suffix(config)) ) if os.path.isfile(hybrid_parallel_weight_file_path): return hybrid_parallel_weight_file_path @@ -954,7 +1029,10 @@ def _resolve_model_file_path( # find the weight file with the above two branch: `bert-base-uncased.pdparams`, `model_state.pdparams` weight_file_path = _find_weight_file_path( - cache_dir=cache_dir, model_class=cls, resource_uri=pretrained_model_name_or_path + cache_dir=cache_dir, + model_class=cls, + resource_uri=pretrained_model_name_or_path, + config=config, ) return weight_file_path @@ -964,7 +1042,7 @@ def _resolve_model_file_path( # in-order to compatible with old style: # file name in pretrained_resouce_file_maps is https://path/to/bert-base-uncased.pdparams, but the registered model-state file name in `resouce_file_maps` is `model_state.pdparams` - return _find_weight_file_path(cache_dir=pretrained_model_name_or_path, model_class=cls) + return _find_weight_file_path(cache_dir=pretrained_model_name_or_path, model_class=cls, config=config) # 4. download from community or hf-hub else: @@ -1187,6 +1265,7 @@ def _find_mismatched_keys( import paddlenlp.ops.fast_transformer.transformer.decoding as ft_decoding state_to_load = ft_decoding.get_ft_para_conf().fit_partial_model(model_to_load, state_dict) + if paddle.in_dynamic_mode(): model_to_load.set_state_dict(state_to_load) @@ -1414,6 +1493,12 @@ def save_pretrained(self, save_dir: str, **kwargs): merge_tensor_parallel = kwargs.get("merge_tensor_parallel", False) variant = kwargs.get("variant", None) is_main_process = kwargs.get("is_main_process", True) + is_bf16 = kwargs.get("is_bf16", False) + param_names_in_master_weights = list(kwargs.get("param_names_in_master_weights", [])) + sharding_group = kwargs.get("sharding_group", None) + optimizer = kwargs.get("optimizer", None) + save_sharding_stage1_model = kwargs.get("save_sharding_stage1_model", False) + use_async_save = kwargs.get("use_async_save", False) # 1. 
retrieve the model related config @@ -1426,10 +1511,12 @@ def save_pretrained(self, save_dir: str, **kwargs): dtype = get_parameter_dtype(model_to_save) model_to_save.config.dtype = str(dtype).split(".")[1] - state_dict_to_save = None config_to_save = copy.deepcopy(model_to_save.config) + state_dict_to_save = model_to_save.state_dict() + if save_sharding_stage1_model: + state_dict_to_save = filter_sharded_params(state_dict_to_save, optimizer, sharding_group) if merge_tensor_parallel and config_to_save.tensor_parallel_degree > 1: - state_dict_to_save = model_to_save.merge_tensor_parallel(model_to_save.state_dict(), config_to_save) + state_dict_to_save = model_to_save.merge_tensor_parallel(state_dict_to_save, config_to_save) config_to_save.tensor_parallel_degree = 1 # set variant to None for merge_tensor_parallel, but there should no relationship with variant setting variant = None @@ -1442,7 +1529,15 @@ def save_pretrained(self, save_dir: str, **kwargs): variant = f"tp{config_to_save.tensor_parallel_rank:0>2d}" # WEIGHTS_NAME = _add_variant(WEIGHTS_NAME, variant) - state_dict_to_save = self.state_dict() + if is_bf16 and save_sharding_stage1_model: + state_dict_to_save = exlclude_paramters_in_state_dict( + state_dict_to_save, param_names_in_master_weights, sharding_group + ) + logger.info( + "param_names_in_master_weights len:{}, bf16 state_dict_to_save len:{}, :{}".format( + len(param_names_in_master_weights), len(state_dict_to_save), state_dict_to_save + ) + ) if is_main_process: # Attach architecture to the config @@ -1452,7 +1547,10 @@ def save_pretrained(self, save_dir: str, **kwargs): # Save model if paddle.in_dynamic_mode(): file_name = os.path.join(save_dir, _add_variant(WEIGHTS_NAME, variant)) - paddle.save(state_dict_to_save, file_name) + if use_async_save: + paddle.async_save(state_dict_to_save, file_name) + else: + paddle.save(state_dict_to_save, file_name) del model_to_save else: logger.warning("Save pretrained model only supported dygraph mode for now!") diff --git a/paddlenlp/transformers/utils.py b/paddlenlp/transformers/utils.py index 555a496eae94..67d09b5637b4 100644 --- a/paddlenlp/transformers/utils.py +++ b/paddlenlp/transformers/utils.py @@ -391,7 +391,7 @@ def optimizer_name_suffix(): return None -def weight_name_suffix(): +def weight_name_suffix(config=None): hcg = use_hybrid_parallel() if hcg is not None: name = [] @@ -399,6 +399,12 @@ def weight_name_suffix(): name.append(f"tp{hcg.get_model_parallel_rank():0>2d}") if hcg.get_pipe_parallel_world_size() > 1: name.append(f"pp{hcg.get_stage_id():0>2d}") + if config and getattr(config, "moe_num_experts", 0): + dp_group = hcg.get_data_parallel_group() + name.append(f"moe{dp_group.rank:0>2d}") return "_".join(name) else: + if config and getattr(config, "moe_num_experts", 0): + rank = paddle.distributed.get_rank() + return f"moe{rank:0>2d}" return None
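To illustrate the effect of the suffix helpers above, here is a small self-contained sketch (assumed ranks, no fleet initialization) that rebuilds the suffix string weight_name_suffix would produce under hybrid parallelism; the variant is then inserted into the checkpoint filename, following the same pattern as the sharded filenames used elsewhere in this diff (e.g. model_state.tp00_pp01_shard03.pdparams):

def build_weight_suffix(tp_rank=None, pp_rank=None, moe_rank=None):
    # Mirror the ordering used above: tp rank, then pp stage, then the MoE (data-parallel) rank.
    parts = []
    if tp_rank is not None:
        parts.append(f"tp{tp_rank:0>2d}")
    if pp_rank is not None:
        parts.append(f"pp{pp_rank:0>2d}")
    if moe_rank is not None:
        parts.append(f"moe{moe_rank:0>2d}")
    return "_".join(parts) if parts else None

# Assumed example ranks: tensor-parallel rank 0, pipeline stage 1, MoE/data-parallel rank 3.
suffix = build_weight_suffix(tp_rank=0, pp_rank=1, moe_rank=3)
print(suffix)                            # tp00_pp01_moe03
print(f"model_state.{suffix}.pdparams")  # illustrative checkpoint filename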