Skip to content

Commit 75e934b

Browse files
committed
format: lint
Signed-off-by: Roni Friedman-Melamed <[email protected]>
1 parent c22397b commit 75e934b

File tree

3 files changed

+52
-44
lines changed

3 files changed

+52
-44
lines changed

src/instructlab/eval/mmlu.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,9 @@ def run(self, server_url: str | None = None) -> tuple:
153153

154154
return overall_score, individual_scores
155155

156-
def _run_mmlu(self, server_url: str | None = None, return_all_results:bool = False) -> dict:
156+
def _run_mmlu(
157+
self, server_url: str | None = None, return_all_results: bool = False
158+
) -> dict:
157159
if server_url is not None:
158160
# Requires lm_eval >= 0.4.4
159161
model_args = f"base_url={server_url}/completions,model={self.model_path},tokenizer_backend=huggingface"

src/instructlab/eval/unitxt.py

Lines changed: 47 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@
55
"""
66

77
# Standard
8-
import os, shutil
9-
import yaml
108
from uuid import uuid4
9+
import os
10+
import shutil
1111

1212
# Third Party
1313
from lm_eval.tasks.unitxt import task
14+
import yaml
1415

1516
# First Party
1617
from instructlab.eval.mmlu import MMLUBranchEvaluator
@@ -20,7 +21,8 @@
2021

2122
logger = setup_logger(__name__)
2223

23-
TEMP_DIR_PREFIX = 'unitxt_temp'
24+
TEMP_DIR_PREFIX = "unitxt_temp"
25+
2426

2527
class UnitxtEvaluator(MMLUBranchEvaluator):
2628
"""
@@ -29,45 +31,51 @@ class UnitxtEvaluator(MMLUBranchEvaluator):
2931
Attributes:
3032
model_path absolute path to or name of a huggingface model
3133
unitxt_recipe unitxt recipe (see unitxt.ai for more information)
32-
A Recipe holds a complete specification of a unitxt pipeline
34+
A Recipe holds a complete specification of a unitxt pipeline
3335
Example: card=cards.wnli,template=templates.classification.multi_class.relation.default,max_train_instances=5,loader_limit=20,num_demos=3,demos_pool_size=10
34-
36+
3537
"""
38+
3639
name = "unitxt"
40+
3741
def __init__(
3842
self,
39-
model_path,
43+
model_path,
4044
unitxt_recipe: str,
4145
):
42-
task = self.assign_task_name()
43-
tasks_dir = self.assign_tasks_dir(task)
46+
unitxt_task = self.assign_task_name()
47+
tasks_dir = self.assign_tasks_dir(unitxt_task)
4448
super().__init__(
45-
model_path = model_path,
46-
tasks_dir = tasks_dir,
47-
tasks = [task],
48-
few_shots = 0
49+
model_path=model_path, tasks_dir=tasks_dir, tasks=[unitxt_task], few_shots=0
4950
)
5051
self.unitxt_recipe = unitxt_recipe
5152

52-
def assign_tasks_dir(self, task):
53-
return f'{TEMP_DIR_PREFIX}_{task}'
53+
def assign_tasks_dir(self, task_name):
54+
return f"{TEMP_DIR_PREFIX}_{task_name}"
5455

5556
def assign_task_name(self):
5657
return str(uuid4())
5758

58-
def prepare_unitxt_files(self)->tuple:
59-
task = self.tasks[0]
60-
yaml_file = os.path.join(self.tasks_dir,f"{task}.yaml")
59+
def prepare_unitxt_files(self) -> None:
60+
taskname = self.tasks[0]
61+
yaml_file = os.path.join(str(self.tasks_dir), f"{taskname}.yaml")
6162
create_unitxt_pointer(self.tasks_dir)
62-
create_unitxt_yaml(yaml_file=yaml_file, unitxt_recipe=self.unitxt_recipe, task_name=task)
63+
create_unitxt_yaml(
64+
yaml_file=yaml_file, unitxt_recipe=self.unitxt_recipe, task_name=taskname
65+
)
6366

6467
def remove_unitxt_files(self):
65-
if self.tasks_dir.startswith(TEMP_DIR_PREFIX): #to avoid unintended deletion if this class is inherited
68+
if self.tasks_dir.startswith(
69+
TEMP_DIR_PREFIX
70+
): # to avoid unintended deletion if this class is inherited
6671
shutil.rmtree(self.tasks_dir)
6772
else:
68-
logger.warning(f"unitxt tasks dir did not start with '{TEMP_DIR_PREFIX}' and therefor was not deleted")
73+
logger.warning(
74+
"unitxt tasks dir did not start with '%s' and therefor was not deleted",
75+
TEMP_DIR_PREFIX,
76+
)
6977

70-
def run(self,server_url: str | None = None) -> tuple:
78+
def run(self, server_url: str | None = None) -> tuple:
7179
"""
7280
Runs evaluation
7381
@@ -80,40 +88,40 @@ def run(self,server_url: str | None = None) -> tuple:
8088
os.environ["TOKENIZERS_PARALLELISM"] = "true"
8189
results = self._run_mmlu(server_url=server_url, return_all_results=True)
8290
taskname = self.tasks[0]
83-
global_scores = results['results'][taskname]
84-
global_scores.pop('alias')
91+
global_scores = results["results"][taskname]
92+
global_scores.pop("alias")
8593
try:
86-
instances = results['samples'][taskname]
94+
instances = results["samples"][taskname]
8795
instance_scores = {}
88-
metrics = [metric.replace('metrics.','') for metric in instances[0]['doc']['metrics']]
89-
for i,instance in enumerate(instances):
96+
metrics = [
97+
metric.replace("metrics.", "")
98+
for metric in instances[0]["doc"]["metrics"]
99+
]
100+
for i, instance in enumerate(instances):
90101
scores = {}
91102
for metric in metrics:
92103
scores[metric] = instance[metric][0]
93104
instance_scores[i] = scores
94-
except Exception as e:
105+
except KeyError as e:
95106
logger.error("Error in extracting single instance scores")
96107
logger.error(e)
97108
logger.error(e.__traceback__)
98109
instance_scores = None
99110
self.remove_unitxt_files()
100-
return global_scores,instance_scores
111+
return global_scores, instance_scores
101112

