# -*- coding: utf-8 -*-
"""
Core of the swe2hs algorithm.

Features of this version:
    - constant new snow density as model parameter
    - rho_max depending on overburden a layer has seen
    - switch to rho max wet when swe decreases?

"""

import numpy as np
from numba import njit

from swe2hs import __version__

__author__ = "Johannes Aschauer"
__copyright__ = "Johannes Aschauer"
__license__ = "GPL-3.0-or-later"


@njit
def _calculate_overburden(
    swe_layers: np.ndarray,
) -> np.ndarray:
    """
    Overburden acting on each layer, expressed as SWE.

    For every layer this is the SWE of all layers stored at higher indices
    plus half of the SWE of the layer itself.

    Parameters
    ----------
    swe_layers : :class:`numpy.ndarray`
        SWE in the layers.

    Returns
    -------
    :class:`numpy.ndarray`
        overburden on the layers
    """
    # Suffix sum, obtained as a cumulative sum over the flipped array which
    # is then flipped back (https://stackoverflow.com/a/16541726). At this
    # point the full SWE of the layer itself is still included.
    load_including_self = np.cumsum(swe_layers[::-1])[::-1]
    # Only 50% of the SWE of the layer itself contributes to its overburden.
    own_half = swe_layers/2
    return load_including_self - own_half


# Settling
# --------
@njit
def _R_reduction_overburden(
    overburden_layers,
    ob_minsettling,
    ob_maxsettling,
):
    """
    R reduction due to overburden.

    The reduction is a ramp between 0 and 1:
        - overburden >= ob_maxsettling: 1
        - overburden <= ob_minsettling: 0
        - in between linear increase

    Parameters
    ----------
    overburden_layers : :class:`numpy.ndarray`
        overburden on the layers
    ob_minsettling : float
        overburden at or below which the reduction is 0.
    ob_maxsettling : float
        overburden at or above which the reduction is 1.

    Returns
    -------
    :class:`numpy.ndarray`
        reduction factor for each layer, same shape and dtype as
        `overburden_layers`.
    """
    reduction = np.zeros_like(overburden_layers)
    ramp_width = ob_maxsettling - ob_minsettling

    for idx in range(len(overburden_layers)):
        load = overburden_layers[idx]
        # The upper threshold is tested first, exactly like the lower one
        # is tested second: this order decides the result when the two
        # thresholds coincide.
        if load >= ob_maxsettling:
            reduction[idx] = 1
            continue
        if load <= ob_minsettling:
            # entry is already zero from the initialisation
            continue
        reduction[idx] = (load - ob_minsettling)/ramp_width

    return reduction


@njit
def _R_reduction_density(
    rho_layers,
    rho_minsettling,
    rho_maxsettling,
):
    """
    R reduction due to density.

    The reduction is a ramp between 1 and 0:
        - density <= rho_maxsettling: 1
        - density >= rho_minsettling: 0
        - in between linear change from 1 (at rho_maxsettling) to 0
          (at rho_minsettling)

    Parameters
    ----------
    rho_layers : :class:`numpy.ndarray`
        current snow density in the layers
    rho_minsettling : float
        density at or above which the reduction is 0.
    rho_maxsettling : float
        density at or below which the reduction is 1.

    Returns
    -------
    :class:`numpy.ndarray`
        reduction factor for each layer, same shape and dtype as
        `rho_layers`.
    """
    reduction = np.zeros_like(rho_layers)
    ramp_width = rho_maxsettling - rho_minsettling

    for idx in range(len(rho_layers)):
        density = rho_layers[idx]
        # Low densities are checked before high densities; keep this order.
        if density <= rho_maxsettling:
            reduction[idx] = 1
            continue
        if density >= rho_minsettling:
            # entry is already zero from the initialisation
            continue
        reduction[idx] = (density - rho_minsettling)/ramp_width

    return reduction


