Skip to content

Commit 607b835

Browse files
committed
fix: example async_jobs_chat not deterministic file
1 parent d8315ac commit 607b835

File tree

1 file changed

+68
-19
lines changed

1 file changed

+68
-19
lines changed

examples/async_jobs_chat.py

Lines changed: 68 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
#!/usr/bin/env python
2-
32
import asyncio
3+
import json
44
import os
5+
import random
6+
from pathlib import Path
57

68
from mistralai import Mistral
79
from mistralai.models import (
@@ -11,46 +13,93 @@
1113

1214
POLLING_INTERVAL = 10
1315

16+
cwd = Path(__file__).parent
17+
18+
user_contents = [
19+
"How far is the Moon from Earth?",
20+
"What's the largest ocean on Earth?",
21+
"How many continents are there?",
22+
"What's the powerhouse of the cell?",
23+
"What's the speed of light?",
24+
"Can you solve a Rubik's Cube?",
25+
"What is the tallest mountain in the world?",
26+
"Who painted the Mona Lisa?",
27+
]
28+
29+
# List of assistant contents
30+
assistant_contents = [
31+
"Around 384,400 kilometers. Give or take a few, like that really matters.",
32+
"The Pacific Ocean. You know, the one that covers more than 60 million square miles. No big deal.",
33+
"There are seven continents. I hope that wasn't too hard to count.",
34+
"The mitochondria. Remember that from high school biology?",
35+
"Approximately 299,792 kilometers per second. You know, faster than your internet speed.",
36+
"I could if I had hands. What's your excuse?",
37+
"Mount Everest, standing at 29,029 feet. You know, just a little hill.",
38+
"Leonardo da Vinci. Just another guy who liked to doodle.",
39+
]
40+
41+
system_message = "Marv is a factual chatbot that is also sarcastic"
42+
43+
def create_validation_file() -> bytes:
44+
return json.dumps({
45+
"messages": [
46+
{"role": "user", "content": "How long does it take to travel around the Earth?"},
47+
{"role": "assistant", "content": "Around 24 hours if you're the Earth itself. For you, depends on your mode of transportation."}
48+
],
49+
"temperature": random.random()
50+
}).encode()
1451

1552
async def main():
1653
api_key = os.environ["MISTRAL_API_KEY"]
1754
client = Mistral(api_key=api_key)
1855

56+
requests = []
57+
for um, am in zip(
58+
random.sample(user_contents, len(user_contents)),
59+
random.sample(assistant_contents, len(assistant_contents)),
60+
):
61+
requests.append(json.dumps({
62+
"messages": [
63+
{"role": "system", "content": system_message},
64+
{"role": "user", "content": um},
65+
{"role": "assistant", "content": am},
66+
]
67+
}))
68+
1969
# Create new files
20-
with open("examples/fixtures/ft_training_file.jsonl", "rb") as f:
21-
training_file = await client.files.upload_async(
22-
file=File(file_name="file.jsonl", content=f)
23-
)
24-
with open("examples/fixtures/ft_validation_file.jsonl", "rb") as f:
25-
validation_file = await client.files.upload_async(
26-
file=File(file_name="validation_file.jsonl", content=f)
27-
)
70+
training_file = await client.files.upload_async(
71+
file=File(
72+
file_name="file.jsonl", content=("\n".join(requests)).encode()
73+
),
74+
purpose="fine-tune",
75+
)
76+
77+
validation_file = await client.files.upload_async(
78+
file=File(
79+
file_name="validation_file.jsonl", content=create_validation_file()
80+
),
81+
purpose="fine-tune",
82+
)
2883
# Create a new job
2984
created_job = await client.fine_tuning.jobs.create_async(
3085
model="open-mistral-7b",
3186
training_files=[{"file_id": training_file.id, "weight": 1}],
3287
validation_files=[validation_file.id],
3388
hyperparameters=CompletionTrainingParametersIn(
34-
training_steps=2,
89+
training_steps=1,
3590
learning_rate=0.0001,
3691
),
3792
)
38-
print(created_job)
3993

40-
while created_job.status in [
41-
"QUEUED",
42-
"STARTED",
43-
"VALIDATING",
44-
"VALIDATED",
45-
"RUNNING",
46-
]:
94+
while created_job.status in ["RUNNING", "STARTED", "QUEUED", "VALIDATING", "VALIDATED"]:
4795
created_job = await client.fine_tuning.jobs.get_async(job_id=created_job.id)
4896
print(f"Job is {created_job.status}, waiting {POLLING_INTERVAL} seconds")
4997
await asyncio.sleep(POLLING_INTERVAL)
5098

51-
if created_job.status != "SUCCESS":
99+
if created_job.status == "FAILED":
52100
print("Job failed")
53101
raise Exception(f"Job failed with {created_job.status}")
102+
54103
print(created_job)
55104
# Chat with model
56105
response = await client.chat.complete_async(

0 commit comments

Comments
 (0)