import { extent, median } from 'd3-array';
import {
    scaleLinear,
    scaleOrdinal,
    ScaleOrdinal,
    scaleQuantile,
    ScaleQuantile,
    scaleSequential,
    ScaleSequential,
} from 'd3-scale';
import {
    interpolateBrBG,
    interpolateGnBu,
    interpolateRdBu,
    interpolateRdPu,
    interpolateSpectral,
    interpolateViridis,
    interpolateYlOrBr,
    schemeGnBu,
    schemeRdPu,
    schemeYlOrBr,
} from 'd3-scale-chromatic';
import { Feature, Geometry } from 'geojson';
import { useMemo } from 'react';
import { getFlatValues } from '../helpers/get-flat-values';
import { Indicator } from '../types';

import { hsl } from 'd3-color';
import { ckmeans } from 'simple-statistics';
import { featureIsIncludedInCurrentFilter } from '../helpers/filter';

// Name of the discrete colour scheme used by the 'naturalbreaks' scale type
// when no `colorScale` is given, or when the given one is not a discrete
// (array) scheme. Can be overridden through the VITE_DEFAULT_COLOR_SCHEME
// entry of `import.meta.env`; it is used as a key into `colorScales`.
const DEFAULT_COLOR_SCHEME =
    import.meta.env.VITE_DEFAULT_COLOR_SCHEME ?? 'schemeGnBu';

// Colour schemes selectable by name (the `colorScale` argument of
// useColorScale). Two kinds of values share this table:
//  - `interpolate*`: functions mapping t in [0, 1] to a colour string, used
//    by the 'linear' scale type;
//  - `scheme*`: arrays indexed by the number of classes (`scheme[nGroups]`
//    is used as the colour range), used by the 'naturalbreaks' scale type.
// Because both kinds live under one `any`-typed table, use sites tell them
// apart with `typeof … === 'function'` / `Array.isArray` checks.
export const colorScales: { [name: string]: any } = {
    interpolateYlOrBr,
    schemeYlOrBr,
    interpolateSpectral,
    interpolateBrBG,
    interpolateRdBu,
    interpolateGnBu,
    interpolateRdPu,
    schemeGnBu,
    schemeRdPu,
    interpolateViridis,
};

// Municipality (kommun) names used as the domain of `kommunScale`, the
// ordinal colour scale returned when no indicator is selected. An ordinal
// scale pairs each name with the colour at the same position in the
// generated colour list, so reordering this array changes which colour each
// municipality receives.
const kommuner = [
    'Botkyrka',
    'Danderyd',
    'Stockholm',
    'Ekerö',
    'Haninge',
    'Huddinge',
    'Järfälla',
    'Lidingö',
    'Nacka',
    'Norrtälje',
    'Nykvarn',
    'Nynäshamn',
    'Österåker',
    'Salem',
    'Sigtuna',
    'Sollentuna',
    'Södertälje',
    'Solna',
    'Sundbyberg',
    'Täby',
    'Tyresö',
    'Upplands Väsby',
    'Upplands-Bro',
    'Vallentuna',
    'Värmdö',
    'Vaxholm',
];

/**
 * Generates `nClasses` hex colours by stepping the hue evenly around the
 * colour wheel. Saturation cycles through 40%, 50% and 60%; lightness is
 * fixed at 50%.
 *
 * @param nClasses Number of colours to generate.
 * @returns One `#rrggbb` string per class, in hue order.
 */
function getColors(nClasses: number) {
    const colors: string[] = [];
    for (let i = 0; i < nClasses; i++) {
        const hue = i * (360 / nClasses);
        // (i * 10) % 30 cycles 0, 10, 20 -> 40%, 50%, 60% saturation.
        const saturation = (40 + ((i * 10) % 30)) / 100;
        // Construct the colour directly instead of formatting a CSS hsl()
        // string and parsing it back; d3-color takes s and l in [0, 1].
        colors.push(hsl(hue, saturation, 0.5).formatHex());
    }
    return colors;
}

// Ordinal scale used when no indicator is selected. One colour is generated
// per municipality, so the count stays in sync if `kommuner` is edited.
const kommunScale = scaleOrdinal(getColors(kommuner.length)).domain(kommuner);

/** How indicator values are mapped onto colours. */
export type ScaleType = 'linear' | 'naturalbreaks';

