Skip to content

Commit dbf29ad

Browse files
authored
Save updated base job template to work pool in provision-infra (PrefectHQ#11355)
1 parent 9a99cb1 commit dbf29ad

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

src/prefect/cli/work_pool.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -434,9 +434,17 @@ async def provision_infrastructure(
434 434   work_pool.type
435 435   )
436 436   provisioner.console = app.console
437     - await provisioner.provision(
    437 + new_base_job_template = await provisioner.provision(
438 438   work_pool_name=name, base_job_template=work_pool.base_job_template
439 439   )
    440 +
    441 + await client.update_work_pool(
    442 + work_pool_name=name,
    443 + work_pool=WorkPoolUpdate(
    444 + base_job_template=new_base_job_template,
    445 + ),
    446 + )
    447 +
440 448   except ValueError as exc:
441 449   app.console.print(f"Error: {exc}")
442 450   app.console.print(

tests/cli/test_work_pool.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -797,6 +797,9 @@ async def test_file(self, mock_collection_registry, tmp_path):
797 797
798 798   class TestProvisionInfrastructure:
799 799   async def test_provision_infra(self, monkeypatch, push_work_pool, prefect_client):
    800 + client_res = await prefect_client.read_work_pool(push_work_pool.name)
    801 + assert client_res.base_job_template != FAKE_DEFAULT_BASE_JOB_TEMPLATE
    802 +
800 803   mock_provision = AsyncMock()
801 804
802 805   class MockProvisioner:
@@ -828,6 +831,10 @@ async def provision(self, *args, **kwargs):
828 831
829 832   assert mock_provision.await_count == 1
830 833
    834 + # ensure work pool base job template was updated
    835 + client_res = await prefect_client.read_work_pool(push_work_pool.name)
    836 + assert client_res.base_job_template == FAKE_DEFAULT_BASE_JOB_TEMPLATE
    837 +
831 838   async def test_provision_infra_unsupported(self, push_work_pool):
832 839   res = await run_sync_in_worker_thread(
833 840   invoke_and_assert,

0 commit comments

Comments
 (0)