@@ -149,7 +149,7 @@ def __test__(cls):
 
     def setUp(self):
         """
-        1. update runfrist and rerun to run defined different config
+        1. update runfirst and rerun to run defined different config
         2. update need_allclose to True if you want to check the result
         3. update rtol to the relative value you want to check
         """
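
The docstring above is the extension contract of TestUnifiedCheckpointBase: a subclass overrides runfirst and rerun to launch the two training runs, and sets need_allclose/rtol to control the result comparison. For orientation, a minimal sketch of a conforming subclass, built only from hooks visible in this diff; the class name is hypothetical, and run_n1c8 is presumably the one-node, eight-card launcher used throughout the file:

class TestUnifiedCheckpointOnN1C8Sketch(TestUnifiedCheckpointBase):
    def setUp(self):
        super().setUp()
        self.need_allclose = True  # compare the results of the two runs
        self.rtol = 1e-7           # relative tolerance for that comparison

    def runfirst(self, train_args):
        # first run: train and save a checkpoint
        self.run_n1c8(self.run_lora_file, **train_args)

    def rerun(self, train_args):
        # second run: resume from the checkpoint written by runfirst
        self.run_n1c8(self.run_lora_file, **train_args)
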
@@ -169,7 +169,7 @@ def setUp(self):
 
         self.run_lora_file = "llm/finetune_generation.py"
 
-    def runfrist(self, train_args):
+    def runfirst(self, train_args):
         self.run_n1c8(self.run_lora_file, **train_args)
 
     def rerun(self, train_args):
@@ -181,7 +181,7 @@ def testTP4PP2(self):
         remove_ckpt(lora_arguments["output_dir"])
 
         train_args = self.configs["TP4PP2"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)
 
         if self.need_allclose:
@@ -196,7 +196,7 @@ def testTP2Sharding4(self):
         remove_ckpt(lora_arguments["output_dir"])
 
         train_args = self.configs["TP2Sharding4"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)
 
         if self.need_allclose:
@@ -213,7 +213,7 @@ def testTP8(self):
         remove_ckpt(lora_arguments["output_dir"])
 
         train_args = self.configs["TP8"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)
 
         if self.need_allclose:
@@ -227,7 +227,7 @@ def testTP4DP2(self):
         remove_ckpt(lora_arguments["output_dir"])
 
         train_args = self.configs["TP4DP2"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)
 
         if self.need_allclose:
@@ -242,7 +242,7 @@ def testTP4Sharding2(self):
         remove_ckpt(lora_arguments["output_dir"])
 
         train_args = self.configs["TP4Sharding2"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)
 
         if self.need_allclose:
@@ -257,7 +257,7 @@ def testTP2PP4(self):
         remove_ckpt(lora_arguments["output_dir"])
 
         train_args = self.configs["TP2PP4"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)
 
         if self.need_allclose:
@@ -272,7 +272,7 @@ def testPP8(self):
         remove_ckpt(lora_arguments["output_dir"])
 
         train_args = self.configs["PP8"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)
 
         if self.need_allclose:
@@ -287,7 +287,7 @@ def testPP4DP2(self):
         remove_ckpt(lora_arguments["output_dir"])
 
         train_args = self.configs["PP4DP2"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)
 
         if self.need_allclose:
@@ -302,7 +302,7 @@ def testPP4Sharding2(self):
         remove_ckpt(lora_arguments["output_dir"])
 
         train_args = self.configs["PP4Sharding2"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)
 
         if self.need_allclose:
@@ -317,7 +317,7 @@ def testSharding8S1(self):
         remove_ckpt(lora_arguments["output_dir"])
 
         train_args = self.configs["Sharding8S1"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)
 
         if self.need_allclose:
@@ -332,7 +332,7 @@ def testSharding8S2(self):
         remove_ckpt(lora_arguments["output_dir"])
 
         train_args = self.configs["Sharding8S2"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)
 
         if self.need_allclose:
@@ -347,7 +347,7 @@ def testSharding4S1DP2(self):
         remove_ckpt(lora_arguments["output_dir"])
 
         train_args = self.configs["Sharding4S1DP2"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)
 
         if self.need_allclose:
@@ -362,7 +362,7 @@ def testSharding4S2DP2(self):
         remove_ckpt(lora_arguments["output_dir"])
 
         train_args = self.configs["Sharding4S2DP2"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)
 
         if self.need_allclose:
@@ -377,7 +377,7 @@ def testSharding2S1DP4(self):
         remove_ckpt(lora_arguments["output_dir"])
 
         train_args = self.configs["Sharding2S1DP4"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)
 
         if self.need_allclose:
@@ -392,7 +392,7 @@ def testSharding2S2DP4(self):
         remove_ckpt(lora_arguments["output_dir"])
 
         train_args = self.configs["Sharding2S2DP4"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)
 
         if self.need_allclose:
@@ -407,7 +407,7 @@ def testDP8(self):
         remove_ckpt(lora_arguments["output_dir"])
 
         train_args = self.configs["DP8"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)
 
         if self.need_allclose:
@@ -416,27 +416,29 @@ def testDP8(self):
         np.testing.assert_allclose(res[0], res[1], self.rtol)
 
 
+@pytest.mark.skipif(True, reason="Skip for None CE")
 class TestUnifiedCheckpointOnN2C4(TestUnifiedCheckpointBase):
     def setUp(self):
         super().setUp()
         self.need_allclose = True
         self.rtol = 1e-7
 
-    def runfrist(self, train_args):
+    def runfirst(self, train_args):
         self.run_n2c4(self.run_lora_file, **train_args)
 
     def rerun(self, train_args):
         self.run_n2c4(self.run_lora_file, **train_args)
 
 
+@pytest.mark.skipif(True, reason="Skip for None CE")
 class TestUnifiedCheckpointOnN1C8CheckpointCompatible(TestUnifiedCheckpointBase):
     def setUp(self):
         super().setUp()
 
         self.need_allclose = True
         self.rtol = 1e-7
 
-    def runfrist(self, train_args):
+    def runfirst(self, train_args):
         train_args["unified_checkpoint"] = 0
         self.run_n1c8(self.run_lora_file, **train_args)
 
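
Two patterns recur in the hunk above. First, @pytest.mark.skipif(True, reason="Skip for None CE") skips these classes unconditionally; the reason string suggests they are intended to run only in the CE (continuous evaluation) pipeline. Second, the compatibility classes force the legacy save path in runfirst only. Condensed below as a hedged sketch: the runfirst body matches the diff, the rerun body matches the next hunk, and the claim about the default checkpoint format is an assumption, not confirmed by this diff:

def runfirst(self, train_args):
    # force the pre-unified (legacy) checkpoint format for the first run
    train_args["unified_checkpoint"] = 0
    self.run_n1c8(self.run_lora_file, **train_args)

def rerun(self, train_args):
    # no override here: the rerun uses the config's default checkpoint
    # path (assumed to be the unified format), so the test exercises
    # resuming from a legacy checkpoint through the newer code path
    self.run_n1c8(self.run_lora_file, **train_args)
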
@@ -445,14 +447,15 @@ def rerun(self, train_args):
         self.run_n1c8(self.run_lora_file, **train_args)
 
 
+@pytest.mark.skipif(True, reason="Skip for None CE")
 class TestPaddleCheckpointOnN1C8Reset(TestUnifiedCheckpointBase):
     def setUp(self):
         super().setUp()
 
         self.need_allclose = True
         self.rtol = 1e-7
 
-    def runfrist(self, train_args):
+    def runfirst(self, train_args):
         train_args["unified_checkpoint"] = 0
         self.run_n1c8(self.run_lora_file, **train_args)
 
@@ -469,7 +472,7 @@ def setUp(self):
         self.need_allclose = True
         self.rtol = 1e-7
 
-    def runfrist(self, train_args):
+    def runfirst(self, train_args):
         train_args["unified_checkpoint"] = 0
         self.run_n2c4(self.run_lora_file, **train_args)
 