import { Layout, PlotMouseEvent } from 'plotly.js';
import Plot from 'react-plotly.js';
import {
  DECREASE_COLOR,
  INCREASE_COLOR,
  PRICE_COLOR,
  PAPER_BG_COLOR,
  PLOT_BG_COLOR,
} from 'shared/utils/plotly';
import { quantile } from 'd3-array';

/**
 * Lower-case row label treated specially by WaterFall: rows with this label are
 * never folded into the "small contributions" bucket.
 */
export const DEMAND = 'demand';
/** Lower-case row label that WaterFall never folds into the "small contributions" bucket. */
export const MARKET = 'market';
/** Lower-case row label that WaterFall never folds into the "small contributions" bucket. */
export const SUPPLY = 'supply';

/**
 * Waterfall payload: one horizontal Plotly waterfall trace plus the layout
 * settings it should be drawn with.
 */
interface waterfallJson {
  /** Name of the waterfall. */
  name: string;
  /**
   * Row data. The arrays are parallel: index i of `x`, `y`, `text` and
   * `measure` all describe the same row.
   */
  trace: {
    /** Bar values along the horizontal axis. */
    x: number[];
    /** Row labels shown on the vertical axis. */
    y: string[];
    /**
     * Annotation printed next to each bar.
     * NOTE(review): WaterFall also handles numeric entries here, so the API may
     * send numbers despite this declared type — confirm the contract.
     */
    text: string[];
    /** Plotly waterfall measure per row: 'relative', 'absolute' or 'total'. */
    measure: string[];
    /** Plotly waterfall `base` value of the trace. */
    base: number;
  };
  /** Layout settings for the chart. */
  layout: {
    title: string;
    width: number;
    height: number;
    xaxis: {
      /** Horizontal axis range as [min, max]. */
      range: number[];
    };
  };
}

/** Props accepted by the WaterFall component. */
interface WaterFallProps {
  /** Waterfall payload to render. */
  waterfall: waterfallJson;
  /** Click handler forwarded to the underlying Plotly chart. */
  onClick?: ((event: Readonly<PlotMouseEvent>) => void) | undefined;
  /** Layout overrides, spread last over the layout WaterFall computes. */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- callers pass arbitrary Plotly layout overrides
  layout?: any;
  /**
   * When truthy, relative rows whose magnitude is at or below the
   * (1 - groupSmallValues) quantile are merged into a single row.
   */
  groupSmallValues?: number;
}

/** Text labels denoting a contribution that rounds to zero; such rows are hidden. */
const ZERO_CONTRIBUTION_TEXTS = ['+-0.0', '+0.0'];

/**
 * Row labels that are never folded into the "small contributions" bucket.
 * Row labels are lower-cased before the lookup, so every entry must be lower-case.
 */
const SPECIAL_ROWS = [DEMAND, MARKET, SUPPLY, 'base value estimator'];

/** Padding (in x-axis units) added to each side of the axis range when the base is moved. */
const AXIS_MARGIN = 5;

/** Pixel height allotted per waterfall row, plus a fixed extra amount for the whole chart. */
const ROW_HEIGHT = 35;
const HEIGHT_PADDING = 20;

/** The parallel per-row arrays of a waterfall trace; index i of each describes the same row. */
interface WaterfallRows {
  x: number[];
  y: string[];
  measure: string[];
  text: string[];
}

/** Renders numeric labels with one decimal; string labels pass through unchanged. */
const formatLabel = (label: string | number): string =>
  typeof label === 'number' ? label.toFixed(1) : label;

/** Returns new arrays with the rows at `indices` removed; `rows` itself is not modified. */
const dropRows = (
  rows: WaterfallRows,
  indices: readonly number[],
): WaterfallRows => {
  const dropped = new Set(indices);
  const keep = (_: unknown, index: number) => !dropped.has(index);
  return {
    x: rows.x.filter(keep),
    y: rows.y.filter(keep),
    measure: rows.measure.filter(keep),
    text: rows.text.filter(keep),
  };
};

/**
 * Folds the smallest relative contributions into one aggregated row.
 *
 * Candidates are rows whose measure is 'relative' and whose lower-cased label is
 * not in SPECIAL_ROWS. Candidates with |x| at or below the (1 - keepFraction)
 * quantile of the candidates' magnitudes are removed. They are replaced, at the
 * position of the first removed row, by a single relative row holding their signed
 * sum, so every later bar keeps its running total.
 *
 * @param rows - Rows to process; not modified.
 * @param keepFraction - Roughly the share of candidate rows kept as individual bars.
 * @returns `rows` unchanged when nothing qualifies, otherwise new arrays.
 */
