79 commits (showing changes from 1 commit)
6b690ec
refactor training loop support
Meiyim May 29, 2023
78d764d
fix trainer
ForFishes Jun 22, 2023
aa738a6
Merge pull request #1 from ForFishes/add_seed_order_nlp
Meiyim Jun 22, 2023
9cfd157
support pp delay_scale_loss and dp_comm_overlap
haohongxiang Jun 23, 2023
28cb000
Merge pull request #3 from haohongxiang/support_pp_new_strategy
Meiyim Jun 23, 2023
3bec988
support sharding stage1 in hybrid parallel.
wuhuachaocoding Jun 23, 2023
dfbd574
pipeline: compatible with non-master-grad
Meiyim Jun 24, 2023
c382dc0
Merge pull request #5 from Meiyim/fix-master-grad
Meiyim Jun 24, 2023
499413b
Merge pull request #4 from wuhuachaocoding/sharding_eb4
Meiyim Jun 24, 2023
fc8f1ab
Bump paddle to `new-model-7`
Meiyim Jun 25, 2023
49c9e5c
fix sharding + `overlap` hangs under progressive batching
Meiyim Jun 25, 2023
40f3a91
Add time statistics for nccl-connection.
GhostScreaming Jun 26, 2023
dc59752
Polish code.
GhostScreaming Jun 26, 2023
9539fdb
Merge pull request #7 from GhostScreaming/refactor-training-loop
Meiyim Jun 26, 2023
0f29ea4
allow to use `main-grad` under TF32/FP32
Meiyim Jun 26, 2023
9a71dd7
Merge pull request #8 from Meiyim/tf32-main-grad
Meiyim Jun 26, 2023
3124e91
[hot-fix] resume from accumulation-step wrong
Meiyim Jun 27, 2023
afebd95
[fix] with pp but without mp, only the pp01 model was saved
xysheng-baidu Jun 29, 2023
1985aeb
Merge pull request #9 from Meiyim/hot-fix-resume-when-accumulate
Meiyim Jun 29, 2023
55ee7d6
Merge pull request #10 from bo-ke/fix_args
Meiyim Jun 29, 2023
16e685e
online hot fix, raise Error when optimizer/lr scheduler no show in re…
Meiyim Jul 2, 2023
56c001a
Merge pull request #11 from Meiyim/hotfix
Meiyim Jul 2, 2023
7798f93
sharding save and load for bf16
Jul 3, 2023
cf5e5bb
fix
Jul 3, 2023
854036c
polish
Jul 3, 2023
2024390
polish
Jul 3, 2023
e182e0b
add reshard
Jul 3, 2023
d5e4512
polish
Jul 3, 2023
519ab65
bf16 save and load demo
Jul 3, 2023
2366c32
polish
Jul 3, 2023
3a16b44
merge
Jul 4, 2023
76f6167
fix range
Jul 4, 2023
059ebe3
support asynchronous save
SylarTiaNII Jul 4, 2023
3099598
add optional argument async_save to trainer
SylarTiaNII Jul 4, 2023
2fb6905
fix _save_checkpoint
SylarTiaNII Jul 4, 2023
8bed559
tidy codes
Jul 5, 2023
53aec0a
reshard if necessary
Jul 5, 2023
6a7faec
polish
Jul 5, 2023
873ebb5
polish
Jul 5, 2023
afe7a63
reformat switch
SylarTiaNII Jul 6, 2023
b498318
format fix
SylarTiaNII Jul 6, 2023
681ca98
fix
SylarTiaNII Jul 6, 2023
4fe6b4f
Merge pull request #15 from SylarTiaNII/async_save
Meiyim Jul 6, 2023
7e16164
fix group
Jul 6, 2023
6a7c120
merge
Jul 6, 2023
a14ee6b
Merge pull request #12 from pangengzheng/sharding_save_and_load
Meiyim Jul 9, 2023
d50f333
fix sharding group
Jul 9, 2023
fc30b57
polish
Jul 9, 2023
d52e163
Merge pull request #17 from pangengzheng/fix_sharding_group
Meiyim Jul 9, 2023
6b939a8
move load model before wrap model because it will broadcast sharding0…
pangengzheng Jul 13, 2023
14a4032
Refactor training loop merge meta (#6359)
liuzhenhai93 Jul 13, 2023
dce1593
broadcast (#6397)
liuzhenhai93 Jul 14, 2023
6c1d9bc
fix ClipGradByAdaptiveNorm load state dict (#6409)
GuoxiaWang Jul 15, 2023
fa76862
remove too many parameter logs (#6407)
sneaxiy Jul 15, 2023
5517ca5
Add profile timer (#6441)
Meiyim Jul 20, 2023
02ae3fc
polish (#6492)
liuzhenhai93 Jul 25, 2023
47a71a1
[Distributed]Add dp/sharding overlap for pipeline (#6504)
ForFishes Jul 26, 2023
083f8d8
[Distributed]Support pipelineparallel in accumulation_steps (#6509)
ForFishes Jul 26, 2023
28d4e0c
[LLM] Fix asynchronize save (#6624)
SylarTiaNII Aug 7, 2023
ed8ca95
[LLM] Support master grad on dp. (#6650)
ZHUI Aug 8, 2023
583bb33
[Distributed] Add hang mode to avoid overlap (#6838)
ForFishes Aug 28, 2023
0a808d3
Add tensor parallel conversion. (#6844)
ZHUI Sep 4, 2023
2e53d60
[Trainer] Support main grad for pure sharding #7013 (#7014)
QimingPeng Sep 13, 2023
0ed3828
improve load sharding model (#6976)
pangengzheng Sep 13, 2023
1191c40
[Trainer] Support for MoE model training (#7089)
Meiyim Oct 9, 2023
8def3e7
[Trainer] Update load/save sharded model for mp model (#7180)
QimingPeng Oct 10, 2023
939224a
[LLM] use pin memory save optimizer (#7185)
xysheng-baidu Oct 11, 2023
cb7efe1
basic func
liuzhenhai93 Oct 12, 2023
2772950
basic func
liuzhenhai93 Oct 12, 2023
06123fa
polish
liuzhenhai93 Oct 12, 2023
d6400c9
code ok
liuzhenhai93 Oct 12, 2023
283371f
basic func
liuzhenhai93 Oct 13, 2023
b4f0034
polish
liuzhenhai93 Oct 13, 2023
05da413
polish
liuzhenhai93 Oct 13, 2023
d0509e4
Merge branch 'refactor-training-loop' of https://github.com/PaddlePad…
Oct 13, 2023
bdf9aa9
polish
liuzhenhai93 Oct 13, 2023
d6d1dcf
paddlenlp/trainer/utils/reshard/sharding_v2.py
liuzhenhai93 Oct 13, 2023
6f8e808
polish
liuzhenhai93 Oct 13, 2023
50950c2
polish
liuzhenhai93 Oct 13, 2023
[LLM] Support master grad on dp. (#6650)
* support master grad on dp.
ZHUI authored Aug 8, 2023
commit ed8ca957999248177dec02b7f2c529c2a72d072d
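For orientation, a minimal usage sketch (not part of this diff) of what the commit enables: turning on FP32 master gradients for a pure data-parallel run through TrainingArguments. The concrete field values below are illustrative assumptions, not settings taken from the PR.

from paddlenlp.trainer import TrainingArguments

# Hypothetical configuration: enable master gradients without tensor/pipeline parallelism.
training_args = TrainingArguments(
    output_dir="./checkpoints",     # assumed output location
    bf16=True,                      # mixed-precision training
    fp16_opt_level="O2",            # keep master weights/gradients in FP32
    amp_master_grad=True,           # the flag this commit supports under plain data parallel
    per_device_train_batch_size=1,  # assumed batch size
)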
16 changes: 15 additions & 1 deletion paddlenlp/trainer/trainer.py
@@ -897,6 +897,11 @@ def train(
                 # sharding
                 # stage1. the same as ddp
                 # stage2. manualy collect gradient on dp group
+
+                hack_dp_master_grad = self.args.amp_master_grad and not self.args.use_hybrid_parallel
+                if hack_dp_master_grad:
+                    is_no_sync = False
+
                 if is_no_sync:
                     # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
                     with model.no_sync():
@@ -949,6 +954,10 @@ def train(
                     self.timers and self.timers("all-reduce").stop()
                     self.timers and self.timers("optimizer-step").start()
 
+                    # Case 3: hack dp with master_grad
+                    if hack_dp_master_grad and not (args.recompute and availiable_no_sync):
+                        fused_allreduce_gradients(list(model.parameters()), None)
+
                     # pipeline parallel mode, handle gradient merge here
                     if args.pipeline_parallel_degree > 1 and enable_delay_scale_loss:
                         for p in model._layers.parameters():
@@ -1518,7 +1527,12 @@ def _wrap_model(self, model, training=True):
 
         # Multi-gpu training
         if self.args.world_size > 1 and not self.args.use_hybrid_parallel:
-            model = paddle.DataParallel(model)
+            if self.args.amp_master_grad:
+                mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype)  # return value has no use
+                logger.warning("Note amp_master_grad using in dp is an experimental support!")
+                self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
+            else:
+                model = paddle.DataParallel(model)
         # Distributed training (should be after fp16 initialization)
 
         in_pipeline_parallel_mode = self.args.pipeline_parallel_degree > 1
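To make the _wrap_model change above easier to follow, here is a standalone sketch of the data-parallel master-grad branch. The helper name wrap_for_dp_master_grad is hypothetical; MixPrecisionLayer and MixPrecisionOptimizer are the Paddle utilities the diff calls.

from paddle.distributed.fleet.utils import mix_precision_utils

def wrap_for_dp_master_grad(model, optimizer, amp_dtype="bfloat16"):
    # Register hooks that maintain FP32 master gradients for the model's
    # parameters; the return value is unused, mirroring the diff above.
    mix_precision_utils.MixPrecisionLayer(model, dtype=amp_dtype)
    # Wrap the optimizer so it steps on the FP32 master gradients.
    optimizer = mix_precision_utils.MixPrecisionOptimizer(optimizer)
    # The model is deliberately not wrapped in paddle.DataParallel on this path,
    # so cross-rank gradient synchronization must be done manually
    # (see the fused_allreduce_gradients call added in train()).
    return model, optimizer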
15 changes: 10 additions & 5 deletions paddlenlp/trainer/training_args.py
@@ -714,13 +714,18 @@ def __post_init__(self):
             self.use_hybrid_parallel = True
 
         if self.amp_master_grad:
-            if self.pipeline_parallel_degree <= 1 and self.tensor_parallel_degree <= 1:
-                raise ValueError(
-                    "Temporarily amp master grad only suport for tensor/pipeline parallel. please set amp_master_grad to False."
-                )
-            # if not (self.bf16 or self.fp16):
+            # if (
+            #     self.pipeline_parallel_degree <= 1 and self.tensor_parallel_degree <= 1
+            # ) or self.fp16_opt_level != "O2":
+            #     raise ValueError(
+            #         "Temporarily amp master grad only suport for tensor/pipeline parallel with fp16_opt_level O2. please set amp_master_grad to False."
+            #     )
+            # if not (self.bf16 or self.fp16) or self.fp16_opt_level != "O2":
             #     logger.warning("set amp_master_grad to false since amp is disabled.")
             #     self.amp_master_grad = False
+            if self.pipeline_parallel_degree <= 1 and self.tensor_parallel_degree <= 1 and len(self.sharding) > 1:
+                logger.warning("set amp_master_grad to false, not support pure sharding yet.")
+                self.amp_master_grad = False
 
         if self.use_hybrid_parallel:
             world_size = paddle.distributed.get_world_size()
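Putting the two pieces together, a hedged sketch of one training step on the pure data-parallel path: because paddle.DataParallel is skipped, gradients are all-reduced manually before the optimizer step. The function below is a simplified stand-in for the Trainer's own loop, not code from this PR.

from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients

def train_step_dp_master_grad(model, optimizer, loss):
    # Backward accumulates FP32 master gradients via the MixPrecisionLayer hooks.
    loss.backward()
    # Manual all-reduce across data-parallel ranks, matching the
    # "Case 3: hack dp with master_grad" branch added in train().
    fused_allreduce_gradients(list(model.parameters()), None)
    optimizer.step()
    optimizer.clear_grad()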