@njit
def _get_settling_resistance(
    rho_layers,
    overburden_layers,
    R_max,
    R_min,
    ratio_settling_influence_rho_vs_ob,
    ob_minsettling,
    ob_maxsettling,
    rho_minsettling,
    rho_maxsettling,
):
    """
    Settling resistance R of the layers.

    R starts at `R_max` and is lowered by a density term and an overburden
    term. The two terms share the span ``R_max - R_min`` in the proportion
    ``ratio_settling_influence_rho_vs_ob : 1`` (density : overburden), so
    that R equals `R_min` when both reductions are 1.

    Parameters
    ----------
    rho_layers : :class:`numpy.ndarray`
        current snow density in the layers
    overburden_layers : :class:`numpy.ndarray`
        overburden on the layers
    R_max : float
        resistance when neither reduction applies.
    R_min : float
        resistance when both reductions apply fully.
    ratio_settling_influence_rho_vs_ob : float
        weight of the density reduction relative to the overburden
        reduction.
    ob_minsettling, ob_maxsettling : float
        thresholds passed on to :func:`_R_reduction_overburden`.
    rho_minsettling, rho_maxsettling : float
        thresholds passed on to :func:`_R_reduction_density`.

    Returns
    -------
    :class:`numpy.ndarray`
        settling resistance for each layer
    """
    # Reduction factors of the two influences.
    reduction_from_density = _R_reduction_density(
        rho_layers,
        rho_minsettling,
        rho_maxsettling,
    )
    reduction_from_overburden = _R_reduction_overburden(
        overburden_layers,
        ob_minsettling,
        ob_maxsettling,
    )

    # Split the available span of R into the two weights (scaling factors).
    R_span = R_max - R_min
    ratio = ratio_settling_influence_rho_vs_ob
    weight_density = R_span * ratio/(ratio + 1)
    weight_overburden = R_span/(ratio + 1)

    resistance = R_max - weight_density*reduction_from_density
    resistance = resistance - weight_overburden*reduction_from_overburden
    return resistance


@njit
def _adjust_rho_max_wet_snowpack(
    rho_max_layers: np.ndarray,
    swe_layers: np.ndarray,
    rho_max_wet: float,
    wetting_speed: float,
) -> np.ndarray:
    """
    Move rho_max of all layers that contain SWE towards `rho_max_wet`.

    The difference between `rho_max_wet` and the current rho_max of a layer
    is multiplied by ``exp(-wetting_speed)``. Layers without SWE keep their
    rho_max.

    Parameters
    ----------
    rho_max_layers : :class:`numpy.ndarray`
        current maximum snow density in the layers
    swe_layers : :class:`numpy.ndarray`
        SWE in the layers
    rho_max_wet : float
        maximum density of "wet" snow.
    wetting_speed : float
        rate of the exponential approach towards `rho_max_wet`.

    Returns
    -------
    :class:`numpy.ndarray`
        updated rho_max in the layers
    """
    remaining_fraction = np.exp(-wetting_speed)
    gap_to_wet = rho_max_wet - rho_max_layers
    wetted_rho_max = rho_max_wet - gap_to_wet*remaining_fraction
    layer_has_snow = swe_layers > 0
    return np.where(layer_has_snow, wetted_rho_max, rho_max_layers)


@njit
def _adjust_rho_max_based_on_overburden(
    rho_max_layers,
    overburden_layers,
    swe_layers,
    rho_max_dry,
    rho_max_wet,
    max_overburden
):
    """
    Raise rho_max of the layers to the value implied by their overburden.

    The overburden based rho_max is
        - rho_max_dry at an overburden of 0
        - rho_max_wet at an overburden of max_overburden and above
        - linearly interpolated in between

    A layer only takes this value if it exceeds the rho_max the layer
    already has, i.e. rho_max is never lowered here. Layers without SWE get
    a rho_max of 0.

    Parameters
    ----------
    rho_max_layers : :class:`numpy.ndarray`
        current maximum snow density in the layers
    overburden_layers : :class:`numpy.ndarray`
        overburden on the layers
    swe_layers : :class:`numpy.ndarray`
        SWE in the layers, used to blank out empty layers.
    rho_max_dry : float
        maximum density of "dry" snow.
    rho_max_wet : float
        maximum density of "wet" snow.
    max_overburden : float
        overburden at which the density of wet snow is assumed.

    Returns
    -------
    :class:`numpy.ndarray`
        Adapted rho_max in the layers based on overburden
    """
    # rho_max that the present overburden alone would imply.
    rho_max_from_load = np.zeros_like(overburden_layers)
    for idx in range(len(overburden_layers)):
        load = overburden_layers[idx]
        if load >= max_overburden:
            rho_max_from_load[idx] = rho_max_wet
            continue
        rho_max_from_load[idx] = (
            ((rho_max_wet-rho_max_dry)/max_overburden)*load + rho_max_dry
        )

    # Keep whichever is larger: the stored value or the overburden value.
    exceeds_stored = rho_max_from_load > rho_max_layers
    raised = np.where(exceeds_stored, rho_max_from_load, rho_max_layers)

    # Empty layers carry no rho_max.
    return np.where(swe_layers > 0, raised, 0.)


