#!/usr/bin/env python3
import os
import sys
import subprocess
from subprocess import PIPE
import re
import hashlib

def run_psql(container, expr, dbname="postgres", user="wire-server"):
    """Run one psql command inside a Docker container and return its output.

    Args:
        container: container ID/name handed to ``docker exec``.
        expr: SQL statement or psql meta-command passed via ``psql -c``.
        dbname: database to connect to.
        user: role to connect as.

    Returns:
        psql's stdout decoded as UTF-8, with surrounding whitespace stripped.

    Raises:
        subprocess.CalledProcessError: if the ``docker exec`` command exits non-zero.
    """
    command = ["docker", "exec", "-i", container, "psql", "-U", user, "-d", dbname, "-c", expr]
    completed = subprocess.run(command, stdout=PIPE, check=True)
    text = completed.stdout.decode("utf8")
    return text.strip()

def run_pg_dump(container, dbname, user="wire-server"):
    """Return the schema-only ``pg_dump`` of *dbname*, run inside *container*.

    pg_dump connects to ``localhost:5432`` from within the container.

    Args:
        container: container ID/name handed to ``docker exec``.
        dbname: database whose schema is dumped.
        user: role to connect as.

    Returns:
        The dump text decoded as UTF-8, with surrounding whitespace stripped.

    Raises:
        subprocess.CalledProcessError: if the ``docker exec`` command exits non-zero.
    """
    exec_prefix = ["docker", "exec", "-i", container]
    dump_args = ["pg_dump", "-h", "localhost", "-p", "5432", "-U", user, "-d", dbname, "--schema-only"]
    completed = subprocess.run(exec_prefix + dump_args, stdout=PIPE, check=True)
    return completed.stdout.decode("utf8").strip()

def get_container_id():
    """Return the ID of the single running container whose name matches ``postgres``.

    ``docker ps --filter=name=postgres`` matches on all or part of a container
    name, so it may print zero IDs or several (one per line). Either case would
    make the later ``docker exec`` calls fail with a confusing error, so it is
    rejected here.

    Returns:
        The container ID as printed by ``docker ps``.

    Raises:
        subprocess.CalledProcessError: if ``docker ps`` exits non-zero.
        RuntimeError: if no container, or more than one container, matches.
    """
    output = subprocess.run(
        ["docker", "ps", "--filter=name=postgres", "--format={{.ID}}"],
        stdout=PIPE,
        check=True,
    ).stdout.decode("utf8")
    ids = output.split()
    if not ids:
        raise RuntimeError("no running docker container matching name 'postgres' found")
    if len(ids) > 1:
        raise RuntimeError(
            "multiple docker containers match name 'postgres': " + ", ".join(ids)
        )
    return ids[0]

def list_databases(container, user="wire-server"):
    """Return the database names listed by psql's ``\\l`` command.

    Parses psql's aligned table output. A data row starts with the database
    name in the first column, followed by ``|``. Rows whose first column is
    blank (continuation lines of a multi-line cell) are skipped. The header
    row ("Name") and the ``template0``, ``template1`` and ``postgres``
    databases are excluded.

    Args:
        container: container ID/name handed to ``docker exec``.
        user: role to connect as.

    Returns:
        Database names in the order psql lists them.
    """
    out = run_psql(container, "\\l", user=user)
    # Capture everything before the first "|" (minus padding). A ``\w+``
    # pattern would silently drop valid names containing "-", "." or spaces.
    # Requiring a non-blank first character skips continuation rows.
    first_column = re.compile(r"\s*([^|\s][^|]*?)\s*\|")
    excluded = {"Name", "template0", "template1", "postgres"}
    dbs = []
    for line in out.splitlines():
        match = first_column.match(line)
        if match and match.group(1) not in excluded:
            dbs.append(match.group(1))
    return dbs

def normalize_rls_hashes(dump_output, dbname):
    """Replace the random ``\\restrict``/``\\unrestrict`` keys with a deterministic one.

    The dump text may contain ``\\restrict KEY`` and ``\\unrestrict KEY`` psql
    meta-command lines whose KEY is random on every run. Left as-is, the
    generated schema file would change on each regeneration. (These are psql
    restrict keys, not row-level-security artefacts, despite the function name.)

    Each key is replaced by the first 63 hex characters of SHA-256(*dbname*),
    so the output is stable per database.

    Args:
        dump_output: text produced by pg_dump.
        dbname: database name used to derive the replacement key.

    Returns:
        *dump_output* with every such key replaced.
    """
    det_hash = hashlib.sha256(dbname.encode()).hexdigest()[:63]
    # Anchor to line start: the meta-commands occupy their own lines. An
    # unanchored pattern could also rewrite matching text inside dumped
    # function bodies or comments.
    pattern = re.compile(r"^\\(restrict|unrestrict) [A-Za-z0-9]+", re.MULTILINE)
    # Use a function replacement so the backslash needs no re-escaping.
    return pattern.sub(lambda m: f"\\{m.group(1)} {det_hash}", dump_output)

def main():
    """Print the schema of one or all databases in the postgres container.

    The SQL goes to stdout, and the progress message goes to stderr.

    The ``PSQL_DB`` environment variable selects a single database. When it is
    unset, empty, or ``all``, every database returned by ``list_databases`` is
    dumped.
    """
    container = get_container_id()
    print("-- automatically generated with `make postgres-schema`")

    # Treat an empty PSQL_DB like an unset one. Otherwise pg_dump would run
    # with an empty database name and the section would get an empty label.
    psql_db = os.environ.get("PSQL_DB") or "all"
    print("processing", psql_db + "...", file=sys.stderr)

    databases = list_databases(container) if psql_db == "all" else [psql_db]

    for db in databases:
        print("\n------------------------------------------------------------------------------------------")
        print(f"-- Database: {db}\n")
        dump = run_pg_dump(container, db, user="wire-server")
        print(normalize_rls_hashes(dump, db))

if __name__ == "__main__":
    # Dump schemas only when executed as a script, not when imported.
    main()
