# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: MIT

import pandas as pd
from mpp.core.types import RawDataFrameColumns as rdc
from pandas import MultiIndex
import re


class Normalizer:
    """
    Rescales raw event counts to a common reference TSC.

    Rows whose event name matches one of the exclusion patterns ('$samplingTime', '$processed_samples',
    'retire_latency', 'type=STATIC', 'REG_STATIC'; matched as regex substrings) keep their original
    value. Every other row has its value rescaled as value * ref_tsc / tsc.
    """

    def __init__(self, ref_tsc):
        """
        @param ref_tsc: reference TSC that every normalizable row is scaled to
        """
        self.ref_tsc = ref_tsc
        # Escaped below because they contain the regex metacharacter '$'
        self.__event_exclusions = ['$samplingTime', '$processed_samples']
        # Used verbatim as regex alternatives
        self.__non_event_exclusions = ['retire_latency', 'type=STATIC', 'REG_STATIC']
        self.__all_exclusions = [re.escape(e) for e in self.__event_exclusions] + self.__non_event_exclusions
        # One alternation pattern, built once instead of on every lookup
        self.__exclusion_pattern = '|'.join(self.__all_exclusions)

    def normalize(self, df: pd.DataFrame, event_axis: str = 'columns') -> pd.DataFrame:
        """
        Computes normalized event count

        Normalizable rows get value * ref_tsc / tsc; rows matching an exclusion pattern keep their
        value. The result lists the normalized rows first (in their original relative order),
        followed by the excluded rows.

        @param df: data frame containing data to normalize
        @param event_axis: axis where event names exist, must be either 'columns' or 'index'

        @return a copy of df where the "value" column is updated to contain normalized values
        @raise ValueError: if event_axis is neither 'columns' nor 'index'
        """
        self.__validate_event_axis(event_axis)
        do_not_normalize, normalize = self.__split_df(df.copy(), event_axis)
        normalize[rdc.VALUE] = normalize[rdc.VALUE] * self.ref_tsc / normalize[rdc.TSC]
        # pd.concat silently drops the None that __split_df returns when nothing is excluded
        return pd.concat([normalize, do_not_normalize])

    def __split_df(self, df: pd.DataFrame, event_axis: str):
        """
        Splits df into the rows to leave untouched and the rows to normalize.

        @param df: frame to split (normalize() passes its own private copy)
        @param event_axis: 'columns' or 'index', already validated

        @return tuple (do_not_normalize, normalize). do_not_normalize is None when no exclusion lookup
                applies. normalize is either df itself or an explicit copy, so it is safe to assign into.
        """
        if rdc.NAME not in df.columns and not self.__contains_all_event_exclusions(df.index):
            return None, df  # all rows can be normalized
        if event_axis == 'columns' and rdc.NAME in df.columns:
            excluded = self.__get_all_exclusions_by_column(df)
        elif event_axis == 'index' and self.__contains_all_event_exclusions(df.index):
            excluded = self.__get_all_exclusions_by_index(df)
        else:
            return None, df  # no exclusion lookup applies for this event_axis; normalize every row
        # Split by the boolean mask itself rather than by index labels: label-based set arithmetic
        # scrambles row order and drops normalizable rows that share an index label with an
        # excluded row. The explicit copy prevents SettingWithCopyWarning on the later assignment.
        return df.loc[excluded], df.loc[~excluded].copy()

    def __contains_all_event_exclusions(self, multi_index: MultiIndex) -> bool:
        """Returns True if every name in the event exclusion list is found in the given index via `in`."""
        return all(metric in multi_index for metric in self.__event_exclusions)

    def __get_all_exclusions_by_column(self, df: pd.DataFrame):
        """Boolean row mask: True where the name column matches an exclusion pattern."""
        # na=False: a missing name matches nothing instead of producing NaN, which .loc cannot use as a mask
        return df[rdc.NAME].str.contains(self.__exclusion_pattern, na=False)

    def __get_all_exclusions_by_index(self, df: pd.DataFrame):
        """Boolean row mask: True where the index's name level matches an exclusion pattern."""
        # na=False for the same reason as in the column variant
        return df.index.get_level_values(rdc.NAME).str.contains(self.__exclusion_pattern, na=False)

    @staticmethod
    def __validate_event_axis(event_axis: str):
        """
        @raise ValueError: if event_axis is neither 'columns' nor 'index'
        """
        if event_axis not in ['columns', 'index']:
            raise ValueError("'event_axis' argument must be either 'columns' or 'index'")
