import React from "react";
import { scaleLinear, scaleBand } from "d3-scale";
import { stack } from "d3-shape";
import YAxis from "../Axes/YAxis";
import SimpleXAxis from "../SimpleXAxis";
import BarLabels from "../BarLabels";
import { maxed } from "../../utils/func";
import SpringBar from "../GroupedBar/SpringBar";
import StackedLinesVisualization from "./StackedLinesVisualization";
import TargetLine from "../TargetLine/TargetLine";
import applyYAxisFormat from "../Axes/applyYAxisFormat";
import PercentLine from "./PercentLine";

export default function StackedBarTwo(props) {
  const {
    data,
    width,
    height,
    keys,
    colors,
    yFormat,
    yLabelFormat,
    labelFormat,
    xKeys,
    summedValues,
    setTooltip,
    term,
    bands,
    yAxisFormat,
    legendItems,
    hideBarLabels,
    initialData,
    useSameYAxisScale,
    lineKyes,
    circleKeys,
    negativeYKeys = [],
    maxYScale,
    totalTrend,
  } = props;
  if (!data) return null;

  const positiveKeys = keys.filter(
    (key) => !negativeYKeys.find((k) => k === key)
  );
  const { positiveData, negativeData } = separatePositivesAndNegatives(
    data,
    negativeYKeys
  );

  const dataStack = stack().keys(positiveKeys);
  const negativeDataStack = stack().keys(negativeYKeys);
  const seriesA = data[0]?.second ? dataStack(data.map((d) => d.second)) : null; // ignore negative values here
  const seriesB = dataStack(positiveData.map((d) => d.first)); // Main
  const seriesBNegative = negativeDataStack(negativeData.map((d) => d.first));
  // then we migrate over

  const x = scaleBand().domain(xKeys).range([0, width]).paddingInner(0.4);

  // We create an arbitrary array that would have a length of up to 2 depending on how many bar groups are used.
  // We pass this array to domain().
  const arbitraryArray = [seriesA, seriesB].filter((v) => v != null);

  const xInner = scaleBand()
    .domain(arbitraryArray)
    .range([0, x.bandwidth()])
    .paddingInner(0.4);

  const rightYAxisKeys = [...lineKyes, ...circleKeys];

  const lineAndCircleYValues = initialData.flatMap((item) =>
    rightYAxisKeys.map((key) => item[key])
  );

  // const leftAxisMaxValue = maxed(summedValues.map((s) => +s.total));
  const rightAxisMaxValue = maxed(lineAndCircleYValues);

  // left y axis scale
  // instead of zero here, we want to use the negative
  const maxPositive = findMaxSum(data, positiveKeys);
  const maxNegative = findMaxSum(data, negativeYKeys);
  // on useSameYAxisScale we need to compare both axes max values to get bigger one
  const max = maxYScale
    ? maxYScale
    : useSameYAxisScale
    ? Math.max(maxPositive, rightAxisMaxValue)
    : maxPositive;

  const fullYScale = scaleLinear()
    .domain([-maxNegative * 1.02, max * 1.02])
    .range([height, 0]);
  const zeroHeight = fullYScale(0);

  const negativeYScale = negativeYKeys.length
    ? scaleLinear()
        .domain([0, maxNegative * 1.02])
        .range([zeroHeight, height])
    : null;
  const yScale1 = scaleLinear()
    .domain([0, max * 1.02])
    .range([zeroHeight, 0]);

  // right y axis scale
  const yScale2 = scaleLinear()
    .domain([0, rightAxisMaxValue])
    .range([height, 0]);

  function compareDifferentTypes(name, key) {
    return name === key || +name === +key;
  }

  const setStaticColor = (key, i) => {
    const { color } =
      legendItems.find((lI) => compareDifferentTypes(lI.name, key)) ?? {};

    if (!color) {
      return colors(i);
    }

    return color;
  };

  return (
    <g data-cy="stacked-bar-two-container">
      <g data-cy="grouped-bar-value-axis">
        <YAxis
          {...props}
          yScale={fullYScale}
          yTicksCount={props.yTicksCount}
          yTicksColor={props.yTicksColor}
          hideYAxisLine={props.hideYAxisLine}
          hideYAxisTicks={props.hideYAxisTicks}
          yAxisGrid={props.yAxisGrid}
          yAxisGridColor={props.yAxisGridColor}
          width={width}
          yAxisFormat={yAxisFormat || applyYAxisFormat(yFormat)}
        />
      </g>
      {seriesA
        ? seriesA.map((s, i) =>
            s.map((v, j) =>
              isNaN(x(xKeys[j])) ? null : (
                <SpringBar
                  key={j}
                  x={x(xKeys[j])}
                  width={xInner.bandwidth() * 0.6}
                  y={yScale1(v[1])}
                  height={
                    height - yScale1(isNaN(v[1] - v[0]) ? 0 : v[1] - v[0])
                  }
                  color={legendItems ? setStaticColor(s.key, i) : colors(i)}
                  startPos={height}
                  opacity={0.4}
                  skipAnimation={seriesB[0].length > 20}
                />
              )
            )
          )
        : null}
      {seriesB
        ? seriesB.map((s, i) =>
            s.map((v, j) =>
              isNaN(x(xKeys[j])) ? null : (
                <SpringBar
                  key={j}
                  x={x(xKeys[j]) + x.bandwidth() * 0.4}
                  width={xInner.bandwidth()}
                  y={yScale1(v[1])}
                  height={
                    zeroHeight - yScale1(isNaN(v[1] - v[0]) ? 0 : v[1] - v[0])
                  }
                  color={legendItems ? setStaticColor(s.key, i) : colors(i)}
                  startPos={height}
                  onMouseEnter={() =>
                    setTooltip({
                      key: s.key,
                      keyLabel: s.key,
                      data: v.data,
                      xPos: Math.floor(
                        x(xKeys[j]) +
                          x.bandwidth() * 0.45 +
                          xInner.bandwidth() * 0.6
                      ),
                      yPos: height - yScale1(v[1]),
                    })
                  }
                  onMouseLeave={() => setTooltip(null)}
                  skipAnimation={seriesB[0].length > 20}
                  useLines={seriesB[0].length > 100}
                />
              )
            )
          )
        : null}
      {seriesBNegative
        ? seriesBNegative.map((s, i) =>
            s.map((v, j) =>
              isNaN(x(xKeys[j])) ? null : (
                <SpringBar
                  key={j}
                  x={x(xKeys[j]) + x.bandwidth() * 0.4}
                  width={xInner.bandwidth()}
                  y={negativeYScale(0)}
                  height={negativeYScale(v[0] - v[1]) - negativeYScale(0)}
                  color={
                    legendItems ? setStaticColor(s.key, i + 1) : colors(i + 1)
                  }
                  startPos={height}
                  onMouseEnter={() =>
                    setTooltip({
                      key: s.key,
                      keyLabel: s.key,
                      data: v.data,
                      xPos: Math.floor(
                        x(xKeys[j]) +
                          x.bandwidth() * 0.45 +
                          xInner.bandwidth() * 0.6
                      ),
                      yPos: height - yScale1(v[1]),
                    })
                  }
                  onMouseLeave={() => setTooltip(null)}
                  skipAnimation={seriesB[0].length > 20}
                  useLines={seriesB[0].length > 100}
                />
              )
            )
          )
        : null}
      <SimpleXAxis
        width={width + 15}
        height={height}
        xScale={x}
        values={xKeys}
        xAxisDate={props.xAxisDate}
        xKeyFormat={props.xKeyFormat ?? term}
        bands={bands}
        xInner={xInner}
        allTicks
      />

      {!hideBarLabels && (
        <BarLabels
          skipAnimation={seriesB[0] && seriesB[0].length > 20}
          values={summedValues}
          x={x}
          y={yScale1}
          yFormat={labelFormat || yLabelFormat || yFormat}
          xInner={xInner}
        />
      )}

      {negativeYKeys.length ? (
        <TargetLine
          width={width}
          y={fullYScale}
          target={0}
          thickness={1}
          color="#499bff"
        />
      ) : null}

      {totalTrend ? <PercentLine data={data} y={fullYScale} x={x} /> : null}

      <StackedLinesVisualization
        {...props}
        initialData={initialData}
        width={width}
        height={height}
        xScale={x}
        setStaticColor={setStaticColor}
        bandWidth={xInner.bandwidth()}
        yScale={useSameYAxisScale ? fullYScale : yScale2}
        hideAxis={!!useSameYAxisScale}
      />
    </g>
  );
}

