paddlenlp/trainer/trainer.py (25 additions, 5 deletions)
@@ -429,6 +429,24 @@ def remove_callback(self, callback):
"""
self.callback_handler.remove_callback(callback)

+    def _check_resume_from_checkpoint(self, resume_from_checkpoint):
+        """Resolve `resume_from_checkpoint` to a concrete checkpoint path, or None.
+
+        Args:
+            resume_from_checkpoint (`str` or `bool`, *optional*):
+                If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a
+                `bool` and equals `True`, resolve to the last checkpoint in *args.output_dir* as saved by a
+                previous instance of [`Trainer`]. Only the model state dict is loaded.
+        """
+        # Normalize falsy values (False, empty string) to None.
+        resume_from_checkpoint = None if not resume_from_checkpoint else resume_from_checkpoint
+        # Load potential model checkpoint: `True` means the latest checkpoint in the output directory.
+        if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
+            resume_from_checkpoint = get_last_checkpoint(self.args.output_dir)
+            if resume_from_checkpoint is None:
+                raise ValueError(f"No valid checkpoint found in output directory ({self.args.output_dir})")
+
+        return resume_from_checkpoint
+

[Contributor comment on `_check_resume_from_checkpoint`] This function isn't needed either.
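For illustration, a usage sketch of the new resolver (the `trainer` instance and checkpoint directory names are hypothetical; `get_last_checkpoint` is the existing PaddleNLP helper that returns the newest checkpoint-* subdirectory of a folder, or None when none exists):

    # Assume args.output_dir contains checkpoint-500/ and checkpoint-1000/.
    path = trainer._check_resume_from_checkpoint(True)
    # -> ".../output_dir/checkpoint-1000": True resolves to the latest checkpoint
    path = trainer._check_resume_from_checkpoint("my_run/checkpoint-500")
    # -> "my_run/checkpoint-500": explicit paths pass through unchanged
    path = trainer._check_resume_from_checkpoint(False)
    # -> None: falsy values mean "train from scratch"
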
     def _load_from_checkpoint(self, resume_from_checkpoint=None):
         """Load the state dict from a checkpoint. Only the model state dict is loaded.

@@ -456,10 +474,12 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
         weight_index_name = PADDLE_WEIGHTS_INDEX_NAME  # Paddle is currently the default; safetensors is not supported.

         if self.args.should_load_sharding_stage1_model:
-            state_dict = self.sharding_io.load_state_dict_from_checkpoint_with_reshard(
-                resume_from_checkpoint,
-                base_weight_name=weight_name,
-            )
+            if resume_from_checkpoint is not None:
+                state_dict = self.sharding_io.load_state_dict_from_checkpoint_with_reshard(
+                    resume_from_checkpoint,
+                    base_weight_name=weight_name,
+                )
+                self._set_state_dict_in_model(state_dict)
         else:
             if resume_from_checkpoint is not None and self.args.dataset_rank == 0:

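This hunk makes two fixes: when there is no checkpoint to resume from, the reshard load is skipped instead of being called with None, and inside the guard the loaded state dict is applied via `_set_state_dict_in_model`. A minimal sketch of the guard pattern, with `load_sharded` and `apply` as hypothetical stand-ins for the two calls above:

    def maybe_resume(resume_from_checkpoint, load_sharded, apply):
        # Training from scratch: nothing to load, so return early.
        if resume_from_checkpoint is None:
            return
        state_dict = load_sharded(resume_from_checkpoint)
        # Loading alone is not enough; the weights must be copied into the model.
        apply(state_dict)
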
@@ -594,7 +614,7 @@ def train(
         self.create_optimizer_and_scheduler(num_training_steps=max_steps)

         self.state = TrainerState()
-
+        resume_from_checkpoint = self._check_resume_from_checkpoint(resume_from_checkpoint)

[Contributor comment] This is done inside the training code, not here. It doesn't need to be added.
[Author reply] Without it, an error is raised.
         if self.args.should_load_sharding_stage1_model:
             model = self._wrap_model_and_load_sharded_checkpoint(resume_from_checkpoint)
         elif self.args.should_save_sharding_stage1_model:
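End to end, a hedged usage sketch (the `trainer` instance and output paths are hypothetical; `train(resume_from_checkpoint=...)` is the existing Trainer entry point shown above):

    # True: resume from the newest checkpoint-* directory under args.output_dir,
    # raising ValueError if none exists.
    trainer.train(resume_from_checkpoint=True)

    # str: resume from an explicit checkpoint directory.
    trainer.train(resume_from_checkpoint="output/checkpoint-1000")

    # None/False: train from scratch; the resolver normalizes both to None.
    trainer.train()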