# SPDX-License-Identifier: AGPL-3.0-or-later
import datetime
from pathlib import Path
from typing import List, Dict, NamedTuple, Optional

from pydantic import BaseModel

from .job_log_parser_data_qa import parse_log_for_data_qa, ExtractedDataQARowDTO
from ..utils.ncml.metadata import NcmlImportantMetadata


class DataQATimelineSegment(BaseModel):
    # A run of data QA rows of one station/parameter that share the same `kind`,
    # as built by DataQATimeline.make_from_log_file_path.
    # NOTE: Field order is the serialization order; keep it stable.
    start: datetime.datetime  # data timestamp at which the segment starts
    end: datetime.datetime  # data timestamp at which the segment ends
    kind: str  # QA kind shared by all rows merged into this segment (copied from the row's `kind`)
    count: int = 0  # number of rows merged into this segment


class DataQATimelineParameter(BaseModel):
    # Data QA timeline of a single parameter of a station.
    # NOTE: Field order is the serialization order; keep it stable.
    parameter: str  # parameter name, as reported by the job log rows
    segments: List[DataQATimelineSegment]  # QA segments of this parameter
    no_timestamp_count: int = 0  # number of rows of this parameter that carried no data timestamp


class DataQATimelineStation(BaseModel):
    # Data QA timeline of a single station, split per parameter.
    # NOTE: Field order is the serialization order; keep it stable.
    station: str  # station name, as reported by the job log rows
    parameters: List[DataQATimelineParameter]  # one entry per parameter seen for this station