@njit
def _remove_swe_from_top(
    swe_layers: np.ndarray,
    delta_swe: float
) -> np.ndarray:
    """
    Remove SWE from top of the layers in order to compensate for a loss in
    SWE (i.e. negative `delta_swe`).

    The top of the snowpack is the end of the array. Layers are emptied
    from the end towards the start until at least ``-delta_swe`` has been
    removed; whatever was removed in excess is put back into the last
    emptied layer. If the loss exceeds the SWE stored in all layers, every
    layer ends up empty.

    Parameters
    ----------
    swe_layers : :class:`numpy.ndarray`
        SWE in the layers. The array is modified in place.
    delta_swe : float
        Change in SWE. Only negative values remove anything; for zero or
        positive values the layers are returned unchanged.

    Returns
    -------
    :class:`numpy.ndarray`
        updated SWE in the layers (the same array object that was passed
        in)
    """
    n_layers = len(swe_layers)
    # Nothing to remove. Without this guard a positive `delta_swe` would be
    # written one element past the end of the array, and an empty array
    # would be indexed with -1.
    if delta_swe >= 0 or n_layers == 0:
        return swe_layers

    swe_removed = 0.0
    top = n_layers - 1
    # melting from top means going backward in layers.
    while swe_removed > delta_swe:  # both are (or will be) negative
        swe_removed = swe_removed - swe_layers[top]
        swe_layers[top] = 0
        top = top - 1
        # minimal floating point errors can cause the while loop to
        # run away, in that case we force it to stop at bottom of the
        # snowpack.
        if top == -1:
            break

    # fill up last removed layer with excess swe:
    excess_swe = delta_swe - swe_removed
    if excess_swe > 0:
        swe_layers[top + 1] = excess_swe
    return swe_layers


