Skip to content

Commit a4f2b56

Browse files
committed
update examples
1 parent 76559ed commit a4f2b56

File tree

10 files changed

+162
-22
lines changed

10 files changed

+162
-22
lines changed

examples/async_classifier.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
#!/usr/bin/env python
2+
3+
from pprint import pprint
4+
import asyncio
5+
from mistralai import Mistral, TrainingFile, ClassifierTrainingParametersIn
6+
7+
import os
8+
9+
10+
async def upload_files(client: Mistral, file_names: list[str]) -> list[str]:
11+
# Upload files
12+
print("Uploading files...")
13+
14+
file_ids = []
15+
for file_name in file_names:
16+
with open(file_name, "rb") as file:
17+
f = await client.files.upload_async(
18+
file={
19+
"file_name": file_name,
20+
"content": file.read(),
21+
},
22+
purpose="fine-tune",
23+
)
24+
file_ids.append(f.id)
25+
print("Files uploaded...")
26+
return file_ids
27+
28+
29+
async def train_classifier(client: Mistral,training_file_ids: list[str]) -> str:
30+
print("Creating job...")
31+
job = await client.fine_tuning.jobs.create_async(
32+
model="ministral-3b-latest",
33+
job_type="classifier",
34+
training_files=[
35+
TrainingFile(file_id=training_file_id)
36+
for training_file_id in training_file_ids
37+
],
38+
hyperparameters=ClassifierTrainingParametersIn(
39+
learning_rate=0.0001,
40+
),
41+
auto_start=True,
42+
)
43+
44+
print(f"Job created ({job.id})")
45+
46+
i = 1
47+
while True:
48+
await asyncio.sleep(10)
49+
detailed_job = await client.fine_tuning.jobs.get_async(job_id=job.id)
50+
if detailed_job.status not in [
51+
"QUEUED",
52+
"STARTED",
53+
"VALIDATING",
54+
"VALIDATED",
55+
"RUNNING",
56+
]:
57+
break
58+
print(f"Still training after {i * 10} seconds")
59+
i += 1
60+
61+
if detailed_job.status != "SUCCESS":
62+
print("Training failed")
63+
raise Exception(f"Job failed {detailed_job.status}")
64+
65+
print(f"Training succeed: {detailed_job.fine_tuned_model}")
66+
67+
return detailed_job.fine_tuned_model
68+
69+
70+
async def main():
71+
training_files = ["./examples/fixtures/classifier_sentiments.jsonl"]
72+
client = Mistral(
73+
api_key=os.environ["MISTRAL_API_KEY"],
74+
)
75+
76+
training_file_ids: list[str] = await upload_files(client=client, file_names=training_files)
77+
model_name: str | None = await train_classifier(client=client,training_file_ids=training_file_ids)
78+
79+
if model_name:
80+
print("Calling inference...")
81+
response = client.classifiers.classify(
82+
model=model_name,
83+
inputs=["It's nice", "It's terrible", "Why not"],
84+
)
85+
print("Inference succeed !")
86+
pprint(response)
87+
88+
print("Calling inference (Chat)...")
89+
response = client.classifiers.classify_chat(
90+
model=model_name,
91+
inputs={"messages": [{"role": "user", "content": "Lame..."}]},
92+
)
93+
print("Inference succeed (Chat)!")
94+
pprint(response)
95+
96+
97+
if __name__ == "__main__":
98+
asyncio.run(main())

examples/async_files.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ async def main():
1616
created_file = await client.files.upload_async(
1717
file=File(
1818
file_name="training_file.jsonl",
19-
content=open("examples/file.jsonl", "rb").read(),
19+
content=open("examples/fixtures/ft_training_file.jsonl", "rb").read(),
2020
)
2121
)
2222
print(created_file)

examples/async_jobs.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import os
55

66
from mistralai import Mistral
7-
from mistralai.models import File, TrainingParametersIn
7+
from mistralai.models import File, CompletionTrainingParametersIn
88

99

1010
async def main():
@@ -13,11 +13,11 @@ async def main():
1313
client = Mistral(api_key=api_key)
1414

