import datetime

import marshmallow
import pendulum
import pytest

import prefect
from prefect import __version__, schedules
from prefect.serialization import schedule as schemas


def serialize_fmt(dt):
    """
    Build the dict the schedule schemas are expected to emit for a datetime.

    `dt` is converted with `pendulum.instance`. The result maps "dt" to the
    naive ISO 8601 string of that instance and "tz" to its timezone name.
    """
    as_pendulum = pendulum.instance(dt)
    naive_iso = as_pendulum.naive().to_iso8601_string()
    tz_name = as_pendulum.tzinfo.name
    return {"dt": naive_iso, "tz": tz_name}


# Every class found in prefect.schedules' namespace that subclasses
# schedules.Schedule, excluding the Schedule base class itself. The schema
# coverage tests below compare this set against the registered schemas.
all_schedule_classes = {
    candidate
    for candidate in vars(schedules).values()
    if isinstance(candidate, type)
    and issubclass(candidate, schedules.Schedule)
    and candidate is not schedules.Schedule
}


@pytest.fixture()
def interval_schedule():
    """An hourly IntervalSchedule from 2020-01-01 to 2020-05-01 (naive datetimes)."""
    hourly = datetime.timedelta(hours=1)
    first_day = datetime.datetime(2020, 1, 1)
    last_day = datetime.datetime(2020, 5, 1)
    return schedules.IntervalSchedule(
        interval=hourly, start_date=first_day, end_date=last_day
    )


@pytest.fixture()
def cron_schedule():
    """A daily-at-midnight CronSchedule from 2020-01-01 to 2020-05-01 (naive datetimes)."""
    first_day = datetime.datetime(2020, 1, 1)
    last_day = datetime.datetime(2020, 5, 1)
    return schedules.CronSchedule(
        cron="0 0 * * *", start_date=first_day, end_date=last_day
    )


def test_all_schedules_have_serialization_schemas():
    """
    Every Schedule subclass in prefect.schedules must be registered, by class
    name, in prefect.serialization.schedule.ScheduleSchema.type_schemas (and
    nothing else may be registered there).
    """
    class_names = {klass.__name__ for klass in all_schedule_classes}
    registered_names = set(schemas.ScheduleSchema.type_schemas.keys())
    assert (
        class_names == registered_names
    ), "Not every schedule class has an associated schema"


def test_all_schedules_have_deserialization_schemas():
    """
    Every Schedule subclass in prefect.schedules must be the deserialization
    target (``Meta.object_class``) of a schema registered in
    prefect.serialization.schedule.ScheduleSchema.type_schemas, and every
    registered schema must target one of those classes.
    """
    object_classes = {
        schema.Meta.object_class
        for schema in schemas.ScheduleSchema.type_schemas.values()
    }
    # The message names the deserialization check specifically so a failure here
    # is distinguishable from the name-based serialization coverage test.
    assert (
        all_schedule_classes == object_classes
    ), "Not every schedule class has an associated deserialization schema"


def test_deserialize_without_type_fails():
    """Loading a payload with no "type" key through ScheduleSchema must fail validation."""
    empty_payload = {}
    with pytest.raises(marshmallow.exceptions.ValidationError):
        schema = schemas.ScheduleSchema()
        schema.load(empty_payload)


def test_deserialize_bad_type_fails():
    """Loading a payload whose "type" names no registered schema must fail validation."""
    unknown_type_payload = {"type": "BadSchedule"}
    with pytest.raises(marshmallow.exceptions.ValidationError):
        schema = schemas.ScheduleSchema()
        schema.load(unknown_type_payload)


def test_serialize_cron_schedule(cron_schedule):
    """CronScheduleSchema dumps the cron string, both dates and the Prefect version."""
    dumped = schemas.CronScheduleSchema().dump(cron_schedule)
    expected = {
        "__version__": __version__,
        "cron": cron_schedule.cron,
        "start_date": serialize_fmt(cron_schedule.start_date),
        "end_date": serialize_fmt(cron_schedule.end_date),
    }
    assert dumped == expected


def test_serialize_interval_schedule(interval_schedule):
    """
    IntervalScheduleSchema dumps both dates, the Prefect version, and the
    interval as an integer number of microseconds.
    """
    schema = schemas.IntervalScheduleSchema()
    # Floor-divide by one microsecond so the expected value is exact. Going
    # through int(total_seconds()) would silently drop any sub-second part of
    # the interval, disagreeing with the schema's microsecond resolution.
    one_microsecond = datetime.timedelta(microseconds=1)
    expected_interval = interval_schedule.interval // one_microsecond
    assert schema.dump(interval_schedule) == {
        "start_date": serialize_fmt(interval_schedule.start_date),
        "end_date": serialize_fmt(interval_schedule.end_date),
        "interval": expected_interval,
        "__version__": __version__,
    }


def test_serialize_onetime_schedule():
    """OneTimeScheduleSchema dumps only the start date and the Prefect version."""
    schema = schemas.OneTimeScheduleSchema()
    today = pendulum.today("utc")
    schedule = schedules.OneTimeSchedule(start_date=today)
    dumped = schema.dump(schedule)
    assert dumped == {
        "start_date": serialize_fmt(schedule.start_date),
        "__version__": __version__,
    }


def test_roundtrip_onetime_schedule():
    """A OneTimeSchedule dumped and loaded back is a OneTimeSchedule with the same dates."""
    schema = schemas.OneTimeScheduleSchema()
    original = schedules.OneTimeSchedule(start_date=pendulum.today("utc"))
    serialized = schema.dump(original)
    restored = schema.load(serialized)

    assert isinstance(restored, schedules.OneTimeSchedule)
    assert restored.start_date == original.start_date
    # The restored end date is expected to coincide with the original start date.
    assert restored.end_date == original.start_date


def test_serialize_interval_at_microsecond_resolution():
    """A 1 minute + 1 microsecond interval serializes to 60000001 microseconds."""
    one_minute_plus = datetime.timedelta(minutes=1, microseconds=1)
    schedule = schedules.IntervalSchedule(
        start_date=pendulum.now("utc"), interval=one_minute_plus
    )
    dumped = schemas.IntervalScheduleSchema().dump(schedule)
    assert dumped["interval"] == 60000001


def test_serialize_interval_at_annual_resolution():
    """A 365 day + 1 microsecond interval serializes to 31536000000001 microseconds."""
    one_year_plus = datetime.timedelta(days=365, microseconds=1)
    schedule = schedules.IntervalSchedule(
        start_date=pendulum.now("utc"), interval=one_year_plus
    )
    dumped = schemas.IntervalScheduleSchema().dump(schedule)
    assert dumped["interval"] == 31536000000001


def test_deserialize_schedule_with_overridden_interval():
    """
    The schema dumps whatever interval the schedule currently holds, even one
    assigned after construction; loading a sub-minute interval back must raise
    a ValueError about the one-minute minimum.
    """
    schedule = schedules.IntervalSchedule(
        start_date=pendulum.now("utc"), interval=datetime.timedelta(minutes=1)
    )
    # Assigning the attribute directly presumably sidesteps the minimum-interval
    # check that the load below is expected to hit.
    schedule.interval = datetime.timedelta(microseconds=1)

    schema = schemas.IntervalScheduleSchema()
    payload = schema.dump(schedule)
    assert payload["interval"] == 1

    with pytest.raises(ValueError) as excinfo:
        schema.load(payload)
    message = str(excinfo.value)
    assert "Interval can not be less than one minute." in message