@njit
def timestep_forward(
    delta_swe,
    swe_layers,
    rho_layers,
    rho_max_layers,
    rho_new,
    rho_max_dry,
    rho_max_wet,
    R_max,
    R_min,
    ratio_settling_influence_rho_vs_ob,
    rho_minsettling,
    rho_maxsettling,
    overburden_minsettling,
    overburden_maxsettling,
    max_overburden,
    wetting_speed,
):
    """Process a timestep forward.

    Assumes that an empty new layer has already been created at the end of
    each of the layer arrays where the function can write values to. This
    padding at the end can be done with
    :func:`_pad_layer_arrays_with_zero`.

    Sequence of operations:

    1. ``delta_swe > 0``: the SWE gain is stored in the last layer.
       ``delta_swe < 0``: the SWE loss is removed from the top of the
       layers and rho_max of the layers that still hold SWE is moved
       towards `rho_max_wet`.
    2. The overburden is calculated and rho_max is raised where the
       overburden requires it.
    3. The settling resistance R is calculated and the density of every
       layer holding SWE approaches its rho_max by the factor
       ``exp(-1/R)``; layers without SWE get a density of 0.
    4. ``delta_swe > 0``: density and rho_max of the last layer are set to
       `rho_new` and `rho_max_dry`.

    Parameters
    ----------
    delta_swe : float
        change in SWE with respect to the previous timestep.
    swe_layers : :class:`numpy.ndarray`
        SWE in the layers. Modified in place when `delta_swe` is nonzero.
    rho_layers : :class:`numpy.ndarray`
        current snow density in the layers. The last element is modified
        in place when `delta_swe` is positive.
    rho_max_layers : :class:`numpy.ndarray`
        current maximum snow density in the layers
    rho_new : float
        density assigned to a newly created layer.
    rho_max_dry : float
        maximum density of "dry" snow.
    rho_max_wet : float
        maximum density of "wet" snow.
    R_max : float
        upper bound of the settling resistance.
    R_min : float
        lower bound of the settling resistance.
    ratio_settling_influence_rho_vs_ob : float
        weight of density relative to overburden in the reduction of the
        settling resistance.
    rho_minsettling : float
        density at or above which density does not reduce R.
    rho_maxsettling : float
        density at or below which density reduces R fully.
    overburden_minsettling : float
        overburden at or below which overburden does not reduce R.
    overburden_maxsettling : float
        overburden at or above which overburden reduces R fully.
    max_overburden : float
        overburden at which the density of wet snow is assumed.
    wetting_speed : float
        rate at which rho_max approaches `rho_max_wet` when SWE decreases.

    Returns
    -------
    swe_layers : :class:`numpy.ndarray`
        updated SWE in the layers
    rho_layers : :class:`numpy.ndarray`
        updated density in the layers
    rho_max_layers : :class:`numpy.ndarray`
        updated maximum density in the layers
    """
    snow_gain = delta_swe > 0
    snow_loss = delta_swe < 0

    if snow_gain:
        # Fill the padded slot at the end with the new layer. Its density
        # is a provisional value here, the final one is set at the end.
        swe_layers[-1] = delta_swe
        rho_layers[-1] = rho_max_dry
    elif snow_loss:
        swe_layers = _remove_swe_from_top(swe_layers, delta_swe)
        rho_max_layers = _adjust_rho_max_wet_snowpack(
            rho_max_layers,
            swe_layers,
            rho_max_wet,
            wetting_speed,
        )

    overburden_layers = _calculate_overburden(swe_layers)

    # update rho_max based on overburden: should not be reversible.
    rho_max_layers = _adjust_rho_max_based_on_overburden(
        rho_max_layers,
        overburden_layers,
        swe_layers,
        rho_max_dry,
        rho_max_wet,
        max_overburden,
    )

    settling_resistance = _get_settling_resistance(
        rho_layers,
        overburden_layers,
        R_max=R_max,
        R_min=R_min,
        ratio_settling_influence_rho_vs_ob=ratio_settling_influence_rho_vs_ob,
        ob_minsettling=overburden_minsettling,
        ob_maxsettling=overburden_maxsettling,
        rho_minsettling=rho_minsettling,
        rho_maxsettling=rho_maxsettling,
    )

    # Settling: the gap between rho and rho_max shrinks by exp(-1/R).
    remaining_gap_fraction = np.exp(-1/settling_resistance)
    settled_rho = (
        rho_max_layers - (rho_max_layers-rho_layers) * remaining_gap_fraction
    )
    rho_layers = np.where(swe_layers > 0, settled_rho, 0.)

    if snow_gain:
        # rho_new should not get modified in the first timestep.
        rho_layers[-1] = rho_new
        rho_max_layers[-1] = rho_max_dry

    return swe_layers, rho_layers, rho_max_layers


@njit
def _calculate_hs_layers(swe_layers, rho_layers):
    """
    Snow depth of the layers as ``swe*1000 / rho``.

    Layers without SWE get a snow depth of 0. They are skipped explicitly
    in a loop in order to avoid a division by zero.

    Parameters
    ----------
    swe_layers : :class:`numpy.ndarray`
        SWE in the layers
    rho_layers : :class:`numpy.ndarray`
        current snow density in the layers

    Returns
    -------
    :class:`numpy.ndarray`
        snow depth of the layers (float64, one entry per entry in
        `swe_layers`)
    """
    hs_layers = np.zeros(len(swe_layers), dtype='float64')
    # Only walk over positions that exist in both arrays.
    n_common = min(len(swe_layers), len(rho_layers))
    for idx in range(n_common):
        layer_swe = swe_layers[idx]
        if layer_swe > 0:
            hs_layers[idx] = layer_swe*1000 / rho_layers[idx]
    return hs_layers


@njit
def _pad_end_of_array_with_zero(array):
    """
    Return a float64 copy of `array` with one zero appended at the end.

    Written by hand because np.pad is not supported in numba, see
    https://github.com/numba/numba/issues/4074
    """
    n_old = len(array)
    extended = np.zeros(n_old + 1, dtype='float64')
    # Copy the existing values; the extra last element stays zero.
    extended[:n_old] = array
    return extended


@njit
def _pad_layer_arrays_with_zero(
    swe_layers_in,
    rho_layers_in,
    rho_max_layers_in,
):
    """
    Append one zero to the end of each of the three layer arrays.

    Returns
    -------
    tuple of :class:`numpy.ndarray`
        padded copies of the SWE, density and maximum density arrays, in
        this order.
    """
    return (
        _pad_end_of_array_with_zero(swe_layers_in),
        _pad_end_of_array_with_zero(rho_layers_in),
        _pad_end_of_array_with_zero(rho_max_layers_in),
    )


