import pytest
import numpy as np
import pandas as pd
import xarray as xr


from swe2hs.one_dimensional import swe2hs_1d
from swe2hs.vectorized import (
    apply_swe2hs,
    _swe2hs_gufunc,
    _wrapped_swe2hs_gufunc,
    )



def test_swe2hs_gufunc_against_1d_version(swe_data_1d_series, default_swe2hs_params):
    """Check that the raw ``_swe2hs_gufunc`` reproduces ``swe2hs_1d``.

    The gufunc takes the model parameters positionally, so they are unpacked
    from ``default_swe2hs_params`` in the exact order listed below.
    """
    # Positional order expected by _swe2hs_gufunc after the swe and month arrays.
    positional_param_names = (
        'rho_new',
        'rho_max_dry',
        'rho_max_wet',
        'R_max',
        'R_min',
        'ratio_settling_influence_rho_vs_ob',
        'rho_minsettling',
        'rho_maxsettling',
        'overburden_minsettling',
        'overburden_maxsettling',
        'max_overburden',
        'wetting_speed',
    )

    expected = swe2hs_1d(swe_data_1d_series, **default_swe2hs_params)

    positional_params = [default_swe2hs_params[name] for name in positional_param_names]
    raw_hs = _swe2hs_gufunc(
        swe_data_1d_series.to_numpy(),
        swe_data_1d_series.index.month.to_numpy(),
        *positional_params,
    )
    actual = pd.Series(raw_hs, index=swe_data_1d_series.index)

    pd.testing.assert_series_equal(expected, actual)


def test_wrapped_swe2hs_gufunc_against_1d_version(swe_data_1d_series, default_swe2hs_params):
    """Check that ``_wrapped_swe2hs_gufunc`` reproduces ``swe2hs_1d``.

    Unlike the raw gufunc, the wrapper receives the model parameters as
    keyword arguments.
    """
    expected = swe2hs_1d(swe_data_1d_series, **default_swe2hs_params)

    swe_values = swe_data_1d_series.to_numpy()
    months = swe_data_1d_series.index.month.to_numpy()
    hs_values = _wrapped_swe2hs_gufunc(swe_values, months, **default_swe2hs_params)
    actual = pd.Series(hs_values, index=swe_data_1d_series.index)

    pd.testing.assert_series_equal(expected, actual)


def test_apply_swe2hs_dask_vs_numpy(
    swe_data_2d_dataarray_numpy,
    swe_data_2d_dataarray_dask,
    default_swe2hs_params
):
    """Compare ``apply_swe2hs`` on the ``_numpy`` and ``_dask`` 2d SWE fixtures.

    The two results must be identical (values, coordinates and attributes).
    """
    hs_numpy, hs_dask = (
        apply_swe2hs(swe, **default_swe2hs_params)
        for swe in (swe_data_2d_dataarray_numpy, swe_data_2d_dataarray_dask)
    )
    xr.testing.assert_identical(hs_numpy, hs_dask)


@pytest.mark.parametrize(
    "swe_dataarray",
    [
        ("swe_data_2d_dataarray_numpy"),
        ("swe_data_2d_dataarray_dask"),
        ("swe_data_2d_dataarray_numpy_changed_dimorder"),
    ],
)
def test_apply_swe2hs_against_1d(
    swe_dataarray,
    swe_data_1d_series,
    default_swe2hs_params,
    request
):
    """Check gridded ``apply_swe2hs`` output against the 1d model.

    ``swe_dataarray`` is the name of a fixture, resolved at runtime via
    ``request.getfixturevalue``. The test asserts that

    1. the (lat=0, lon=0) pixel of the gridded result matches
       ``swe2hs_1d`` run on ``swe_data_1d_series``, and
    2. every pixel of the gridded result matches the (lat=0, lon=0) pixel.

    Both checks treat NaNs in matching positions as equal.
    """
    swe_da = request.getfixturevalue(swe_dataarray)
    hs_da = apply_swe2hs(swe_data=swe_da, **default_swe2hs_params)
    hs_series = swe2hs_1d(swe_data_1d_series, **default_swe2hs_params)

    # drop=True removes the scalar lat/lon coordinates that .sel leaves behind,
    # so the reference pixel only carries the remaining dimension's coordinate.
    ref_pixel = hs_da.sel(lat=0, lon=0, drop=True)

    # 1.) check if 1d equals the lon=0 and lat=0 pixel.
    pd.testing.assert_series_equal(
        ref_pixel.to_pandas(),
        hs_series,
        check_names=False,
        check_freq=False
        )
    # 2.) check if the lon=0 and lat=0 pixel equals all other pixels.
    # Plain ``==`` evaluates NaN == NaN as False, so positions where both
    # sides are NaN are counted as equal explicitly. This keeps check 2
    # consistent with assert_series_equal in check 1.
    matches = (hs_da == ref_pixel) | (hs_da.isnull() & ref_pixel.isnull())
    assert bool(matches.all())


@pytest.mark.parametrize(
    "swe_dataarray",
    [
        "swe_data_2d_dataarray_numpy",
        "swe_data_2d_dataarray_dask",
        "swe_data_2d_dataarray_numpy_changed_dimorder",
    ],
)
def test_dimension_preservation_in_apply_swe2hs(
    swe_dataarray,
    default_swe2hs_params,
    request
):
    """``apply_swe2hs`` must return the same dims tuple (names and order) as its input.

    ``swe_dataarray`` is the name of a fixture, resolved at runtime through
    ``request.getfixturevalue``.
    """
    input_da = request.getfixturevalue(swe_dataarray)
    output_da = apply_swe2hs(swe_data=input_da, **default_swe2hs_params)

    assert input_da.dims == output_da.dims