class DataQATimeline(BaseModel):
    # Data QA timeline built from a job log: one entry per station, plus the warnings
    # collected while building it (comments instead of a docstring, to keep the schema unchanged).
    data: List[DataQATimelineStation]
    reprocessing_warnings: List[str]  # human-readable notes about degraded analysis (e.g. missing NcML data)

    @staticmethod
    def make_from_log_file_path(
        log_path: Path,
        filter_station: Optional[str] = None,
        filter_parameter: Optional[str] = None,
        ncml_imps: Optional[List[NcmlImportantMetadata]] = None  # for improved analysis
    ) -> 'DataQATimeline':
        """Build a data QA timeline by scanning the data QA rows of a job log.

        Rows are grouped per (station, parameter). Within a group, a row extends the current segment
        when it has the same ``kind`` and its data timestamp is not later than the segment end plus the
        station's maximum connect gap; otherwise the current segment is closed and a new one is started.
        Rows without a data timestamp are only counted (``no_timestamp_count``).

        :param log_path: Job log to scan (passed to ``parse_log_for_data_qa``).
        :param filter_station: If given (non-empty), only rows of this station are considered.
        :param filter_parameter: If given (non-empty), only rows of this parameter are considered.
        :param ncml_imps: NcML metadata; each entry's ``time_coverage_resolution_timedelta`` is used as
            the maximum connect gap of its ``location_station``. Stations without a known resolution
            fall back to 24 hours (with a warning).
        :return: The timeline. Stations, and parameters within a station, are listed in order of first
            appearance in the log. Every (station, parameter) pair that passed the filters is listed,
            even when none of its rows had a data timestamp.
        """

        class _Key(NamedTuple):
            station: str
            parameter: str

        _reprocessing_warnings: List[str] = []
        # Closed segments per key, in scan order.
        _accumulated_segments: Dict[_Key, List[DataQATimelineSegment]] = dict()
        # The segment currently being extended, per key.
        _current_segments: Dict[_Key, DataQATimelineSegment] = dict()
        _no_timestamp_counts: Dict[_Key, int] = dict()
        # Every key that passed the filters, in order of first appearance (dict used as an ordered set).
        _seen_keys: Dict[_Key, None] = dict()

        _max_connect_gap_default = datetime.timedelta(hours=24)
        _station_to_max_connect_gap: Dict[str, datetime.timedelta] = dict()
        _station_to_unknown_max_connect_gap_warned: Dict[str, bool] = dict()

        # Process NcML data to get the max connect gap
        if ncml_imps is None:
            _reprocessing_warnings.append('NcML data not available for improved analysis.')
        else:
            for ncml_imp in ncml_imps:
                if ncml_imp.location_station is None:
                    _reprocessing_warnings.append('Found some NcML data without location station.')
                elif ncml_imp.time_coverage_resolution_timedelta is None:
                    # Marked as already warned, so the "will default" warning is not added on top of this one.
                    _station_to_unknown_max_connect_gap_warned[ncml_imp.location_station] = True
                    _reprocessing_warnings.append(
                        f'NcML for station {ncml_imp.location_station} does not provide time coverage resolution.')
                    # NOTE: We could set the default max gap here already, but the discovered NcML file
                    #       might actually not be required, and the default might not be used.
                else:
                    _station_to_max_connect_gap[ncml_imp.location_station] = \
                        ncml_imp.time_coverage_resolution_timedelta

        ############

        def _max_connect_gap_for(station: str) -> datetime.timedelta:
            # Max gap between a segment end and the next row for the row to extend that segment.
            _gap = _station_to_max_connect_gap.get(station)
            if _gap is not None:
                return _gap
            if not _station_to_unknown_max_connect_gap_warned.get(station, False):
                _reprocessing_warnings.append(
                    f'Unknown time coverage resolution for station {station}, '
                    f'will default to {_max_connect_gap_default.total_seconds()} seconds.')
                _station_to_unknown_max_connect_gap_warned[station] = True
            # Remember the default so the fallback is only resolved once per station.
            _station_to_max_connect_gap[station] = _max_connect_gap_default
            return _max_connect_gap_default

        def _start_segment(_key: _Key, row: ExtractedDataQARowDTO) -> None:
            # Open a new current segment for the key, containing only this row.
            _current_segments[_key] = DataQATimelineSegment(
                start=row.data_timestamp,
                end=row.data_timestamp,
                kind=row.kind,
                count=1,
            )

        def _move_current_segment_to_accumulated(_key: _Key) -> None:
            # Close the key's current segment. A missing current segment is a logic bug (raises KeyError).
            _accumulated_segments.setdefault(_key, []).append(_current_segments.pop(_key))

        # NOTE: Will scan the messages...
        #       A current segment is built and extended per key...
        #       When a new segment is required, the current one is moved to the accumulated segments...
        #       At the end of the scan, the remaining current segments are moved to the accumulated segments...
        #       Then, only the accumulated segments are used to produce the output.

        for row in parse_log_for_data_qa(log_path):
            row: ExtractedDataQARowDTO

            if filter_station and row.station != filter_station:
                continue
            if filter_parameter and row.parameter != filter_parameter:
                continue

            _key = _Key(station=row.station, parameter=row.parameter)
            _seen_keys.setdefault(_key, None)

            if row.data_timestamp is None:
                _no_timestamp_counts[_key] = _no_timestamp_counts.get(_key, 0) + 1
                continue

            _current_segment = _current_segments.get(_key)
            if _current_segment is None:
                # No current segment --> New Segment!
                _start_segment(_key, row)
                continue

            # There's a current segment --> Extend it if same kind and small time gap, start a new one otherwise.
            _max_connect_gap = _max_connect_gap_for(row.station)
            if (
                _current_segment.kind == row.kind and
                row.data_timestamp.timestamp() <= (_current_segment.end + _max_connect_gap).timestamp()
            ):
                # Only ever widen the segment: for out-of-order rows, blindly assigning `end` would move it
                # backwards (possibly before `start`).
                _row_ts = row.data_timestamp.timestamp()
                if _row_ts >= _current_segment.end.timestamp():
                    _current_segment.end = row.data_timestamp
                if _row_ts < _current_segment.start.timestamp():
                    _current_segment.start = row.data_timestamp
                _current_segment.count += 1
            else:
                _move_current_segment_to_accumulated(_key)
                _start_segment(_key, row)

        # Close the segments still open at the end of the scan.
        for _open_key in list(_current_segments):
            _move_current_segment_to_accumulated(_open_key)

        # Group the keys per station in a single pass, keeping first-appearance order.
        _station_to_keys: Dict[str, List[_Key]] = dict()
        for _seen_key in _seen_keys:
            _station_to_keys.setdefault(_seen_key.station, []).append(_seen_key)

        return DataQATimeline(
            data=[
                DataQATimelineStation(
                    station=_station,
                    parameters=[
                        DataQATimelineParameter(
                            parameter=_station_key.parameter,
                            segments=_accumulated_segments.get(_station_key, []),
                            no_timestamp_count=_no_timestamp_counts.get(_station_key, 0),
                        )
                        for _station_key in _station_keys
                    ],
                )
                for _station, _station_keys in _station_to_keys.items()
            ],
            reprocessing_warnings=_reprocessing_warnings,
        )
