# SPDX-License-Identifier: AGPL-3.0-or-later
import math
from dataclasses import dataclass
from datetime import datetime
from typing import Callable, Iterable, List, Optional, Tuple

from litestar import get
from litestar.pagination import OffsetPagination
from rapidfuzz import fuzz
from sqlalchemy import select

from ...models import EntityManager
from ...models.datasets import DatasetRelation, Access
from ...models.errors import UserReportableBadRequestException
from ...models.job_after_effects import get_job_after_effects
from ...utils.ncml.metadata import NcmlImportantMetadata
from ...utils.text import normalize, compute_phonetic_set


@dataclass
class NcMLSearchEntry:
    """One search hit: an NcML metadata entry of a released job of a public dataset."""

    ds_id: str  # id of the dataset the job belongs to
    job_id: str  # id of the released job whose metadata matched
    ncml_imp: NcmlImportantMetadata  # the metadata entry that passed the search filters


@get(path='/ncml_search', sync_to_thread=True)
def search_datasets(
    entity_manager: EntityManager,
    bbox: Optional[str] = None,
    time_start: Optional[datetime] = None,
    time_end: Optional[datetime] = None,
    text_query: Optional[str] = None,
    limit: int = 1000,
    offset: int = 0,
) -> OffsetPagination[NcMLSearchEntry]:
    """
    Search the NcML metadata of the released jobs of all public datasets.

    An entry is returned only if every given filter accepts it; an omitted filter accepts everything.

    :param entity_manager: used to enumerate public datasets and their released jobs.
    :param bbox: ``'lon_min,lat_min,lon_max,lat_max'``; keeps entries whose bbox overlaps it.
    :param time_start: keeps entries whose time coverage ends at or after this instant.
    :param time_end: keeps entries whose time coverage starts at or before this instant.
    :param text_query: free text, matched phonetically or fuzzily against the metadata text.
    :param limit: maximum number of entries in the returned page (must be >= 0).
    :param offset: number of matching entries skipped before the page (must be >= 0).
    :return: the requested page; ``total`` is the number of entries matching all filters.
    :raises UserReportableBadRequestException: on a malformed bbox or a negative limit/offset.
    """
    # Negative values would turn the slice below into "count from the end" (offset)
    # or a silently empty page (limit) instead of reporting a client error.
    if limit < 0 or offset < 0:
        raise UserReportableBadRequestException("Invalid pagination: 'limit' and 'offset' must be non-negative")

    # Build all predicates up front so input errors surface before any dataset is fetched.
    intersects = _bbox_predicate(bbox)
    temporal_ok = _time_predicate(time_start, time_end)
    text_ok = _text_predicate(text_query)

    items: List[NcMLSearchEntry] = []
    for ds_id, job_id, ncml_imps in _iter_public_ds_released_job_ncml_imps(entity_manager):
        for ncml_imp in ncml_imps:
            if intersects(ncml_imp) and temporal_ok(ncml_imp) and text_ok(ncml_imp):
                items.append(NcMLSearchEntry(ds_id=ds_id, job_id=job_id, ncml_imp=ncml_imp))

    return OffsetPagination(
        items=items[offset: offset + limit],
        limit=limit,
        offset=offset,
        total=len(items),  # Count of items matching the filters
    )


def _accept_all(md: NcmlImportantMetadata) -> bool:
    """Predicate used for a disabled filter: accepts every metadata entry."""
    return True


def _bbox_predicate(bbox: Optional[str]) -> Callable[[NcmlImportantMetadata], bool]:
    """
    Build the spatial filter of ``search_datasets``.

    :param bbox: ``'lon_min,lat_min,lon_max,lat_max'``, or empty/None to disable the filter.
    :return: predicate accepting metadata whose computed bbox overlaps the requested one.
    :raises UserReportableBadRequestException: if ``bbox`` is not exactly 4 comma-separated numbers.
    """
    if not bbox:
        return _accept_all

    error_message = "Invalid bbox format: expected 'lon_min,lat_min,lon_max,lat_max'"
    bbox_parts = bbox.split(",")
    if len(bbox_parts) != 4:
        raise UserReportableBadRequestException(error_message)
    try:
        coords = [float(part) for part in bbox_parts]
    except ValueError:
        # A non-numeric part is the client's mistake: report it instead of failing with a server error.
        raise UserReportableBadRequestException(error_message) from None
    if any(math.isnan(c) for c in coords):
        # NaN compares False against everything, which would silently disable that side of the filter.
        raise UserReportableBadRequestException(error_message)
    lon_min, lat_min, lon_max, lat_max = coords

    # NOTE: boxes crossing the antimeridian (lon_min > lon_max) are not treated specially.
    def intersects(md: NcmlImportantMetadata) -> bool:
        b = md.bbox  # computed [minLon, minLat, maxLon, maxLat]
        if not b:
            return False
        return not (
            # No overlap if one box is completely to one side of the other
            b[2] < lon_min  # meta.maxLon < filter.minLon
            or b[0] > lon_max  # meta.minLon > filter.maxLon
            or b[3] < lat_min  # meta.maxLat < filter.minLat
            or b[1] > lat_max  # meta.minLat > filter.maxLat
        )

    return intersects