102113

103-
def create_unitxt_yaml(yaml_file,unitxt_recipe, task_name):
104-
data = {
105-
'task': f'{task_name}',
106-
'include': 'unitxt',
107-
'recipe': f'{unitxt_recipe}'
108-
}
109-
with open(yaml_file, 'w') as file:
114+
def create_unitxt_yaml(yaml_file, unitxt_recipe, task_name):
115+
data = {"task": f"{task_name}", "include": "unitxt", "recipe": f"{unitxt_recipe}"}
116+
with open(yaml_file, "w", encoding="utf-8") as file:
110117
yaml.dump(data, file, default_flow_style=False)
111-
logger.debug(f"task {task} unitxt recipe written to {yaml_file}")
118+
logger.debug("task %s unitxt recipe written to %s", task_name, yaml_file)
119+
112120

113121
def create_unitxt_pointer(tasks_dir):
114122
class_line = "class: !function " + task.__file__.replace("task.py", "task.Unitxt")
115-
output_file = os.path.join(tasks_dir,'unitxt')
123+
output_file = os.path.join(tasks_dir, "unitxt")
116124
os.makedirs(os.path.dirname(output_file), exist_ok=True)
117-
with open(output_file, 'w') as f:
125+
with open(output_file, "w", encoding="utf-8") as f:
118126
f.write(class_line)
119-
logger.debug(f"Unitxt task pointer written to {output_file}")
127+
logger.debug("Unitxt task pointer written to %s", output_file)

tests/test_unitxt.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@ def test_unitxt():
77
try:
88
model_path = "instructlab/granite-7b-lab"
99
unitxt_recipe = "card=cards.wnli,template=templates.classification.multi_class.relation.default,max_train_instances=5,loader_limit=20,num_demos=3,demos_pool_size=10"
10-
unitxt = UnitxtEvaluator(
11-
model_path=model_path, unitxt_recipe=unitxt_recipe
12-
)
10+
unitxt = UnitxtEvaluator(model_path=model_path, unitxt_recipe=unitxt_recipe)
1311
overall_score, single_scores = unitxt.run()
1412
print(overall_score)
1513
except Exception as exc:
@@ -19,4 +17,4 @@ def test_unitxt():
1917

2018

2119
if __name__ == "__main__":
22-
assert test_unitxt() == True
20+
assert test_unitxt() == True

0 commit comments

Comments
 (0)