const groupSmallContributions = (
  rows: WaterfallRows,
  keepFraction: number,
): WaterfallRows => {
  const isCandidate = (i: number): boolean =>
    rows.measure[i] === 'relative' &&
    !SPECIAL_ROWS.includes(rows.y[i].toLowerCase());

  const magnitudes: number[] = [];
  rows.measure.forEach((_, i) => {
    if (isCandidate(i)) {
      magnitudes.push(Math.abs(rows.x[i]));
    }
  });
  // Numeric comparator: the default sort compares numbers as strings. d3-array v1
  // `quantile` requires sorted input; newer versions accept it either way.
  magnitudes.sort((a, b) => a - b);
  const cutoff = quantile(magnitudes, 1.0 - keepFraction) ?? 0;

  const smallRows: number[] = [];
  let smallSum = 0;
  rows.measure.forEach((_, i) => {
    if (isCandidate(i) && Math.abs(rows.x[i]) <= cutoff) {
      smallRows.push(i);
      // Signed sum: an absolute sum would distort the running total of later bars.
      smallSum += rows.x[i];
    }
  });
  if (smallRows.length === 0) {
    return rows;
  }

  const grouped = dropRows(rows, smallRows);
  // Indices were collected in ascending order, so this is the first removed row.
  const insertAt = smallRows[0];
  const bottomPercent = Math.round((1.0 - keepFraction) * 100);
  grouped.x.splice(insertAt, 0, smallSum);
  grouped.y.splice(
    insertAt,
    0,
    `Bottom ${bottomPercent}% contributions (impact < |${cutoff.toFixed(1)}|)`,
  );
  grouped.measure.splice(insertAt, 0, 'relative');
  grouped.text.splice(insertAt, 0, smallSum.toFixed(1));
  return grouped;
};

/**
 * Horizontal Plotly waterfall chart.
 *
 * Steps:
 * - Moves a base lying right of the axis start onto the axis start, and pads the axis range.
 * - Formats numeric labels with one decimal.
 * - Hides rows whose contribution text is a rounded zero.
 * - Optionally merges the smallest relative contributions into one row.
 *
 * @param waterfall - Payload to render; it is deep-copied, so the prop is never mutated.
 * @param onClick - Forwarded to the Plotly chart.
 * @param layout - Layout overrides, spread last over the computed layout.
 * @param groupSmallValues - When truthy, relative rows with |x| at or below the
 *   (1 - groupSmallValues) quantile are merged into one "Bottom N%" row.
 */
const WaterFall = ({
  waterfall,
  onClick,
  layout,
  groupSmallValues,
}: WaterFallProps) => {
  // Deep copy so the axis-range edits below never touch the caller's prop.
  const waterfallCopy = JSON.parse(JSON.stringify(waterfall)) as waterfallJson;
  const { trace, layout: apiLayout } = waterfallCopy;
  const { title, xaxis } = apiLayout;
  let { base } = trace;

  let rows: WaterfallRows = {
    x: trace.x,
    y: trace.y,
    measure: trace.measure,
    // Labels may arrive as numbers despite the declared type.
    text: trace.text.map(formatLabel),
  };

  // When the base lies right of the axis start, start the bars at the axis start
  // instead: shift absolute/total rows right by the same offset, then pad the range.
  if (base > xaxis.range[0]) {
    const offset = base - xaxis.range[0];
    const measures = rows.measure;
    rows.x = rows.x.map((value, i) =>
      ['absolute', 'total'].includes(measures[i]) ? value + offset : value,
    );
    base = xaxis.range[0];
    xaxis.range[0] -= AXIS_MARGIN;
    xaxis.range[1] += AXIS_MARGIN;
  }

  // Hide features with a (rounded) zero contribution.
  const zeroRows: number[] = [];
  rows.text.forEach((label, i) => {
    if (ZERO_CONTRIBUTION_TEXTS.includes(label)) {
      zeroRows.push(i);
    }
  });
  rows = dropRows(rows, zeroRows);

  if (groupSmallValues) {
    rows = groupSmallContributions(rows, groupSmallValues);
  }

  const waterfallHeight = rows.measure.length * ROW_HEIGHT + HEIGHT_PADDING;

  const waterFallTrace = {
    name: title,
    x: rows.x,
    y: rows.y,
    measure: rows.measure,
    text: rows.text,
    base: base,
    orientation: 'h',
    type: 'waterfall',
    decreasing: { marker: { color: DECREASE_COLOR } },
    increasing: { marker: { color: INCREASE_COLOR } },
    totals: { marker: { color: PRICE_COLOR } },
    textposition: 'outside',
    connector: {
      mode: 'spanning',
      line: { width: 1, color: 'rgba(255,255,255,.7)', dash: 'solid' },
    },
  };

  const _layout: Partial<Layout> = {
    ...apiLayout,
    title: null,
    paper_bgcolor: PAPER_BG_COLOR,
    plot_bgcolor: PLOT_BG_COLOR,
    font: {
      color: '#fff',
    },
    margin: {
      t: 8,
    },
    xaxis: {
      linecolor: 'rgba(255,255,255,.5)',
      range: xaxis.range,
      automargin: true,
      gridcolor: 'rgba(255,255,255,.5)',
    },
    yaxis: {
      linecolor: 'rgba(255,255,255,.5)',
      automargin: true,
      autorange: 'reversed',
    },
    height: waterfallHeight,
    ...layout,
  };

  return (
    <Plot
      data={[waterFallTrace as any]}
      layout={_layout}
      onClick={onClick}
      config={{ responsive: true }}
    />
  );
};

export default WaterFall;