def _time_predicate(
    time_start: Optional[datetime],
    time_end: Optional[datetime],
) -> Callable[[NcmlImportantMetadata], bool]:
    """
    Build the temporal filter of ``search_datasets``.

    :param time_start: lower bound of the requested interval, or None for open-ended.
    :param time_end: upper bound of the requested interval, or None for open-ended.
    :return: predicate accepting metadata whose time coverage overlaps the requested interval.
             When any bound is given, metadata missing either coverage end is rejected.
    """
    if not (time_start or time_end):
        return _accept_all

    # NOTE(review): comparing a naive with a timezone-aware datetime raises TypeError;
    #   confirm the query parameters and the metadata use the same convention.
    def temporal_ok(m: NcmlImportantMetadata) -> bool:
        ts = m.time_coverage_start
        te = m.time_coverage_end

        # if metadata missing one end, drop it
        if not ts or not te:
            return False

        # overlap check: (te >= time_start) and (ts <= time_end)
        if time_start and te < time_start:
            return False
        if time_end and ts > time_end:
            return False
        return True

    return temporal_ok


def _text_predicate(text_query: Optional[str]) -> Callable[[NcmlImportantMetadata], bool]:
    """
    Build the free-text filter of ``search_datasets``.

    An entry matches if it contains at least 75% of the query's phonetic keys, or if the
    rapidfuzz ``token_set_ratio`` between the normalized query and its text haystack is >= 75.

    :param text_query: raw query text, or empty/None to disable the filter.
    :return: predicate implementing the match described above.
    """
    if not text_query:
        return _accept_all

    query = normalize(text_query)
    query_phonetic_set = compute_phonetic_set([query])
    min_phonetic_hits = len(query_phonetic_set) * 0.75

    def text_ok(m: NcmlImportantMetadata) -> bool:
        # TODO: expose score for sorting
        # Only take the phonetic shortcut when the query produced phonetic keys: with an empty
        # set the check degenerates to `0 >= 0` and would accept every entry.
        if query_phonetic_set and len(query_phonetic_set.intersection(m.phonetic_set)) >= min_phonetic_hits:
            return True
        return fuzz.token_set_ratio(query, m.normalized_text_haystack) >= 75

    return text_ok


def _iter_public_ds_ids(entity_manager: EntityManager) -> Iterable[str]:
    """
    Iterate over the ids of all datasets whose access is public.

    The ids are collected inside the index transaction and yielded only after it is closed.
    Yielding from inside the transaction would keep it open for as long as the caller keeps
    processing ids (or until the generator is garbage collected).

    :param entity_manager: provides the datasets index.
    :return: dataset ids.
    """
    # Code inspired from: pymeteoio_platform.app.public.datasets.list_datasets
    with entity_manager.datasets.index.transaction() as tx:
        stm = (
            select(DatasetRelation)
            .filter(DatasetRelation.access == Access.public)
        )
        # Only the plain id string is kept, so nothing bound to the transaction escapes it.
        ds_ids: List[str] = [ds_rel.dataset_id for (ds_rel,) in tx.execute(stm).unique()]
    yield from ds_ids


def _iter_public_ds_released_job_ids(entity_manager: EntityManager) -> Iterable[Tuple[str, str]]:
    """
    Iterate over the released jobs of all public datasets.

    The pairs are collected inside the runs-index transaction and yielded only after it is closed.
    Otherwise the transaction would stay open while the caller does its per-job work (e.g. loading
    job after effects).

    :param entity_manager: provides the datasets and their runs index.
    :return: tuples of <dataset id, job id>
    """
    # Code inspired from: pymeteoio_platform.app.public.datasets.get_released_jobs
    released: List[Tuple[str, str]] = []
    with entity_manager.datasets.runs_index.transaction() as tx:
        for ds_id in _iter_public_ds_ids(entity_manager):
            # The per-user permission check is skipped: _iter_public_ds_ids only yields public datasets.
            dataset = entity_manager.datasets.get(ds_id, check_user_permission=False)
            released.extend(
                (ds_id, ds_run_dto.job_id)
                for ds_run_dto in dataset.iter_runs(tx=tx, released=True)
            )
    yield from released


def _iter_public_ds_released_job_ncml_imps(entity_manager: EntityManager) -> Iterable[
    Tuple[str, str, List[NcmlImportantMetadata]]]:
    """
    Iterate over the released jobs of all public datasets, each paired with its NcML metadata.

    :param entity_manager: used to enumerate public datasets and their released jobs.
    :return: tuples of <dataset id, job id, list of ncml_imp>
    """
    # NOTE: Make a cached version of this function?
    #        - datasets and released jobs are fetched fresh from DB.
    #        - job after effects, including ncml_imps parsing and derived data, are already cached.
    #       It seems like it's fast enough (with small numbers) without extra caching...
    released_jobs = _iter_public_ds_released_job_ids(entity_manager)
    # A released job must already be finished, so loading its after effects is not expected to raise.
    yield from (
        (ds_id, job_id, get_job_after_effects(job_id).ncml_imps)
        for ds_id, job_id in released_jobs
    )
