import numpy as np
import xarray as xr
import dask.array as da
from numba import guvectorize

from .core import swe2hs_snowpack_evolution

from ._default_model_parameters import *


# Layout: swe_input (n), month_input (n), the 12 scalar model parameters,
# and the output hs_out (n). Leading (loop) dimensions are broadcast by numba.
@guvectorize(
    ['void(float64[:], int64[:], float64, float64, float64, float64, float64, float64, float64, float64, float64, float64, float64, float64, float64[:])'],
    '(n),(n),(),(),(),(),(),(),(),(),(),(),(),()->(n)',
    nopython=True
    )
def _swe2hs_gufunc(
    swe_input,
    month_input,
    rho_new,
    rho_max_dry,
    rho_max_wet,
    R_max,
    R_min,
    ratio_settling_influence_rho_vs_ob,
    rho_minsettling,
    rho_maxsettling,
    overburden_minsettling,
    overburden_maxsettling,
    max_overburden,
    wetting_speed,
    hs_out,
):
    """
    Numba gufunc which resets the snowpack on September 2nd and calls
    :func:`swe2hs_snowpack_evolution` for every hydrological year.

    The time series is split into consecutive segments. Each segment is
    passed separately to :func:`swe2hs_snowpack_evolution`, so no model
    state carries over from one segment to the next. The result is written
    into the matching slice of ``hs_out``.

    Parameters
    ----------
    swe_input : float64 array, core dimension ``n``
        SWE time series.
    month_input : int64 array, core dimension ``n``
        Month number of every entry of ``swe_input`` (September == 9).
    rho_new, rho_max_dry, rho_max_wet, R_max, R_min,
    ratio_settling_influence_rho_vs_ob, rho_minsettling, rho_maxsettling,
    overburden_minsettling, overburden_maxsettling, max_overburden,
    wetting_speed : float
        Scalar model parameters, forwarded unchanged to
        :func:`swe2hs_snowpack_evolution`.
    hs_out : float64 array, core dimension ``n``
        Output array, filled in place segment by segment.
    """
    # `month` is rebound by the loop below; this initial value is never read.
    month = 0
    # Month of the previous and second-previous entry. Starting at 0 makes
    # them count as "not September" for the first entries of the series.
    month_day_before = 0
    month_two_days_before = 0
    # Start index of every segment (the end index is appended after the loop).
    split_locations = []
    for i, month in enumerate(month_input):
        # Start a new segment at the first entry and at the second consecutive
        # September entry that follows a non-September entry, i.e. on
        # September 2nd assuming consecutive daily entries.
        # NOTE(review): because the history starts at 0, a series whose first
        # two entries are both in September also splits at index 1. That is
        # correct when the series starts on Sept 1, but it is a spurious
        # one-entry segment otherwise. Confirm whether inputs may start
        # mid-September.
        if i == 0 or (month_two_days_before != 9 and month_day_before == 9 and month == 9):
            split_locations.append(i)
        month_two_days_before = month_day_before
        month_day_before = month
    # End sentinel so consecutive pairs form half-open [start, stop) slices.
    split_locations.append(len(month_input))

    for start, stop in zip(split_locations[:-1], split_locations[1:]):
        # Simulate each segment independently. Only element [0] of the
        # result is used (presumably the snow depth series - TODO confirm
        # against core.swe2hs_snowpack_evolution).
        hs_out[start:stop] = swe2hs_snowpack_evolution(
            swe_input[start:stop],
            rho_new,
            rho_max_dry,
            rho_max_wet,
            R_max,
            R_min,
            ratio_settling_influence_rho_vs_ob,
            rho_minsettling,
            rho_maxsettling,
            overburden_minsettling,
            overburden_maxsettling,
            max_overburden,
            wetting_speed,
            )[0]


def _wrapped_swe2hs_gufunc(
    swe_input,
    month_input,
    rho_new=RHO_NEW,
    rho_max_dry=RHO_MAX_DRY,
    rho_max_wet=RHO_MAX_WET,
    R_max=R_MAX,
    R_min=R_MIN,
    ratio_settling_influence_rho_vs_ob=RATIO_SETTLING_INFLUENCE_RHO_VS_OB,
    rho_minsettling=RHO_MINSETTLING,
    rho_maxsettling=RHO_MAXSETTLING,
    overburden_minsettling=OVERBURDEN_MINSETTLING,
    overburden_maxsettling=OVERBURDEN_MAXSETTLING,
    max_overburden=MAX_OVERBURDEN,
    wetting_speed=WETTING_SPEED,
):
    """
    Wrap the gufunc in order to accept keyword arguments.

    ``_swe2hs_gufunc`` is called with every argument in positional order.
    This wrapper lets callers (e.g. :func:`xarray.apply_ufunc` via its
    ``kwargs``) pass the model parameters by name, and it falls back to the
    module defaults for any parameter that is omitted.

    Parameters
    ----------
    swe_input : array_like of float64
        SWE data with time as the last axis (the gufunc core dimension).
    month_input : array_like of int64
        Month number of every time step, broadcastable against
        ``swe_input``.
    rho_new, rho_max_dry, rho_max_wet, R_max, R_min,
    ratio_settling_influence_rho_vs_ob, rho_minsettling, rho_maxsettling,
    overburden_minsettling, overburden_maxsettling, max_overburden,
    wetting_speed : float, optional
        Scalar model parameters forwarded to the gufunc.

    Returns
    -------
    numpy.ndarray
        Snow depth with the broadcast shape of the inputs (time last).
    """
    # The gufunc allocates its own output array when none is passed, so no
    # pre-allocation is needed here.
    return _swe2hs_gufunc(
        swe_input,
        month_input,
        rho_new,
        rho_max_dry,
        rho_max_wet,
        R_max,
        R_min,
        ratio_settling_influence_rho_vs_ob,
        rho_minsettling,
        rho_maxsettling,
        overburden_minsettling,
        overburden_maxsettling,
        max_overburden,
        wetting_speed,
        )


