|
1 | 1 | #!/usr/bin/env python
|
2 |
| - |
3 | 2 | import asyncio
|
| 3 | +import json |
4 | 4 | import os
|
| 5 | +import random |
| 6 | +from pathlib import Path |
5 | 7 |
|
6 | 8 | from mistralai import Mistral
|
7 | 9 | from mistralai.models import (
|
|
POLLING_INTERVAL = 10  # seconds to sleep between fine-tuning job status polls

# Directory containing this script.
# NOTE(review): appears unused in the visible code now that fixtures are
# built in memory — confirm against the rest of the file before removing.
cwd = Path(__file__).parent

# Sample user prompts; each entry is paired by index with the reply
# at the same position in assistant_contents below.
user_contents = [
    "How far is the Moon from Earth?",
    "What's the largest ocean on Earth?",
    "How many continents are there?",
    "What's the powerhouse of the cell?",
    "What's the speed of light?",
    "Can you solve a Rubik's Cube?",
    "What is the tallest mountain in the world?",
    "Who painted the Mona Lisa?",
]

# Sarcastic-but-factual answers matching the prompts above, one per prompt.
assistant_contents = [
    "Around 384,400 kilometers. Give or take a few, like that really matters.",
    "The Pacific Ocean. You know, the one that covers more than 60 million square miles. No big deal.",
    "There are seven continents. I hope that wasn't too hard to count.",
    "The mitochondria. Remember that from high school biology?",
    "Approximately 299,792 kilometers per second. You know, faster than your internet speed.",
    "I could if I had hands. What's your excuse?",
    "Mount Everest, standing at 29,029 feet. You know, just a little hill.",
    "Leonardo da Vinci. Just another guy who liked to doodle.",
]

# System persona injected into every generated training example.
system_message = "Marv is a factual chatbot that is also sarcastic"
def create_validation_file() -> bytes:
    """Return a single JSONL validation record, encoded as UTF-8 bytes.

    The record holds one fixed user/assistant exchange plus a random
    "temperature" field, so successive uploads differ in content.
    """
    record = {
        "messages": [
            {
                "role": "user",
                "content": "How long does it take to travel around the Earth?",
            },
            {
                "role": "assistant",
                "content": "Around 24 hours if you're the Earth itself. For you, depends on your mode of transportation.",
            },
        ],
        "temperature": random.random(),
    }
    return json.dumps(record).encode()
14 | 51 |
|
15 | 52 | async def main():
|
16 | 53 | api_key = os.environ["MISTRAL_API_KEY"]
|
17 | 54 | client = Mistral(api_key=api_key)
|
18 | 55 |
|
| 56 | + requests = [] |
| 57 | + for um, am in zip( |
| 58 | + random.sample(user_contents, len(user_contents)), |
| 59 | + random.sample(assistant_contents, len(assistant_contents)), |
| 60 | + ): |
| 61 | + requests.append(json.dumps({ |
| 62 | + "messages": [ |
| 63 | + {"role": "system", "content": system_message}, |
| 64 | + {"role": "user", "content": um}, |
| 65 | + {"role": "assistant", "content": am}, |
| 66 | + ] |
| 67 | + })) |
| 68 | + |
19 | 69 | # Create new files
|
20 |
| - with open("examples/fixtures/ft_training_file.jsonl", "rb") as f: |
21 |
| - training_file = await client.files.upload_async( |
22 |
| - file=File(file_name="file.jsonl", content=f) |
23 |
| - ) |
24 |
| - with open("examples/fixtures/ft_validation_file.jsonl", "rb") as f: |
25 |
| - validation_file = await client.files.upload_async( |
26 |
| - file=File(file_name="validation_file.jsonl", content=f) |
27 |
| - ) |
| 70 | + training_file = await client.files.upload_async( |
| 71 | + file=File( |
| 72 | + file_name="file.jsonl", content=("\n".join(requests)).encode() |
| 73 | + ), |
| 74 | + purpose="fine-tune", |
| 75 | + ) |
| 76 | + |
| 77 | + validation_file = await client.files.upload_async( |
| 78 | + file=File( |
| 79 | + file_name="validation_file.jsonl", content=create_validation_file() |
| 80 | + ), |
| 81 | + purpose="fine-tune", |
| 82 | + ) |
28 | 83 | # Create a new job
|
29 | 84 | created_job = await client.fine_tuning.jobs.create_async(
|
30 | 85 | model="open-mistral-7b",
|
31 | 86 | training_files=[{"file_id": training_file.id, "weight": 1}],
|
32 | 87 | validation_files=[validation_file.id],
|
33 | 88 | hyperparameters=CompletionTrainingParametersIn(
|
34 |
| - training_steps=2, |
| 89 | + training_steps=1, |
35 | 90 | learning_rate=0.0001,
|
36 | 91 | ),
|
37 | 92 | )
|
38 |
| - print(created_job) |
39 | 93 |
|
40 |
| - while created_job.status in [ |
41 |
| - "QUEUED", |
42 |
| - "STARTED", |
43 |
| - "VALIDATING", |
44 |
| - "VALIDATED", |
45 |
| - "RUNNING", |
46 |
| - ]: |
| 94 | + while created_job.status in ["RUNNING", "STARTED", "QUEUED", "VALIDATING", "VALIDATED"]: |
47 | 95 | created_job = await client.fine_tuning.jobs.get_async(job_id=created_job.id)
|
48 | 96 | print(f"Job is {created_job.status}, waiting {POLLING_INTERVAL} seconds")
|
49 | 97 | await asyncio.sleep(POLLING_INTERVAL)
|
50 | 98 |
|
51 |
| - if created_job.status != "SUCCESS": |
| 99 | + if created_job.status == "FAILED": |
52 | 100 | print("Job failed")
|
53 | 101 | raise Exception(f"Job failed with {created_job.status}")
|
| 102 | + |
54 | 103 | print(created_job)
|
55 | 104 | # Chat with model
|
56 | 105 | response = await client.chat.complete_async(
|
|
0 commit comments