multiprocess-tests-schedule-run #239
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Workflow identity and trigger. The only trigger defined is a cron schedule.
name: multiprocess-tests-schedule-run

on:
  # Continuous (time-based) runs.
  schedule:
    # Run every 4 hours, at minute 0.
    - cron: '0 */4 * * *'
# Explicit GITHUB_TOKEN scopes. Once this block is present, every scope that is
# not listed here is set to `none`.
permissions:
  contents: read
  actions: write  # to cancel previous workflows
  # Required by the "Report success or failure as github status" step, which
  # POSTs to /repos/{repo}/statuses/{sha} using the GITHUB_TOKEN.
  statuses: write
# Runs are grouped by workflow name plus head ref (falling back to the ref);
# starting a new run in a group cancels the one still in progress.
concurrency:
  group: "${{ github.workflow }}-${{ github.head_ref || github.ref }}"
  cancel-in-progress: true
jobs:
  # Single job: installs the package from the `checkpoint` directory, runs every
  # benchmark config listed in multiprocess_benchmark_configs.txt, runs the test
  # files listed in multiprocess_tests.txt, and finally publishes the job
  # outcome as a commit status.
  multiprocess-checkpoint-benchmarks:
    name: "multiprocess-checkpoint-benchmarks (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})"
    runs-on: linux-g2-16-l4-1gpu-x4
    # runs-on: linux-x86-ct5lp-4tpu-x4
    container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e
    defaults:
      run:
        # Every `run` step starts in this directory of the checkout.
        working-directory: checkpoint
    strategy:
      matrix:
        python-version: ["3.12"]
        # Used as a lower bound (`>=`) by the install step unless set to
        # "newest" or "nightly".
        jax-version: ["0.6.0"]
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b  # v5.3.0
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install dependencies
        run: |
          pip install -e .
          # Extras are quoted so the shell can never glob-expand the brackets.
          pip install -e ".[testing,gcs]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          pip uninstall -y orbax
          if [[ "${{ matrix.jax-version }}" == "newest" ]]; then
            pip install -U "jax[k8s,cuda12]" jaxlib
          elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then
            pip install -U --pre "jax[k8s,cuda12]" jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/
          else
            pip install "jax[k8s,cuda12]>=${{ matrix.jax-version }}" "jaxlib>=${{ matrix.jax-version }}"
          fi
          pip install gcsfs
          pip install portpicker

      - name: Run correctness integration tests
        env:
          GCS_BUCKET_PATH: gs://orbax-benchmarks/benchmark-results/${{ github.run_id }}
          # Environment values are strings; quote them so no YAML tool retypes them.
          TF_FORCE_GPU_ALLOW_GROWTH: "true"
          XLA_PYTHON_CLIENT_PREALLOCATE: "false"
        run: |
          cd orbax/checkpoint/_src/testing/benchmarks
          failed_benchmarks=""
          benchmark_configs_file="multiprocess_benchmark_configs.txt"
          echo "Running benchmarks specified in $benchmark_configs_file"
          # One config path per line; blank lines are skipped. The `|| [ -n ... ]`
          # keeps a final line that has no trailing newline.
          while IFS= read -r entry || [ -n "$entry" ]; do
            if [ -n "$entry" ]; then
              echo "Running benchmark for $entry"
              # stdin is detached so the benchmark process can never consume the
              # remaining entries of the config list the loop is reading.
              if ! python run_benchmarks.py --config_file="$entry" --output_directory="$GCS_BUCKET_PATH" < /dev/null; then
                echo "Benchmark $entry failed"
                failed_benchmarks="$failed_benchmarks $entry"
              fi
            fi
          done < "$benchmark_configs_file"
          cd ../../../../..
          # Every config is attempted; the step fails afterwards if any of them failed.
          if [ -n "$failed_benchmarks" ]; then
            echo "The following benchmarks failed:$failed_benchmarks"
            exit 1
          fi
          # python -m pytest orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py
          # cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH

      - name: Run multiprocess tests
        env:
          TEST_TMPDIR: gs://orbax-benchmarks/unit-tests/${{ github.run_id }}
        run: |
          # pytest.main() only *returns* its exit code; it must be passed to
          # sys.exit() or this step would succeed even when tests fail.
          python -c "
          import sys
          import jax
          jax.distributed.initialize()
          print(jax.devices())
          import pytest
          with open('orbax/checkpoint/_src/testing/multiprocess_tests.txt') as f:
              test_files = [line.strip() for line in f if line.strip()]
          sys.exit(pytest.main(['-c', '/dev/null'] + test_files))
          "
          # python -m pytest orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py
          # cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
          # python -m pytest orbax/checkpoint/checkpoint_manager_test.py

      # The below step just reports the success or failure of tests as a "commit status".
      # This is needed for copybara integration.
      - name: Report success or failure as github status
        if: always()
        shell: bash
        env:
          # Passed through the environment instead of being pasted into the script text.
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          JOB_STATUS: ${{ job.status }}
        run: |
          status="${JOB_STATUS}"
          lowercase_status=$(echo "$status" | tr '[:upper:]' '[:lower:]')
          # The commit status API accepts only error/failure/pending/success,
          # while job.status can also be "cancelled"; report anything else as error.
          case "$lowercase_status" in
            success|failure) state="$lowercase_status" ;;
            *) state="error" ;;
          esac
          curl -sS --request POST \
            --url "https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }}" \
            --header "authorization: Bearer ${GH_TOKEN}" \
            --header 'content-type: application/json' \
            --data '{
              "state": "'"$state"'",
              "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}",
              "description": "'"$status"'",
              "context": "github-actions/build"
            }'