1515
# Create new files
16-
with open("examples/file.jsonl", "rb") as f:
16+
with open("examples/fixtures/ft_training_file.jsonl", "rb") as f:
1717
training_file = await client.files.upload_async(
1818
file=File(file_name="file.jsonl", content=f)
1919
)
20-
with open("examples/validation_file.jsonl", "rb") as f:
20+
with open("examples/fixtures/ft_validation_file.jsonl", "rb") as f:
2121
validation_file = await client.files.upload_async(
2222
file=File(file_name="validation_file.jsonl", content=f)
2323
)
@@ -27,7 +27,7 @@ async def main():
2727
model="open-mistral-7b",
2828
training_files=[{"file_id": training_file.id, "weight": 1}],
2929
validation_files=[validation_file.id],
30-
hyperparameters=TrainingParametersIn(
30+
hyperparameters=CompletionTrainingParametersIn(
3131
training_steps=1,
3232
learning_rate=0.0001,
3333
),

examples/async_jobs_chat.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
import os
55

66
from mistralai import Mistral
7-
from mistralai.models import File, TrainingParametersIn
7+
from mistralai.models import (
8+
File,
9+
CompletionTrainingParametersIn,
10+
)
811

912
POLLING_INTERVAL = 10
1013

@@ -14,11 +17,11 @@ async def main():
1417
client = Mistral(api_key=api_key)
1518

1619
# Create new files
17-
with open("examples/file.jsonl", "rb") as f:
20+
with open("examples/fixtures/ft_training_file.jsonl", "rb") as f:
1821
training_file = await client.files.upload_async(
1922
file=File(file_name="file.jsonl", content=f)
2023
)
21-
with open("examples/validation_file.jsonl", "rb") as f:
24+
with open("examples/fixtures/ft_validation_file.jsonl", "rb") as f:
2225
validation_file = await client.files.upload_async(
2326
file=File(file_name="validation_file.jsonl", content=f)
2427
)
@@ -27,22 +30,28 @@ async def main():
2730
model="open-mistral-7b",
2831
training_files=[{"file_id": training_file.id, "weight": 1}],
2932
validation_files=[validation_file.id],
30-
hyperparameters=TrainingParametersIn(
31-
training_steps=1,
33+
hyperparameters=CompletionTrainingParametersIn(
34+
training_steps=2,
3235
learning_rate=0.0001,
3336
),
3437
)
3538
print(created_job)
3639

37-
while created_job.status in ["RUNNING", "QUEUED"]:
40+
while created_job.status in [
41+
"QUEUED",
42+
"STARTED",
43+
"VALIDATING",
44+
"VALIDATED",
45+
"RUNNING",
46+
]:
3847
created_job = await client.fine_tuning.jobs.get_async(job_id=created_job.id)
3948
print(f"Job is {created_job.status}, waiting {POLLING_INTERVAL} seconds")
4049
await asyncio.sleep(POLLING_INTERVAL)
4150

42-
if created_job.status == "FAILED":
51+
if created_job.status != "SUCCESS":
4352
print("Job failed")
44-
return
45-
53+
raise Exception(f"Job failed with {created_job.status}")
54+
print(created_job)
4655
# Chat with model
4756
response = await client.chat.complete_async(
4857
model=created_job.fine_tuned_model,

examples/dry_run_job.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import os
55

66
from mistralai import Mistral
7-
from mistralai.models import TrainingParametersIn
7+
from mistralai.models import CompletionTrainingParametersIn
88

99

1010
async def main():
@@ -13,7 +13,7 @@ async def main():
1313
client = Mistral(api_key=api_key)
1414

1515
# Create new files
16-
with open("examples/file.jsonl", "rb") as f:
16+
with open("examples/fixtures/ft_training_file.jsonl", "rb") as f:
1717
training_file = await client.files.upload_async(
1818
file={"file_name": "test-file.jsonl", "content": f}
1919
)
@@ -22,7 +22,7 @@ async def main():
2222
dry_run_job = await client.fine_tuning.jobs.create_async(
2323
model="open-mistral-7b",
2424
training_files=[{"file_id": training_file.id, "weight": 1}],
25-
hyperparameters=TrainingParametersIn(
25+
hyperparameters=CompletionTrainingParametersIn(
2626
training_steps=1,
2727
learning_rate=0.0001,
2828
warmup_fraction=0.01,

examples/files.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def main():
1515
created_file = client.files.upload(
1616
file=File(
1717
file_name="training_file.jsonl",
18-
content=open("examples/file.jsonl", "rb").read(),
18+
content=open("examples/fixtures/ft_training_file.jsonl", "rb").read(),
1919
)
2020
)
2121
print(created_file)
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
{"text": "I love this product!", "labels": {"sentiment": "positive"}}
2+
{"text": "The game was amazing.", "labels": {"sentiment": "positive"}}
3+
{"text": "The new policy is controversial.", "labels": {"sentiment": "neutral"}}
4+
{"text": "I don't like the new design.", "labels": {"sentiment": "negative"}}
5+
{"text": "The team won the championship.", "labels": {"sentiment": "positive"}}
6+
{"text": "The economy is in a bad shape.", "labels": {"sentiment": "negative"}}
7+
{"text": "The weather is nice today.", "labels": {"sentiment": "positive"}}
8+
{"text": "The match ended in a draw.", "labels": {"sentiment": "neutral"}}
9+
{"text": "The new law will be implemented soon.", "labels": {"sentiment": "neutral"}}
10+
{"text": "I had a great time at the concert.", "labels": {"sentiment": "positive"}}
11+
{"text": "This movie was fantastic!", "labels": {"sentiment": "positive"}}
12+
{"text": "The service was terrible.", "labels": {"sentiment": "negative"}}
13+
{"text": "The food was delicious.", "labels": {"sentiment": "positive"}}
14+
{"text": "I'm not sure about this decision.", "labels": {"sentiment": "neutral"}}
15+
{"text": "The book was boring.", "labels": {"sentiment": "negative"}}
16+
{"text": "The view from the top was breathtaking.", "labels": {"sentiment": "positive"}}
17+
{"text": "The traffic was awful today.", "labels": {"sentiment": "negative"}}
18+
{"text": "The event was well-organized.", "labels": {"sentiment": "positive"}}
19+
{"text": "The meeting went on for too long.", "labels": {"sentiment": "negative"}}
20+
{"text": "The presentation was informative.", "labels": {"sentiment": "positive"}}
21+
{"text": "The new software update is buggy.", "labels": {"sentiment": "negative"}}
22+
{"text": "The concert was sold out.", "labels": {"sentiment": "positive"}}
23+
{"text": "The weather forecast is unreliable.", "labels": {"sentiment": "negative"}}
24+
{"text": "The new phone is expensive.", "labels": {"sentiment": "neutral"}}
25+
{"text": "The customer service was excellent.", "labels": {"sentiment": "positive"}}
26+
{"text": "The new restaurant opened today.", "labels": {"sentiment": "neutral"}}
27+
{"text": "The movie had a surprising ending.", "labels": {"sentiment": "positive"}}
28+
{"text": "The project deadline is approaching.", "labels": {"sentiment": "neutral"}}
29+
{"text": "The team is working hard.", "labels": {"sentiment": "positive"}}
30+
{"text": "The new product launch was successful.", "labels": {"sentiment": "positive"}}
31+
{"text": "The conference was insightful.", "labels": {"sentiment": "positive"}}
32+
{"text": "The flight was delayed.", "labels": {"sentiment": "negative"}}
33+
{"text": "The vacation was relaxing.", "labels": {"sentiment": "positive"}}
File renamed without changes.
File renamed without changes.

examples/jobs.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33

44
from mistralai import Mistral
5-
from mistralai.models import File, TrainingParametersIn
5+
from mistralai.models import File, CompletionTrainingParametersIn
66

77

88
def main():
@@ -11,11 +11,11 @@ def main():
1111
client = Mistral(api_key=api_key)
1212

1313
# Create new files
14-
with open("examples/file.jsonl", "rb") as f:
14+
with open("examples/fixtures/ft_training_file.jsonl", "rb") as f:
1515
training_file = client.files.upload(
1616
file=File(file_name="file.jsonl", content=f)
1717
)
18-
with open("examples/validation_file.jsonl", "rb") as f:
18+
with open("examples/fixtures/ft_validation_file.jsonl", "rb") as f:
1919
validation_file = client.files.upload(
2020
file=File(file_name="validation_file.jsonl", content=f)
2121
)
@@ -25,7 +25,7 @@ def main():
2525
model="open-mistral-7b",
2626
training_files=[{"file_id": training_file.id, "weight": 1}],
2727
validation_files=[validation_file.id],
28-
hyperparameters=TrainingParametersIn(
28+
hyperparameters=CompletionTrainingParametersIn(
2929
training_steps=1,
3030
learning_rate=0.0001,
3131
),

0 commit comments

Comments
 (0)