|
13 | 13 | from prefect.core.edge import Edge |
14 | 14 | from prefect.core.flow import Flow |
15 | 15 | from prefect.core.task import Parameter, Task |
| 16 | +from prefect.engine.cache_validators import partial_inputs_only |
16 | 17 | from prefect.engine.result_handlers import LocalResultHandler, ResultHandler |
17 | 18 | from prefect.engine.signals import PrefectError |
18 | 19 | from prefect.engine.state import ( |
@@ -1418,6 +1419,55 @@ def handler(task, old, new): |
1418 | 1419 | assert len(state_history) == 5 # Running, Failed, Retrying, Running, Success |
1419 | 1420 |
|
1420 | 1421 |
|
def test_flow_dot_run_handles_cached_states():
    """Check that repeated scheduled runs of ``flow.run()`` honor cached states.

    The flow is run three times by a mock schedule.  ``return_x`` is cached
    for one minute and validated only on its ``x`` input, so the expected
    outcome asserted at the end is:

    * run 1: x=2, y=1 -> nothing cached yet, ``return_x`` yields 1
    * run 2: x=2, y=2 -> ``x`` unchanged, the cached 1 is expected again
    * run 3: x=3, y=3 -> ``x`` changed, ``return_x`` is expected to yield 3
    """

    class MockSchedule(prefect.schedules.Schedule):
        # Counts how many run times have been handed out so far.
        call_count = 0

        def next(self, n):
            # Hand out an immediate run time for the first three calls.
            if self.call_count < 3:
                self.call_count += 1
                return [pendulum.now("utc")]
            else:
                # Sentinel exception used to break out of ``f.run()``; the
                # test below expects exactly this exception type.
                raise SyntaxError("Cease scheduling!")

    class StatefulTask(Task):
        def __init__(self, maxit=False, **kwargs):
            self.maxit = maxit
            super().__init__(**kwargs)

        # Class-level default; ``self.call_count += 1`` in ``run`` then keeps
        # a separate count on each instance.
        call_count = 0

        def run(self):
            self.call_count += 1
            if self.maxit:
                # Yields 2, 2, 3, ... so the value stays constant across the
                # first two runs and changes on the third.
                return max(self.call_count, 2)
            else:
                # Yields 1, 2, 3, ...
                return self.call_count

    @task(
        cache_for=datetime.timedelta(minutes=1),
        # NOTE(review): presumably only changes to ``x`` invalidate the cache;
        # confirm against ``partial_inputs_only``.
        cache_validator=partial_inputs_only(validate_on=["x"]),
    )
    def return_x(x, y):
        return y

    # Records the value ``store_y`` received in each flow run.
    storage = {"y": []}

    @task
    def store_y(y):
        storage["y"].append(y)

    t1, t2 = StatefulTask(maxit=True), StatefulTask()
    schedule = MockSchedule()
    with Flow(name="test", schedule=schedule) as f:
        store_y(return_x(x=t1, y=t2))

    # The mock schedule raises once it has handed out three run times.
    with pytest.raises(SyntaxError):
        f.run()

    assert storage == dict(y=[1, 1, 3])
| 1469 | + |
| 1470 | + |
1421 | 1471 | def test_scheduled_runs_handle_mapped_retries(): |
1422 | 1472 | class StatefulTask(Task): |
1423 | 1473 | call_count = 0 |
|
0 commit comments