diff --git a/reclist/reclist.py b/reclist/reclist.py
index 634b2b1..f63fb08 100644
--- a/reclist/reclist.py
+++ b/reclist/reclist.py
@@ -8,6 +8,8 @@
 from reclist.charts import CHART_TYPE
 from reclist.logs import LOGGER, logger_factory
 from reclist.metadata import METADATA_STORE, metadata_store_factory
+from reclist.summarizer import summarize_statistics
+from typing import Optional
 import datetime
 
 def rec_test(test_type: str, display_type: CHART_TYPE = None):
@@ -148,7 +150,17 @@ def _display_rich_table(self, table_name: str, results: list):
 
         return
 
-    def __call__(self, verbose=True, *args, **kwargs):
+    def __call__(self, verbose: bool = True, summarize: bool = False, compare_model: Optional[ABC] = None, *args, **kwargs):
+        """Run the RecList and collect metrics.
+
+        If `summarize` is True, it is assumed you have an OpenAI API key
+        set up in ~/.openai_api_key.
+
+        Args:
+            verbose (bool): Whether to produce verbose output.
+            summarize (bool): Whether to summarize the metrics; if `compare_model` is provided, the summary compares against it.
+            compare_model (Optional[RecList]): Optional model to compare against.
+        """
         from rich.progress import track
 
         self.meta_store_path = self.create_data_store()
@@ -172,7 +184,16 @@
         test_2_fig = self._generate_report_and_plot(self._test_results, self.meta_store_path)
         for test, fig in test_2_fig.items():
             self.logger_service.save_plot(name=test, fig=fig)
-
+        if summarize:
+            compare_model_name = None
+            compare_statistics = None
+            # duck-typed check: any object exposing model_name and _test_results can be compared against
+            if compare_model is not None and hasattr(compare_model, 'model_name') and hasattr(compare_model, '_test_results'):
+                compare_model_name = compare_model.model_name
+                compare_statistics = compare_model._test_results
+            summary = summarize_statistics(model_name=self.model_name, statistics=self._test_results,
+                                           compare_model_name=compare_model_name, compare_statistics=compare_statistics)
+            print(summary)
         return
 
     def _generate_report_and_plot(self, test_results: list, meta_store_path: str):
diff --git a/reclist/summarizer.py b/reclist/summarizer.py
new file mode 100644
index 0000000..dcfa5f5
--- /dev/null
+++ b/reclist/summarizer.py
@@ -0,0 +1,90 @@
+import guidance
+from typing import Optional
+
+# the summarizer talks to an OpenAI chat model through the guidance library
+guidance.llm = guidance.llms.OpenAI("gpt-3.5-turbo")
+PROMPT = """
+{{#system~}}
+Assume you are a Data Scientist assistant helping Data Science practitioners evaluate their recommender system models.
+You will be given a list of metrics and you should do 2 tasks:
+1. Help summarize the findings
+2. Provide advice on what could be done to improve the metrics
+You will report your findings specifically, for example referring to the actual metrics and values, while being succinct and using bullet points.
+For example, you can look at correlations between metrics, outliers or the range of the metrics to draw conclusions.
+As a Data Scientist you do not need to report on all the metrics, but only on the ones providing incremental value to the analysis.
+Therefore, it is key to only output information that provides value, maximizing insight while minimizing verbosity.
+Do not hesitate to group multiple metrics into one bullet point if they point towards the same conclusion.
+It will only make your reasoning stronger.
+You should aim for each bullet point to be no more than one sentence, so digesting this information is easy and fast.
+Finally, you do not need to explain what the metrics are, as you are already speaking to an expert.
+Do not hesitate to use technical jargon if it helps you to be more concise.
+
+The metrics are given as an array of JSON objects, each element having these keys:
+ 1. "name" This is the name of the metric; it follows the pattern _ where the slice name is optional.
+ 2. "description" This is an optional description entered by the user.
+ 3. "result" This is where you will find the metric value, or additional slices of the metric "name".
+ 4. "display_type" Ignore this key.
+
+In addition, here is a mapping of the metric names:
+MRR means mean reciprocal rank
+HIT_RATE means hit rate
+MRED means miss rate equality difference
+BEING_LESS_WRONG means the cosine similarity between the true label and the predictions
+MR means miss rate, which is the opposite of HIT_RATE
+{{#if compare_statistics}}
+You will be given 2 sets of model metrics to compare 2 different models.
+Please focus on the comparison so the Data Scientist can draw conclusions.
+{{/if}}
+{{~/system}}
+{{#user~}}
+Given that I have a model that I named {{model_name}} and these statistics:
+ {{statistics}}
+{{#if compare_statistics}}
+In addition, my second model is named {{compare_model_name}} and has these statistics:
+{{compare_statistics}}
+{{/if}}
+Please summarize your findings.
+{{~/user}}
+
+{{#assistant~}}
+{{gen 'out' temperature=0}}
+{{~/assistant}}
+"""
+# compile the guidance program once at import time
+PROGRAM = guidance(PROMPT)
+
+
+def summarize_statistics(
+    model_name: str,
+    statistics: list,
+    compare_model_name: Optional[str] = None,
+    compare_statistics: Optional[list] = None,
+) -> Optional[str]:
+    """Use OpenAI to summarize one model's reclist statistics, or to compare two models.
+
+    `compare_model_name` and `compare_statistics` are optional and only used to compare 2 models.
+    If they are not provided, the summary covers the single model defined by `model_name` and `statistics`.
+
+    Args:
+        model_name (str): Model name
+        statistics (list): List of statistics as defined by reclist
+        compare_model_name (Optional[str]): Optional, name of the model to compare against
+        compare_statistics (Optional[list]): Optional, statistics as defined by reclist to compare against
+
+    Returns:
+        A string summarizing the model statistics, or the comparison between the two models
+
+    Raises:
+        ValueError: If only one of `compare_model_name` and `compare_statistics` is provided
+    """
+    if compare_model_name is not None and compare_statistics is None:
+        raise ValueError(
+            "You have specified a compare_model_name without compare_statistics"
+        )
+    if compare_model_name is None and compare_statistics is not None:
+        raise ValueError(
+            "You have specified compare_statistics without a compare_model_name"
+        )
+    summary = PROGRAM(model_name=model_name, statistics=statistics,
+                      compare_model_name=compare_model_name, compare_statistics=compare_statistics)
+    return summary["out"]
diff --git a/requirements.txt b/requirements.txt
index 8dc00f1..abc9bec 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,5 +10,6 @@ pathos==0.2.8
 networkx==2.6.3
 python-Levenshtein==0.12.2
 pyarrow==12.0.1
+guidance==0.0.64
 scikit-learn
 rich
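
A quick usage sketch for reviewers (not part of the patch): calling the new summarizer directly, with hypothetical metric entries shaped the way the prompt's JSON schema describes, and assuming an OpenAI key is configured as noted in the __call__ docstring.

from reclist.summarizer import summarize_statistics

# hypothetical metric entries, following the schema described in the prompt
stats_prod = [{"name": "HIT_RATE", "description": "overall hit rate", "result": 0.31, "display_type": None}]
stats_new = [{"name": "HIT_RATE", "description": "overall hit rate", "result": 0.27, "display_type": None}]

# single-model summary
print(summarize_statistics(model_name="prod", statistics=stats_prod))

# two-model comparison: the two compare_* arguments must be provided together,
# otherwise summarize_statistics raises ValueError
print(summarize_statistics(model_name="prod", statistics=stats_prod,
                           compare_model_name="candidate", compare_statistics=stats_new))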