
Conversation

@haohongxiang
Contributor

PR types

Bug fixes

PR changes

Others

Description

[Auto Parallel] Support semi-auto trainer and fit Llama2 training

@paddle-bot

paddle-bot bot commented Jan 23, 2024

Thanks for your contribution!

@haohongxiang force-pushed the semi_auto_trainer_llama2 branch 6 times, most recently from 9668320 to 97498b9 on January 23, 2024 06:22
Contributor

@ZHUI left a comment


The changes are fairly large, so I'm requesting changes for now.

@haohongxiang force-pushed the semi_auto_trainer_llama2 branch from 97498b9 to 16bca68 on January 23, 2024 06:37
        )

        return optimizer

    def _wrap_dist_loader(self, train_dataloader):
Collaborator


Is this not used in dynamic mode?

Contributor Author


Done. It's now used in both dynamic and static modes.
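For context, a rough sketch of what such a dataloader wrapper can look like with the semi-auto API, built on paddle.distributed.shard_dataloader. The _get_meshes helper, the mesh layout, and the exact keyword arguments here are assumptions for illustration, not the PR's actual code:

import paddle.distributed as dist

class AutoTrainerSketch:
    def _get_meshes(self):
        # Hypothetical helper: one ProcessMesh per pipeline stage, e.g. two
        # stages over four ranks, each stage laid out as "dp" x "mp".
        return [
            dist.ProcessMesh([[0, 1]], dim_names=["dp", "mp"]),
            dist.ProcessMesh([[2, 3]], dim_names=["dp", "mp"]),
        ]

    def _wrap_dist_loader(self, train_dataloader):
        # Shard the batches along the data-parallel axis of each stage mesh;
        # the same wrapped loader can feed both dynamic and static semi-auto runs.
        return dist.shard_dataloader(
            train_dataloader,
            meshes=self._get_meshes(),
            shard_dims="dp",
        )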

            meshes.append(_get_mesh(pp_idx))
        return meshes

    def _wrap_dist_loader(self, train_dataloader):
Collaborator


What's the difference from _wrap_dist_loader in run_pretrain_3D_auto.py?

            shard_dims="dp",
        )

    def _wrap_for_static(self, model, train_dataloader):
Collaborator


It seems to be unused?

Contributor Author

@haohongxiang Jan 24, 2024


It's called by the Trainer in paddlenlp/trainer/trainer.py to wrap the model into a DistModel in static mode.
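For readers unfamiliar with that path, a minimal sketch of such a wrapper around paddle.distributed.to_static is shown below; the self.criterion/self.optimizer attributes and the exact call order are assumptions about the surrounding Trainer state, not a quote of this PR:

import paddle.distributed as dist

def _wrap_for_static(self, model, train_dataloader):
    # Re-wrap the dataloader as a distributed loader, then hand model, loader,
    # loss and optimizer to dist.to_static, which returns a DistModel that the
    # static-graph semi-auto engine executes.
    dist_loader = self._wrap_dist_loader(train_dataloader)
    dist_model = dist.to_static(model, dist_loader, self.criterion, self.optimizer)
    return dist_model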

        position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))
        # NOTE(zhaoyingli): infer spmd does not support [seq_len] --> [batch, seq_len] in data_parallel
-       position_ids = dist.shard_tensor(position_ids, get_mesh(), [dist.Shard(0), dist.Replicate()])
+       position_ids = dist.shard_tensor(position_ids, get_mesh(), [dist.Replicate(), dist.Replicate()])
Collaborator


Why change this to replicated?

Contributor Author


Because in static mode, infer spmd doesn't yet support the [seq_len] --> [batch, seq_len] case.
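To make the difference concrete, here is a small standalone illustration of the two placements on a hypothetical 2x2 ["dp", "mp"] mesh; the mesh shape, rank ids, and tensor sizes are assumptions for the example, and it would need a 4-rank paddle.distributed.launch run to execute:

import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])

batch_size, seq_length = 8, 1024
position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))

# Old placement: shard the batch dimension across "dp"; each dp group holds 4 rows.
sharded = dist.shard_tensor(position_ids, mesh, [dist.Shard(0), dist.Replicate()])

# New placement: replicate on both mesh axes; every rank holds the full (8, 1024)
# tensor, which sidesteps the [seq_len] -> [batch, seq_len] infer-spmd case that
# static mode cannot derive yet.
replicated = dist.shard_tensor(position_ids, mesh, [dist.Replicate(), dist.Replicate()])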

@haohongxiang force-pushed the semi_auto_trainer_llama2 branch 6 times, most recently from 0afdc96 to 624abd7 on January 24, 2024 08:26
Comment on lines 723 to 726
        if self.args.use_auto_parallel and self.args.run_static_semi_auto:
            model = self._wrap_for_static(model, train_dataloader)

        self.model = model
Contributor


Suggested change

        if self.args.use_auto_parallel and self.args.run_static_semi_auto:
            model = self._wrap_for_static(model, train_dataloader)
        self.model = model

Contributor Author


Done

@haohongxiang force-pushed the semi_auto_trainer_llama2 branch from 624abd7 to 6a381c3 on January 25, 2024 06:38
@codecov

codecov bot commented Jan 25, 2024

Codecov Report

Attention: 398 lines in your changes are missing coverage. Please review.

Comparison is base (44bfeb0) 56.80% compared to head (f13a0bf) 56.57%.

Files Patch % Lines
paddlenlp/trainer/auto_trainer.py 0.00% 326 Missing ⚠️
paddlenlp/transformers/llama/modeling_3D_auto.py 4.54% 42 Missing ⚠️
paddlenlp/trainer/training_args.py 47.36% 20 Missing ⚠️
paddlenlp/trainer/trainer_utils.py 30.76% 9 Missing ⚠️
paddlenlp/trainer/trainer.py 88.88% 1 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #7885      +/-   ##
===========================================
- Coverage    56.80%   56.57%   -0.23%     
===========================================
  Files          588      589       +1     
  Lines        89536    89900     +364     
===========================================
+ Hits         50858    50865       +7     
- Misses       38678    39035     +357     

☔ View full report in Codecov by Sentry.

@haohongxiang force-pushed the semi_auto_trainer_llama2 branch from 4df557e to dee9d04 on January 25, 2024 08:16
@haohongxiang force-pushed the semi_auto_trainer_llama2 branch 4 times, most recently from e3dfa0b to eda936c on January 28, 2024 22:12
@haohongxiang force-pushed the semi_auto_trainer_llama2 branch from eda936c to e541379 on January 28, 2024 23:15

            # all_gather + mean() to get average loss over all processes
-           tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
+           tr_loss_scalar = self._get_item_from_loss(self._nested_gather(tr_loss).mean())
Contributor


    def _get_item_from_loss(self, loss):
        assert isinstance(loss, paddle.Tensor) and loss._is_initialized()
        return loss.item()

    def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval, **kwargs):
        if self.control.should_log:
            logs: Dict[str, float] = {}
            # all_gather + mean() to get average loss over all processes
            tr_loss_scalar = self._get_item_from_loss(self._nested_gather(tr_loss).mean())

I see you're reusing the _maybe_log_save_evaluate function here, and there's already a guard outside it. Why add the assert isinstance(loss, paddle.Tensor) and loss._is_initialized() check here?

Contributor Author


This can be removed here; the semi-auto check just needs to be overridden in auto_trainer.
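For example, a hypothetical override in the semi-auto trainer could look like this; the class name and method placement are assumptions, and only the guard itself comes from the snippet quoted above:

import paddle
from paddlenlp.trainer import Trainer

class AutoTrainer(Trainer):
    def _get_item_from_loss(self, loss):
        # Under semi-auto parallel the gathered loss may be a dist tensor whose
        # local value is only materialized on some ranks, so check before .item().
        assert isinstance(loss, paddle.Tensor) and loss._is_initialized()
        return loss.item()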