// Fallback props for StackedBarTwo. Of these, only `labelFormat` is read
// directly in this file; any others can only take effect through the
// `{...props}` spreads into YAxis and StackedLinesVisualization.
// NOTE(review): React deprecates `defaultProps` on function components;
// consider migrating these to destructuring defaults together with callers.
const stackedBarTwoDefaultProps = {
  labelFormat: ".2s",
  valueKey1: "value1",
  valueKey2: "value2",
  xKey: "month",
  labelSize: 12,
  labelWeight: "normal",
  labelColor: "black",
  intraGroupSpacing: 1,
  valuePadding: 1,
  xTicksColor: "black",
  barWidth1: 1,
  barWidth2: 1,
  meta: { fields: [] },
};

StackedBarTwo.defaultProps = stackedBarTwoDefaultProps;

/**
 * Largest per-row total of `keys` across `data`, reading each row's `first`
 * object.
 *
 * Values are coerced with Number(), matching how d3's stack() coerces them,
 * so numeric strings add rather than concatenate. Missing or non-numeric
 * values count as 0.
 *
 * @param {Array<{first: Object}>} data rows to scan
 * @param {Array<string>} keys property names to sum within each row
 * @returns {number} the maximum row sum; 0 when there are no keys or no rows,
 *   so the result is always safe to use in a scale domain (previously empty
 *   `data` yielded -Infinity)
 */