/** Colour ramp family; only consulted by the 'linear' scale type. */
export type ColorType = 'sequential' | 'diverging';

/** Arguments accepted by useColorScale. */
interface Args {
    /** Features whose indicator values determine the scale domain. */
    data: Feature<Geometry>[];
    /** Selected indicator; `null` selects the municipality ordinal scale. */
    cVar: Indicator | null;
    scaleType?: ScaleType;
    colorType?: ColorType;
    /** Key into `colorScales`. */
    colorScale?: string;
    /** Filter handed to featureIsIncludedInCurrentFilter. */
    featureFilter: string[];
    /** When truthy, only features matching `featureFilter` define the domain. */
    zoomToFiltered?: boolean;
}

/** Number of natural-breaks classes, and the scheme size looked up for them. */
const nGroups = 5;

// Result of useColorScale, discriminated by `scaleType`.
// NOTE: this module-local alias shadows TypeScript's built-in `ReturnType<F>`
// utility type inside this file.
type ReturnType =
    | {
          scaleType: 'naturalbreaks';
          cScale: ScaleQuantile<string>;
      }
    | {
          scaleType: 'linear';
          cScale: DivergingColorScale | ScaleSequential<string>;
      }
    | {
          scaleType: 'ordinal';
          cScale: ScaleOrdinal<string, string>;
      };

/**
 * React hook that builds the colour scale for the selected indicator.
 *
 * - With no indicator (`cVar` null) it returns the fixed municipality
 *   ordinal scale (`scaleType: 'ordinal'`).
 * - `scaleType: 'naturalbreaks'` classifies the indicator values with
 *   ckmeans and maps them onto a discrete `nGroups`-colour scheme through a
 *   quantile scale.
 * - `scaleType: 'linear'` maps the value extent onto a continuous
 *   interpolator. With `colorType: 'diverging'` the scale is split at
 *   `cVar.midPoint`, or at the median of the values when that is undefined.
 *
 * When `zoomToFiltered` is truthy, only features passing `featureFilter`
 * contribute to the value range.
 *
 * @throws Error('Failed to get median') for a diverging scale without a
 *   configured midpoint when the values have no median (e.g. no numbers).
 */
