import Decimal from 'decimal.js';

/**
 * Main function to calculate the correlation matrix from RIPS data.
 *
 * For every ISIN in `order`, the `unitGbp` values are converted into simple
 * period-on-period returns (`current / previous - 1`). The correlation between
 * ISINs X and Y is then
 *
 *   Σ(dX · dY) / √(Σ dY² · Σ dX²)
 *
 * where `d` is the deviation of each return from that ISIN's mean return.
 *
 * @param ripsData - Data points keyed by ISIN. Values are consumed in array
 *   order. NOTE(review): this assumes every series is already sorted by `date`
 *   and covers the same dates; neither is checked here.
 * @param order - The ISINs in the order they appear as matrix rows and
 *   columns. Must have one entry per key of `ripsData`.
 * @returns A square matrix where `matrix[i][j]` (for `j <= i`) is the
 *   correlation between `order[i]` and `order[j]`; cells above the diagonal
 *   are `NaN`. A series with fewer than two data points (no returns), or whose
 *   squared deviations sum to zero, yields `0 / 0`, i.e. `NaN`, correlations.
 * @throws Error when `order` and the keys of `ripsData` differ in length, when
 *   an ISIN in `order` has no entry in `ripsData`, or when the series do not
 *   all have the same number of data points.
 */
export const calculateCorrelationMatrixFromRipsData = (
  ripsData: Map<
    string,
    {
      unitGbp: number | string;
      secId: string;
      performanceId: string;
      date: string;
    }[]
  >,
  order: string[],
) => {
  if (order.length !== ripsData.size) {
    throw new Error(`The order array length does not match the ripsData keys`);
  }

  // Per-ISIN statistics, stored by position in `order` so the matrix loops
  // below can index them directly.
  const deviationsByIndex: Decimal[][] = []; // Column D
  const sumOfSquaredDeviationsByIndex: Decimal[] = []; // Sum of column E
  let expectedLength: number | undefined;

  // Data preparation per isin
  for (const isin of order) {
    const data = ripsData.get(isin);
    if (data === undefined) {
      throw new Error(`The ripsData has no entry for ISIN ${isin}`);
    }
    // Series are multiplied element by element below, so they must all be the
    // same length; otherwise the products would silently misalign.
    if (expectedLength === undefined) {
      expectedLength = data.length;
    } else if (data.length !== expectedLength) {
      throw new Error(
        `The ripsData series for ${isin} has ${data.length} data points, expected ${expectedLength}`,
      );
    }

    // The start value is a flat array of the unitGbp values
    const ripsValues = data.map(dataPoint => new Decimal(dataPoint.unitGbp));
    const returns = calculateReturns(ripsValues);
    const averageReturn = meanOfValues(returns); // Cell C9
    const deviations = calculateDeviation(returns, averageReturn);

    deviationsByIndex.push(deviations);
    // Hoisted out of the matrix loops: this sum depends on one ISIN only.
    sumOfSquaredDeviationsByIndex.push(
      sumOfValues(squareDeviations(deviations)),
    );
  }

  /**
   * Keeping in mind that we have a multi-dimensional array
   * where the items pushed to correlationMatrix are the rows, AKA the Y axis!
   *
   * Table visualization with dummy values of a Correlation Matrix:
   *  ______________________________________
   * |     ISIN     | ISIN1 | ISIN2 | ISIN3 |
   * |--------------|-------|-------|-------|
   * | ISIN1        | 1     | 0.8   | 0.65  |
   * | ISIN2        | 0.8   | 1     | 0.75  |
   * | ISIN3        | 0.65  | 0.75  | 1     |
   */
  const correlationMatrix: number[][] = [];

  // This loop is the rows
  for (let i = 0; i < order.length; i++) {
    const deviationsY = deviationsByIndex[i];
    const sY = sumOfSquaredDeviationsByIndex[i];
    const xCorrelations: number[] = [];

    // This loop is the columns
    for (let j = 0; j < order.length; j++) {
      if (j > i) {
        // Only the lower triangle (incl. diagonal) is computed; the matrix is
        // symmetric, so the other side is filled with NaN.
        xCorrelations.push(NaN);
        continue;
      }

      const deviationsX = deviationsByIndex[j];
      const sX = sumOfSquaredDeviationsByIndex[j];
      const sXY = sumOfValues(
        deviationsY.map((deviationY, index) =>
          deviationY.times(deviationsX[index]),
        ),
      );
      xCorrelations.push(sXY.dividedBy(sY.times(sX).sqrt()).toNumber());
    }
    correlationMatrix.push(xCorrelations);
  }
  return correlationMatrix;
};

/**
 * Adds up a list of Decimal values. An empty list sums to zero.
 *
 * @deprecated - use ./packages/server-core/src/compute/formula/sumOfValues.ts
 */
export const sumOfValues = (values: Decimal[]): Decimal => {
  let total = new Decimal(0);
  for (const value of values) {
    total = total.plus(value);
  }
  return total;
};

/**
 * Arithmetic mean of a list of Decimal values.
 *
 * An empty list divides zero by zero, which decimal.js reports as `NaN`.
 *
 * @deprecated - use ./packages/server-core/src/compute/formula/meanOfValues.ts
 */
export const meanOfValues = (values: Decimal[]): Decimal =>
  // sumOfValues already returns a fresh Decimal, so no extra copy is needed.
  sumOfValues(values).dividedBy(values.length);

/**
 * Correlation matrix specific calculations
 *
 * Converts a series of values into simple period-on-period returns,
 * `values[i] / values[i - 1] - 1` for every i >= 1. The result is one element
 * shorter than the input, and empty when fewer than two values are given.
 */
export const calculateReturns = (values: Decimal[]): Decimal[] =>
  // Element k of the slice is values[k + 1]; its predecessor is values[k].
  values
    .slice(1)
    .map((currentValue, offset) =>
      currentValue.dividedBy(values[offset]).minus(1),
    );

/**
 * Subtracts `mean` from every value and returns the deviations as a new array
 * (the input is not modified).
 *
 * @deprecated - use ./packages/server-core/src/compute/formula/calculateDeviation.ts
 */
export const calculateDeviation = (
  values: Decimal[],
  mean: Decimal,
): Decimal[] => values.map(value => value.minus(mean));

/**
 * Squares each deviation and returns the results as a new array; this is the
 * "Column E" step of the correlation calculation.
 */
export const squareDeviations = (deviations: Decimal[]): Decimal[] =>
  deviations.map(deviation => deviation.pow(2));