const findMaxSum = (data, keys) => {
  if (!keys?.length || !data?.length) return 0;

  return data.reduce((maxSum, item) => {
    const sum = keys.reduce(
      (acc, key) => acc + (Number(item.first[key]) || 0),
      0
    );
    return sum > maxSum ? sum : maxSum;
  }, -Infinity);
};

/**
 * Split each row's `first` object into two parallel arrays.
 *
 * - `positiveData` holds every key not listed in `negativeYKeys`.
 * - `negativeData` holds the `negativeYKeys` values multiplied by -1.
 *
 * Both arrays always contain exactly one entry per input row, in order,
 * possibly with an empty `first`. Rows are never skipped because the caller
 * pairs stack index j with the j-th x position. Dropping a row would shift
 * every later bar one slot to the left.
 *
 * @param {Array<{first: Object, xKey?: *}>} data input rows
 * @param {Array<string>} negativeYKeys keys to move into `negativeData`
 * @returns {{positiveData: Array<{first: Object, xKey: *}>,
 *            negativeData: Array<{first: Object, xKey: *}>}}
 */
const separatePositivesAndNegatives = (data, negativeYKeys) => {
  const negativeKeySet = new Set(negativeYKeys);
  const positiveData = [];
  const negativeData = [];

  for (const item of data) {
    const positives = {};
    const negatives = {};

    for (const [key, value] of Object.entries(item.first)) {
      if (negativeKeySet.has(key)) {
        negatives[key] = value * -1;
      } else {
        positives[key] = value;
      }
    }

    positiveData.push({ first: positives, xKey: item.xKey });
    negativeData.push({ first: negatives, xKey: item.xKey });
  }

  return { positiveData, negativeData };
};