export function useColorScale({
    data,
    cVar,
    scaleType = 'naturalbreaks',
    colorType = 'sequential',

    featureFilter,
    zoomToFiltered,
    colorScale,
}: Args): ReturnType {
    // Features that define the value range: all of them, unless the view is
    // zoomed to the current filter.
    const features = useMemo(
        () =>
            data.filter((d) => {
                if (!zoomToFiltered) return true;
                return featureIsIncludedInCurrentFilter(d, featureFilter);
            }),
        [data, featureFilter, zoomToFiltered]
    );

    // Scale domain: class breaks for 'naturalbreaks', [min, max] for
    // 'linear', null without an indicator. `features` already reflects
    // featureFilter/zoomToFiltered, so they are not separate dependencies.
    const cValueRange = useMemo(() => {
        if (!cVar) {
            return null;
        }

        const flatValues = getFlatValues<number>(features, cVar.id);

        switch (scaleType) {
            case 'naturalbreaks': {
                // ckmeans would be asked for -1 classes on an empty array and
                // throws; an empty domain still yields a usable quantile scale.
                if (flatValues.length === 0) {
                    return [];
                }
                // With at most nGroups values, request one class fewer than
                // there are values.
                const _nGroups =
                    flatValues.length <= nGroups
                        ? flatValues.length - 1
                        : nGroups;
                const _extent = extent<number>(flatValues);
                // Last value of each ckmeans cluster, taken as that class's
                // upper break (assumes clusters come back sorted ascending).
                const jenks = ckmeans(flatValues, _nGroups).map((v) => v.pop());
                return [_extent[0], ...jenks];
            }
            case 'linear': {
                return extent<number>(flatValues);
            }
        }
    }, [cVar, scaleType, features]);

    if (!cVar) {
        return { cScale: kommunScale, scaleType: 'ordinal' };
    }

    switch (scaleType) {
        case 'naturalbreaks': {
            // Discrete schemes are arrays of colour arrays, indexed by class count.
            let scheme = colorScales[
                colorScale ?? DEFAULT_COLOR_SCHEME
            ] as ReadonlyArray<ReadonlyArray<string>>;

            // Fallback to avoid cases where user has picked a linear scaletype
            // and then switches to a discrete scale type
            if (!Array.isArray(scheme)) {
                scheme = colorScales[DEFAULT_COLOR_SCHEME];
            }
            // DEFAULT_COLOR_SCHEME comes from the environment and may itself
            // not name a discrete scheme; use the built-in default then.
            if (!Array.isArray(scheme)) {
                scheme = colorScales['schemeGnBu'];
            }

            const cScale = scaleQuantile(scheme[nGroups]).domain(cValueRange!);
            return { cScale, scaleType: 'naturalbreaks' };
        }
        case 'linear': {
            const range = cValueRange as number[];

            if (colorType === 'diverging') {
                let divergingInterpolator = colorScales[
                    colorScale ?? 'interpolateRdBu'
                ] as (t: number) => string;

                // Fallback to avoid cases where user has picked a discrete scaletype
                // and then switches to a diverging color scale
                if (typeof divergingInterpolator !== 'function') {
                    divergingInterpolator = interpolateRdBu;
                }

                let midpoint: number;
                if (typeof cVar.midPoint !== 'undefined') {
                    midpoint = cVar.midPoint;
                } else {
                    const flatValues = getFlatValues<number>(features, cVar.id);
                    const cVarMedian = median(flatValues);

                    // Compare against undefined explicitly: a median of 0 is a
                    // valid midpoint, which a truthiness check would reject.
                    if (cVarMedian === undefined) {
                        throw new Error('Failed to get median');
                    }

                    midpoint = cVarMedian;
                }

                const c = createDivergingColorScale(
                    midpoint,
                    range[0],
                    range[1],
                    divergingInterpolator
                );

                return { cScale: c, scaleType: 'linear' };
            }

            let sequentialInterpolator = colorScales[
                colorScale ?? 'interpolateGnBu'
            ] as (t: number) => string;

            // Same fallback as the diverging case: a discrete scheme (an array)
            // would make scaleSequential fail as soon as a value is coloured.
            if (typeof sequentialInterpolator !== 'function') {
                sequentialInterpolator = colorScales['interpolateGnBu'];
            }

            const cScale = scaleSequential(sequentialInterpolator).domain(range);

            return { cScale, scaleType: 'linear' };
        }
    }
}

/**
 * Callable colour scale produced by createDivergingColorScale. Calling it
 * maps a value to a colour string, and `domain()` reports the three
 * anchor values `[min, midpoint, max]`.
 */
export interface DivergingColorScale {
    (value: number): string;
    domain(): number[];
}

/**
 * Builds a diverging colour scale. Values in `[minValue, midpoint]` map onto
 * the lower half of `colorInterpolator` (t in [0, 0.5]) and values in
 * `[midpoint, maxValue]` onto the upper half (t in [0.5, 1]). The midpoint
 * therefore always gets the centre colour, even when the two sides span
 * different widths.
 *
 * @param midpoint Value mapped to t = 0.5.
 * @param minValue Value mapped to t = 0.
 * @param maxValue Value mapped to t = 1.
 * @param colorInterpolator Maps t in [0, 1] to a colour string.
 * @returns Callable scale whose `domain()` is `[minValue, midpoint, maxValue]`.
 */
function createDivergingColorScale(
    midpoint: number,
    minValue: number,
    maxValue: number,
    colorInterpolator: (t: number) => string
): DivergingColorScale {
    // Separate linear scales for each side of the midpoint. Clamping keeps
    // values outside [minValue, maxValue] from pushing t outside [0, 1].
    const negativeScale = scaleLinear()
        .domain([minValue, midpoint])
        .range([0, 0.5])
        .clamp(true);

    const positiveScale = scaleLinear()
        .domain([midpoint, maxValue])
        .range([0.5, 1])
        .clamp(true);

    function scale(value: number) {
        // Handle the midpoint explicitly: when midpoint === maxValue the
        // positive half has a zero-width domain, and d3 maps such a domain
        // to the middle of its range (0.75) instead of 0.5.
        if (value === midpoint) {
            return colorInterpolator(0.5);
        }
        return colorInterpolator(
            value < midpoint ? negativeScale(value) : positiveScale(value)
        );
    }

    scale.domain = () => [minValue, midpoint, maxValue];
    return scale;
}