def apply_swe2hs(
    swe_data,
    rho_new=RHO_NEW,
    rho_max_dry=RHO_MAX_DRY,
    rho_max_wet=RHO_MAX_WET,
    R_max=R_MAX,
    R_min=R_MIN,
    ratio_settling_influence_rho_vs_ob=RATIO_SETTLING_INFLUENCE_RHO_VS_OB,
    rho_minsettling=RHO_MINSETTLING,
    rho_maxsettling=RHO_MAXSETTLING,
    overburden_minsettling=OVERBURDEN_MINSETTLING,
    overburden_maxsettling=OVERBURDEN_MAXSETTLING,
    max_overburden=MAX_OVERBURDEN,
    wetting_speed=WETTING_SPEED,
    time_dim_name='time',
):
    """
    Apply the swe2hs model on a :class:`xarray.DataArray`.

    This function calls the vectorized version of swe2hs within
    :func:`xarray.apply_ufunc`. The model is run independently for every
    series along ``time_dim_name``, and the snowpack is reset once per year
    (on September 2nd).

    Parameters
    ----------
    swe_data : :class:`xarray.DataArray`
        DataArray containing the SWE data in [m]. It must have a dimension
        named ``time_dim_name`` whose coordinate supports the ``.month``
        datetime component (e.g. a datetime64 coordinate). The underlying
        data must be a :class:`numpy.ndarray` or a dask array.
    rho_new : float, optional
        Model parameter, by default ``RHO_NEW``.
    rho_max_dry : float, optional
        Model parameter, by default ``RHO_MAX_DRY``.
    rho_max_wet : float, optional
        Model parameter, by default ``RHO_MAX_WET``.
    R_max : float, optional
        Model parameter, by default ``R_MAX``.
    R_min : float, optional
        Model parameter, by default ``R_MIN``.
    ratio_settling_influence_rho_vs_ob : float, optional
        Model parameter, by default ``RATIO_SETTLING_INFLUENCE_RHO_VS_OB``.
    rho_minsettling : float, optional
        Model parameter, by default ``RHO_MINSETTLING``.
    rho_maxsettling : float, optional
        Model parameter, by default ``RHO_MAXSETTLING``.
    overburden_minsettling : float, optional
        Model parameter, by default ``OVERBURDEN_MINSETTLING``.
    overburden_maxsettling : float, optional
        Model parameter, by default ``OVERBURDEN_MAXSETTLING``.
    max_overburden : float, optional
        Model parameter, by default ``MAX_OVERBURDEN``.
    wetting_speed : float, optional
        Model parameter, by default ``WETTING_SPEED``.
    time_dim_name : str, optional
        Name of the time dimension of ``swe_data``, by default 'time'.

    Returns
    -------
    :class:`xarray.DataArray`
        Calculated snow depth, with the same dimensions (in the same order)
        as the input data.

    Raises
    ------
    TypeError
        If ``swe_data`` is not a :class:`xarray.DataArray`, or if its
        underlying data is neither a numpy array nor a dask array.
    ValueError
        If ``time_dim_name`` is not a dimension of ``swe_data``.
    """
    if not isinstance(swe_data, xr.DataArray):
        raise TypeError("swe2hs: swe data needs to be a xarray.DataArray.")

    input_dims = swe_data.dims

    if time_dim_name not in input_dims:
        raise ValueError(
            "swe2hs: you assigned the time dimension name "
            f"'{time_dim_name}' which is \nnot in the dimensions "
            "of the SWE input DataArray."
        )

    # pass parameters to a dict for later reuse
    params = {
        'rho_new': rho_new,
        'rho_max_dry': rho_max_dry,
        'rho_max_wet': rho_max_wet,
        'R_max': R_max,
        'R_min': R_min,
        'ratio_settling_influence_rho_vs_ob': ratio_settling_influence_rho_vs_ob,
        'rho_minsettling': rho_minsettling,
        'rho_maxsettling': rho_maxsettling,
        'overburden_minsettling': overburden_minsettling,
        'overburden_maxsettling': overburden_maxsettling,
        'max_overburden': max_overburden,
        'wetting_speed': wetting_speed,
    }

    # numpy-backed data is computed eagerly. Dask-backed data needs
    # apply_ufunc's 'parallelized' mode together with the output dtype.
    if isinstance(swe_data.data, np.ndarray):
        dask_kwargs = {}
    elif isinstance(swe_data.data, da.core.Array):
        dask_kwargs = {'dask': 'parallelized', 'output_dtypes': ['float64']}
    else:
        raise TypeError(("swe2hs: underlying data in the xr.DataArray needs to "
                         "be numpy.ndarray or dask array."))

    hs = xr.apply_ufunc(
        _wrapped_swe2hs_gufunc,
        swe_data,
        # month number of every time step, used to locate the yearly reset
        swe_data.coords[f'{time_dim_name}.month'],
        kwargs=params,
        input_core_dims=[[time_dim_name], [time_dim_name]],
        output_core_dims=[[time_dim_name]],
        **dask_kwargs,
    )

    # apply_ufunc moves the core (time) dimension to the end; restore the
    # input dimension order for inputs of any dimensionality.
    return hs.transpose(*input_dims)