@njit
def swe2hs_snowpack_evolution(
    swe_input,
    rho_new,
    rho_max_dry,
    rho_max_wet,
    R_max,
    R_min,
    ratio_settling_influence_rho_vs_ob,
    rho_minsettling,
    rho_maxsettling,
    overburden_minsettling,
    overburden_maxsettling,
    max_overburden,
    wetting_speed,  # should be in range 0.1 - 2.0
):
    """
    Snowpack evolution within the swe2hs model.

    Meant to be called on single hydrological years or chunks of nonzeros

    For every element of `swe_input` one layer slot is appended to the
    snowpack and :func:`timestep_forward` is called with the change in SWE
    with respect to the previous element (the first element is taken as
    the change itself).

    Parameters
    ----------
    swe_input : :class:`numpy.ndarray`
        series of SWE values, one per timestep.
    rho_new, rho_max_dry, rho_max_wet, R_max, R_min,
    ratio_settling_influence_rho_vs_ob, rho_minsettling, rho_maxsettling,
    overburden_minsettling, overburden_maxsettling, max_overburden,
    wetting_speed : float
        model parameters, passed on unchanged to :func:`timestep_forward`.

    Returns
    -------
    hs_out : :class:`numpy.ndarray`
        snow depth per timestep (sum over the layers), length ``n``.
    hs_layers_out : :class:`numpy.ndarray`
        ``(n, n)`` array, column ``i`` holds the snow depth of the layers
        at timestep ``i``.
    rho_layers_out : :class:`numpy.ndarray`
        ``(n, n)`` array, column ``i`` holds the density of the layers at
        timestep ``i``.
    rho_max_layers_out : :class:`numpy.ndarray`
        ``(n, n)`` array, column ``i`` holds the maximum density of the
        layers at timestep ``i``.

    Raises
    ------
    AssertionError
        if ``R_max < R_min``, if `overburden_minsettling` is not smaller
        than `overburden_maxsettling` or if `rho_maxsettling` is not
        smaller than `rho_minsettling`.
    """

    assert R_max >= R_min
    assert overburden_minsettling < overburden_maxsettling
    assert rho_maxsettling < rho_minsettling

    n_steps = len(swe_input)

    # Outputs. The layer resolved outputs are square: at most one layer is
    # created per timestep.
    hs_out = np.zeros(n_steps)
    hs_layers_out = np.zeros((n_steps, n_steps))
    rho_layers_out = np.zeros((n_steps, n_steps))
    rho_max_layers_out = np.zeros((n_steps, n_steps))

    # State of the snowpack, starts without any layer.
    swe_layers = np.zeros(0)  # tracking of swe
    rho_layers = np.zeros(0)  # tracking of density
    rho_max_layers = np.zeros(0)  # tracking of maximum density

    for i in range(n_steps):
        swe = swe_input[i]
        # change in SWE; the first value counts fully as change.
        delta_swe = swe if i == 0 else swe - swe_input[i-1]

        # make room for a possible new layer.
        swe_layers, rho_layers, rho_max_layers = _pad_layer_arrays_with_zero(
            swe_layers,
            rho_layers,
            rho_max_layers,
        )

        swe_layers, rho_layers, rho_max_layers = timestep_forward(
            delta_swe,
            swe_layers,
            rho_layers,
            rho_max_layers,
            rho_new,
            rho_max_dry,
            rho_max_wet,
            R_max,
            R_min,
            ratio_settling_influence_rho_vs_ob,
            rho_minsettling,
            rho_maxsettling,
            overburden_minsettling,
            overburden_maxsettling,
            max_overburden,
            wetting_speed,
        )

        hs_layers = _calculate_hs_layers(swe_layers, rho_layers)
        hs_out[i] = np.sum(hs_layers)

        # After step i there are i+1 layer slots; store them in column i.
        n_slots = i + 1
        hs_layers_out[:n_slots, i] = hs_layers
        rho_layers_out[:n_slots, i] = rho_layers
        rho_max_layers_out[:n_slots, i] = rho_max_layers

    return hs_out, hs_layers_out, rho_layers_out, rho_max_layers_out
