ํ๊ทธ๋ก ๊ฐ์ธ์ง ํ
์คํธ๋ ๊ฐ์ ธ์ค์ง ์์
+
+ result = {"num": len(texts), "divs": divs, "contents": texts}
+ data.loc[idx, 'ํ์คํธ'] = str(result)
+ print(idx, result)
+ print("-" * 100)
+ time.sleep(3)
+ break
+
+ except Exception as e:
+ print(f"ํ์ด์ง ๋ก๋ฉ/ํฌ๋กค๋ง ์ค ์ค๋ฅ ๋ฐ์ (์๋ {retry_count+1}/{max_retries}): {e}")
+ retry_count += 1
+ time.sleep(3) # ์ฌ์๋ ์ ๋๊ธฐ
+
+ if retry_count == max_retries:
+ data.loc[idx, 'ํ์คํธ'] = ""
+ print(f"URL {url} ํฌ๋กค๋ง ์คํจ, ๋ค์ URL๋ก ์ด๋")
+
+ driver.quit()
+
+ output_path = f"{data_dir}/{total_text_csv}"
+ data.to_csv(output_path, index=False)
+ logger.info(f"ํฌ๋กค๋ง ๊ฒฐ๊ณผ๊ฐ {output_path} ์ ์ ์ฅ๋์์ต๋๋ค.")
diff --git a/models/product_summarization/utils/__init__.py b/models/product_summarization/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/models/product_summarization/utils/data_processing.py b/models/product_summarization/utils/data_processing.py
new file mode 100644
index 0000000..135f88a
--- /dev/null
+++ b/models/product_summarization/utils/data_processing.py
@@ -0,0 +1,13 @@
+def product_introduction_processing(row):
+ """
+ 'ํ์คํธ' ์ปฌ๋ผ์ ์ ์ฅ๋ JSON-like ๋ฌธ์์ด์ ํ์ฑํด ์ํ์๊ฐ ๋ฌธ๊ตฌ๋ฅผ ์์ฑํ๋ ํจ์
+ """
+ if not row['ํ์คํธ'] or row['ํ์คํธ'] == "":
+ return ""
+
+ text = eval(row['ํ์คํธ']) # string -> dict ๋ณํ
+ text_title = ' '.join(text['divs'])
+ text_contents = [item for item in text['contents'] if "SSG.COM" not in item]
+ text_contents = ' '.join(text_contents)
+ product_introduction = text_title + text_contents
+ return product_introduction
diff --git a/models/product_summarization/utils/hcx.py b/models/product_summarization/utils/hcx.py
new file mode 100644
index 0000000..8fee7fe
--- /dev/null
+++ b/models/product_summarization/utils/hcx.py
@@ -0,0 +1,104 @@
+import time
+import requests
+
+# HCX-003 ๊ธฐ๋ณธ ๋ชจ๋ธ
+class CompletionExecutor:
+ def __init__(self, host, api_key, request_id):
+ self._host = host
+ self._api_key = api_key
+ self._request_id = request_id
+
+ def execute(self, completion_request, max_retries=5, retry_delay=20):
+ headers = {
+ 'Authorization': self._api_key,
+ 'X-NCP-CLOVASTUDIO-REQUEST-ID': self._request_id,
+ 'Content-Type': 'application/json; charset=utf-8',
+ }
+
+ for attempt in range(max_retries):
+ try:
+ start_time = time.time()
+ with requests.post(
+ self._host + '/testapp/v1/chat-completions/HCX-003',
+ headers=headers, json=completion_request
+ ) as r:
+ elapsed_time = time.time() - start_time
+ response = r.json()
+
+ if response.get("status", {}).get("code") == "20000":
+ return response["result"]["message"]["content"], elapsed_time
+ else:
+ raise ValueError(f"Invalid status code: {response.get('status', {}).get('code')}")
+ except (requests.RequestException, ValueError, KeyError) as e:
+ if attempt < max_retries - 1:
+ print(f"์๋ฌ ๋ฐ์: {str(e)}. {retry_delay}์ด ํ ์ฌ์๋ํฉ๋๋ค. (์๋ {attempt + 1}/{max_retries})")
+ time.sleep(retry_delay)
+ else:
+ print(f"์ต๋ ์ฌ์๋ ํ์ {max_retries}ํ๋ฅผ ์ด๊ณผํ์ต๋๋ค. ์ต์ข์๋ฌ: {str(e)}")
+ return None, None
+
+ return None, None
+
+
+# ํ์ต ์์ฑ
+class CreateTaskExecutor:
+ def __init__(self, host, uri, api_key, request_id):
+ self._host = host
+ self._uri = uri
+ self._api_key = api_key
+ self._request_id = request_id
+
+ def _send_request(self, create_request):
+ headers = {
+ 'Authorization': self._api_key,
+ 'X-NCP-CLOVASTUDIO-REQUEST-ID': self._request_id
+ }
+ result = requests.post(self._host + self._uri, json=create_request, headers=headers).json()
+ return result
+
+ def execute(self, create_request):
+ res = self._send_request(create_request)
+ if 'status' in res and res['status']['code'] == '20000':
+ return res['result']
+ else:
+ return res
+
+
+# ํ๋๋ HCX-003 ๋ชจ๋ธ
+class FinetunedCompletionExecutor:
+ def __init__(self, host, api_key, request_id, taskId):
+ self._host = host
+ self._api_key = api_key
+ self._request_id = request_id
+ self._taskID = taskId
+
+ def execute(self, completion_request, max_retries=5, retry_delay=20):
+ headers = {
+ 'Authorization': self._api_key,
+ 'X-NCP-CLOVASTUDIO-REQUEST-ID': self._request_id,
+ 'Content-Type': 'application/json; charset=utf-8',
+ }
+
+ for attempt in range(max_retries):
+ try:
+ start_time = time.time()
+ with requests.post(
+ self._host + f'/testapp/v2/tasks/{self._taskID}/chat-completions',
+ headers=headers, json=completion_request
+ ) as r:
+ elapsed_time = time.time() - start_time
+ response = r.json()
+
+ if response.get("status", {}).get("code") == "20000":
+ return response["result"]["message"]["content"], elapsed_time
+ else:
+ raise ValueError(f"Invalid status code: {response.get('status', {}).get('code')}")
+ except (requests.RequestException, ValueError, KeyError) as e:
+ if attempt < max_retries - 1:
+ print(f"์๋ฌ ๋ฐ์: {str(e)}. {retry_delay}์ด ํ ์ฌ์๋ํฉ๋๋ค. (์๋ {attempt + 1}/{max_retries})")
+ time.sleep(retry_delay)
+ else:
+ print(f"์ต๋ ์ฌ์๋ ํ์ {max_retries}ํ๋ฅผ ์ด๊ณผํ์ต๋๋ค. ์ต์ข์๋ฌ: {str(e)}")
+ return None, None
+
+ return None, None
diff --git a/models/review/README.md b/models/review/README.md
new file mode 100644
index 0000000..654e908
--- /dev/null
+++ b/models/review/README.md
@@ -0,0 +1,161 @@
+"""
+# ๋ฆฌ๋ทฐ ํ์ดํ๋ผ์ธ
+> ๋ณธ ๋ฆฌ๋ทฐ ํ์ดํ๋ผ์ธ์ ๋ฆฌ๋ทฐ ๋ฐ์ดํฐ๋ฅผ ์ ์ ๋ฐ ์์ฝํ์ฌ, ์ฌ์ฉ์๊ฐ ์ํ ๋ฆฌ๋ทฐ๋ฅผ ํจ๊ณผ์ ์ผ๋ก ์ดํดํ๊ณ , ์ถ์ฒ ํค์๋๋ฅผ ํตํด ์ํ ๊ฒ์ ๋ฐ ์ ๋ ฌ ๊ธฐ๋ฅ์ ์ ๊ณตํ๋ ๊ฒ์ ๋ชฉํ๋ก ํฉ๋๋ค.
+> ํ์ดํ๋ผ์ธ์ ๋ ๊ฐ์ ์ฃผ์ ๋ชจ๋๋ก ๊ตฌ์ฑ๋์ด ์์ต๋๋ค:
+> 1. **ํ์ต ๋ฐ์ดํฐ ์์ฑ ํ์ดํ๋ผ์ธ** (ASTE Task์ฉ ๋ฐ์ดํฐ ์์ฑ ๋ฐ ๋ชจ๋ธ ํ์ธํ๋)
+> 2. **์ถ์ฒ ํค์๋ ๊ฒ์๊ณผ ๋ฆฌ๋ทฐ ์์ฝ ํ์ดํ๋ผ์ธ** (ASTE ์ธํผ๋ฐ์ค, ์๋ฒ ๋ฉ, ํด๋ฌ์คํฐ๋ง ๋ฐ ํ์ฒ๋ฆฌ)
+
+## ์ฃผ์ ํน์ง
+1. **๋ฆฌ๋ทฐ ํฌ๋กค๋ง ๋ฐ ์ ์ฒ๋ฆฌ**:
+ ์จ๋ผ์ธ ์ผํ๋ชฐ์์ ์ํ ์ ๋ณด ๋ฐ ๋ฆฌ๋ทฐ๋ฅผ ์์งํ ํ, ํ
์คํธ ํด๋ ์ง, ํน์๋ฌธ์ ์ ๊ฑฐ, ๊ฐํ ๋ฌธ์ ๋ณํ, ์์ด/์ซ์ ๋น์จ ํํฐ๋ง, ์ค๋ณต ์ ๊ฑฐ, ์งง์ ๋ฆฌ๋ทฐ ๋ฐฐ์ , ๋ง์ถค๋ฒ ๊ต์ ๋ฑ ๋ค์ํ ์ ์ฒ๋ฆฌ ๊ณผ์ ์ ์ํํฉ๋๋ค.
+2. **ASTE(Aspect Sentiment Triplet Extraction) ๋ฐ์ดํฐ ์์ฑ**:
+ ์ ์ฒ๋ฆฌ๋ ๋ฆฌ๋ทฐ ๋ฐ์ดํฐ๋ฅผ ๋ฐํ์ผ๋ก Sentence-BERT ์๋ฒ ๋ฉ๊ณผ K-Means ํด๋ฌ์คํฐ๋ง์ ํ์ฉํ์ฌ ์ค๋ณต์ ๋ฐฉ์งํ ๋ํ ๋ฆฌ๋ทฐ ์ํ์ ์ ํํ๊ณ , GPT API๋ฅผ ํตํด ASTE ๊ด๋ จ 900๊ฐ์ ํ์ต ๋ฐ์ดํฐ๋ฅผ ์์ฑํฉ๋๋ค.
+3. **๋ชจ๋ธ ํ์ธํ๋**:
+ 'DeepSeek-R1-Distill-Qwen' ๋ชจ๋ธ์ ๊ธฐ๋ฐ์ผ๋ก Supervised Fine-Tuning(SFT)์ ์งํํ์ฌ, ๋ฆฌ๋ทฐ์ Aspect, Opinion, Sentiment ์ถ์ถ ๋ฅ๋ ฅ์ ํฅ์์ํต๋๋ค. 100๊ฐ์ ์๋ ๋ ์ด๋ธ๋ง ๋ฐ์ดํฐ์ Custom Evaluation Metric์ ์ด์ฉํ์ฌ ์ ๋์ ์ผ๋ก ๋ชจ๋ธ์ ์ฑ๋ฅ ํ๊ฐ๋ฅผ ์ค์ํฉ๋๋ค.
+4. **์ถ์ฒ ํค์๋ ๋ฐ ๋ฆฌ๋ทฐ ์์ฝ**:
+ ํ์ธํ๋๋ ASTE ๋ชจ๋ธ์ ์ธํผ๋ฐ์ค ๊ฒฐ๊ณผ๋ฅผ Sentence-BERT ์๋ฒ ๋ฉ๊ณผ UMAP ์ฐจ์ ์ถ์, Agglomerative ํด๋ฌ์คํฐ๋ง์ผ๋ก ๋ถ์ํ์ฌ HyperClovaX๋ก ๋ํ ํค์๋๋ฅผ ๋์ถํ๊ณ , ๊ธ์ /๋ถ์ ๋ฆฌ๋ทฐ ์์ฝ์ ์์ฑํฉ๋๋ค.
+5. **ํด๋ฌ์คํฐ๋ง ํ๊ฐ ๋ฐ ์๊ฐํ**:
+ T-SNE๋ฅผ ํ์ฉํ ์๊ฐํ์ Silhouette, DBI ๋ฑ์ ํ๊ฐ ์งํ๋ฅผ ํตํด ํด๋ฌ์คํฐ๋ง ๊ฒฐ๊ณผ์ ํ์ง์ ์ ๋์ ์ผ๋ก ๋ถ์ํฉ๋๋ค.
+
+> [Note] ํ์ต ๋ฐ์ดํฐ ์์ฑ๊ณผ ์ถ์ฒ ํค์๋ ๊ฒ์ ๋ฐ ๋ฆฌ๋ทฐ ์์ฝ์ ๋
๋ฆฝ์ ์ธ ๋ชจ๋๋ก ๊ตฌ์ฑ๋์ด ์์ผ๋ฉฐ, ํ์์ ๋ฐ๋ผ ๊ฐ๋ณ ์คํ์ด ๊ฐ๋ฅํฉ๋๋ค.
+
+## ํด๋ ๊ตฌ์กฐ
+```bash
+.
+โโโ README.md
+โโโ main.py # ํ์ดํ๋ผ์ธ ์คํ ์ฝ๋
+โโโ config
+โ โโโ config.yaml # ์ค์ ํ์ผ (ํ์ผ ๊ฒฝ๋ก, ์คํ ์ต์
๋ฑ)
+โโโ data
+โ โโโ ASTE
+โ โ โโโ ASTE_10_shots.csv # 10-shot ์์ ๋ฐ์ดํฐ
+โ โ โโโ ASTE_sampled.csv # ์ํ๋ง๋ ๋ฆฌ๋ทฐ ๋ฐ์ดํฐ
+โ โ โโโ eval
+โ โ โ โโโ ASTE_annotation_100_golden_label.csv # 100๊ฐ Golden Label ๋ฐ์ดํฐ
+โ โ โโโ inference
+โ โ โ โโโ deepseek_inference.csv # DeepSeek ์ธํผ๋ฐ์ค ๊ฒฐ๊ณผ (๋ชจ๋ธ ์ถ๋ ฅ ํฌํจ)
+โ โ โโโ processed_except_GL.csv # Golden Label ์ ์ธ ์ ์ฒ๋ฆฌ ๋ฆฌ๋ทฐ ๋ฐ์ดํฐ
+โ โ โโโ train
+โ โ โโโ train_data.csv # ํ์ต์ฉ ๋ฆฌ๋ทฐ ๋ฐ์ดํฐ
+โ โโโ crawled_reviews
+โ โ โโโ crawled_reviews_meals.csv # ๋ผ๋ฉด/๊ฐํธ์ ๊ด๋ จ ๋ฆฌ๋ทฐ ํฌ๋กค๋ง ๋ฐ์ดํฐ
+โ โ โโโ crawled_reviews_snacks.csv # ๊ณผ์/๋น๊ณผ ๊ด๋ จ ๋ฆฌ๋ทฐ ํฌ๋กค๋ง ๋ฐ์ดํฐ
+โ โโโ embedding_matrics
+โ โ โโโ cluster_result.png # ํด๋ฌ์คํฐ๋ง ๊ฒฐ๊ณผ ์๊ฐํ ์ด๋ฏธ์ง
+โ โ โโโ clustering_evaluation.json # ํด๋ฌ์คํฐ๋ง ํ๊ฐ ์งํ
+โ โ โโโ deepseek_inference.npy # ์ ์ฒด ์ธํผ๋ฐ์ค ์๋ฒ ๋ฉ ํ๋ ฌ
+โ โ โโโ deepseek_inference_meals.npy # ๋ผ๋ฉด/๊ฐํธ์ ๋ฆฌ๋ทฐ ์๋ฒ ๋ฉ ๋ฐ์ดํฐ
+โ โ โโโ deepseek_inference_reduced.npy # ์ฐจ์ ์ถ์ ํ ์๋ฒ ๋ฉ ๋ฐ์ดํฐ
+โ โ โโโ deepseek_inference_snacks.npy # ๊ณผ์/๋น๊ณผ ๋ฆฌ๋ทฐ ์๋ฒ ๋ฉ ๋ฐ์ดํฐ
+โ โ โโโ meals_cluster_result.png
+โ โ โโโ meals_clustering_evaluation.json
+โ โ โโโ snacks_cluster_result.png
+โ โ โโโ snacks_clustering_evaluation.json
+โ โโโ preprocessed
+โ โโโ meta_reviews_meals.csv # ๋ผ๋ฉด/๊ฐํธ์ ๋ฆฌ๋ทฐ ๋ฉํ ๋ฐ์ดํฐ
+โ โโโ meta_reviews_snacks.csv # ๊ณผ์/๋น๊ณผ ๋ฆฌ๋ทฐ ๋ฉํ ๋ฐ์ดํฐ
+โ โโโ processed_reviews_all.csv # ์ ์ฒด ์ ์ฒ๋ฆฌ ๋ฆฌ๋ทฐ ๋ฐ์ดํฐ
+โ โโโ processed_reviews_meals.csv # ๋ผ๋ฉด/๊ฐํธ์ ๋ฆฌ๋ทฐ ์ ์ฒ๋ฆฌ ๊ฒฐ๊ณผ
+โ โโโ processed_reviews_snacks.csv # ๊ณผ์/๋น๊ณผ ๋ฆฌ๋ทฐ ์ ์ฒ๋ฆฌ ๊ฒฐ๊ณผ
+โโโ prompt
+โ โโโ keyword_recommendation
+โ โ โโโ recommendation_fewshot.json # ์ถ์ฒ ํค์๋ Few-shot ์์
+โ โ โโโ recommendation_prompt.txt # ์ถ์ฒ ํค์๋ ํ๋กฌํํธ ํ
ํ๋ฆฟ
+โ โโโ prompt_loader.py # ํ๋กฌํํธ ๋ก๋ ์คํฌ๋ฆฝํธ
+โ โโโ review_annotation
+โ โ โโโ annotation_fewshot.json # ๋ฆฌ๋ทฐ ์ด๋
ธํ
์ด์
Few-shot ์์
+โ โ โโโ annotation_prompt.txt # ๋ฆฌ๋ทฐ ์ด๋
ธํ
์ด์
ํ๋กฌํํธ ํ
ํ๋ฆฟ
+โ โโโ review_summarization
+โ โโโ negative_fewshot.json # ๋ถ์ ๋ฆฌ๋ทฐ ์์ฝ ์์
+โ โโโ negative_prompt.txt # ๋ถ์ ๋ฆฌ๋ทฐ ์์ฝ ํ๋กฌํํธ ํ
ํ๋ฆฟ
+โ โโโ positive_fewshot.json # ๊ธ์ ๋ฆฌ๋ทฐ ์์ฝ ์์
+โ โโโ positive_prompt.txt # ๊ธ์ ๋ฆฌ๋ทฐ ์์ฝ ํ๋กฌํํธ ํ
ํ๋ฆฟ
+โโโ src
+โ โโโ review_pipeline
+โ โ โโโ ASTE_inference.py # ASTE ๋ชจ๋ธ ์ธํผ๋ฐ์ค ์คํ(๋๋ฏธ ๋ฐ์ดํฐ)
+โ โ โโโ keyword_recommendation.py # ์ถ์ฒ ํค์๋ ์ถ์ถ ๋ฐ ์ ๋ ฌ
+โ โ โโโ qwen_deepseek_14b_inference.py # Qwen 14B ๊ธฐ๋ฐ ์ธํผ๋ฐ์ค
+โ โ โโโ qwen_deepseek_32b_inference.py # Qwen 32B ๊ธฐ๋ฐ ์ธํผ๋ฐ์ค
+โ โ โโโ review_summarization.py # ๋ฆฌ๋ทฐ ์์ฝ ์ถ์ถ ์คํ
+โ โ โโโ visualization.py # ํด๋ฌ์คํฐ๋ง ๊ฒฐ๊ณผ ์๊ฐํ ๋ฐ ํ๊ฐ
+โ โโโ sft_pipeline
+โ โโโ qwen_deepseek_14b_finetuning.py # Qwen 14B ๋ชจ๋ธ ํ์ธํ๋
+โ โโโ qwen_deepseek_32b_finetuning.py # Qwen 32B ๋ชจ๋ธ ํ์ธํ๋
+โ โโโ review_crawling.py # ๋ฆฌ๋ทฐ ๋ฐ์ดํฐ ํฌ๋กค๋ง
+โ โโโ review_preprocessing.py # ๋ฆฌ๋ทฐ ์ ์ฒ๋ฆฌ ์คํ
+โ โโโ sft.py # Supervised Fine-Tuning ์คํ(๋๋ฏธ ๋ฐ์ดํฐ)
+โ โโโ train_data_annotating.py # ๋ฆฌ๋ทฐ ์ด๋
ธํ
์ด์
๋ฐ์ดํฐ ์์ฑ
+โ โโโ train_data_sampling.py # ๋ฆฌ๋ทฐ ์ํ๋ง
+โโโ environment.yml # Conda ํ๊ฒฝ ์ค์ ํ์ผ
+โโโ utils
+ โโโ evaluate.py # ASTE ๋ฐ ํด๋ฌ์คํฐ๋ง ํ๊ฐ ์ฝ๋
+ โโโ utils.py # ์ ํธ๋ฆฌํฐ ํจ์ ๋ชจ์
+```
+
+## ์ค์น ๋ฐ ์คํ ๋ฐฉ๋ฒ
+### 1) ํ๊ฒฝ ๊ตฌ์ถ
+- Python 3.11.11 ๋ฒ์ ๊ถ์ฅ
+- ์์กด์ฑ ํจํค์ง ์ค์น:
+```bash
+conda env create -f environment.yml
+```
+
+### 2) ์ค์
+- `config/config.yaml` ํ์ผ์์ ๋ค์ ์ ๋ณด๋ฅผ ์ ์ ํ ์ค์ ํฉ๋๋ค.
+ - **ํ์ผ ๊ฒฝ๋ก**: ๋ฐ์ดํฐ ํ์ผ ๊ฒฝ๋ก ๋ฑ
+ - **ํ์ดํ๋ผ์ธ ์คํ ์ฌ๋ถ**: pipeline ์น์
์ true/false ๊ฐ์ผ๋ก ๊ฐ ๋จ๊ณ(ํฌ๋กค๋ง, ์ ์ฒ๋ฆฌ, ์ธํผ๋ฐ์ค, ํ์ธํ๋, ์ถ์ฒ ๋ฑ) ์คํ ์ ์ด
+ - **ํ์ต ๋ฐ์ดํฐ ์์ฑ ์ค์ **: GPT๋ชจ๋ธ ์ ํ, ํ์ต ๋ฐ์ดํฐ ๊ฐ์ ์ค์ ๋ฑ
+
+### 3) ์คํ
+- ๊ธฐ๋ณธ ์คํ (๊ธฐ๋ณธ `config/config.yaml` ์ฌ์ฉ ์)
+```bash
+python main.py -p sft
+python main.py -p review
+```
+
+## Input & Output
+### 1. Input:
+- ํฌ๋กค๋ง ๋ฐ์ดํฐ:
+ - `crawled_reviews/`: ์จ๋ผ์ธ ์ผํ๋ชฐ์์ ํฌ๋กค๋งํ ์๋ณธ ๋ฆฌ๋ทฐ ๋ฐ์ดํฐ (์: `crawled_reviews_meals.csv`, `crawled_reviews_snacks.csv`)
+- ์ ์ฒ๋ฆฌ ๋ฐ์ดํฐ:
+ - `preprocessed/`: ์ ์ฒ๋ฆฌ๋ ๋ฆฌ๋ทฐ ๋ฐ์ดํฐ (`processed_reviews_all.csv`, `processed_reviews_meals.csv`, `processed_reviews_snacks.csv`)
+- ASTE ๊ด๋ จ CSV ํ์ผ:
+ - `aste/`: Golden Label, ์ํ๋ง ๋ฐ์ดํฐ ๋ฑ ASTE ํ์ต ๋ฐ ์ธํผ๋ฐ์ค์ ์ฌ์ฉ๋๋ ๋ฐ์ดํฐ
+- ํ๋กฌํํธ ํ
ํ๋ฆฟ:
+ - `prompt/*`: ์ถ์ฒ ํค์๋, ๋ฆฌ๋ทฐ ์ด๋
ธํ
์ด์
๋ฐ ์์ฝ์ ์ํ ํ๋กฌํํธ ํ
ํ๋ฆฟ
+
+### 2. Output:
+- ํ์ต ๋ฐ์ดํฐ ์์ฑ ํ์ดํ๋ผ์ธ ๊ฒฐ๊ณผ
+ - `train_data.csv`: ํ์ธํ๋์ฉ ๋ฆฌ๋ทฐ ํ์ต ๋ฐ์ดํฐ
+ - `ASTE_sampled.csv`: ํด๋ฌ์คํฐ๋ง์ผ๋ก ์ํ๋ง๋ ๋ฆฌ๋ทฐ ๋ฐ์ดํฐ
+ - ๊ธฐํ ํ์ต ๋ฐ์ดํฐ ์์ฑ ๊ฒฐ๊ณผ CSV (GPT API๋ฅผ ํตํ ASTE ๋ฐ์ดํฐ ์์ฑ ๊ฒฐ๊ณผ)
+- ์ถ์ฒ ํค์๋ ๊ฒ์ ๋ฐ ๋ฆฌ๋ทฐ ์์ฝ ํ์ดํ๋ผ์ธ ๊ฒฐ๊ณผ
+ - `deepseek_inference.csv`: ASTE ์ธํผ๋ฐ์ค ๊ฒฐ๊ณผ (Aspect, Opinion, Sentiment ํฌํจ)
+- ์ต์ข
์ถ์ฒ CSV ํ์ผ: ์ถ์ฒ ํค์๋ ๊ธฐ๋ฐ ์ํ ์ฌ์ ๋ ฌ ๋ฐ ๋ฆฌ๋ทฐ ์์ฝ ๊ฒฐ๊ณผ
+- ํด๋ฌ์คํฐ๋ง ํ๊ฐ ์๋ฃ (์๊ฐํ ์ด๋ฏธ์ง, ํ๊ฐ ์งํ JSON ๋ฑ)
+
+
+
+
+## ์ฝ๋ ์ค๋ช
+`main.py`
+ํ์ดํ๋ผ์ธ์ ์ง์
์ ์ผ๋ก, config/config.yaml ํ์ผ์ ์ค์ ์ ๋ฐ๋ผ ๊ฐ ๋จ๊ณ(ํฌ๋กค๋ง, ์ ์ฒ๋ฆฌ, ํ์ต ๋ฐ์ดํฐ ์์ฑ, SFT, ์ธํผ๋ฐ์ค ๋ฑ)๋ฅผ ์์ฐจ์ ์ผ๋ก ์คํํฉ๋๋ค.
+
+- `src/review_pipeline/`
+ - `ASTE_inference.py`: ASTE ๋ชจ๋ธ ์ธํผ๋ฐ์ค๋ฅผ ์ํํ์ฌ ๋ฆฌ๋ทฐ์์ (Aspect, Opinion, Sentiment)๋ฅผ ์ถ์ถํฉ๋๋ค.
+ - `keyword_recommendation.py`: ์ธํผ๋ฐ์ค ๊ฒฐ๊ณผ๋ฅผ ๊ธฐ๋ฐ์ผ๋ก Sentence-BERT ์๋ฒ ๋ฉ๊ณผ ํด๋ฌ์คํฐ๋ง์ ํตํด ๋ํ ํค์๋๋ฅผ ๋์ถํ๊ณ , ์ํ ์ ๋ ฌ์ ํ์ฉํฉ๋๋ค.
+ - `qwen_deepseek_14b_inference.py`, `qwen_deepseek_32b_inference.py`: ๋ค์ํ Qwen ๊ธฐ๋ฐ ๋ชจ๋ธ์ ํ์ฉํ ์ธํผ๋ฐ์ค ์คํ ์ฝ๋์
๋๋ค.
+ - `review_summarization.py`: ๋ฆฌ๋ทฐ์ ํต์ฌ ํฌ์ธํธ๋ฅผ ์์ฝํ์ฌ ๊ธ์ ๋ฐ ๋ถ์ ๋ฆฌ๋ทฐ ์์ฝ์ ์์ฑํฉ๋๋ค.
+ - `visualization.py`: T-SNE ์๊ฐํ ๋ฐ ํด๋ฌ์คํฐ๋ง ํ๊ฐ(์ค๋ฃจ์ฃ, DBI ๋ฑ)๋ฅผ ์ํํฉ๋๋ค.
+
+- `src/sft_pipeline/`
+ - `qwen_deepseek_14b_finetuning.py`, `qwen_deepseek_32b_finetuning.py`: ์ ํ๋ Qwen ๋ชจ๋ธ์ ๋ํด ASTE Task์ SFT(ํ์ธํ๋)๋ฅผ ์งํํฉ๋๋ค.
+ - `review_crawling.py`: ์จ๋ผ์ธ ์ผํ๋ชฐ์์ ์ํ ์ ๋ณด์ ๋ฆฌ๋ทฐ ๋ฐ์ดํฐ๋ฅผ ํฌ๋กค๋งํฉ๋๋ค.
+ - `review_preprocessing.py`: ํฌ๋กค๋ง๋ ๋ฆฌ๋ทฐ ๋ฐ์ดํฐ๋ฅผ ์ ์ฒ๋ฆฌ(ํน์๋ฌธ์ ์ ๊ฑฐ, ๋ง์ถค๋ฒ ๊ต์ , ์ค๋ณต ์ ๊ฑฐ ๋ฑ)ํฉ๋๋ค.
+ - `sft.py`: SFT(์ํผ๋ฐ์ด์ฆ๋ ํ์ธํ๋) ์คํ์ ์ํ ํต์ฌ ์ฝ๋์
๋๋ค.
+ - `train_data_annotating.py`: GPT API๋ฅผ ํ์ฉํ์ฌ ๋ฆฌ๋ทฐ ์ด๋
ธํ
์ด์
๋ฐ์ดํฐ๋ฅผ ์์ฑํฉ๋๋ค.
+ - `train_data_sampling.py`: Sentence-BERT ์๋ฒ ๋ฉ๊ณผ K-Means ํด๋ฌ์คํฐ๋ง์ ์ด์ฉํด ๋ํ ๋ฆฌ๋ทฐ ์ํ์ ์ถ์ถํฉ๋๋ค.
+
+- `utils/`
+ - `evaluate.py`: ASTE ๋ฐ ํด๋ฌ์คํฐ๋ง ํ๊ฐ(์ ๋์ ์งํ ์ฐ์ถ)๋ฅผ ์ํํ๋ ์ฝ๋์
๋๋ค.
+ - `utils.py`: ๋ฐ์ดํฐ ์ ์ฒ๋ฆฌ, ํ์ผ ์
์ถ๋ ฅ ๋ฑ ๋ค์ํ ์ ํธ๋ฆฌํฐ ํจ์ ๋ชจ์์
๋๋ค.
\ No newline at end of file
diff --git a/models/review/config/config.yaml b/models/review/config/config.yaml
new file mode 100644
index 0000000..a7245fe
--- /dev/null
+++ b/models/review/config/config.yaml
@@ -0,0 +1,34 @@
+# config/config.yaml
+paths:
+ data_dir: "./data"
+ crawled_reviews_dir: "./data/crawled_reviews"
+ preprocessed_dir: "./data/preprocessed"
+ embedding_dir: "./data/embedding_matrics"
+ prompt_dir: "./prompt"
+ final_outputs_dir: "../final_outputs"
+
+ aste_dir: "./data/aste"
+ train_dir: "./data/aste/train"
+ eval_dir: "./data/aste/eval"
+ inference_dir: "./data/aste/inference"
+
+pipeline:
+ sft:
+ review_crawling: false # ํฌ๋กค๋ง์ ๋ณ๋ ์คํ (์, ํ์คํธ ์ False)
+ review_preprocessing: true
+ train_data_sampling: false # train data ์์ฑ์์๋ง ๋ณ๋ ์คํ
+ train_data_annotating: false # train data ์์ฑ์์๋ง ๋ณ๋ ์คํ
+ sft: true # SFT๋ ํ์ผ๋ก ๋ณ๋ ์คํ
+ review:
+ aste_inference: true # inference ๋ณ๋ ์คํ
+ review_summarization: true
+ keyword_recommendation: true
+
+
+# Train Data Annotation ๊ด๋ จ
+train_data_annotating:
+ num_train_data: 900
+ annotation_model: "gpt-4o"
+
+# Inference ๊ด๋ จ
+inference_data: "deepseek_inference.csv"
\ No newline at end of file
diff --git a/models/review/data/aste/eval/.gitkeep b/models/review/data/aste/eval/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/models/review/data/aste/inference/.gitkeep b/models/review/data/aste/inference/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/models/review/data/aste/train/.gitkeep b/models/review/data/aste/train/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/models/review/data/crawled_reviews/.gitkeep b/models/review/data/crawled_reviews/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/models/review/data/embedding_matrics/.gitkeep b/models/review/data/embedding_matrics/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/models/review/data/embedding_matrics/cluster_result.png b/models/review/data/embedding_matrics/cluster_result.png
new file mode 100644
index 0000000..4cc607c
Binary files /dev/null and b/models/review/data/embedding_matrics/cluster_result.png differ
diff --git a/models/review/data/embedding_matrics/clustering_evaluation.json b/models/review/data/embedding_matrics/clustering_evaluation.json
new file mode 100644
index 0000000..3da1804
--- /dev/null
+++ b/models/review/data/embedding_matrics/clustering_evaluation.json
@@ -0,0 +1,4 @@
+{
+ "Silhouette": 0.7992019057273865,
+ "DBI": 0.31616701731820174
+}
\ No newline at end of file
diff --git a/models/review/data/embedding_matrics/meals_cluster_result.png b/models/review/data/embedding_matrics/meals_cluster_result.png
new file mode 100644
index 0000000..c4533f0
Binary files /dev/null and b/models/review/data/embedding_matrics/meals_cluster_result.png differ
diff --git a/models/review/data/embedding_matrics/meals_clustering_evaluation.json b/models/review/data/embedding_matrics/meals_clustering_evaluation.json
new file mode 100644
index 0000000..2a56105
--- /dev/null
+++ b/models/review/data/embedding_matrics/meals_clustering_evaluation.json
@@ -0,0 +1,5 @@
+{
+ "category": "meals",
+ "Silhouette": 0.8087015748023987,
+ "DBI": 0.2853419528622432
+}
\ No newline at end of file
diff --git a/models/review/data/embedding_matrics/snacks_cluster_result.png b/models/review/data/embedding_matrics/snacks_cluster_result.png
new file mode 100644
index 0000000..2a555b3
Binary files /dev/null and b/models/review/data/embedding_matrics/snacks_cluster_result.png differ
diff --git a/models/review/data/embedding_matrics/snacks_clustering_evaluation.json b/models/review/data/embedding_matrics/snacks_clustering_evaluation.json
new file mode 100644
index 0000000..5a10606
--- /dev/null
+++ b/models/review/data/embedding_matrics/snacks_clustering_evaluation.json
@@ -0,0 +1,5 @@
+{
+ "category": "snacks",
+ "Silhouette": 0.8068775534629822,
+ "DBI": 0.28233734826217466
+}
\ No newline at end of file
diff --git a/models/review/data/preprocessed/.gitkeep b/models/review/data/preprocessed/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/models/review/environment.yml b/models/review/environment.yml
new file mode 100644
index 0000000..2510c63
--- /dev/null
+++ b/models/review/environment.yml
@@ -0,0 +1,150 @@
+name: review
+channels:
+ - conda-forge
+ - defaults
+dependencies:
+ - _libgcc_mutex=0.1=main
+ - _openmp_mutex=5.1=1_gnu
+ - asttokens=3.0.0=pyhd8ed1ab_1
+ - bzip2=1.0.8=h5eee18b_6
+ - ca-certificates=2025.1.31=hbcca054_0
+ - comm=0.2.2=pyhd8ed1ab_1
+ - debugpy=1.8.11=py311h6a678d5_0
+ - decorator=5.1.1=pyhd8ed1ab_1
+ - exceptiongroup=1.2.2=pyhd8ed1ab_1
+ - executing=2.1.0=pyhd8ed1ab_1
+ - importlib-metadata=8.6.1=pyha770c72_0
+ - ipykernel=6.29.5=pyh3099207_0
+ - ipython=8.32.0=pyh907856f_0
+ - jedi=0.19.2=pyhd8ed1ab_1
+ - jupyter_client=8.6.3=pyhd8ed1ab_1
+ - jupyter_core=5.7.2=pyh31011fe_1
+ - ld_impl_linux-64=2.40=h12ee557_0
+ - libffi=3.4.4=h6a678d5_1
+ - libgcc-ng=11.2.0=h1234567_1
+ - libgomp=11.2.0=h1234567_1
+ - libsodium=1.0.18=h36c2ea0_1
+ - libstdcxx-ng=11.2.0=h1234567_1
+ - libuuid=1.41.5=h5eee18b_0
+ - matplotlib-inline=0.1.7=pyhd8ed1ab_1
+ - ncurses=6.4=h6a678d5_0
+ - nest-asyncio=1.6.0=pyhd8ed1ab_1
+ - openssl=3.0.15=h5eee18b_0
+ - packaging=24.2=pyhd8ed1ab_2
+ - parso=0.8.4=pyhd8ed1ab_1
+ - pexpect=4.9.0=pyhd8ed1ab_1
+ - pickleshare=0.7.5=pyhd8ed1ab_1004
+ - pip=25.0=py311h06a4308_0
+ - platformdirs=4.3.6=pyhd8ed1ab_1
+ - prompt-toolkit=3.0.50=pyha770c72_0
+ - psutil=5.9.0=py311h5eee18b_1
+ - ptyprocess=0.7.0=pyhd8ed1ab_1
+ - pure_eval=0.2.3=pyhd8ed1ab_1
+ - pygments=2.19.1=pyhd8ed1ab_0
+ - python=3.11.11=he870216_0
+ - python-dateutil=2.9.0.post0=pyhff2d567_1
+ - pyzmq=26.2.0=py311h6a678d5_0
+ - readline=8.2=h5eee18b_0
+ - setuptools=75.8.0=py311h06a4308_0
+ - six=1.17.0=pyhd8ed1ab_0
+ - sqlite=3.45.3=h5eee18b_0
+ - stack_data=0.6.3=pyhd8ed1ab_1
+ - tk=8.6.14=h39e8969_0
+ - tornado=6.4.2=py311h5eee18b_0
+ - traitlets=5.14.3=pyhd8ed1ab_1
+ - typing_extensions=4.12.2=pyha770c72_1
+ - wcwidth=0.2.13=pyhd8ed1ab_1
+ - wheel=0.45.1=py311h06a4308_0
+ - xz=5.4.6=h5eee18b_1
+ - zeromq=4.3.5=h6a678d5_0
+ - zipp=3.21.0=pyhd8ed1ab_1
+ - zlib=1.2.13=h5eee18b_1
+ - pip:
+ - annotated-types==0.7.0
+ - anyio==4.8.0
+ - attrs==25.1.0
+ - beautifulsoup4==4.13.3
+ - bs4==0.0.2
+ - certifi==2025.1.31
+ - charset-normalizer==3.4.1
+ - contourpy==1.3.1
+ - cycler==0.12.1
+ - distro==1.9.0
+ - filelock==3.17.0
+ - fonttools==4.56.0
+ - fsspec==2025.2.0
+ - h11==0.14.0
+ - hdbscan==0.8.40
+ - httpcore==1.0.7
+ - httpx==0.28.1
+ - huggingface-hub==0.28.1
+ - idna==3.10
+ - ipywidgets==8.1.5
+ - jinja2==3.1.5
+ - jiter==0.8.2
+ - joblib==1.4.2
+ - jpype1==1.5.2
+ - jupyterlab-widgets==3.0.13
+ - kiwisolver==1.4.8
+ - konlpy==0.6.0
+ - llvmlite==0.44.0
+ - lxml==5.3.0
+ - markupsafe==3.0.2
+ - matplotlib==3.10.0
+ - mpmath==1.3.0
+ - networkx==3.4.2
+ - numba==0.61.0
+ - numpy==2.1.0
+ - nvidia-cublas-cu12==12.4.5.8
+ - nvidia-cuda-cupti-cu12==12.4.127
+ - nvidia-cuda-nvrtc-cu12==12.4.127
+ - nvidia-cuda-runtime-cu12==12.4.127
+ - nvidia-cudnn-cu12==9.1.0.70
+ - nvidia-cufft-cu12==11.2.1.3
+ - nvidia-curand-cu12==10.3.5.147
+ - nvidia-cusolver-cu12==11.6.1.9
+ - nvidia-cusparse-cu12==12.3.1.170
+ - nvidia-cusparselt-cu12==0.6.2
+ - nvidia-nccl-cu12==2.21.5
+ - nvidia-nvjitlink-cu12==12.4.127
+ - nvidia-nvtx-cu12==12.4.127
+ - openai==1.61.1
+ - outcome==1.3.0.post0
+ - pandas==2.2.3
+ - pillow==11.1.0
+ - pydantic==2.10.6
+ - pydantic-core==2.27.2
+ - pynndescent==0.5.13
+ - pyparsing==3.2.1
+ - pysocks==1.7.1
+ - python-dotenv==1.0.1
+ - pytz==2025.1
+ - pyyaml==6.0.2
+ - regex==2024.11.6
+ - requests==2.32.3
+ - safetensors==0.5.2
+ - scikit-learn==1.6.1
+ - scipy==1.15.1
+ - selenium==4.28.1
+ - sentence-transformers==3.4.1
+ - sentencepiece==0.2.0
+ - sniffio==1.3.1
+ - sortedcontainers==2.4.0
+ - soupsieve==2.6
+ - sympy==1.13.1
+ - threadpoolctl==3.5.0
+ - tokenizers==0.21.0
+ - torch==2.6.0
+ - tqdm==4.67.1
+ - transformers==4.48.3
+ - trio==0.28.0
+ - trio-websocket==0.11.1
+ - triton==3.2.0
+ - tzdata==2025.1
+ - umap-learn==0.5.7
+ - urllib3==2.3.0
+ - webdriver-manager==4.0.2
+ - websocket-client==1.8.0
+ - widgetsnbextension==4.0.13
+ - wsproto==1.2.0
+prefix: /data/ephemeral/home/.condaenv/envs/review
diff --git a/models/review/main.py b/models/review/main.py
new file mode 100644
index 0000000..5e581c9
--- /dev/null
+++ b/models/review/main.py
@@ -0,0 +1,82 @@
+# main.py
+import argparse
+import yaml
+import os
+import logging
+
+# ๊ฐ ํ์ดํ๋ผ์ธ ๋ชจ๋ import
+from src.review_pipeline import (
+ aste_inference,
+ review_summarization,
+ keyword_recommendation
+)
+from src.sft_pipeline import (
+ review_crawling,
+ review_preprocessing,
+ train_data_sampling,
+ train_data_annotating,
+ sft # 3.sft.py โ ์์ง ์ฝ๋ ์์
+)
+
+def setup_logger():
+ logging.basicConfig(
+ level=logging.INFO,
+ format="%(asctime)s [%(levelname)s] %(name)s - %(message)s"
+ )
+
+def run_sft_pipeline(config):
+ logging.info("SFT ํ์ดํ๋ผ์ธ ์์")
+ if config["pipeline"]["sft"].get("review_crawling", False):
+ review_crawling.run_review_crawling(config)
+ if config["pipeline"]["sft"].get("review_preprocessing", False):
+ review_preprocessing.run_review_preprocessing(config)
+ if config["pipeline"]["sft"].get("train_data_sampling", False):
+ train_data_sampling.run_train_data_sampling(config)
+ if config["pipeline"]["sft"].get("train_data_annotating", False):
+ train_data_annotating.run_train_data_annotating(config)
+ if config["pipeline"]["sft"].get("sft", False):
+ sft.run_sft(config)
+ logging.info("SFT ํ์ดํ๋ผ์ธ ์๋ฃ.")
+
+def run_review_pipeline(config):
+ logging.info("๋ฆฌ๋ทฐ ํ์ดํ๋ผ์ธ ์์")
+ # (ํฌ๋กค๋ง์ ๋ณ๋ ์คํํ ๊ฒฝ์ฐ config์ crawling ์ต์
์ ๋ฐ๋ผ ์คํ)
+ if config["pipeline"]["review"].get("aste_inference", False):
+ aste_inference.run_aste_inference(config)
+ if config["pipeline"]["review"].get("review_summarization", False):
+ review_summarization.run_review_summarization(config)
+ if config["pipeline"]["review"].get("keyword_recommendation", False):
+ keyword_recommendation.run_keyword_recommendation(config)
+ logging.info("๋ฆฌ๋ทฐ ํ์ดํ๋ผ์ธ ์๋ฃ.")
+
+
+def main():
+ setup_logger()
+ parser = argparse.ArgumentParser(description="ํ์ดํ๋ผ์ธ ์คํ")
+ parser.add_argument(
+ "--config",
+ "-c",
+ default="config/config.yaml",
+ help="์ค์ ํ์ผ ๊ฒฝ๋ก (๊ธฐ๋ณธ๊ฐ: config/config.yaml)"
+ )
+ parser.add_argument(
+ "--pipeline",
+ "-p",
+ choices=["review", "sft", "all"],
+ default="all",
+ help="์คํํ ํ์ดํ๋ผ์ธ ์ ํ (review, sft, all)"
+ )
+ args = parser.parse_args()
+
+ with open(args.config, "r", encoding="utf-8") as f:
+ config = yaml.safe_load(f)
+
+ if args.pipeline in ["sft", "all"]:
+ run_sft_pipeline(config)
+
+ if args.pipeline in ["review", "all"]:
+ run_review_pipeline(config)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/models/review/prompt/keyword_recommendation/recommendation_fewshot.json b/models/review/prompt/keyword_recommendation/recommendation_fewshot.json
new file mode 100644
index 0000000..c7e1f79
--- /dev/null
+++ b/models/review/prompt/keyword_recommendation/recommendation_fewshot.json
@@ -0,0 +1,18 @@
+[
+ {
+ "query": "['ํ์ฌ ๊ฐ์๋ฐ ์ฑ์ฐ๋ ค๊ณ ๊ตฌ๋งคํ์ต๋๋ค', '์ฌ๋ฌด์ค ๊ฐ์์ผ๋ก ์ข์์', '์ฌ๋ฌด์ค ๊ฐ์์ผ๋ก ์ข์ต๋๋ค']",
+ "answer": "์ฌ๋ฌด์ค ํ์ ๊ฐ์"
+ },
+ {
+ "query": "['๋ฐ์ญํ๊ณ ', '๋ฐ์ญ๋ฐ์ญํด์']",
+ "answer": "๋ฐ์ญํจ ๋ํ์"
+ },
+ {
+ "query": "['์ปคํผ๋ ์ฐ์ ๋ ๋จน์ผ๋ฉด ๋์ฑ ์ข์์', '์ปคํผ๋ ์ฐ์ ๋ ๋ค ์ด์ธ๋ ค์']",
+ "answer": "์ปคํผ์ ์ฐฐ๋ก"
+ },
+ {
+ "query": "['๊ฐ์ฑ๋น ์ข์์', '๊ฐ์ฑ๋น ์ข์ ์ ํ์ด๋ค์']",
+ "answer": "๊ฐ์ฑ๋น ๊ฐ"
+ }
+]
\ No newline at end of file
diff --git a/models/review/prompt/keyword_recommendation/recommendation_prompt.txt b/models/review/prompt/keyword_recommendation/recommendation_prompt.txt
new file mode 100644
index 0000000..10f31a3
--- /dev/null
+++ b/models/review/prompt/keyword_recommendation/recommendation_prompt.txt
@@ -0,0 +1,9 @@
+๋น์ ์ ์จ๋ผ์ธ ์๋ฃํ ๋ฆฌ๋ทฐ ๋ถ์๊ฐ์
๋๋ค. ์ ๊ณต๋ ๋ฆฌ๋ทฐ๋ค์ ๋ถ์ํ์ฌ, ํด๋น ํด๋ฌ์คํฐ๋ฅผ ๋ํํ๋ ์งง๊ณ ์ง๊ด์ ์ธ ๋ง์ผํ
ํค์๋๋ฅผ ํ๋๋ง ์์ฑํด์ฃผ์ธ์.
+
+๋ค์ ๊ธฐ์ค์ ๋ฐ๋ฅด์ธ์:
+1. ๋ฆฌ๋ทฐ์์ ๋ฐ๋ณต์ ์ผ๋ก ์ธ๊ธ๋๋ ์ฃผ์ ํน์ง์ ๋ฐ์ํ ๊ฒ (์: \"๊ฐ์ฑ๋น ์ข์\", \"์งํ ํ๋ฏธ\")
+2. ์์ฐ์ค๋ฝ๊ณ ์งง์ ํ๊ตญ์ด ํํ์ ์ฌ์ฉํ ๊ฒ
+3. ์ํ๋ช
์ด๋ ๋ธ๋๋ ๋ฑ ๊ณ ์ ๋ช
์ฌ๋ ํฌํจํ์ง ์์ ๊ฒ
+
+์ถ๋ ฅ์ ์๋์ JSON ํํ๋ฅผ ๋ฐ๋ฅด๋ฉฐ, ๋ฐ๋์ ํ๋์ ํค์๋๋ง ์์ฑํฉ๋๋ค.
+{\"keyword\": \"<ํค์๋>\"}
\ No newline at end of file
diff --git a/models/review/prompt/prompt_loader.py b/models/review/prompt/prompt_loader.py
new file mode 100644
index 0000000..cc8627e
--- /dev/null
+++ b/models/review/prompt/prompt_loader.py
@@ -0,0 +1,21 @@
+# prompt/prompt_loader.py
+import os
+import json
+
+def load_prompt(prompt_filename: str, prompt_dir: str = "./prompt") -> str:
+ """
+ ์ฃผ์ด์ง ํ์ผ ์ด๋ฆ์ ํ๋กฌํํธ ํ
์คํธ๋ฅผ prompt ํด๋์์ ์ฝ์ด ๋ฐํํฉ๋๋ค.
+ ์) "summarization_positive_prompt.txt"
+ """
+ file_path = os.path.join(prompt_dir, prompt_filename)
+ with open(file_path, "r", encoding="utf-8") as f:
+ return f.read()
+
+def load_fewshot(fewshot_filename: str, prompt_dir: str = "./prompt") -> list:
+ """
+ ์ฃผ์ด์ง ํ์ผ ์ด๋ฆ์ few-shot ์์๋ฅผ JSON ํ์ผ๋ก๋ถํฐ ์ฝ์ด ๋ฐํํฉ๋๋ค.
+ ์) "summarization_positive_fewshot.json"
+ """
+ file_path = os.path.join(prompt_dir, fewshot_filename)
+ with open(file_path, "r", encoding="utf-8") as f:
+ return json.load(f)
diff --git a/models/review/prompt/review_annotation/annotation_fewshot.json b/models/review/prompt/review_annotation/annotation_fewshot.json
new file mode 100644
index 0000000..b2c37e5
--- /dev/null
+++ b/models/review/prompt/review_annotation/annotation_fewshot.json
@@ -0,0 +1,34 @@
+[
+ {
+ "query": "๋ง์์ด์ ์ฌ์ฃผ๋ฌธํด์.",
+ "answer": "
\n๋ฆฌ๋ทฐ \"๋ง์์ด์ ์ฌ์ฃผ๋ฌธํด์.\"์์ '๋ง'๊ณผ '์ํ' ๋ ์์ฑ์ ์ถ์ถ.\n\n```json\n[\n {\"์์ฑ\": \"๋ง\", \"ํ๊ฐ\": \"๋ง์๋ค.\", \"๊ฐ์ \": \"๊ธ์ \"},\n {\"์์ฑ\": \"์ํ\", \"ํ๊ฐ\": \"์ฌ์ฃผ๋ฌธํ๋ค.\", \"๊ฐ์ \": \"๊ธ์ \"}\n]\n```"
+ },
+ {
+ "query": "์ธ๊ฒ ์ ์๋ค์ ๋ฐฐ์ก.",
+ "answer": "
\n\"์ธ๊ฒ ์ ์๋ค์\"์์ ๊ฐ๊ฒฉ๊ณผ ์ํ ํ๊ฐ ๋์ถ.\n\n```json\n[\n {\"์์ฑ\": \"๊ฐ๊ฒฉ\", \"ํ๊ฐ\": \"์ ๋ ดํ๋ค.\", \"๊ฐ์ \": \"๊ธ์ \"},\n {\"์์ฑ\": \"์ํ\", \"ํ๊ฐ\": \"์ ์๋ค.\", \"๊ฐ์ \": \"๊ธ์ \"}\n]\n```"
+ },
+ {
+ "query": "๋ฐฐ์ก์ด ๋น ๋ฅด๊ณ ํ์ง์ด ์ข์์.",
+ "answer": "
\n\"๋ฐฐ์ก์ด ๋น ๋ฅด๊ณ \" โ ๋ฐฐ์ก, \"ํ์ง์ด ์ข์์\" โ ์ํ.\n\n```json\n[\n {\"์์ฑ\": \"๋ฐฐ์ก\", \"ํ๊ฐ\": \"๋น ๋ฅด๋ค.\", \"๊ฐ์ \": \"๊ธ์ \"},\n {\"์์ฑ\": \"์ํ\", \"ํ๊ฐ\": \"ํ์ง์ด ์ข๋ค.\", \"๊ฐ์ \": \"๊ธ์ \"}\n]\n```"
+ },
+ {
+ "query": "๋
ธ๋ธ๋๋ ๊ณผ์ ์ค ์ ์ ํ
์
๋๋ค ์ธ๊ณ ๋ง์์ด์.",
+ "answer": "
\n๋
ธ๋ธ๋๋ ๊ณผ์ โ ์ํ, \"์ ์ ํ
์
๋๋ค\" โ ๊ธ์ ํ๊ฐ, \"์ธ๊ณ \" โ ๊ฐ๊ฒฉ, \"๋ง์์ด์\" โ ๋ง.\n\n```json\n[\n {\"์์ฑ\": \"์ํ\", \"ํ๊ฐ\": \"์ ์ ํ
์ด๋ค.\", \"๊ฐ์ \": \"๊ธ์ \"},\n {\"์์ฑ\": \"๊ฐ๊ฒฉ\", \"ํ๊ฐ\": \"์ ๋ ดํ๋ค.\", \"๊ฐ์ \": \"๊ธ์ \"},\n {\"์์ฑ\": \"๋ง\", \"ํ๊ฐ\": \"๋ง์๋ค.\", \"๊ฐ์ \": \"๊ธ์ \"}\n]\n```"
+ },
+ {
+ "query": "์ฒ์ ์ฃผ๋ฌธํด ๋ดค๋๋ฐ ๋ง์์ด์.",
+ "answer": "
\n\"๋ง์์ด์\" โ ๋ง ํ๊ฐ.\n\n```json\n[\n {\"์์ฑ\": \"๋ง\", \"ํ๊ฐ\": \"๋ง์๋ค.\", \"๊ฐ์ \": \"๊ธ์ \"}\n]\n```"
+ },
+ {
+ "query": "๋ก๊ตญ ํด ๋จน๊ธฐ ์ข๊ณ ์ฐ๊ฐ ๋์ผ ๋ ์ฌ์ฉํ๊ธฐ ์ข์์.",
+ "answer": "
\n๋ ๋ฌธ์ฅ ๋ชจ๋ '์ํ' ์์ฑ์ผ๋ก ํ๊ฐ.\n\n```json\n[\n {\"์์ฑ\": \"์ํ\", \"ํ๊ฐ\": \"๋ก๊ตญ ํด ๋จน๊ธฐ ์ข๋ค.\", \"๊ฐ์ \": \"๊ธ์ \"},\n {\"์์ฑ\": \"์ํ\", \"ํ๊ฐ\": \"์ฐ๊ฐ ๋์ผ ๋ ์ฌ์ฉํ๊ธฐ ์ข๋ค.\", \"๊ฐ์ \": \"๊ธ์ \"}\n]\n```"
+ },
+ {
+ "query": "๋ง์์ด์ ํญ์ ์ฌ ๋จน๋ ๊ณผ์.",
+ "answer": "
\n\"๋ง์์ด์\" โ ๋ง, \"ํญ์ ์ฌ ๋จน๋\" โ ์ํ.\n\n```json\n[\n {\"์์ฑ\": \"๋ง\", \"ํ๊ฐ\": \"๋ง์๋ค.\", \"๊ฐ์ \": \"๊ธ์ \"},\n {\"์์ฑ\": \"์ํ\", \"ํ๊ฐ\": \"ํญ์ ์ฌ ๋จน๋๋ค.\", \"๊ฐ์ \": \"๊ธ์ \"}\n]\n```"
+ },
+ {
+ "query": "์ข์ ์ ํ ๋น ๋ฅด๊ฒ ์ ๋ฐ์์ต๋๋ค.",
+ "answer": "
\n\"์ข์ ์ ํ\" โ ์ํ, \"๋น ๋ฅด๊ฒ ์ ๋ฐ์์ต๋๋ค\" โ ๋ฐฐ์ก ํ๊ฐ ๋ถ๋ฆฌ.\n\n```json\n[\n {\"์์ฑ\": \"์ํ\", \"ํ๊ฐ\": \"์ข๋ค.\", \"๊ฐ์ \": \"๊ธ์ \"},\n {\"์์ฑ\": \"๋ฐฐ์ก\", \"ํ๊ฐ\": \"๋น ๋ฅด๋ค.\", \"๊ฐ์ \": \"๊ธ์ \"},\n {\"์์ฑ\": \"๋ฐฐ์ก\", \"ํ๊ฐ\": \"์ ๋์ฐฉํ๋ค.\", \"๊ฐ์ \": \"๊ธ์ \"}\n]\n```"
+ }
+]
diff --git a/models/review/prompt/review_annotation/annotation_prompt.txt b/models/review/prompt/review_annotation/annotation_prompt.txt
new file mode 100644
index 0000000..9fd0bc3
--- /dev/null
+++ b/models/review/prompt/review_annotation/annotation_prompt.txt
@@ -0,0 +1,9 @@
+๋น์ ์ ๊ฐ์ฑ ๋ถ์ ๋ฐ ์ํ ํ๊ฐ ์ ๋ฌธ๊ฐ์ด๋ค. ์ฃผ์ด์ง ์ํ ๋ฆฌ๋ทฐ๋ฅผ ๋ถ์ํ์ฌ ํด๋น ๋ฆฌ๋ทฐ์์ ์ํ ์์ฑ๊ณผ ํ๊ฐ, ๊ฐ์ ์ ์ถ์ถํ๋ผ.
+
+### ์์
๋ชฉํ:
+1. ์
๋ ฅ ๋ฆฌ๋ทฐ๋ฅผ ๋ถ์ํ์ฌ JSON ํ์์ผ๋ก ์ถ๋ ฅ
+2. ์ค๊ฐ ๋ถ์ ๊ณผ์ ์
(๋ถ์๊ณผ์ ) ํ๊ทธ๋ก ํฌํจ
+3. ์ถ๋ ฅ ์์๋ ๋ฆฌ์คํธ ํํ์ JSON ๊ฐ์ฒด๋ก ๊ตฌ์ฑ
+
+์์ฑ์ ์ฃผ๋ก: ["์ํ", "๋ฐฐ์ก", "๊ฐ๊ฒฉ", "๋ง", "์ ์ ๋", "์", "ํฌ์ฅ"].
+ํ์์ ์๋ก์ด ์์ฑ ์ฌ์ฉ ๊ฐ๋ฅํ๋ฉฐ ๋ชจ๋ ์ํ๋ช
์ "์ํ"์ผ๋ก ํต์ผํ๋ค.
diff --git a/models/review/prompt/review_summarization/negative_fewshot.json b/models/review/prompt/review_summarization/negative_fewshot.json
new file mode 100644
index 0000000..1a686cf
--- /dev/null
+++ b/models/review/prompt/review_summarization/negative_fewshot.json
@@ -0,0 +1,45 @@
+[
+ {
+ "query": [
+ "์ข ๋๋ผํด์",
+ "์๋ฌด๋๋ ์๋ก๋ณ๋ณด๋ค๋ ์ข ๋ถ์กฑํ์ง๋ง",
+ "์ฝ๊ฐ ๊ธฐ๋ฆ๋ง์ด ๋ง์ด ๋๊ธด ํ๋๋ฐ",
+ "๋ง์ด ๊ธฐ๋๋ฉ๋๋ค",
+ "์ฝ๊ฐ ๋๋ผํ๊ธด ํ๋ฐ ๊ทธ๋๋ ๋ถ๋ด ์์ด ๋จน๊ธฐ ์ข์ ๊ฒ ๊ฐ์์",
+ "๋ง๋ ๋์์ง ์์ต๋๋ค",
+ "ํ์ ํ๊ณผ ๋ง์ด ๋ณ๋ฐ ์ฐจ์ด ์๊ณ "
+ ],
+ "answer": "์กฐ๊ธ ๊ธฐ๋ฆ์ง ๋ง์ด ๋๊ธด ํ๋๋ฐ, ๊ทธ๋๋ ๋ถ๋ด ์์ด ๋จน์ ์ ์์ด์."
+ },
+ {
+ "query": [
+ "๋ฌ์์ ๋๋ฃฝ์ง์ ๊ตฌ์ํ๊ณ ๋ด๋ฐฑํจ์ ์์ด์",
+ "์์ง ์ ๋จน์ด ๋ดค์ง๋ง ๋ง์๊ฒ ์ฃ ",
+ "์ฒ์์ ์ง์ง ๋๋ฃฝ์ง ํฅ์ด ๋์ ๊ด์ฐฎ๋ค ์ถ์๋๋ฐ",
+ "๋ฌด๋ํ๊ณ ์ด์ดํ๊ณ ๋ฌ๋ฌํด์",
+ "๋ง์ด ๋ฌ์์",
+ "์์ง ์ ๋จน์ด ๋ดค์ด์",
+ "ํ๊ธด ๋๋ฃฝ์ง ์คํ ๋ฟ๋ฆฐ ๋ง์
๋๋ค"
+ ],
+ "answer": "๋๋ฃฝ์ง์ ๊ตฌ์ํ ๋ง๋ณด๋ค๋ ๋จ๋ง์ด ๊ฐํ๊ฒ ๋๊ปด์ ธ ์์ฌ์ ์ต๋๋ค."
+ },
+ {
+ "query": [
+ "์์ง ๋จน์ด๋ณด์ง ์์์ ๋ง์ ์ ๋ชจ๋ฅด๊ฒ ์ด์"
+ ],
+ "answer": "์๋ณธ ๋ฆฌ๋ทฐ๊ฐ ์์ต๋๋ค."
+ },
+ {
+ "query": [
+ "๋ฑ๊ฐ ํฌ์ฅ์ด๋ผ ๋จน๊ธฐ๋ ํธํ์ง๋ง",
+ "์ํฌ์ฅ์ด๋ผ ์ฐ๋ ๊ธฐ ๊ฑฑ์ ์ ์ข ๋์ง๋ง ๋ณด๊ดํ๊ธฐ ํธํด์ ์ฌ๊ฒ ๋๋ค์"
+ ],
+ "answer": "๋ฑ๊ฐ ํฌ์ฅ์ด ํธ๋ฆฌํ์ง๋ง, ์ํฌ์ฅ์ด๋ผ ์ฐ๋ ๊ธฐ๊ฐ ๋์ด๋ ๊น ๊ฑฑ์ ๋๋ค์."
+ },
+ {
+ "query": [
+ "์ข
์ด๋ดํฌ์ ์ํ ์กฐ์กํ๊ฒ ์ฒ๋ฐ์์ ๋ณด๋ด๋ ๊ฑด ์ฌ์ ํ๋ค์"
+ ],
+ "answer": "์ํ์ด ์ข
์ด๋ดํฌ์ ๋ถํธํ๊ฒ ํฌ์ฅ๋์ด ์์ด ์์ฌ์์ด ์์์ต๋๋ค."
+ }
+]
diff --git a/models/review/prompt/review_summarization/negative_prompt.txt b/models/review/prompt/review_summarization/negative_prompt.txt
new file mode 100644
index 0000000..0ccd304
--- /dev/null
+++ b/models/review/prompt/review_summarization/negative_prompt.txt
@@ -0,0 +1 @@
+๋น์ ์ ์ ๋ฌธ ๋ฆฌ๋ทฐ ๋ถ์๊ฐ์
๋๋ค. ์๋์ ์ ๊ณต๋ ์๋น์๋ค์ ์๋ณธ ๋ฆฌ๋ทฐ๋ฅผ ๊ธฐ๋ฐ์ผ๋ก, ๋ถ์ ์ ๋๋ ์ค๋ฆฝ์ ์ธ ํ๊ฐ์ ํต์ฌ ํฌ์ธํธ๋ฅผ ํ๋์ ์์ ํ ๋ฌธ์ฅ์ผ๋ก ์์ฝํ์ธ์. ์๋น์๊ฐ ์ง์ ํ ๋จ์ ์ด๋ ๊ฐ์ ์ ์ ์ ์คํ ๋์๋ง๋ก ํํํ์ญ์์ค. ์ถ๋ ฅ์ ๋ฐ๋์ JSON ํ์์ผ๋ก, ์: {{\"summarization\": \"<์์ฝ>\"}} ํํ์ฌ์ผ ํฉ๋๋ค."
\ No newline at end of file
diff --git a/models/review/prompt/review_summarization/positive_fewshot.json b/models/review/prompt/review_summarization/positive_fewshot.json
new file mode 100644
index 0000000..e771a96
--- /dev/null
+++ b/models/review/prompt/review_summarization/positive_fewshot.json
@@ -0,0 +1,23 @@
+[
+ {
+ "query": "['๋ฐฐ์ก ๋น ๋ฅด๊ณ ์ข์ต๋๋ค', '์๋ ํฌ์ฅ ๋๋ฌด ์ข์์', '๋น ๋ฅด๊ณ ์์ ๋ฐฐ์ก ๊ฐ์ฌํฉ๋๋ค', 'ํฌ์ฅ ๋ฐ ๋ฐฐ์ก ์ํ ๋ค ์ข์์', '์ฐ๊ทธ๋ฌ์ง ๊ณณ ์์ด ๋ฐฐ์ก ์ข๊ณ ', '๋น ๋ฅธ ๋ฐฐ์ก ๊ฐ์ฌํฉ๋๋ค', '๋น ๋ฅธ ๋ฐฐ์ก์ด๊ณ ', '๋นจ๋ฆฌ ๋ฐฐ์ก๋์ด์ ์์ฃผ ๋ง์กฑํฉ๋๋ค']",
+ "answer": "๋น ๋ฅด๊ณ ์์ ํ ๋ฐฐ์ก๊ณผ ๊ผผ๊ผผํ ํฌ์ฅ ์ํ์ ๋งค์ฐ ๋ง์กฑํฉ๋๋ค."
+ },
+ {
+ "query": "['๋น ๋ฅด๊ฒ ํธํ๊ฒ ๊ตฌ์
ํ์ด์', '๋ฐฐ์ก๋ ๋น ๋ฅด๊ณ ์ข์์', '๋น ๋ฅธ ๋ฐฐ์ก ๊ฐ์ฌํฉ๋๋ค', '๋ฐฐ์ก ์ข์์ ๋ ์ ์ฉํด์', '์ง์์ ํธํ ๋ฐ์๋ณผ ์ ์์ด ๊ฐ์ฌํฉ๋๋ค', '๋นจ๋ฆฌ ๋ฐฐ์กํด ์ฃผ์
์', '๋ฐฐ์ก ๋น ๋ฅด๊ฒ ์์ต๋๋ค', 'ํ๋ฐฐ ์์ ์จ ๊ณ ๋ง์ต๋๋ค', '๋ฐฐ์ก ์ ์์ต๋๋ค', '์๊ฐ ๋ง์ถฐ ๋ณด๋ด์ฃผ์
จ์ด์', '๋ฐฐ์ก๋ ๋น ๋ฅด๊ตฌ์', '๋ฐฐ์ก ๋น ๋ฅด๊ณ ์ ํํ๊ฒ ์์ต๋๋ค', '๋งค๋ฒ ์ ๋ฐ๊ณ ์์ต๋๋ค', '๋น ๋ฅธ ๋ฐฐ์ก ์ต๊ณ ์์', '์งํผ๋ฐฑ์ผ๋ก ๋ผ์์ด์ ๋ฐ์ญํจ์ด ์ค๋ ์ ์ง๋๊ณ ', '๋ฐฐ์ก๋ ๋น ๋ฅด๋ ์์ฃผ ์ํค๋ค์', 'ํฌ์ฅ ์ํ๋ ์ข์ต๋๋ค', '๋น ๋ฅด๊ฒ ๋ฐฐ์ก๋๊ณ ']",
+ "answer": "๋น ๋ฅด๊ณ ์ ํํ ๋ฐฐ์ก๊ณผ ๊ผผ๊ผผํ ํฌ์ฅ ๋๋ถ์ ๋งค์ฐ ๋ง์กฑํ๋ฉฐ ์์ฃผ ์ด์ฉํฉ๋๋ค."
+ },
+ {
+ "query": "['๋๋ฌด ๋ง์์ด์', '๋คํฌ์ด์ฝ๋ฆฟ์ด์ง๋ง ๋ฌ์ฝคํฉ๋๋ค', '๋ฌ์ง๋ ์ฐ์ง๋ ์๊ณ ์ข์์', '๋ฌ๋ฌํ๊ณ ๋คํฌํด์ ๊ตฟ', '๋ฌ์ง ์๊ณ ๋ง์์ด์', '๊ฐ๊ฒฉ ๋๋น ๋ง์๊ณ ', '๋ง์์ด์ ์์ฃผ ๋จน์ด์', '๊ณ ๊ธ์ง๋ฉด์ ๊ธฐ๋ถ ๋์ ์ด๋ง์ด ์๋', '๋ฌ๋ฌํ๋ ๋ง์์ต๋๋ค', '๋ง์๊ฒ ์ ๋จน์์ต๋๋ค', '์ด ์ด์ฝ๋ฆฟ ์ข์ํด์ ๊ตฌ๋งคํฉ๋๋ค', '๋คํฌ์ด์ฝ๋ฆฟ ๋ง์์ต๋๋ค', '๋๋ฌด ๋ฌ์ง๋ ์๊ณ ๋ง์์ด์', '์์ํ๊ฒ ๋ง์์ด์', '๋คํฌ์ด์ฝ๋ฆฟ ์ปคํผ๋ ๋จน์ผ๋ฉด ์ ๋ง ๋ง์์ฃ ', '๋ง๋ ๊ฐ๊ฒฉ๋ ์๋์ 1์', '๋ง์๊ณ ๋ง์์ด์', '์์ด๋ค์ด ๋ง์๊ฒ ์ ๋จน์์ต๋๋ค']",
+ "answer": "๋ฌ์ฝคํ๋ฉด์๋ ์์ํ ๋คํฌ์ด์ฝ๋ฆฟ์ ๊ท ํ ์กํ ๋ง์ด ์งํ๊ณ ๋ถ๋๋ฌ์ ์์ฃผ ์ฌ๊ตฌ๋งคํ๊ฒ ๋ฉ๋๋ค."
+ },
+ {
+ "query": "['์ญ์ ์ด๋์ธ', '๋๋ฌด ๋ง์์ด์', '์ธ์ ๋ ๋ง์์ด์', 'ํฌ์นด์นฉ ๋๋ฌด ๋ง์์ด์', '๊ฐ์์นฉ ์ค์ ์ ์ผ ๋ง์์', '์์ด๋ค์ด ๋ง์๋ค๊ณ ํ๋ค์', '์ด๋ค ์ ํ๋ ๋ฐ๋ผ์ฌ ์ ์๋ ๋ง', '์ํ๋ง์ด ์ข์', '์ญ์ ์ํ๋ง์ด ๋ง์๋ค์', '๋ง์์ด์ ํญ์ ๋จ์ด์ง๋ฉด ์ํค๋ ์ ํ์ด์์', 'ํญ์ ๋จน์ง๋ง ๋ง์์ด์', '๊ฐ์์นฉ ์ค์์๋ ํฌ์นด์นฉ์ ๋ฐ๋ผ๊ฐ ์ ํ์ด ์๋ ๊ฒ ๊ฐ์์', 'ํฌ์นด์นฉ ์ด๋์ธ๋ง ์ข์์', '์ค๋ฆฌ์ง๋๋ณด๋ค ๋ง์๋ค์', '๋ง์์ด์ ํญ์ ๊ตฌ์
ํ๋ ๊ณผ์์
๋๋ค', '๋ง์์ด์ ๋จน๊ธฐ ์ข์์', 'ํญ์ ๋ง์๊ฒ ๋จน๊ณ ์์ด์']",
+ "answer": "์งญ์งคํ๊ณ ๊ณ ์ํ ์ด๋์ธ ๋ง์ด ๋ฐ์ด๋, ํฌ์นด์นฉ ์ค์์ ๊ฐ์ฅ ๋ง์๊ณ ํญ์ ์ฌ๊ตฌ๋งคํ๊ฒ ๋ฉ๋๋ค."
+ },
+ {
+ "query": "['๋๋ฌด ๋ง์์ด์', '์ฌ์ ํ๋ค์', '์ฌ๊ธฐ ๊ฒ์ด ๋ง์์ด์', '๋ด๋ฐฑํ๊ณ ๋ง์์ด์', '์ ๋ง ๊ทธ๋๋ก๋ค์', '์ ๋ง ๊ณ ์ํ๋ค์', '๋ง์๋ ๊ณผ์์์', '์ด์ง ์งญ์งคํ๋ ๋ง์์ด์', '๊ฐ๋ ๋จน์ผ๋ฉด ๋ง์๋ ๊ณผ์์์', '์คํ
๋์
๋ฌ์ธ ๋งํผ ๋ง์ ๋ณด์ฅ๋ผ ์์ฃ ', '๋ง์ด ๋ณํ์ง ์์์ ์ข์์', '๋ง์์ด์ ์์ฃผ ๊ตฌ๋งคํด์', '๋ง์์ด์ ๋จน๊ธฐ ์ข์์', '์ง๋ฆฌ์ง ์๋ ๋ง', '๊ณ ์ํ๊ณ ๋ง์์ด์', '๋ด๋ฐฑํ๊ณ ๊ณ ์ํ๊ณ ๋ง์์ด์', '๋ง์๊ฒ ๋จน๊ณ ์์ต๋๋ค', '๋๋ฌด ๋ง์์ด์', '์ถ์ต์ ๋ง์ด์ฃ ', '์ปคํผ ํ ์์ด ์๊ฐ๋๋ ๋ง์ด์์']",
+ "answer": "๋ณํจ์๋ ๊ณ ์ํ๊ณ ๋ด๋ฐฑํ ๋ง์ผ๋ก, ์ธ์ ๋จน์ด๋ ๋ง์์ด ์์ฃผ ๊ตฌ๋งคํ๊ฒ ๋ฉ๋๋ค."
+ }
+ ]
+
\ No newline at end of file
diff --git a/models/review/prompt/review_summarization/positive_prompt.txt b/models/review/prompt/review_summarization/positive_prompt.txt
new file mode 100644
index 0000000..d100b27
--- /dev/null
+++ b/models/review/prompt/review_summarization/positive_prompt.txt
@@ -0,0 +1 @@
+๋น์ ์ ์ ๋ฌธ ๋ฆฌ๋ทฐ ๋ถ์๊ฐ์
๋๋ค. ์๋์ ์ ๊ณต๋ ์๋น์๋ค์ ์๋ณธ ๋ฆฌ๋ทฐ๋ฅผ ๊ธฐ๋ฐ์ผ๋ก, ๊ธ์ ์ ์ธ ํ๊ฐ์ ํต์ฌ ํฌ์ธํธ๋ฅผ ํ๋์ ์์ ํ ๋ฌธ์ฅ์ผ๋ก ์์ฝํ์ธ์. ์๋น์๊ฐ ๊ฐ์กฐํ ๊ฐ์ ๊ณผ ์ฅ์ ์ ์ ์คํ ๋์๋ง๋ก ํํํ์ญ์์ค. ์ถ๋ ฅ์ ๋ฐ๋์ JSON ํ์์ผ๋ก, ์: {"summarization": "<์์ฝ>"} ํํ์ฌ์ผ ํฉ๋๋ค.
diff --git a/models/review/src/review_pipeline/__init__.py b/models/review/src/review_pipeline/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/models/review/src/review_pipeline/aste_inference.py b/models/review/src/review_pipeline/aste_inference.py
new file mode 100644
index 0000000..c57ef54
--- /dev/null
+++ b/models/review/src/review_pipeline/aste_inference.py
@@ -0,0 +1,8 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
def run_aste_inference(config):
    """Placeholder pipeline stage.

    ASTE inference is not run from here; it is performed by the standalone
    qwen_deepseek_*_inference.py scripts. This hook only prints a pointer
    to them so the pipeline runner has a uniform entry point.
    """
    print("qwen_deepseek_14,32b_inference.py ํ์ผ๋ก Inference๋ฅผ ์งํํ์ธ์.\n")

if __name__ == "__main__":
    # Smoke-run with an empty config; the function ignores its argument.
    run_aste_inference({"paths": {}})
diff --git a/models/review/src/review_pipeline/keyword_recommendation.py b/models/review/src/review_pipeline/keyword_recommendation.py
new file mode 100644
index 0000000..75b48ff
--- /dev/null
+++ b/models/review/src/review_pipeline/keyword_recommendation.py
@@ -0,0 +1,263 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+[์ถ์ฒ ํค์๋ ๊ธฐ๋ฐ ์ํ ์ฌ์ ๋ ฌ ํ์ดํ๋ผ์ธ]
+- ์
๋ ฅ CSV ํ์ผ ๋ก๋ถํฐ ์ถ์ฒ ํค์๋ ๋ณ๋ก ์ํ์ ์ฌ์ ๋ ฌํ์ฌ ์ต์ข
์ถ์ฒ CSV ํ์ผ์ ์์ฑํ๋ค.
+- ์ต์ข
์ถ๋ ฅ CSV ํ์ผ์ ๋ค์ ์ด๋ก ๊ตฌ์ฑ๋๋ค:
+ ์นดํ
๊ณ ๋ฆฌ, ํค์๋, ID, ์ํ๋ช
, opinion ๊ฐ์
+"""
+
+import os
+import re
+import json
+import time
+import requests
+import pandas as pd
+import numpy as np
+from dotenv import load_dotenv
+from utils.utils import load_data, expand_inference_data, sentenceBERT_embeddings, umap_reduce_embeddings, agglomerative_clustering, visualize_clustering, evaluate_clustering
+from prompt.prompt_loader import load_prompt, load_fewshot
+
+#########################################################
+# ๋ฐ์ดํฐ ์ ์ฒ๋ฆฌ ๋ฐ ํ์ฅ ๊ด๋ จ ํจ์
+#########################################################
+
def filter_invalid_value(raw_value):
    """
    Validate one cell of the ``unsloth_deepseek_32b`` column.

    Parameters
    ----------
    raw_value : Any
        Raw cell value; expected to be a JSON array string, but may be
        NaN (a float), the empty-list string "[]", or arbitrary junk.

    Returns
    -------
    list or None
        The parsed list when ``raw_value`` is a valid JSON array string,
        otherwise None so callers can drop the row via ``dropna``.
    """
    # NaN cells arrive as float; "[]" carries no opinions — both are invalid.
    if isinstance(raw_value, float) or (isinstance(raw_value, str) and raw_value.strip() == "[]"):
        return None
    try:
        parsed = json.loads(raw_value)
    except (json.JSONDecodeError, TypeError, ValueError):
        # TypeError: non-string input (e.g. None); the original code only
        # caught JSONDecodeError and crashed on such values.
        return None
    # Only a JSON array is a usable opinion list; objects/scalars are rejected.
    return parsed if isinstance(parsed, list) else None
+
def load_and_prepare_data(config):
    """
    Load the inference CSV named in *config* and expand it to one row per
    extracted (aspect, opinion, sentiment) triple.

    Rows whose ``unsloth_deepseek_32b`` cell is NaN/"[]"/unparsable are
    dropped; leftover annotation columns are removed when present.
    """
    train_data_path = os.path.join(config["paths"]["inference_dir"], config["inference_data"])
    df_infer = load_data(train_data_path)
    # Parse the JSON column; invalid cells become None so dropna removes them.
    df_infer["unsloth_deepseek_32b"] = df_infer["unsloth_deepseek_32b"].apply(filter_invalid_value)
    df_infer = df_infer.dropna(subset=["unsloth_deepseek_32b"])
    # One output row per opinion item inside each review's JSON list.
    aste_df = expand_inference_data(df_infer, "unsloth_deepseek_32b")
    # Drop alternative-annotation columns if this CSV happens to carry them.
    cols_to_drop = ['aste_hcx', 'aste_gpt', 'aste_golden_label']
    aste_df.drop(columns=cols_to_drop, errors='ignore', inplace=True)
    print(f"๋ฆฌ๋ทฐ Opinion ๋ฐ์ดํฐ ๊ฐ์: {aste_df.shape[0]}")
    return aste_df
+
def extract_product_id(review_id):
    """
    Derive a product ID from a review ID.

    A review ID of the form "emart-<digits>-<digits>" maps to the product
    ID "emart-<digits>"; any other shape is returned unchanged.
    """
    found = re.match(r"(emart-\d+)-\d+", review_id)
    if found is None:
        return review_id
    return found.group(1)
+
+#########################################################
+# ํด๋ฌ์คํฐ๋ง ๋ฐ ์ถ์ฒ ํค์๋ ์์ฑ ๊ด๋ จ ํจ์
+#########################################################
+
def get_sorted_clusters(df, cluster_column="cluster_label"):
    """
    Count rows per cluster label and return the counts, largest first.

    The noise label -1 is excluded. The returned DataFrame has two
    columns: *cluster_column* and ``size``.
    """
    sizes = (
        df[cluster_column]
        .value_counts()
        .rename_axis(cluster_column)
        .reset_index(name="size")
    )
    real_clusters = sizes[sizes[cluster_column] != -1]
    return real_clusters.sort_values(by="size", ascending=False)
+
def hcx_generate_cluster_keywords(df, sorted_clusters, text_column="review", cluster_column="cluster_label"):
    """
    Generate one representative keyword per cluster via the HCX API.

    For each cluster (largest first), up to 50 unique review texts are
    sampled and sent through :func:`robust_inference`; the returned keyword
    is collected into a {cluster_label: keyword} map.
    """
    id_keyword_map = {}
    for cluster in sorted_clusters[cluster_column]:
        size = sorted_clusters[sorted_clusters[cluster_column] == cluster]['size'].values[0]
        print(f"Cluster {cluster} (Size: {size})")
        cluster_texts = df[df[cluster_column] == cluster][text_column]
        # set() dedupes but does not preserve order, so the 50-text sample is arbitrary.
        unique_texts = list(set(cluster_texts.to_list()))[:50]
        print("์ํ ๋ฆฌ๋ทฐ:", unique_texts)
        keyword = robust_inference(str(unique_texts))
        # Strip whitespace and the surrounding quotes the model sometimes emits.
        id_keyword_map[cluster] = keyword.strip().strip('"')
        print("์์ฑ๋ ํค์๋:", id_keyword_map[cluster])
        print("-" * 80)
    return id_keyword_map
+
def generate_recommendations(df, id_keyword_map, selected_clusters):
    """
    Build the recommendation table for the selected clusters.

    For each cluster, products mentioned at least 5 times within the
    cluster are kept; each product contributes one output row whose ID is
    derived from its first matching review-ID and whose category is taken
    from the product's first "category" value.

    Returns a DataFrame with columns:
    ์นดํ๊ณ ๋ฆฌ, ํค์๋, ID, ์ํ๋ช, opinion ๊ฐ์
    """
    columns = ["์นดํ๊ณ ๋ฆฌ", "ํค์๋", "ID", "์ํ๋ช", "opinion ๊ฐ์"]
    # Collect plain dicts and build the DataFrame once at the end; the
    # original pd.concat-per-row pattern is quadratic in the row count.
    rows = []
    for cluster in selected_clusters:
        # Per-product opinion frequency within the cluster (keep counts >= 5).
        targets = df[df["cluster_label"] == cluster].value_counts(subset=["name"])
        targets = targets[targets >= 5]
        items = [item[0] for item in targets.keys().tolist()]
        keyword = id_keyword_map.get(cluster, "")
        for item, count in zip(items, targets.tolist()):
            product_rows = df[df["name"] == item]
            # Product ID comes from the first review-ID of this product.
            pid = extract_product_id(product_rows["review-ID"].values[0])
            # Category is the product's first "category" value.
            cat = product_rows["category"].values[0]
            rows.append({
                "์นดํ๊ณ ๋ฆฌ": cat,
                "ํค์๋": keyword,
                "ID": pid,
                "์ํ๋ช": item,
                "opinion ๊ฐ์": count,
            })
    return pd.DataFrame(rows, columns=columns)
+
def get_prompt_and_fewshot():
    """
    Load the keyword-recommendation prompt and its few-shot examples.

    Both artifacts live as files under the prompt directory so they can be
    edited without touching code.
    """
    prompt_dir = "./prompt/keyword_recommendation/"
    system_prompt = load_prompt(
        prompt_filename="recommendation_prompt.txt",
        prompt_dir=prompt_dir,
    )
    examples = load_fewshot(
        fewshot_filename="recommendation_fewshot.json",
        prompt_dir=prompt_dir,
    )
    return system_prompt, examples
+
def robust_inference(query, retry_delay=2, max_retries=None):
    """
    Call :func:`inference` and retry while it reports an API/request error.

    Parameters
    ----------
    query : str
        Prompt text forwarded to the HCX API.
    retry_delay : int
        Seconds to sleep between retries.
    max_retries : int or None
        Maximum number of retries; None (the default) retries forever,
        preserving the original unbounded behavior.

    Returns
    -------
    str
        The first successful response, or the last error string when
        ``max_retries`` is exhausted.
    """
    attempts = 0
    while True:
        result = inference(query)
        # inference() signals failure via these string prefixes rather than raising.
        if not (result.startswith("API Error") or result.startswith("Request Error")):
            return result
        attempts += 1
        if max_retries is not None and attempts >= max_retries:
            return result
        print("API ์ค๋ฅ ๋ฐ์, ๋ค์ ์๋ํฉ๋๋ค...")
        time.sleep(retry_delay)
+
def inference(query):
    """
    Send one chat-completion request to the CLOVA Studio HCX-003 endpoint.

    Credentials are read from ~/.env; the system prompt and few-shot
    examples are prepended before the user *query*. Returns the model text
    on success, or an "API Error:" / "Request Error:" string on failure
    (robust_inference keys off those prefixes).

    Raises:
        ValueError: when either credential is missing from ~/.env.
    """
    load_dotenv(os.path.expanduser("~/.env"))
    AUTHORIZATION = os.getenv("AUTHORIZATION")
    X_NCP_CLOVASTUDIO_REQUEST_ID = os.getenv("X_NCP_CLOVASTUDIO_REQUEST_ID")
    if not AUTHORIZATION or not X_NCP_CLOVASTUDIO_REQUEST_ID:
        raise ValueError("ํ์ API ์ธ์ฆ ์ ๋ณด๊ฐ .env์ ์ค์ ๋์ด ์์ง ์์ต๋๋ค.")
    headers = {
        'Authorization': AUTHORIZATION,
        'X-NCP-CLOVASTUDIO-REQUEST-ID': X_NCP_CLOVASTUDIO_REQUEST_ID,
        'Content-Type': 'application/json; charset=utf-8',
    }
    prompt, fewshot = get_prompt_and_fewshot()

    messages = [{"role": "system", "content": prompt}]
    for example in fewshot:
        # Few-shot queries may be stored as lists; join them into one string.
        query_text = " ".join(example["query"]) if isinstance(example["query"], list) else example["query"]
        messages.append({"role": "user", "content": query_text})
        messages.append({"role": "assistant", "content": example["answer"]})
    messages.append({"role": "user", "content": query})

    request_data = {
        'messages': messages,
        'topP': 0.8,
        'topK': 0,
        'maxTokens': 1024,
        'temperature': 0.5,
        'repeatPenalty': 5.0,
        'stopBefore': [],
        'includeAiFilters': False,
        'seed': 42  # fixed seed for reproducible generations
    }
    try:
        response = requests.post("https://clovastudio.stream.ntruss.com/testapp/v1/chat-completions/HCX-003", headers=headers, json=request_data)
        response.raise_for_status()
        response_json = response.json()
        # "20000" is the service's success status code.
        if response_json.get("status", {}).get("code") == "20000":
            output_text = response_json["result"]["message"]["content"]
            return output_text
        else:
            return f"API Error: {response_json.get('status', {}).get('message')}"
    except requests.exceptions.RequestException as e:
        return f"Request Error: {e}"
+
+#########################################################
+# ์ต์ข
์ถ์ฒ ํ์ดํ๋ผ์ธ ์คํ ํจ์
+#########################################################
+
def run_keyword_recommendation(config):
    """
    End-to-end keyword-recommendation pipeline.

    Steps: load/expand the inference CSV, keep positive opinions, normalize
    category names, then per category: embed opinions, reduce with UMAP,
    cluster, generate one keyword per cluster via HCX, build and save the
    recommendation CSV. Returns the list of per-category DataFrames.
    """
    print("\n[์ถ์ฒ ํค์๋ ๊ธฐ๋ฐ ์ํ ์ฌ์ ๋ ฌ ํ์ดํ๋ผ์ธ ์์]\n")

    # Load and expand the data.
    aste_df = load_and_prepare_data(config)

    # Keep only positive opinions (recommendation keywords are built from them).
    infer_pos = aste_df[aste_df["sentiment"] == "๊ธ์ "]
    infer_pos.loc[:, 'category'] = infer_pos['category'].replace({'์์ด๊ฐ์': '๋ผ๋ฉด/๊ฐํธ์'})

    # Map Korean category labels to filesystem-safe English names.
    category_map = {
        "๊ณผ์/๋น๊ณผ": "snacks",
        "๋ผ๋ฉด/๊ฐํธ์": "meals"
    }

    infer_pos["category"] = infer_pos["category"].map(category_map).fillna(infer_pos["category"])

    # Split into one DataFrame per category.
    category_dfs = {category: df for category, df in infer_pos.groupby("category")}

    all_recommendations = []  # collects the recommendation DataFrame of every category

    # Run the full pipeline per category.
    for category, df_cat in category_dfs.items():
        print(f"\n[์นดํ๊ณ ๋ฆฌ: {category} ์ฒ๋ฆฌ ์์]\n")

        # 1) Build (or load cached) sentence embeddings for the opinions.
        embedding_file = os.path.join(config["paths"]["embedding_dir"], f"deepseek_inference_{category}.npy")
        embedding_matrix = sentenceBERT_embeddings(embedding_file, df=df_cat, column="opinion")

        # 2) UMAP dimensionality reduction.
        reduced_embeddings = umap_reduce_embeddings(embedding_matrix, n_components=256)

        # 3) Agglomerative clustering with a tuned distance threshold.
        cluster_labels = agglomerative_clustering(reduced_embeddings, distance_threshold=21.5)
        df_cat['cluster_label'] = cluster_labels

        # 4) Sort clusters by size and generate one keyword per cluster.
        sorted_clusters = get_sorted_clusters(df_cat, cluster_column="cluster_label")
        id_keyword_map = hcx_generate_cluster_keywords(df_cat, sorted_clusters, text_column="review", cluster_column="cluster_label")

        # 5) Select clusters to recommend from (currently: all of them).
        selected_clusters = sorted_clusters["cluster_label"].tolist()

        # 6) Build the recommendation table.
        recommendation_df = generate_recommendations(df_cat, id_keyword_map, selected_clusters)
        final_columns = ["์นดํ๊ณ ๋ฆฌ", "ํค์๋", "ID", "์ํ๋ช", "opinion ๊ฐ์"]
        recommendation_df = recommendation_df[final_columns]

        # 7) Save one CSV per category.
        output_file = os.path.join(config["paths"]["final_outputs_dir"], f"recommendation_{category}.csv")
        recommendation_df.to_csv(output_file, index=False)
        print(f"์นดํ๊ณ ๋ฆฌ {category}์ ์ถ์ฒ ๊ฒฐ๊ณผ๊ฐ ์ ์ฅ๋์์ต๋๋ค: {output_file}")

        all_recommendations.append(recommendation_df)
        print(f"\n[์นดํ๊ณ ๋ฆฌ: {category} ์ฒ๋ฆฌ ์๋ฃ]\n")

    return all_recommendations
+
+if __name__ == "__main__":
+ # config๋ฅผ ํ์ฉํ์ฌ ๊ฒฝ๋ก ๋ฐ ํ์ผ๋ช
์ ๋ชจ๋ํ
+ config = {
+ "paths": {
+ "inference_dir": "./data/aste/inference",
+ "embedding_dir": "./data/embedding_matrics",
+ "final_outputs_dir": "../final_outputs"
+ },
+ "inference_data": "inferenced_reviews_snacks_meals_temp.csv"
+ }
+ run_keyword_recommendation(config)
diff --git a/models/review/src/review_pipeline/qwen_deepseek_14b_inference.py b/models/review/src/review_pipeline/qwen_deepseek_14b_inference.py
new file mode 100644
index 0000000..3a28274
--- /dev/null
+++ b/models/review/src/review_pipeline/qwen_deepseek_14b_inference.py
@@ -0,0 +1,151 @@
+import time
+import torch
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+from ast import literal_eval
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+
+from utils.evaluate import evaluate_aste
+
+
+PROMPT = """๋น์ ์ ์ํ ๋ฆฌ๋ทฐ์ ๊ฐ์ฑ ๋ถ์ ๋ฐ ํ๊ฐ ์ ๋ฌธ๊ฐ์
๋๋ค. ์ฃผ์ด์ง ์ํ ๋ฆฌ๋ทฐ๋ฅผ ๋ถ์ํ์ฌ ํด๋น ๋ฆฌ๋ทฐ์์ ์์ฑ๊ณผ ํ๊ฐ, ๊ฐ์ฑ์ ์ถ์ถํ์ธ์.
+
+### ์์
๋ชฉํ:
+1. ์
๋ ฅ์ผ๋ก ์ฃผ์ด์ง๋ ์ํ ๋ฆฌ๋ทฐ๋ฅผ ๋ถ์ํ์ฌ JSON ํ์์ผ๋ก ์ถ๋ ฅํ์ธ์.
+2. JSON ํ์์ ์๋์ ๊ฐ์ด ๋ฆฌ์คํธ ํํ๋ก ๊ตฌ์ฑ๋ฉ๋๋ค.
+ ```json
+ [
+ {{"์์ฑ": "<์์ฑ๋ช
1>", "ํ๊ฐ": "<ํ๊ฐ ๋ด์ฉ1>", "๊ฐ์ ": "<๊ธ์ /๋ถ์ /์ค๋ฆฝ>"}},
+ {{"์์ฑ": "<์์ฑ๋ช
2>", "ํ๊ฐ": "<ํ๊ฐ ๋ด์ฉ2>", "๊ฐ์ ": "<๊ธ์ /๋ถ์ /์ค๋ฆฝ>"}},
+ ...
+ ]
+ ```
+
+์์ฑ์ ๋ค์ ์ค ํ๋์ผ ๊ฐ๋ฅ์ฑ์ด ๋์ต๋๋ค: ["์ํ", "๋ฐฐ์ก", "๊ฐ๊ฒฉ", "๋ง", "์ ์ ๋", "์", "ํฌ์ฅ"].
+๋ง์ฝ ์๋ก์ด ์์ฑ์ด ํ์ํ๋ฉด ์์ฑํ์ฌ ์ฌ์ฉ ๊ฐ๋ฅํฉ๋๋ค.
+๋ชจ๋ ์ํ๋ช
์ "์ํ"์ผ๋ก ํต์ผํฉ๋๋ค.
+
+### ์ธ๋ถ ๊ท์น:
+- ๊ฐ์ ๋ถ์
+ - ๋ฆฌ๋ทฐ์์ ๊ฐ์ ์ด ๊ธ์ ์ ์ธ ๊ฒฝ์ฐ "๊ฐ์ ": "๊ธ์ "์ผ๋ก ์ค์ ํฉ๋๋ค.
+ - ๋ถ์ ์ ์ธ ํํ์ด ํฌํจ๋ ๊ฒฝ์ฐ "๊ฐ์ ": "๋ถ์ "์ผ๋ก ์ค์ ํฉ๋๋ค.
+ - ํ๊ฐ๊ฐ ๋ชจํธํ๊ฑฐ๋ ๊ฐ์ ์ ์ธ ๊ฒฝ์ฐ "๊ฐ์ ": "์ค๋ฆฝ"์ผ๋ก ์ค์ ํฉ๋๋ค.
+
+- ํ๊ฐ ๋ฌธ๊ตฌ ์ ์
+ - ๋ฆฌ๋ทฐ์์ ๋ํ๋ ์ฃผ์ ํ๊ฐ๋ฅผ ๊ฐ๊ฒฐํ ๋ฌธ์ฅ์ผ๋ก ๋ณํํฉ๋๋ค.
+ - ํต์ฌ ํค์๋๋ฅผ ์ ์งํ๋ฉด์ ๋ถํ์ํ ํํ์ ์ ๊ฑฐํฉ๋๋ค.
+ - ํ๊ฐ ๋ฌธ๊ตฌ๋ '~๋ค.'๋ก ๋๋๋ ํ์ฌํ, ํ์๋ฌธ์ผ๋ก ๋ต๋ณํฉ๋๋ค. ์๋ฅผ ๋ค์ด, '์ข์ต๋๋ค.' ๊ฐ ์๋ '์ข๋ค.' ๋ก ์์ฑํฉ๋๋ค.
+
+- ์์ธ ์ฌํญ
+ - ์ํ ์ฌ์ฉ ํ๊ธฐ๊ฐ ์๋ ์ํ์ ๋ํ ์์์ด๋ ๊ธฐ๋ํ๋ ๋ถ๋ถ์ ๋ถ๋ฆฌํ์ฌ ์ ์ธํฉ๋๋ค.
+ - ๋ณตํฉ์ ์ธ ํ๊ฐ๊ฐ ์กด์ฌํ๋ ๊ฒฝ์ฐ ํด๋น ๋ด์ฉ์ ๋ถ๋ฆฌํ์ฌ ๊ฐ๊ฐ JSON ํญ๋ชฉ์ผ๋ก ์์ฑํฉ๋๋ค. ์๋ฅผ ๋ค์ด, '๋ฐฐ์ก์ด ์์ ํ๊ณ ๋นจ๋์ด์'์ ๊ฒฝ์ฐ '์์ ํ๋ค.' ์ '๋น ๋ฅด๋ค.' ๋ ๊ฐ์ง๋ก ๊ตฌ๋ถํฉ๋๋ค.
+
+### ์
๋ ฅ:
+{review}
+
+"""
+
+
def prepare_data(path):
    """Read the review CSV at *path* into a DataFrame."""
    return pd.read_csv(path)
+
def extract_aste(model_size, quant_type, data_df, col_name):
    """
    Run ASTE extraction over every row of *data_df* whose *col_name* is NaN.

    Parameters:
        model_size: 14 or 8 — NOTE(review): currently unused, a local
            checkpoint name is hard-coded below; confirm this is intended.
        quant_type: 8 or 4 (bitsandbytes quantization); falsy disables it.
        data_df: DataFrame, mutated in place; predictions land in *col_name*.

    Side effects: writes the whole DataFrame to a fixed CSV path and prints
    per-row progress. Rows whose output fails to parse keep NaN so a caller
    loop can re-invoke this function until all rows succeed.
    """

    # model_name = f"deepseek-ai/DeepSeek-R1-Distill-Qwen-{model_size}B"
    model_name = "deepseek_14b_custom_eval"

    if not quant_type:
        bnb_config = None
    elif quant_type == 4:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float32
        )
    else:
        bnb_config = BitsAndBytesConfig(
            load_in_8bit=True
        )

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map="auto"
    )

    total_time = 0
    # Only rows still missing a prediction are (re)processed.
    for idx, row in tqdm(data_df[pd.isna(data_df[col_name])].iterrows()):
        txt = row["processed"]
        # answer = row["aste_golden_label"]

        input_text = PROMPT.format(review=txt)
        print(row["review-ID"], txt)

        start_time = time.time()
        inputs = tokenizer(input_text, return_tensors="pt").to("cuda")
        output = model.generate(**inputs,
                                max_new_tokens=512,
                                temperature=0.6,
                                top_p=0.95,
                                do_sample=True)
        output = tokenizer.decode(output[0])

        try:
            # NOTE(review): the split token appears lost in an encoding
            # round-trip — split("") raises ValueError; presumably it was the
            # reasoning end tag (e.g. "</think>"). Confirm against the original.
            aste = str(literal_eval(output.split("")[1].split("```json")[1].split("```<")[0]))
            # thinking = output.split("")[1].split("")[0]
        except Exception as e:
            # Any parse failure leaves NaN so the retry loop revisits the row.
            aste = np.nan
            # thinking = output.split("")[1]

        spent = time.time() - start_time

        # print(f"\nreasoning: \n{thinking}")
        print(f"\naste: {aste}")
        # print(f"\nanswer: {answer}")
        print(f"\n{spent}")
        print("\n=============================================================\n")

        total_time += spent
        data_df.loc[idx, col_name] = aste
        # data_df.loc[idx, "thinking"] = thinking

    data_df.to_csv("./data/aste/inference/deepseek_14b_inference.csv", index=False)
    print(f"total time: {total_time / len(data_df)}")
+
+
+if __name__ == "__main__":
+ # data_path = "processed_except_GL.csv" #
+ data_path = "./data/aste/eval/aste_annotation_100_golden_label.csv"
+ col_name = "inference"
+
+ data_df = prepare_data(data_path)
+ data_df[col_name] = None
+
+ start_time = time.time()
+ extract_aste(14, 4, data_df, col_name)
+ first_lap = time.time() - start_time
+
+ num_null = 0
+ while data_df[col_name].isna().sum() > 0:
+ num_null += data_df[col_name].isna().sum()
+ extract_aste(14, 4, data_df, col_name)
+ end_time = time.time() - start_time
+
+ print(f"\nํ ๋ฐํด: {first_lap}์ด, ์ด: {end_time}์ด, ํ๊ท : {end_time / 100}์ด")
+ print(f"์ด {num_null}๊ฐ์ ์ถ๋ก ์คํจ ํ ์ฌ์๋")
+
+ print("\n=== Start Evaluation ===\n")
+
+ evaluate_aste(
+ data_df,
+ golden_label_col="aste_golden_label",
+ model_prediction_col=col_name,
+ )
+
\ No newline at end of file
diff --git a/models/review/src/review_pipeline/qwen_deepseek_32b_inference.py b/models/review/src/review_pipeline/qwen_deepseek_32b_inference.py
new file mode 100644
index 0000000..fb8f5cb
--- /dev/null
+++ b/models/review/src/review_pipeline/qwen_deepseek_32b_inference.py
@@ -0,0 +1,241 @@
+import argparse
+import pandas as pd
+import time
+import re
+import json
+from tqdm import tqdm
+
+from unsloth import FastLanguageModel
+from utils.evaluate import evaluate_aste # ํ๊ฐ ํจ์ ์ํฌํธ (๊ฒฝ๋ก์ ๋ง๊ฒ ์์ )
+
+
def load_model():
    """Load the fine-tuned checkpoint and its tokenizer via unsloth (4-bit)."""
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name='./output_zeroshot/checkpoint-336',
        max_seq_length=2048,
        dtype=None,  # let unsloth pick the compute dtype
        load_in_4bit=True,
        temperature=0.6  # NOTE(review): temperature is a generation setting, not a load option — likely ignored here; confirm
    )
    # Switch the model into unsloth's optimized inference mode.
    FastLanguageModel.for_inference(model)
    return model, tokenizer
+
+
+PROMPT_TEMPLATE = """๋น์ ์ ์ํ ๋ฆฌ๋ทฐ์ ๊ฐ์ฑ ๋ถ์ ๋ฐ ํ๊ฐ ์ ๋ฌธ๊ฐ์
๋๋ค. ์ฃผ์ด์ง ์ํ ๋ฆฌ๋ทฐ๋ฅผ ๋ถ์ํ์ฌ ํด๋น ๋ฆฌ๋ทฐ์์ ์์ฑ๊ณผ ํ๊ฐ, ๊ฐ์ฑ์ ์ถ์ถํ์ธ์.
+
+### ์์
๋ชฉํ:
+1. ์
๋ ฅ์ผ๋ก ์ฃผ์ด์ง๋ ์ํ ๋ฆฌ๋ทฐ๋ฅผ ๋ถ์ํ์ฌ JSON ํ์์ผ๋ก ์ถ๋ ฅํ์ธ์.
+2. JSON ํ์์ ์๋์ ๊ฐ์ด ๋ฆฌ์คํธ ํํ๋ก ๊ตฌ์ฑ๋ฉ๋๋ค.
+ ```json
+ [
+ {{"์์ฑ": "<์์ฑ๋ช
1>", "ํ๊ฐ": "<ํ๊ฐ ๋ด์ฉ1>", "๊ฐ์ ": "<๊ธ์ /๋ถ์ /์ค๋ฆฝ>"}},
+ {{"์์ฑ": "<์์ฑ๋ช
2>", "ํ๊ฐ": "<ํ๊ฐ ๋ด์ฉ2>", "๊ฐ์ ": "<๊ธ์ /๋ถ์ /์ค๋ฆฝ>"}},
+ ...
+ ]
+ ```
+
+์์ฑ์ ๋ค์ ์ค ํ๋์ผ ๊ฐ๋ฅ์ฑ์ด ๋์ต๋๋ค: ["์ํ", "๋ฐฐ์ก", "๊ฐ๊ฒฉ", "๋ง", "์ ์ ๋", "์", "ํฌ์ฅ"].
+๋ง์ฝ ์๋ก์ด ์์ฑ์ด ํ์ํ๋ฉด ์์ฑํ์ฌ ์ฌ์ฉ ๊ฐ๋ฅํฉ๋๋ค.
+๋ชจ๋ ์ํ๋ช
์ "์ํ"์ผ๋ก ํต์ผํฉ๋๋ค.
+
+### ์ธ๋ถ ๊ท์น:
+- ๊ฐ์ ๋ถ์
+ - ๋ฆฌ๋ทฐ์์ ๊ฐ์ ์ด ๊ธ์ ์ ์ธ ๊ฒฝ์ฐ "๊ฐ์ ": "๊ธ์ "์ผ๋ก ์ค์ ํฉ๋๋ค.
+ - ๋ถ์ ์ ์ธ ํํ์ด ํฌํจ๋ ๊ฒฝ์ฐ "๊ฐ์ ": "๋ถ์ "์ผ๋ก ์ค์ ํฉ๋๋ค.
+ - ํ๊ฐ๊ฐ ๋ชจํธํ๊ฑฐ๋ ๊ฐ์ ์ ์ธ ๊ฒฝ์ฐ "๊ฐ์ ": "์ค๋ฆฝ"์ผ๋ก ์ค์ ํฉ๋๋ค.
+
+- ํ๊ฐ ๋ฌธ๊ตฌ ์ ์
+ - ๋ฆฌ๋ทฐ์์ ๋ํ๋ ์ฃผ์ ํ๊ฐ๋ฅผ ๊ฐ๊ฒฐํ ๋ฌธ์ฅ์ผ๋ก ๋ณํํฉ๋๋ค.
+ - ํต์ฌ ํค์๋๋ฅผ ์ ์งํ๋ฉด์ ๋ถํ์ํ ํํ์ ์ ๊ฑฐํฉ๋๋ค.
+ - ํ๊ฐ ๋ฌธ๊ตฌ๋ '~๋ค.'๋ก ๋๋๋ ํ์ฌํ, ํ์๋ฌธ์ผ๋ก ๋ต๋ณํฉ๋๋ค. ์๋ฅผ ๋ค์ด, '์ข์ต๋๋ค.' ๊ฐ ์๋ '์ข๋ค.' ๋ก ์์ฑํฉ๋๋ค.
+
+- ์์ธ ์ฌํญ
+ - ์ํ ์ฌ์ฉ ํ๊ธฐ๊ฐ ์๋ ์ํ์ ๋ํ ์์์ด๋ ๊ธฐ๋ํ๋ ๋ถ๋ถ์ ๋ถ๋ฆฌํ์ฌ ์ ์ธํฉ๋๋ค.
+ - ๋ณตํฉ์ ์ธ ํ๊ฐ๊ฐ ์กด์ฌํ๋ ๊ฒฝ์ฐ ํด๋น ๋ด์ฉ์ ๋ถ๋ฆฌํ์ฌ ๊ฐ๊ฐ JSON ํญ๋ชฉ์ผ๋ก ์์ฑํฉ๋๋ค. ์๋ฅผ ๋ค์ด, '๋ฐฐ์ก์ด ์์ ํ๊ณ ๋นจ๋์ด์'์ ๊ฒฝ์ฐ '์์ ํ๋ค.' ์ '๋น ๋ฅด๋ค.' ๋ ๊ฐ์ง๋ก ๊ตฌ๋ถํฉ๋๋ค.
+
+### ์์:
+์์๋ฅผ ์ฐธ๊ณ ํ์ฌ ๋
ผ๋ฆฌ์ ์ผ๋ก ์ฌ๊ณ ํ๋, ๋ถํ์ํ ๊ณผ์ ์ ์๋ตํ๊ณ ๊ฐ๊ฒฐํ๊ฒ ๋ต๋ณํ์ธ์.
+
+์์ 1.
+์ด๋ฒ์ ๋ง์ด ์ฃผ๋ฌธํ๋๋ฐ ์ ๋์ฐฉํ๊ณ ์ ๊ณ๋๋ ๊นจ์ง ๊ฑฐ ์์ด ์ ๋ฐ์์ด์ ์ ํต ๊ธฐํ๋ ๋๋ํ๊ณ ๋ง์กฑํฉ๋๋ค.
+[{{'์์ฑ': '๋ฐฐ์ก', 'ํ๊ฐ': '์ ๋์ฐฉํ๋ค.', '๊ฐ์ ': '๊ธ์ '}}, {{'์์ฑ': '๋ฐฐ์ก', 'ํ๊ฐ': '์์ ํ๋ค.', '๊ฐ์ ': '๊ธ์ '}}, {{'์์ฑ': '์ํ', 'ํ๊ฐ': '์ ํต๊ธฐํ์ด ๋๋ํ๋ค.', '๊ฐ์ ': '๊ธ์ '}}, {{'์์ฑ': '์ํ', 'ํ๊ฐ': '๋ง์กฑ์ค๋ฝ๋ค.', '๊ฐ์ ': '๊ธ์ '}}]
+
+์์ 2.
+๋ฐฐ์ก ๋น ๋ฅด๊ณ ์ข์ต๋๋ค.
+[{{'์์ฑ': '๋ฐฐ์ก', 'ํ๊ฐ': '๋น ๋ฅด๋ค.', '๊ฐ์ ': '๊ธ์ '}}, {{'์์ฑ': '๋ฐฐ์ก', 'ํ๊ฐ': '์ข๋ค.', '๊ฐ์ ': '๊ธ์ '}}]
+
+์์ 3.
+์๋ ์ ๋จน๋ ๊ฑฐ๋ผ ์ฌ๊ตฌ๋งค์
๋๋ค.
+[{{'์์ฑ': '์ํ', 'ํ๊ฐ': '์์ฃผ ๋จน๋๋ค.', '๊ฐ์ ': '๊ธ์ '}}, {{'์์ฑ': '์ํ', 'ํ๊ฐ': '์ฌ๊ตฌ๋งค๋ค.', '๊ฐ์ ': '๊ธ์ '}}]
+
+์์ 4.
+๋น๊ตํด ๋ณด๋ ค๊ณ ์ฌ ๋ดค์ด์. ์ ์ ํ๊ณ ๋ง์๋ค์.
+[{{"์์ฑ": "์ํ", "ํ๊ฐ": "๋น๊ตํด๋ณด๋ ค ๊ตฌ๋งคํ๋ค.", "๊ฐ์ ": "์ค๋ฆฝ"}}, {{'์์ฑ': '์ ์ ๋', 'ํ๊ฐ': '์ ์ ํ๋ค.', '๊ฐ์ ': '๊ธ์ '}}, {{'์์ฑ': '๋ง', 'ํ๊ฐ': '๋ง์๋ค.', '๊ฐ์ ': '๊ธ์ '}}]
+
+์์ 5.
+๋ฆฌ๋ทฐ ๋ณด๊ณ ์ฃผ๋ฌธํ๋๋ฐ ๊ธฐ๋๋๋ค์.
+[]
+
+์์ 6.
+๋
ธ๋ธ๋๋ ์ ๋ ดํ๊ณ ์ข์์.
+[{{'์์ฑ': '๊ฐ๊ฒฉ', 'ํ๊ฐ': '์ ๋ ดํ๋ค.', '๊ฐ์ ': '๊ธ์ '}}, {{'์์ฑ': '์ํ', 'ํ๊ฐ': '์ข๋ค.', '๊ฐ์ ': '๊ธ์ '}}]
+
+์์ 7.
+์น์ฆํผ์๋ฅผ ์ข์ํด์ ๋ง์๊ฒ ์ ๋จน์์ด์ ๋น๊ต์ ๋ํฐํด์ ๊ธ๋ฐฉ ๋ฐฐ ๋ถ๋ฌ์ ธ์ ๋ง์๊ณ ๋ฐฐ๋ถ๋ฅด๊ฒ ์ ๋จน์์ด์.
+[{{'์์ฑ': '๋ง', 'ํ๊ฐ': '๋ง์๋ค.', '๊ฐ์ ': '๊ธ์ '}}, {{'์์ฑ': '์', 'ํ๊ฐ': '๋ํฐํด์ ๊ธ๋ฐฉ ๋ฐฐ๋ถ๋ฅด๋ค.', '๊ฐ์ ': '๊ธ์ '}}, {{'์์ฑ': '์ํ', 'ํ๊ฐ': '์ ๋จน์๋ค.', '๊ฐ์ ': '๊ธ์ '}}]
+
+์์ 8.
+์ฐ์ ๋ ๋จน์ด๋ ๋ง์๊ณ ๊ทธ๋ฅ ๊ณผ์์ฒ๋ผ ๋จน์ด๋ ๋ง์์ด์.
+[{{'์์ฑ': '๋ง', 'ํ๊ฐ': '์ฐ์ ๋ ๋จน์ผ๋ฉด ๋ง์๋ค.', '๊ฐ์ ': '๊ธ์ '}}, {{'์์ฑ': '๋ง', 'ํ๊ฐ': '๊ทธ๋ฅ ๋จน์ด๋ ๋ง์๋ค.', '๊ฐ์ ': '๊ธ์ '}}]
+
+์์ 9.
+์์ฐ๊นก์ด ๋น๊ธธ ๋๊ฐ ์์ด์ ๊ฐ์ถ๋๋ฆฝ๋๋ค ์ฌ๊ตฌ๋งคํ๋ ค๊ณ ์ ์ ๋ ดํ๊ฒ ์ ์์ด์.
+[{{'์์ฑ': '๋ง', 'ํ๊ฐ': '๋น๊ธฐ๋ ๋ง์ด๋ค.', '๊ฐ์ ': '๊ธ์ '}}, {{'์์ฑ': '์ํ', 'ํ๊ฐ': '์ถ์ฒํ๋ค.', '๊ฐ์ ': '๊ธ์ '}}, {{'์์ฑ': '์ํ', 'ํ๊ฐ': '์ฌ๊ตฌ๋งค ์์ฌ ์๋ค.', '๊ฐ์ ': '๊ธ์ '}}, {{'์์ฑ': '๊ฐ๊ฒฉ', 'ํ๊ฐ': '์ ๋ ดํ๋ค.', '๊ฐ์ ': '๊ธ์ '}}]
+
+์์ 10.
+๋จ์ด์ง๋ฉด ๋ค์ ์ฑ์ ๋ฃ๊ณ ๋จน๊ณ ์์ด์ ๋๋ฌด ๋ง์์ด์.
+[{{'์์ฑ': '์ํ', 'ํ๊ฐ': 'ํญ์ ์ฌ๊ตฌ๋งคํ๋ค.', '๊ฐ์ ': '๊ธ์ '}}, {{'์์ฑ': '๋ง', 'ํ๊ฐ': '๋ง์๋ค.', '๊ฐ์ ': '๊ธ์ '}}]
+
+### ์
๋ ฅ:
+{review}
+
+"""
+
+
def inference(review_text: str, model, tokenizer):
    """
    Run one generation for *review_text*.

    Returns a (cot, ans) pair: the chain-of-thought portion and the final
    answer, split on the model's reasoning end tag.
    """
    messages = [{"role": "user", "content": PROMPT_TEMPLATE.format(review=review_text)}]
    print("์๋ ฅ ๋ฉ์์ง:", messages)

    # Apply the chat template to build the raw prompt string.
    formatted_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    model_inputs = tokenizer([formatted_text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=512
    )
    # Keep only the tokens generated after the prompt.
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # NOTE(review): the split token was lost in an encoding round-trip
    # (split("") raises ValueError); it was presumably the "</think>"
    # reasoning end tag — confirm against the original source.
    cot = response.split("")[0].strip()
    ans = response.split("")[-1].strip()

    return cot, ans
+
def post_process_answer(answer: str):
    """
    Pull the JSON payload out of a model answer.

    Prefers a fenced ```json ... ``` block; otherwise the whole answer is
    treated as JSON. Returns the parsed object, or None when parsing fails.
    """
    fenced = re.search(r"```json\s*(.*?)\s*```", answer, re.DOTALL)
    candidate = fenced.group(1) if fenced else answer
    try:
        return json.loads(candidate)
    except json.JSONDecodeError as e:
        print("JSON ํ์ฑ ์๋ฌ:", e)
        return None
+
def run_inference_on_dataframe(df: pd.DataFrame, model, tokenizer, num_samples: int = None) -> pd.DataFrame:
    """
    Run model inference for every review in *df* and record the results.

    JSON parsing is retried up to 20 times per review; per-review retry
    counts and wall-clock times are printed. Results are stored as JSON
    strings in the "unsloth_deepseek_32b" column (a final failure is
    stored as "null").
    """
    times = []
    failure_counts = []  # per-review retry count (0 = first attempt succeeded)
    if num_samples is not None:
        df = df.head(num_samples)

    for idx, row in tqdm(df.iterrows(), total=df.shape[0], desc="Inference"):
        review_text = row["processed"]
        start_time = time.time()
        attempts = 0
        ans_json = None

        # Up to 20 attempts including the first one.
        while attempts < 20:
            cot, ans = inference(review_text, model, tokenizer)
            ans_json = post_process_answer(ans)
            print(ans_json)
            if ans_json is not None:
                break
            attempts += 1
            print(f"JSON ์ถ์ถ ์คํจ, ์ฌ์๋ {attempts}ํ")

        if ans_json is None:
            print("์ต๋ ์ฌ์๋ ํ์(20ํ)๋ฅผ ์ด๊ณผํ์ต๋๋ค. ๊ฒฐ๊ณผ๋ฅผ None์ผ๋ก ์ ์ฅํฉ๋๋ค.")
        elapsed = time.time() - start_time
        times.append(elapsed)
        failure_counts.append(attempts)
        print(f"์ฒ๋ฆฌ ์๊ฐ: {elapsed:.2f}์ด, ์คํจ ํ์: {attempts}")
        print("\n" + "=" * 60 + "\n")

        # Persist the result as a JSON string (None serializes to "null").
        df.loc[idx, "unsloth_deepseek_32b"] = json.dumps(ans_json, ensure_ascii=False)

    print("์ ์ฒด ์ฒ๋ฆฌ ์๊ฐ ๋ฆฌ์คํธ:", times)
    if times:
        print("ํ๊ท ์ฒ๋ฆฌ ์๊ฐ: {:.2f}์ด".format(sum(times) / len(times)))
    print("๊ฐ ๋ฆฌ๋ทฐ๋ณ ์ฌ์๋ ์คํจ ํ์:", failure_counts)
    return df
+
def main():
    """CLI entry point: parse args, run ASTE inference over the input CSV,
    evaluate against golden labels, and optionally save the result CSV."""
    parser = argparse.ArgumentParser(description="aste ๋ชจ๋ธ ์ถ๋ก ๋ฐ ํ๊ฐ ์คํฌ๋ฆฝํธ")
    parser.add_argument("--input_csv", type=str, required=True,
                        help="์๋ ฅ CSV ํ์ผ ๊ฒฝ๋ก (์: aste_annotation_100_gpt_after.csv)")
    parser.add_argument("--output_csv", type=str, default=None,
                        help="์ถ๋ ฅ CSV ํ์ผ ๊ฒฝ๋ก (์ ์ฅํ ๊ฒฝ์ฐ)")
    parser.add_argument("--num_samples", type=int, default=None,
                        help="์ถ๋ก ํ ์ํ ๊ฐ์ (ID ํํฐ๋ง ํ ์ง์ ๊ฐ๋ฅ)")
    parser.add_argument("--selected_review_ids", type=str, nargs="*", default=None,
                        help="ํ๊ฐํ ํน์ review-ID ๋ชฉ๋ก (์: emart-118 emart-50 ...)")
    args = parser.parse_args()

    model, tokenizer = load_model()

    df = pd.read_csv(args.input_csv)
    print(f"์ด {len(df)}๊ฐ์ ๋ฆฌ๋ทฐ ๋ก๋๋จ.")

    # Filter down to explicitly requested review IDs, if any were given.
    if args.selected_review_ids:
        df = df[df["review-ID"].isin(args.selected_review_ids)]
        print(f"์ ํ๋ ๋ฆฌ๋ทฐ: {df['review-ID'].tolist()}")
        print(f"์ด {len(df)}๊ฐ์ ๋ฆฌ๋ทฐ๊ฐ ์ ํ๋จ.")

    # Apply num_samples via head() (all rows are used when fewer remain).
    if args.num_samples is not None:
        df = df.head(args.num_samples)

    print("์ต์ข๋ฐ์ดํฐ ์:", len(df))

    df = run_inference_on_dataframe(df, model, tokenizer, num_samples=args.num_samples)

    evaluate_aste(
        df.head(args.num_samples),
        golden_label_col="aste_golden_label",
        model_prediction_col="unsloth_deepseek_32b"
    )

    if args.output_csv:
        df.to_csv(args.output_csv, index=False)
        print(f"๊ฒฐ๊ณผ๊ฐ {args.output_csv}์ ์ ์ฅ๋์์ต๋๋ค.")
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/models/review/src/review_pipeline/review_summarization.py b/models/review/src/review_pipeline/review_summarization.py
new file mode 100644
index 0000000..3c87c0e
--- /dev/null
+++ b/models/review/src/review_pipeline/review_summarization.py
@@ -0,0 +1,278 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+[๋ฆฌ๋ทฐ ์์ฝ ์ถ์ถ ํ์ดํ๋ผ์ธ]
+- ๋ฆฌ๋ทฐ ๋ฐ์ดํฐ๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ๋ ๊ฐ์ aspect๋ณ ์์ฝ์ ์์ฑ:
+ 1. "๋ง" (๋จ๋
)
+ 2. "๋ฐฐ์ก"๊ณผ "ํฌ์ฅ"์ ํตํฉํ "๋ฐฐ์ก ๋ฐ ํฌ์ฅ"
+- ๊ฐ ์ํ์ ์๋ณธ ๋ฆฌ๋ทฐ๋ฅผ ํ์ฉํ์ฌ ๊ฐ aspect๋ณ ๊ธ์ /๋ถ์ ํต์ฌ ํฌ์ธํธ๋ฅผ ์์ฝํ๊ณ ,
+ ๊ฐ aspect๋ณ ๊ณ ์ ๋ฆฌ๋ทฐ ๊ฐ์๋ฅผ ์ฐ์ถํ์ฌ ์ต์ข
CSV ํ์ผ๋ก ์ ์ฅํ๋ค.
+"""
+
+import os
+import sys
+import time
+import json
+import requests
+import pandas as pd
+from dotenv import load_dotenv
+from prompt.prompt_loader import load_prompt, load_fewshot
+import re
+
def load_data(file_path):
    """Read a CSV file into a DataFrame and log which path was loaded."""
    frame = pd.read_csv(file_path)
    print("๋ฐ์ดํฐ ๋ก๋:", file_path)
    return frame
+
def expand_inference_data(df, json_column="unsloth_deepseek_32b"):
    """Explode per-review ASTE triple lists into one row per triple.

    Each element of the JSON list stored in *json_column* becomes its own row,
    carrying the original review columns plus 'aspect', 'opinion' and
    'sentiment'. Rows whose payload is neither a JSON-array string nor a list
    are skipped; unparseable strings are logged and skipped.
    """
    rows = []
    for _, source_row in df.iterrows():
        payload = source_row[json_column]

        # Accept either a JSON string or an already-parsed list.
        if isinstance(payload, str):
            try:
                triples = json.loads(payload)
            except json.JSONDecodeError:
                print(f"JSON ํ์ฑ ์๋ฌ, review-ID: {source_row.get('review-ID', 'N/A')}")
                continue
        elif isinstance(payload, list):
            triples = payload
        else:
            # Any other type (e.g. NaN) is ignored.
            continue

        if isinstance(triples, list):
            for triple in triples:
                flat = source_row.copy()
                flat["aspect"] = triple.get("์์ฑ", None)
                flat["opinion"] = triple.get("ํ๊ฐ", None)
                flat["sentiment"] = triple.get("๊ฐ์ ", None)
                rows.append(flat)

    result = pd.DataFrame(rows)
    result.reset_index(drop=True, inplace=True)
    result.ffill(inplace=True)
    return result
+
def filter_invalid_value(raw_value):
    """Validate one raw model-output cell.

    Returns the parsed list when *raw_value* is a JSON-array string; returns
    None for NaN cells (floats), the literal "[]" string, valid JSON that is
    not a list, or strings that fail to parse.
    """
    is_nan_cell = isinstance(raw_value, float)
    is_empty_list_str = isinstance(raw_value, str) and raw_value.strip() == "[]"
    if is_nan_cell or is_empty_list_str:
        return None
    try:
        decoded = json.loads(raw_value)
    except json.JSONDecodeError:
        return None
    return decoded if isinstance(decoded, list) else None
+
def load_and_prepare_data(config):
    """Load the inference CSV, keep rows with a valid JSON prediction list,
    and explode them into one row per (aspect, opinion, sentiment) triple.

    NOTE(review): this function is re-defined later in the same module with a
    slightly different body; the later definition wins at import time.
    """
    train_data_path = os.path.join(config["paths"]["inference_dir"], config["inference_data"])
    df_infer = load_data(train_data_path)
    df_infer["unsloth_deepseek_32b"] = df_infer["unsloth_deepseek_32b"].apply(filter_invalid_value)
    df_infer = df_infer.dropna(subset=["unsloth_deepseek_32b"])
    aste_df = expand_inference_data(df_infer, "unsloth_deepseek_32b")
    # Drop annotation columns that are not needed downstream (ignore if absent).
    cols_to_drop = ['aste_hcx', 'aste_gpt', 'aste_golden_label']
    aste_df.drop(columns=cols_to_drop, errors='ignore', inplace=True)
    print(f"๋ฆฌ๋ทฐ Opinion ๋ฐ์ดํฐ ๊ฐ์: {aste_df.shape[0]}")
    # Optional opinion-keyword filter, kept for reference (disabled):
    # FILTER_KEYWORDS = "๋ง|์ข"
    # aste_df = aste_df[~aste_df["opinion"].str.contains(FILTER_KEYWORDS)]
    # print(f"ํํฐ๋ง ํ ๋ฆฌ๋ทฐ ๋ฐ์ดํฐ: {aste_df.shape[0]}")
    return aste_df
+
def get_prompt_and_fewshot(aspect: str, sentiment: str):
    """Return the (prompt, few-shot examples) pair for *sentiment*.

    Positive sentiment loads the positive prompt/few-shot files; any other
    sentiment falls back to the negative pair. *aspect* is currently unused
    but kept for interface stability.
    """
    prompt_dir = "./prompt/review_summarization/"
    if sentiment == "๊ธ์ ":
        prompt_name, fewshot_name = "positive_prompt.txt", "positive_fewshot.json"
    else:
        prompt_name, fewshot_name = "negative_prompt.txt", "negative_fewshot.json"
    prompt = load_prompt(prompt_filename=prompt_name, prompt_dir=prompt_dir)
    fewshot = load_fewshot(fewshot_filename=fewshot_name, prompt_dir=prompt_dir)
    return prompt, fewshot
+
def inference(query: str, sentiment: str, aspect: str) -> str:
    """Request one summary from the CLOVA Studio HCX-003 chat-completions API.

    Builds a chat from the sentiment-specific system prompt, its few-shot
    user/assistant example pairs, and *query*, then returns the model reply.
    On API or transport failure an "API Error: ..." / "Request Error: ..."
    string is returned instead of raising, so callers detect failure by prefix.

    Raises:
        ValueError: when the required credentials are missing from ~/.env.
    """
    # Credentials are (re)read from the user's home .env on every call.
    load_dotenv(os.path.expanduser("~/.env"))
    AUTHORIZATION = os.getenv("AUTHORIZATION")
    X_NCP_CLOVASTUDIO_REQUEST_ID = os.getenv("X_NCP_CLOVASTUDIO_REQUEST_ID")
    if not AUTHORIZATION or not X_NCP_CLOVASTUDIO_REQUEST_ID:
        raise ValueError("ํ์ API ์ธ์ฆ ์ ๋ณด๊ฐ .env์ ์ค์ ๋์ด ์์ง ์์ต๋๋ค.")
    headers = {
        'Authorization': AUTHORIZATION,
        'X-NCP-CLOVASTUDIO-REQUEST-ID': X_NCP_CLOVASTUDIO_REQUEST_ID,
        'Content-Type': 'application/json; charset=utf-8',
    }
    prompt, fewshot = get_prompt_and_fewshot(aspect, sentiment)
    # If the query is a stringified list ("[...]"), join its items into one
    # space-separated string before sending.
    if query.startswith("[") and query.endswith("]"):
        try:
            import ast
            q_list = ast.literal_eval(query)
            if isinstance(q_list, list):
                query = " ".join(q_list)
        except Exception as e:
            print("Query parsing error:", e)

    # Chat layout: system prompt, alternating few-shot user/assistant pairs,
    # then the real user query.
    messages = [{"role": "system", "content": prompt}]
    for example in fewshot:
        # Few-shot queries may themselves be lists; flatten them the same way.
        query_text = " ".join(example["query"]) if isinstance(example["query"], list) else example["query"]
        messages.append({"role": "user", "content": query_text})
        messages.append({"role": "assistant", "content": example["answer"]})
    messages.append({"role": "user", "content": query})

    request_data = {
        'messages': messages,
        'topP': 0.8,
        'topK': 0,
        'maxTokens': 1024,
        'temperature': 0.5,
        'repeatPenalty': 5.0,
        'stopBefore': [],
        'includeAiFilters': False,
        'seed': 42
    }
    try:
        response = requests.post(
            "https://clovastudio.stream.ntruss.com/testapp/v1/chat-completions/HCX-003",
            headers=headers, json=request_data
        )
        response.raise_for_status()
        response_json = response.json()
        # CLOVA signals success via status.code == "20000" in the body.
        if response_json.get("status", {}).get("code") == "20000":
            output_text = response_json["result"]["message"]["content"]
            return output_text
        else:
            return f"API Error: {response_json.get('status', {}).get('message')}"
    except requests.exceptions.RequestException as e:
        return f"Request Error: {e}"
+
def robust_inference(query: str, sentiment: str, aspect: str, retry_delay: int = 2,
                     max_retries=None) -> str:
    """Call inference() until it succeeds, sleeping between failed attempts.

    Args:
        query: Review text (or stringified review list) to summarize.
        sentiment: Forwarded to inference() for prompt selection.
        aspect: Forwarded to inference().
        retry_delay: Seconds to wait between failed attempts.
        max_retries: Optional cap on the number of attempts. None (default)
            retries forever, matching the original behavior; a positive int
            bounds the loop so transient-but-permanent API failures cannot
            hang the pipeline.

    Returns:
        The successful model output, or the last error string ("API Error..."
        / "Request Error...") when the retry budget is exhausted.
    """
    attempts = 0
    while True:
        result = inference(query, sentiment, aspect)
        if not (result.startswith("API Error") or result.startswith("Request Error")):
            print(result)
            return result
        attempts += 1
        # Fix: the original loop had no upper bound; allow callers to cap it.
        if max_retries is not None and attempts >= max_retries:
            return result
        print("API ์ค๋ฅ ๋ฐ์, ๋ค์ ์๋ํฉ๋๋ค...")
        time.sleep(retry_delay)
+
def summarize_opinions_with_original(reviews: list, sentiment: str, aspect: str) -> tuple:
    """Summarize up to 20 unique reviews for one aspect/sentiment bucket.

    Args:
        reviews: Raw review texts (may contain duplicates).
        sentiment: Sentiment label forwarded to the prompt selection.
        aspect: Aspect label forwarded to the prompt selection.

    Returns:
        (summary, stringified_sample) — both "์์ต๋๋ค." when no reviews exist.

    Fix: deduplicate with dict.fromkeys instead of set() so the 20-review
    sample keeps first-seen order and is deterministic; set iteration order
    varies per process (hash randomization), which made summaries
    non-reproducible across runs.
    """
    unique_reviews = list(dict.fromkeys(reviews))
    sample_reviews = unique_reviews[:20]
    if not sample_reviews:
        return "์์ต๋๋ค.", "์์ต๋๋ค."
    reviews_str = str(sample_reviews)
    summary = robust_inference(reviews_str, sentiment, aspect)
    # Strip one pair of wrapping double quotes if the model echoed them.
    summary = summary.strip()[1:-1] if summary.startswith('"') and summary.endswith('"') else summary
    return summary, reviews_str
+
def process_product(product_id: str, aste_df: pd.DataFrame) -> dict:
    """Build one summary record for a product: taste and delivery/packaging,
    each with a positive and a negative summary.

    NOTE(review): `str.contains(product_id)` is an unanchored substring match,
    so e.g. "emart-1" also matches "emart-11" review IDs — consider anchoring
    to f"{product_id}-" as extract_product_id's pattern implies.
    NOTE(review): `unique()[0]` raises IndexError when every `name` is NaN
    for the product — confirm that cannot happen upstream.
    NOTE(review): the product-name dict key below is mojibake-garbled in the
    source; reproduced as-is since it is a runtime string.
    """
    SENTIMENT_POSITIVE = "๊ธ์ "
    SENTIMENT_NEGATIVE = "๋ถ์ "
    product_reviews = aste_df[aste_df["review-ID"].str.contains(product_id)]
    product_name = product_reviews["name"].dropna().unique()[0]
    product_dict = {
        "ID": product_id,
        "์ํ๋ช ": product_name,
    }
    # Taste ("๋ง") handled on its own.
    pos_reviews_m = product_reviews[(product_reviews["aspect"] == "๋ง") & (product_reviews["sentiment"] == SENTIMENT_POSITIVE)]["review"].tolist()
    neg_reviews_m = product_reviews[(product_reviews["aspect"] == "๋ง") & (product_reviews["sentiment"] != SENTIMENT_POSITIVE)]["review"].tolist()
    summary_pos_m, _ = summarize_opinions_with_original(pos_reviews_m, SENTIMENT_POSITIVE, "๋ง")
    summary_neg_m, _ = summarize_opinions_with_original(neg_reviews_m, SENTIMENT_NEGATIVE, "๋ง")
    product_dict["๋ง-๊ธ์ "] = summary_pos_m
    product_dict["๋ง-๋ถ์ "] = summary_neg_m

    # Delivery and packaging merged into a single combined bucket.
    pos_reviews_dp = product_reviews[((product_reviews["aspect"] == "๋ฐฐ์ก") | (product_reviews["aspect"] == "ํฌ์ฅ")) &
                                     (product_reviews["sentiment"] == SENTIMENT_POSITIVE)]["review"].tolist()
    neg_reviews_dp = product_reviews[((product_reviews["aspect"] == "๋ฐฐ์ก") | (product_reviews["aspect"] == "ํฌ์ฅ")) &
                                     (product_reviews["sentiment"] != SENTIMENT_POSITIVE)]["review"].tolist()
    summary_pos_dp, _ = summarize_opinions_with_original(pos_reviews_dp, SENTIMENT_POSITIVE, "๋ฐฐ์ก ๋ฐ ํฌ์ฅ")
    summary_neg_dp, _ = summarize_opinions_with_original(neg_reviews_dp, SENTIMENT_NEGATIVE, "๋ฐฐ์ก ๋ฐ ํฌ์ฅ")
    product_dict["๋ฐฐ์ก ๋ฐ ํฌ์ฅ-๊ธ์ "] = summary_pos_dp
    product_dict["๋ฐฐ์ก ๋ฐ ํฌ์ฅ-๋ถ์ "] = summary_neg_dp

    return product_dict
+
def update_summary_counts(summary_df: pd.DataFrame, aste_df: pd.DataFrame) -> pd.DataFrame:
    """Attach per-product unique-review counts for each aspect/sentiment bucket.

    For every product row in *summary_df*, counts unique review texts in
    *aste_df* for: taste-positive, taste-negative, delivery/packaging-positive
    and delivery/packaging-negative, and writes them into the "num ..." columns.

    Fix: products are matched with an anchored prefix
    (review-ID.startswith(f"{product_id}-")) instead of a plain substring
    `contains` — the substring form let e.g. "emart-1" also match "emart-11"'s
    reviews, inflating the counts.

    Returns the mutated *summary_df* (modified in place as before).
    """
    count_columns = ["num ๋ง-๊ธ์ ", "num ๋ง-๋ถ์ ", "num ๋ฐฐ์ก ๋ฐ ํฌ์ฅ-๊ธ์ ", "num ๋ฐฐ์ก ๋ฐ ํฌ์ฅ-๋ถ์ "]
    for col in count_columns:
        summary_df[col] = None

    for idx, row in summary_df.iterrows():
        product_id = row["ID"]
        # Anchored match: review-IDs look like "<product_id>-<review_no>".
        product_reviews = aste_df[aste_df["review-ID"].str.startswith(f"{product_id}-")]

        is_taste = product_reviews["aspect"] == "๋ง"
        is_ship_pack = (product_reviews["aspect"] == "๋ฐฐ์ก") | (product_reviews["aspect"] == "ํฌ์ฅ")
        is_pos = product_reviews["sentiment"] == "๊ธ์ "

        # nunique(dropna=False) matches the original len(unique()) semantics,
        # which counted NaN as a distinct value.
        summary_df.at[idx, "num ๋ง-๊ธ์ "] = product_reviews[is_taste & is_pos]["review"].nunique(dropna=False)
        summary_df.at[idx, "num ๋ง-๋ถ์ "] = product_reviews[is_taste & ~is_pos]["review"].nunique(dropna=False)
        summary_df.at[idx, "num ๋ฐฐ์ก ๋ฐ ํฌ์ฅ-๊ธ์ "] = product_reviews[is_ship_pack & is_pos]["review"].nunique(dropna=False)
        summary_df.at[idx, "num ๋ฐฐ์ก ๋ฐ ํฌ์ฅ-๋ถ์ "] = product_reviews[is_ship_pack & ~is_pos]["review"].nunique(dropna=False)

    return summary_df
+
def extract_product_id(review_id):
    """Strip the trailing review number from an "emart-<pid>-<n>" review ID.

    Falls back to returning the input unchanged when the pattern does not
    match.
    """
    found = re.match(r"(emart-\d+)-\d+", review_id)
    if found:
        return found.group(1)
    return review_id
+
def load_and_prepare_data(config):
    """Load the inference CSV, keep rows whose prediction column holds a valid
    JSON list, and explode them into one row per (aspect, opinion, sentiment).

    Fix: dropped the second "ํํฐ๋ง ํ ..." print — no filtering happens
    between the two prints, so it repeated the same row count under a
    misleading "after filtering" label.

    NOTE(review): this definition shadows an identical one defined earlier in
    the module; consider deleting one of them.
    """
    train_data_path = os.path.join(config["paths"]["inference_dir"], config["inference_data"])
    df_infer = load_data(train_data_path)
    df_infer["unsloth_deepseek_32b"] = df_infer["unsloth_deepseek_32b"].apply(filter_invalid_value)
    df_infer = df_infer.dropna(subset=["unsloth_deepseek_32b"])
    aste_df = expand_inference_data(df_infer, "unsloth_deepseek_32b")
    # Annotation columns are not used downstream; drop when present.
    cols_to_drop = ['aste_hcx', 'aste_gpt', 'aste_golden_label']
    aste_df.drop(columns=cols_to_drop, errors='ignore', inplace=True)
    print(f"๋ฆฌ๋ทฐ Opinion ๋ฐ์ดํฐ ๊ฐ์: {aste_df.shape[0]}")
    return aste_df
+
+
def run_review_summarization(config):
    """Full pipeline: expand predictions, summarize each product, attach
    unique-review counts, and write the final summary CSV.

    Returns the summary DataFrame.

    NOTE(review): some Korean strings below are mojibake-garbled in the
    source; they are runtime strings/column names and are kept unchanged.
    """
    print("\n[๋ฆฌ๋ทฐ ์์ฝ ์ถ์ถ ์์]\n")
    aste_df = load_and_prepare_data(config)
    # One entry per product, derived from the review-ID prefix.
    product_ids = list({extract_product_id(item) for item in aste_df["review-ID"].tolist()})
    print(f"์ฒ๋ฆฌํ  ์ํ ์: {len(product_ids)}")
    summary_list = []
    for prod_id in product_ids:
        prod_summary = process_product(prod_id, aste_df)
        summary_list.append(prod_summary)
    summary_df = pd.DataFrame(summary_list)
    summary_df = update_summary_counts(summary_df, aste_df)
    # Fixed column order for the output file.
    final_columns = ["ID", "์ํ๋ช ", "๋ง-๊ธ์ ", "๋ง-๋ถ์ ", "๋ฐฐ์ก ๋ฐ ํฌ์ฅ-๊ธ์ ", "๋ฐฐ์ก ๋ฐ ํฌ์ฅ-๋ถ์ ",
                     "num ๋ง-๊ธ์ ", "num ๋ง-๋ถ์ ", "num ๋ฐฐ์ก ๋ฐ ํฌ์ฅ-๊ธ์ ", "num ๋ฐฐ์ก ๋ฐ ํฌ์ฅ-๋ถ์ "]
    summary_df = summary_df[final_columns]
    output_file = os.path.join(config["paths"]["final_outputs_dir"], "summarization.csv")
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    summary_df.to_csv(output_file, index=False)
    print(f"์ต์ข ์์ฝ ๊ฒฐ๊ณผ๊ฐ ์ ์ฅ๋์์ต๋๋ค: {output_file}")
    print("\n[๋ฆฌ๋ทฐ ์์ฝ ์ถ์ถ ์๋ฃ]\n")
    return summary_df
+
+
+if __name__ == "__main__":
+ run_review_summarization({
+ "paths": {
+ "inference_dir": "./data/aste/inference",
+ "preprocessed_dir": "./data/preprocessed",
+ "final_outputs_dir": "../final_outputs"
+ },
+ "inference_data": "inferenced_reviews_snacks_meals_2.csv"
+ })
diff --git a/models/review/src/review_pipeline/visualization.py b/models/review/src/review_pipeline/visualization.py
new file mode 100644
index 0000000..94a712d
--- /dev/null
+++ b/models/review/src/review_pipeline/visualization.py
@@ -0,0 +1,114 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+[ํด๋ฌ์คํฐ๋ง ๊ฒฐ๊ณผ ์๊ฐํ ๋ฐ ์ ๋์ ํ๊ฐ ํ์ดํ๋ผ์ธ]
+- ์
๋ ฅ CSV ํ์ผ ๋ก๋ถํฐ ์ถ์ฒ ํค์๋ ๋ณ๋ก ์ํ์ ์ฌ์ ๋ ฌํ์ฌ ์ต์ข
์ถ์ฒ CSV ํ์ผ์ ์์ฑํ๋ค.
+- ์ต์ข
์ถ๋ ฅ CSV ํ์ผ์ ๋ค์ ์ด๋ก ๊ตฌ์ฑ๋๋ค:
+ ์นดํ
๊ณ ๋ฆฌ, ํค์๋, ID, ์ํ๋ช
, opinion ๊ฐ์
+- ๋ณธ ํ์ดํ๋ผ์ธ์ config๋ฅผ ํ์ฉํ๊ณ , prompt ํด๋์ ํ๋กฌํํธ ๋ฐ few-shot ์์๋ฅผ ์ฐธ๊ณ ํ์ฌ API ํธ์ถ์ ์ํํ๋ค.
+- run_keyword_recommendation() ํจ์๋ฅผ ํตํด ์ ์ฒด ํ์ดํ๋ผ์ธ์ ์คํํ๋ค.
+"""
+
+import os, sys
+import re
+import json
+import time
+import requests
+import pandas as pd
+import numpy as np
+import yaml
+
+project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
+if project_root not in sys.path:
+ sys.path.insert(0, project_root)
+
+from utils.utils import load_data, expand_inference_data, sentenceBERT_embeddings, umap_reduce_embeddings, agglomerative_clustering, visualize_clustering, evaluate_clustering
+
+#########################################################
+# ๋ฐ์ดํฐ ์ ์ฒ๋ฆฌ ๋ฐ ํ์ฅ ๊ด๋ จ ํจ์
+#########################################################
+
def filter_invalid_value(raw_value):
    """Return the parsed list for a JSON-array string, else None.

    None is returned for NaN cells (floats), the literal "[]" string, valid
    JSON that is not a list, and strings that fail to parse.
    """
    if isinstance(raw_value, float):
        return None
    if isinstance(raw_value, str) and raw_value.strip() == "[]":
        return None
    try:
        value = json.loads(raw_value)
    except json.JSONDecodeError:
        return None
    if not isinstance(value, list):
        return None
    return value
+
def load_and_prepare_data(config):
    """Load the inference CSV named in *config*, keep rows whose
    unsloth_deepseek_32b cell parses to a JSON list, explode those lists into
    per-triple rows, and drop annotation columns unused here.
    """
    train_data_path = os.path.join(config["paths"]["inference_dir"], config["inference_data"])
    df_infer = load_data(train_data_path)
    df_infer["unsloth_deepseek_32b"] = df_infer["unsloth_deepseek_32b"].apply(filter_invalid_value)
    df_infer = df_infer.dropna(subset=["unsloth_deepseek_32b"])
    aste_df = expand_inference_data(df_infer, "unsloth_deepseek_32b")
    # Gold/auxiliary annotation columns are not needed for clustering.
    cols_to_drop = ['aste_hcx', 'aste_gpt', 'aste_golden_label']
    aste_df.drop(columns=cols_to_drop, errors='ignore', inplace=True)
    print(f"๋ฆฌ๋ทฐ Opinion ๋ฐ์ดํฐ ๊ฐ์: {aste_df.shape[0]}")
    return aste_df
+
+
#########################################################
# Final evaluation script (NOTE(review): runs at import time)
#########################################################

print("\n[์ถ์ฒ ํค์๋ ๊ธฐ๋ฐ ์ํ ์ฌ์ ๋ ฌ ์๊ฐํ, ์ ๋์  ํ๊ฐ]\n")

# Load config from <project root>/config/config.yaml.
config_path = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "..", "..", "config", "config.yaml"
)

with open(config_path, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)


# 1. Load and expand the inference data.
aste_df = load_and_prepare_data(config)

# 2. Keep positive opinions only (recommendation keywords come from them).
# NOTE(review): infer_pos is a slice of aste_df; the assignments below may
# trigger pandas' SettingWithCopyWarning — consider adding .copy().
infer_pos = aste_df[aste_df["sentiment"] == "๊ธ์ "]
infer_pos.loc[:, 'category'] = infer_pos['category'].replace({'์์ด๊ฐ์': '๋ผ๋ฉด/๊ฐํธ์'})

# Map Korean category labels to file-name-safe English keys; unmapped
# categories keep their original label via fillna.
category_map = {
    "๊ณผ์/๋น๊ณผ": "snacks",
    "๋ผ๋ฉด/๊ฐํธ์": "meals"
}

infer_pos["category"] = infer_pos["category"].map(category_map).fillna(infer_pos["category"])

# Split into one DataFrame per category.
category_dfs = {category: df for category, df in infer_pos.groupby("category")}

for category, df_cat in category_dfs.items():
    # 3. Sentence embeddings over the opinion column (cached to a .npy file).
    embedding_file = os.path.join(config["paths"]["embedding_dir"], f"deepseek_inference_{category}.npy")
    embedding_matrix = sentenceBERT_embeddings(embedding_file, df=df_cat, column="opinion")

    # 4. UMAP dimensionality reduction.
    reduced_embeddings = umap_reduce_embeddings(embedding_matrix, n_components=256)

    # 5. Agglomerative clustering with a tuned distance threshold.
    cluster_labels = agglomerative_clustering(reduced_embeddings, distance_threshold=21.5)
    df_cat['cluster_label'] = cluster_labels

    # 6. Plot the clusters.
    visualize_clustering(reduced_embeddings, cluster_labels, config, category)

    # 7. Quantitative clustering metrics.
    evaluate_clustering(reduced_embeddings, cluster_labels, config, category)
diff --git a/models/review/src/sft_pipeline/__init__.py b/models/review/src/sft_pipeline/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/models/review/src/sft_pipeline/qwen_deepseek_14b_finetuning.py b/models/review/src/sft_pipeline/qwen_deepseek_14b_finetuning.py
new file mode 100644
index 0000000..c518c9f
--- /dev/null
+++ b/models/review/src/sft_pipeline/qwen_deepseek_14b_finetuning.py
@@ -0,0 +1,137 @@
+import torch
+import pandas as pd
+from datasets import Dataset
+from peft import get_peft_model, LoraConfig, TaskType
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainerCallback, TrainingArguments, Trainer
+
+
class Review_DeepSeekTrainer:
    """LoRA fine-tuning harness for a 4-bit quantized DeepSeek ASTE model.

    Fixes applied in review:
    - ``self.train_path`` was assigned with a trailing comma, which turned the
      path into a 1-tuple and would break ``pd.read_csv`` later.
    - ``train()`` referenced never-assigned ``self.train_datset`` /
      ``self.eval_datset`` (typos); it now obtains the datasets explicitly
      from ``load_datasets()``.
    - ``TrainingArguments`` was given both ``evaluation_strategy`` and its
      newer alias ``eval_strategy``; only one is kept.
    """

    def __init__(self, model_path, train_path, eval_path, output_dir):
        self.model_path = model_path
        self.train_path = train_path  # fix: was `train_path,` (a tuple)
        self.eval_path = eval_path
        self.output_dir = output_dir
        # NOTE(review): the Korean prompt text below is mojibake-garbled in
        # the source; it is a runtime string and is reproduced as-is.
        self.PROMPT = """๋น์ ์ ์ํ ๋ฆฌ๋ทฐ์ ๊ฐ์ฑ ๋ถ์ ๋ฐ ํ๊ฐ ์ ๋ฌธ๊ฐ์ ๋๋ค. ์ฃผ์ด์ง ์ํ ๋ฆฌ๋ทฐ๋ฅผ ๋ถ์ํ์ฌ ํด๋น ๋ฆฌ๋ทฐ์์ ์์ฑ๊ณผ ํ๊ฐ, ๊ฐ์ฑ์ ์ถ์ถํ์ธ์.

### ์์ ๋ชฉํ:
1. ์ ๋ ฅ์ผ๋ก ์ฃผ์ด์ง๋ ์ํ ๋ฆฌ๋ทฐ๋ฅผ ๋ถ์ํ์ฌ JSON ํ์์ผ๋ก ์ถ๋ ฅํ์ธ์.
2. JSON ํ์์ ์๋์ ๊ฐ์ด ๋ฆฌ์คํธ ํํ๋ก ๊ตฌ์ฑ๋ฉ๋๋ค.
    ```json
    [
        {{"์์ฑ": "<์์ฑ๋ช 1>", "ํ๊ฐ": "<ํ๊ฐ ๋ด์ฉ1>", "๊ฐ์ ": "<๊ธ์ /๋ถ์ /์ค๋ฆฝ>"}},
        {{"์์ฑ": "<์์ฑ๋ช 2>", "ํ๊ฐ": "<ํ๊ฐ ๋ด์ฉ2>", "๊ฐ์ ": "<๊ธ์ /๋ถ์ /์ค๋ฆฝ>"}},
        ...
    ]
    ```

์์ฑ์ ๋ค์ ์ค ํ๋์ผ ๊ฐ๋ฅ์ฑ์ด ๋์ต๋๋ค: ["์ํ", "๋ฐฐ์ก", "๊ฐ๊ฒฉ", "๋ง", "์ ์ ๋", "์", "ํฌ์ฅ"].
๋ง์ฝ ์๋ก์ด ์์ฑ์ด ํ์ํ๋ฉด ์์ฑํ์ฌ ์ฌ์ฉ ๊ฐ๋ฅํฉ๋๋ค.
๋ชจ๋  ์ํ๋ช ์ "์ํ"์ผ๋ก ํต์ผํฉ๋๋ค.

### ์ธ๋ถ ๊ท์น:
- ๊ฐ์  ๋ถ์
    - ๋ฆฌ๋ทฐ์์ ๊ฐ์ ์ด ๊ธ์ ์ ์ธ ๊ฒฝ์ฐ "๊ฐ์ ": "๊ธ์ "์ผ๋ก ์ค์ ํฉ๋๋ค.
    - ๋ถ์ ์ ์ธ ํํ์ด ํฌํจ๋ ๊ฒฝ์ฐ "๊ฐ์ ": "๋ถ์ "์ผ๋ก ์ค์ ํฉ๋๋ค.
    - ํ๊ฐ๊ฐ ๋ชจํธํ๊ฑฐ๋ ๊ฐ์ ์ ์ธ ๊ฒฝ์ฐ "๊ฐ์ ": "์ค๋ฆฝ"์ผ๋ก ์ค์ ํฉ๋๋ค.

- ํ๊ฐ ๋ฌธ๊ตฌ ์ ์ 
    - ๋ฆฌ๋ทฐ์์ ๋ํ๋ ์ฃผ์ ํ๊ฐ๋ฅผ ๊ฐ๊ฒฐํ ๋ฌธ์ฅ์ผ๋ก ๋ณํํฉ๋๋ค.
    - ํต์ฌ ํค์๋๋ฅผ ์ ์งํ๋ฉด์ ๋ถํ์ํ ํํ์ ์ ๊ฑฐํฉ๋๋ค.
    - ํ๊ฐ ๋ฌธ๊ตฌ๋ '~๋ค.'๋ก ๋๋๋ ํ์ฌํ, ํ์๋ฌธ์ผ๋ก ๋ต๋ณํฉ๋๋ค. ์๋ฅผ ๋ค์ด, '์ข์ต๋๋ค.' ๊ฐ ์๋ '์ข๋ค.' ๋ก ์์ฑํฉ๋๋ค.

- ์์ธ ์ฌํญ
    - ์ํ ์ฌ์ฉ ํ๊ธฐ๊ฐ ์๋ ์ํ์ ๋ํ ์์์ด๋ ๊ธฐ๋ํ๋ ๋ถ๋ถ์ ๋ถ๋ฆฌํ์ฌ ์ ์ธํฉ๋๋ค.
    - ๋ณตํฉ์ ์ธ ํ๊ฐ๊ฐ ์กด์ฌํ๋ ๊ฒฝ์ฐ ํด๋น ๋ด์ฉ์ ๋ถ๋ฆฌํ์ฌ ๊ฐ๊ฐ JSON ํญ๋ชฉ์ผ๋ก ์์ฑํฉ๋๋ค. ์๋ฅผ ๋ค์ด, '๋ฐฐ์ก์ด ์์ ํ๊ณ  ๋นจ๋์ด์'์ ๊ฒฝ์ฐ '์์ ํ๋ค.' ์ '๋น ๋ฅด๋ค.' ๋ ๊ฐ์ง๋ก ๊ตฌ๋ถํฉ๋๋ค.

### ์ ๋ ฅ:
{review}
"""
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = self.load_model()

    def load_model(self):
        """Load the base model in 4-bit and wrap it with a LoRA adapter."""
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float32
        )

        model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            quantization_config=bnb_config,
            device_map="auto"
        )

        lora_config = LoraConfig(
            target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],
            r=8,
            lora_alpha=16,
            lora_dropout=0.1,
            task_type=TaskType.CAUSAL_LM
        )

        return get_peft_model(model, lora_config)

    def load_datasets(self):
        """Read the train/eval CSVs and build prompt/completion datasets."""
        train_data_df = pd.read_csv(self.train_path)
        eval_data_df = pd.read_csv(self.eval_path)

        train_dataset = [{"prompt": self.PROMPT.format(review=txt), "completion": think}
                         for txt, think in zip(train_data_df["processed"], train_data_df["GPT-4o-Response"])]
        eval_dataset = [{"prompt": self.PROMPT.format(review=txt), "completion": answer}
                        for txt, answer in zip(eval_data_df["processed"], eval_data_df["GPT-4o-Response"])]

        return Dataset.from_list(train_dataset), Dataset.from_list(eval_dataset)

    def tokenize_func(self, dataset):
        """Tokenize prompt+completion pairs; labels mirror input_ids (causal LM)."""
        combined_texts = [f"{prompt}\n{completion}" for prompt, completion in zip(dataset["prompt"], dataset["completion"])]
        tokenized = self.tokenizer(combined_texts, truncation=True, max_length=800, padding="max_length")
        tokenized["labels"] = tokenized["input_ids"].copy()
        return tokenized

    def train(self):
        """Run supervised fine-tuning and save the adapter + tokenizer."""
        # Fix: the original referenced self.train_datset / self.eval_datset,
        # attributes that were never assigned anywhere; load them here.
        train_dataset, eval_dataset = self.load_datasets()
        tokenized_train_dataset = train_dataset.map(self.tokenize_func, batched=True)
        tokenized_eval_dataset = eval_dataset.map(self.tokenize_func, batched=True)

        training_args = TrainingArguments(
            output_dir=self.output_dir,
            num_train_epochs=5,
            per_device_train_batch_size=1,
            gradient_accumulation_steps=16,
            fp16=True,
            logging_steps=10,
            save_steps=100,
            per_device_eval_batch_size=1,
            # Fix: both `evaluation_strategy` and `eval_strategy` were passed;
            # keep only the modern name (rename back for transformers < 4.41).
            eval_strategy="steps",
            eval_steps=10,
            learning_rate=1e-5,
            logging_dir="./logs",
            run_name=self.output_dir,
        )

        class MoveToCPUTensorCallback(TrainerCallback):
            # Free cached CUDA memory after every optimizer step.
            def on_step_end(self, args, state, control, **kwargs):
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=tokenized_train_dataset,
            eval_dataset=tokenized_eval_dataset,
            callbacks=[MoveToCPUTensorCallback()]
        )

        trainer.train()

        self.model.save_pretrained(self.output_dir)
        self.tokenizer.save_pretrained(self.output_dir)
        print(f"Model saved to {self.output_dir}")
+
+if __name__ == "__main__":
+ trainer = Review_DeepSeekTrainer(
+ model_path="deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
+ train_path="./data/aste/train/train_data.csv",
+ eval_path="./data/aste/eval/aste_annotation_100_golden_label.csv",
+ output_dir="./deepseek_14b_finetune"
+ )
diff --git a/models/review/src/sft_pipeline/qwen_deepseek_32b_finetuning.py b/models/review/src/sft_pipeline/qwen_deepseek_32b_finetuning.py
new file mode 100644
index 0000000..a452de8
--- /dev/null
+++ b/models/review/src/sft_pipeline/qwen_deepseek_32b_finetuning.py
@@ -0,0 +1,217 @@
+from unsloth import FastLanguageModel
+
+
def set_model_size(parameters: str) -> str:
    """Map a model-size label ("14B" or "32B") to its unsloth checkpoint name.

    Raises:
        ValueError: if *parameters* is not a supported size label.
    """
    checkpoints = {
        "14B": "unsloth/DeepSeek-R1-Distill-Qwen-14B-bnb-4bit",
        "32B": "unsloth/DeepSeek-R1-Distill-Qwen-32B-bnb-4bit",
    }
    if parameters not in checkpoints:
        raise ValueError("Invalid model size. Choose '14B' or '32B'.")
    print(f"Setting model size to {parameters}")
    return checkpoints[parameters]
+
def load_model(model_size: str = "32B"):
    """Load the 4-bit unsloth checkpoint for *model_size* and switch it to
    inference mode. Returns (model, tokenizer).

    NOTE(review): `temperature` is not a typical from_pretrained argument —
    verify that unsloth accepts it here; it may be silently ignored.
    """
    model_name = set_model_size(model_size)
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_name,
        max_seq_length=2048,
        dtype=None,
        load_in_4bit=True,
        temperature=0.6
    )
    FastLanguageModel.for_inference(model)
    return model, tokenizer
+
+PROMPT_TEMPLATE = """๋น์ ์ ์ํ ๋ฆฌ๋ทฐ์ ๊ฐ์ฑ ๋ถ์ ๋ฐ ํ๊ฐ ์ ๋ฌธ๊ฐ์
๋๋ค. ์ฃผ์ด์ง ์ํ ๋ฆฌ๋ทฐ๋ฅผ ๋ถ์ํ์ฌ ํด๋น ๋ฆฌ๋ทฐ์์ ์์ฑ๊ณผ ํ๊ฐ, ๊ฐ์ฑ์ ์ถ์ถํ์ธ์.
+
+### ์์
๋ชฉํ:
+1. ์
๋ ฅ์ผ๋ก ์ฃผ์ด์ง๋ ์ํ ๋ฆฌ๋ทฐ๋ฅผ ๋ถ์ํ์ฌ JSON ํ์์ผ๋ก ์ถ๋ ฅํ์ธ์.
+2. JSON ํ์์ ์๋์ ๊ฐ์ด ๋ฆฌ์คํธ ํํ๋ก ๊ตฌ์ฑ๋ฉ๋๋ค.
+ ```json
+ [
+ {{"์์ฑ": "<์์ฑ๋ช
1>", "ํ๊ฐ": "<ํ๊ฐ ๋ด์ฉ1>", "๊ฐ์ ": "<๊ธ์ /๋ถ์ /์ค๋ฆฝ>"}},
+ {{"์์ฑ": "<์์ฑ๋ช
2>", "ํ๊ฐ": "<ํ๊ฐ ๋ด์ฉ2>", "๊ฐ์ ": "<๊ธ์ /๋ถ์ /์ค๋ฆฝ>"}},
+ ...
+ ]
+ ```
+
+์์ฑ์ ๋ค์ ์ค ํ๋์ผ ๊ฐ๋ฅ์ฑ์ด ๋์ต๋๋ค: ["์ํ", "๋ฐฐ์ก", "๊ฐ๊ฒฉ", "๋ง", "์ ์ ๋", "์", "ํฌ์ฅ"].
+๋ง์ฝ ์๋ก์ด ์์ฑ์ด ํ์ํ๋ฉด ์์ฑํ์ฌ ์ฌ์ฉ ๊ฐ๋ฅํฉ๋๋ค.
+๋ชจ๋ ์ํ๋ช
์ "์ํ"์ผ๋ก ํต์ผํฉ๋๋ค.
+
+### ์ธ๋ถ ๊ท์น:
+- ๊ฐ์ ๋ถ์
+ - ๋ฆฌ๋ทฐ์์ ๊ฐ์ ์ด ๊ธ์ ์ ์ธ ๊ฒฝ์ฐ "๊ฐ์ ": "๊ธ์ "์ผ๋ก ์ค์ ํฉ๋๋ค.
+ - ๋ถ์ ์ ์ธ ํํ์ด ํฌํจ๋ ๊ฒฝ์ฐ "๊ฐ์ ": "๋ถ์ "์ผ๋ก ์ค์ ํฉ๋๋ค.
+ - ํ๊ฐ๊ฐ ๋ชจํธํ๊ฑฐ๋ ๊ฐ์ ์ ์ธ ๊ฒฝ์ฐ "๊ฐ์ ": "์ค๋ฆฝ"์ผ๋ก ์ค์ ํฉ๋๋ค.
+
+- ํ๊ฐ ๋ฌธ๊ตฌ ์ ์
+ - ๋ฆฌ๋ทฐ์์ ๋ํ๋ ์ฃผ์ ํ๊ฐ๋ฅผ ๊ฐ๊ฒฐํ ๋ฌธ์ฅ์ผ๋ก ๋ณํํฉ๋๋ค.
+ - ํต์ฌ ํค์๋๋ฅผ ์ ์งํ๋ฉด์ ๋ถํ์ํ ํํ์ ์ ๊ฑฐํฉ๋๋ค.
+ - ํ๊ฐ ๋ฌธ๊ตฌ๋ '~๋ค.'๋ก ๋๋๋ ํ์ฌํ, ํ์๋ฌธ์ผ๋ก ๋ต๋ณํฉ๋๋ค. ์๋ฅผ ๋ค์ด, '์ข์ต๋๋ค.' ๊ฐ ์๋ '์ข๋ค.' ๋ก ์์ฑํฉ๋๋ค.
+
+- ์์ธ ์ฌํญ
+ - ์ํ ์ฌ์ฉ ํ๊ธฐ๊ฐ ์๋ ์ํ์ ๋ํ ์์์ด๋ ๊ธฐ๋ํ๋ ๋ถ๋ถ์ ๋ถ๋ฆฌํ์ฌ ์ ์ธํฉ๋๋ค.
+ - ๋ณตํฉ์ ์ธ ํ๊ฐ๊ฐ ์กด์ฌํ๋ ๊ฒฝ์ฐ ํด๋น ๋ด์ฉ์ ๋ถ๋ฆฌํ์ฌ ๊ฐ๊ฐ JSON ํญ๋ชฉ์ผ๋ก ์์ฑํฉ๋๋ค. ์๋ฅผ ๋ค์ด, '๋ฐฐ์ก์ด ์์ ํ๊ณ ๋นจ๋์ด์'์ ๊ฒฝ์ฐ '์์ ํ๋ค.' ์ '๋น ๋ฅด๋ค.' ๋ ๊ฐ์ง๋ก ๊ตฌ๋ถํฉ๋๋ค.
+
+### ์
๋ ฅ:
+{review}
+
+"""
+
+model, tokenizer = load_model("32B")
+
+model = FastLanguageModel.get_peft_model(
+ model,
+ r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
+ target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
+ "gate_proj", "up_proj", "down_proj",],
+ lora_alpha = 16,
+ lora_dropout = 0, # Supports any, but = 0 is optimized
+ bias = "none", # Supports any, but = "none" is optimized
+ # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
+ use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
+ random_state = 3407,
+ use_rslora = False, # We support rank stabilized LoRA
+ loftq_config = None, # And LoftQ
+)
+
+import json
+import re
+import pandas as pd
+from torch.utils.data import Dataset
+
+EOS_TOKEN = tokenizer.eos_token
+
class ASTEDataset(Dataset):
    """Torch dataset pairing each review with its GPT-labelled ASTE output.

    Each item renders the review into PROMPT_TEMPLATE and appends the
    chain-of-thought reasoning plus the gold-label JSON block, ending with
    the tokenizer's EOS token.
    """

    def __init__(self, csv_file, encoding='utf-8'):
        """
        Args:
            csv_file: Path to the annotation CSV.
            encoding: CSV file encoding (default 'utf-8').
        """
        self.data = pd.read_csv(csv_file, encoding=encoding)

        # Collapse doubled quotes ("" -> ") left over from CSV escaping.
        # Robustness fix: skip non-string cells (NaN) instead of crashing
        # inside re.sub.
        self.data['GPT-4o-Answer'] = self.data['GPT-4o-Answer'].apply(
            lambda x: re.sub(r'""', '"', x) if isinstance(x, str) else x
        )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]

        # Render the review into the shared ASTE prompt template.
        prompt_text = row['๋ฆฌ๋ทฐ']
        prompt = PROMPT_TEMPLATE.format(review=prompt_text)

        label_json_str = row['GPT-4o-Answer']
        try:
            label_obj = json.loads(label_json_str)
        except json.JSONDecodeError:
            # Unparseable gold label -> train on an empty triple list.
            label_obj = []

        cot = row["GPT-4o-Reasoning"] + "\n"

        # Gold label formatted as a fenced JSON block.
        label_text = f'```json\n{json.dumps(label_obj, ensure_ascii=False, indent=4)}\n```'

        # Fix: removed the debug print(text_all) that dumped every rendered
        # training sample to stdout.
        text_all = prompt + cot + label_text + EOS_TOKEN

        return {
            'instruction': prompt,
            'output': label_text,
            'text': text_all,
        }
+
+from ast import literal_eval
+
def process_aste_example(example):
    """Batched datasets.map() callback that renders training texts.

    *example* is a dict of column-name -> list (one list entry per sample in
    the batch). Returns parallel lists: the rendered prompt ('instruction'),
    the gold-label JSON block ('output'), and the full training text
    ('text' = prompt + reasoning + label block + EOS).
    """
    prompts = []
    label_blocks = []
    full_texts = []

    for review, reasoning, raw_label in zip(
            example['๋ฆฌ๋ทฐ'], example['GPT-4o-Reasoning'], example['GPT-4o-Answer']):
        prompt = PROMPT_TEMPLATE.format(review=review)

        # Parse the gold label: strict JSON first, then a Python-literal
        # fallback (handles doubled-quote artifacts); empty list if both fail.
        try:
            parsed_label = json.loads(raw_label)
        except json.JSONDecodeError:
            try:
                parsed_label = literal_eval(raw_label)
            except Exception:
                parsed_label = []

        label_block = f'```json\n{json.dumps(parsed_label, ensure_ascii=False, indent=4)}\n```'

        prompts.append(prompt)
        label_blocks.append(label_block)
        full_texts.append(prompt + reasoning + '\n\n' + label_block + EOS_TOKEN)

    return {
        'instruction': prompts,
        'output': label_blocks,
        'text': full_texts,
    }
+
from datasets import load_dataset

# Build the training dataset from the split CSV and render each row into
# instruction/output/text columns via the batched mapper.
dataset = load_dataset("csv", data_files="train_gpt_splitted.csv", split="train")
dataset = dataset.map(process_aste_example, batched=True, )

# Quick sanity check of one rendered example.
print("---")
print("dataset:", dataset[1])
print("---")
print("dataset[1]['text']:", dataset[1]["text"])
+
from trl import SFTTrainer
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported

# Supervised fine-tuning over the rendered `text` column.
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    dataset_text_field = "text",
    max_seq_length = 2048,
    dataset_num_proc = 2,
    packing = False, # Can make training 5x faster for short sequences.
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        num_train_epochs = 3, # Set this for 1 full training run.
        # max_steps = 60,
        learning_rate = 2e-4,
        # Use bf16 when the GPU supports it, otherwise fall back to fp16.
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 42,
        output_dir = "output_zeroshot",
        report_to = "none", # Use this for WandB etc
    ),
)

trainer_stats = trainer.train()
diff --git a/models/review/src/sft_pipeline/review_crawling.py b/models/review/src/sft_pipeline/review_crawling.py
new file mode 100644
index 0000000..b28bf2d
--- /dev/null
+++ b/models/review/src/sft_pipeline/review_crawling.py
@@ -0,0 +1,282 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+[์ํ ์ ๋ณด ๋ฐ ๋ฆฌ๋ทฐ ํฌ๋กค๋ง ํ์ดํ๋ผ์ธ]
+- ์จ๋ผ์ธ ์ผํ๋ชฐ์์ ์ํ ์ ๋ณด, ์์ธ ์ ๋ณด, ๋ฆฌ๋ทฐ ๋ฐ์ดํฐ๋ฅผ ์์งํ๋ ์ฝ๋์
๋๋ค.
+- ์ฃผ์ ๊ธฐ๋ฅ:
+ 1. get_product_urls: ์ง์ ๋ ์นดํ
๊ณ ๋ฆฌ ํ์ด์ง์์ ์ํ๋ช
๊ณผ URL ์ ๋ณด๋ฅผ ์์ง
+ 2. get_product_details: ๊ฐ ์ํ์ ์์ธ ํ์ด์ง์์ ์ ๋ณด๋ฅผ ์์ง
+ 3. details_to_csv: ์์งํ ์ ๋ณด๋ฅผ CSV ํ์ผ๋ก ์ ์ฅ
+ 4. get_product_reviews: ๋ฆฌ๋ทฐ ๋ฐ์ดํฐ๋ฅผ ์์งํ์ฌ CSV๋ก ์ ์ฅ
+"""
+
+import re
+import json
+import time
+import os
+import pandas as pd
+from tqdm import tqdm
+from pprint import pprint
+from bs4 import BeautifulSoup
+from selenium import webdriver
+from selenium.webdriver.common.by import By
+from selenium.webdriver.chrome.options import Options
+from selenium.webdriver.chrome.service import Service
+from selenium.webdriver.support.ui import WebDriverWait
+from selenium.webdriver.support import expected_conditions as EC
+from webdriver_manager.chrome import ChromeDriverManager
+
+def get_product_urls(mall_category_url, market="emart", category="์์ด๊ฐ์", limit=50):
+ BASE_URL = "https://shopping.naver.com"
+ chrome_options = Options()
+ chrome_options.add_argument("--headless")
+ driver = webdriver.Chrome(service=Service(ChromeDriverManager().install()), options=chrome_options)
+ all_items = []
+ driver.get(mall_category_url)
+ last_height = driver.execute_script("return document.body.scrollHeight")
+ while True:
+ try:
+ wait = WebDriverWait(driver, 10)
+ wait.until(EC.presence_of_element_located((By.CLASS_NAME, "_3m7zfsGIZR")))
+ except Exception as e:
+ print("ํ์ด์ง ๋ก๋ฉ ์ค ์ค๋ฅ ๋ฐ์:", e)
+ soup = BeautifulSoup(driver.page_source, "html.parser")
+ items_list = soup.find_all("li", class_="_3m7zfsGIZR")
+ for item in items_list:
+ source = item.find("a", class_=re.compile(".*_3OaphyWXEP.*"))
+ if source is None:
+ continue
+ name_dict = json.loads(source["data-shp-contents-dtl"])
+ all_items.append({
+ "name": name_dict[0]["value"],
+ "url": BASE_URL + source["href"],
+ "market": market,
+ "category": category
+ })
+ all_items = [dict(i) for i in {frozenset(item.items()) for item in all_items}]
+ if len(all_items) >= limit:
+ break
+ driver.execute_script("window.scrollTo(0, document.body.scrollHeight);")
+ time.sleep(5)
+ new_height = driver.execute_script("return document.body.scrollHeight")
+ if new_height == last_height:
+ break
+ last_height = new_height
+ driver.quit()
+ print("์ํ ์ด ๊ฐ์:", len(all_items))
+ return all_items
+
+def get_product_details(all_items, limit=50):
+    """Visit each product's detail page and enrich the dicts in `all_items`
+    in place with thumbnail, price, star rating, detail images/texts and
+    review count. Returns the same (mutated) list.
+    """
+    from selenium.webdriver.chrome.options import Options  # guard against duplicate import (kept from original)
+    chrome_options = Options()
+    chrome_options.add_argument("--headless")
+    driver = webdriver.Chrome(service=Service(ChromeDriverManager().install()), options=chrome_options)
+    for item in tqdm(all_items[:limit], desc="์ํ ์์ธ ์ ๋ณด ์์ง", total=len(all_items[:limit])):
+        driver.get(item["url"])
+        try:
+            # Wait for the detail-content container to render.
+            wait = WebDriverWait(driver, 10)
+            wait.until(EC.presence_of_element_located((By.CLASS_NAME, "_1Z00EgoxQ9")))
+        except Exception as e:
+            print("ํ์ด์ง ๋ก๋ฉ ์ค ์ค๋ฅ ๋ฐ์:", e)
+        soup = BeautifulSoup(driver.page_source, "html.parser")
+        # Thumbnail image (None when the element is missing).
+        try:
+            thumbnail_img = soup.find("div", class_="_2tT_gkmAOr").find("img")["src"]
+            item["thumbnail"] = thumbnail_img
+        except Exception as e:
+            item["thumbnail"] = None
+        # Collect every numeric price on the page; with two or more prices the
+        # max is treated as the pre-discount price and the min as discounted.
+        items_list = soup.find_all("span", class_="_1LY7DqCnwR")
+        prices = []
+        for span in items_list:
+            try:
+                prices.append(int(span.get_text().replace(",", "")))
+            except Exception:
+                continue
+        if prices:
+            if len(prices) > 1:
+                item["price"] = {
+                    "before_price": max(prices),
+                    "after_price": min(prices)
+                }
+            else:
+                item["price"] = {"before_price": prices[0]}
+        else:
+            item["price"] = {}
+        # Star rating: the element text carries a 2-char prefix before the number.
+        try:
+            strong = soup.find("strong", class_="_2pgHN-ntx6")
+            star = strong.get_text()[2:]
+            item["star"] = float(star)
+        except Exception as e:
+            item["star"] = None
+        # Detail-section images and texts.
+        sources = soup.find_all("div", class_="_1Z00EgoxQ9")
+        imgs = []
+        divs = []
+        texts = []
+        for s in sources:
+            # NOTE(review): reassigned on every iteration, so only the LAST
+            # matching container's content survives -- confirm intended.
+            imgs = [str(i["src"]) for i in s.find_all("img")]
+            divs = [div.get_text() for div in s.find_all("div", class_="tmpl_tit_para")]
+            texts = [p.get_text() for p in s.find_all(["h2", "p", "strong", "b"])]
+        item["imgs"] = {"num": len(imgs), "urls": imgs}
+        item["texts"] = {"num": len(texts), "divs": divs, "contents": texts}
+        # Review count (0 when the badge is missing).
+        try:
+            item["reviews"] = int(soup.find("span", class_="_3HJHJjSrNK").get_text().replace(",", ""))
+        except Exception as e:
+            item["reviews"] = 0
+            print("๋ฆฌ๋ทฐ ๊ฐ์ ์์ง ์คํจ:", item["url"])
+        pprint(item)
+        print("\n" + "=" * 100 + "\n")
+    driver.quit()
+    return all_items
+
+def details_to_csv(all_items):
+ data_df = pd.DataFrame(columns=["ID", "img-ID", "category", "name", "url", "before_price", "after_price", "star", "thumbnail", "imgs", "texts", "num_reviews"])
+ for idx, item in enumerate(all_items[:2]): # ์์ : 2๊ฐ ์ํ์ ๋ํด ์ ์ฅ
+ market = item["market"]
+ product_df = pd.DataFrame({
+ "ID": f"{market}-{str(idx+1)}",
+ "img-ID": f"{market}-{str(idx+1)}-0",
+ "category": item["category"],
+ "name": item["name"],
+ "url": item["url"],
+ "before_price": item["price"].get("before_price", None),
+ "after_price": item["price"].get("after_price", None),
+ "thumbnail": item.get("thumbnail", None),
+ "star": item.get("star", None),
+ "texts": [str(item.get("texts", {}))],
+ "num_reviews": item.get("reviews", 0)
+ })
+ image_df = pd.DataFrame(columns=["ID", "img-ID", "imgs"])
+ for i, img in enumerate(item.get("imgs", {}).get("urls", [])):
+ image_df.loc[i] = [f"{market}-{str(idx+1)}", f"{market}-{str(idx+1)}-{str(i+1)}", img]
+ product_df = pd.concat([product_df, image_df], axis=0)
+ data_df = pd.concat([data_df, product_df], axis=0)
+ data_df.to_csv("product_details.csv", index=False)
+ print("์์ธ ์ ๋ณด CSV ์ ์ฅ ์๋ฃ: product_details.csv")
+
+def get_product_reviews(all_items):
+    """Crawl the review tab of each product and return one flat DataFrame.
+
+    Per product: one metadata row (review-ID suffix -0; aggregate star plus a
+    per-criterion evaluation breakdown) followed by one row per scraped
+    review (up to two pages, sorted by most recent).
+    """
+    from selenium.webdriver.chrome.options import Options
+    chrome_options = Options()
+    chrome_options.add_argument("--headless")
+    driver = webdriver.Chrome(service=Service(ChromeDriverManager().install()), options=chrome_options)
+    review_df = pd.DataFrame(columns=["ID", "review-ID", "category", "name", "url", "meta", "star", "review"])
+    for idx, item in enumerate(all_items):
+        market = item["market"]
+        driver.get(item["url"])
+        last_height = driver.execute_script("return document.body.scrollHeight")
+        try:
+            wait = WebDriverWait(driver, 10)
+            wait.until(EC.presence_of_element_located((By.CLASS_NAME, "_11xjFby3Le")))
+        except Exception as e:
+            print("๋ฆฌ๋ทฐ ํ์ด์ง ๋ก๋ฉ ์ค ์ค๋ฅ ๋ฐ์:", e)
+        # Open the review tab; skip this product if the tab cannot be clicked.
+        try:
+            review_button = driver.find_element(By.XPATH, "//a[contains(text(), '๋ฆฌ๋ทฐ')]")
+            review_button.click()
+        except Exception as e:
+            print("๋ฆฌ๋ทฐ ํญ ํด๋ฆญ ์คํจ:", e)
+            continue
+        time.sleep(1)
+        # Aggregate star + per-criterion evaluation stats ("meta").
+        meta_dict = {}
+        soup = BeautifulSoup(driver.page_source, "html.parser")
+        try:
+            meta_star = float(soup.find("strong", class_="_2pgHN-ntx6").get_text()[2:])
+        except Exception:
+            meta_star = None
+        meta_keys = [key.get_text() for key in soup.find_all("em", class_="_1ehAE1FZXP")]
+        for key in meta_keys:
+            try:
+                # Advance the criterion carousel, then scrape its breakdown.
+                button = driver.find_element(By.XPATH, "//button[contains(@class, '_3pfVLZDLde') and @data-shp-area-id='evalnext']")
+                button.click()
+                time.sleep(1)
+                soup = BeautifulSoup(driver.page_source, "html.parser")
+                detail_keys = [key.get_text() for key in soup.find_all("em", class_="_2QT-bjUbDv")]
+                detail_values = [int(value.get_text()[:-1]) for value in soup.find_all("span", class_="_1CGcLXygdq")]
+                detail_dict = {}
+                for d_k, d_v in zip(detail_keys, detail_values):
+                    detail_dict[d_k] = d_v
+                meta_dict[key] = detail_dict
+            except Exception as e:
+                continue
+        review_meta_df = pd.DataFrame({
+            "ID": f"{market}-{str(idx+1)}",
+            "review-ID": f"{market}-{str(idx+1)}-0",
+            "category": item["category"],
+            "name": item["name"],
+            "url": item["url"],
+            "meta": [str(meta_dict)],
+            "star": meta_star,
+        })
+        # Sort reviews by most recent before scraping.
+        try:
+            latest_button = driver.find_element(By.XPATH, "//a[text()='์ต์ ์']")
+            latest_button.click()
+            time.sleep(1)
+        except Exception as e:
+            print("์ต์ ์ ๋ฒํผ ํด๋ฆญ ์คํจ:", e)
+        soup = BeautifulSoup(driver.page_source, "html.parser")
+        try:
+            review_num = int(soup.find("span", class_="_9Fgp3X8HT7").get_text().replace(",", ""))
+        except Exception:
+            review_num = 0
+        stars = []
+        review_texts = []
+        # Scrape at most two pages of reviews.
+        for i in range(2):
+            try:
+                scrollTo = driver.find_element(By.CLASS_NAME, "_1McWUwk15j")
+                driver.execute_script("arguments[0].scrollIntoView();", scrollTo)
+            except Exception as e:
+                pass
+            soup = BeautifulSoup(driver.page_source, "html.parser")
+            review_divs = soup.find_all("div", class_="_1McWUwk15j")
+            for review in review_divs:
+                try:
+                    star = review.find("em", class_="_15NU42F3kT").get_text()
+                    stars.append(star)
+                except Exception:
+                    stars.append(None)
+                try:
+                    # The review body is the last span of the text container.
+                    text_div = review.find("div", class_="_1kMfD5ErZ6").find_all("span")
+                    review_texts.append(text_div[-1].get_text())
+                except Exception:
+                    review_texts.append("")
+            # Click the next page number; stop when it does not exist.
+            try:
+                next_page = driver.find_element(By.XPATH, f"//a[contains(@class, 'UWN4IvaQza') and @data-shp-contents-id='{str(i+2)}']")
+                next_page.click()
+                time.sleep(2)
+            except Exception as e:
+                break
+        # One row per scraped review, IDs suffixed -1, -2, ...
+        review_text_df = pd.DataFrame(columns=["ID", "review-ID", "star", "review"])
+        for i, (star, review) in enumerate(zip(stars, review_texts)):
+            tmp = pd.DataFrame({
+                "ID": f"{market}-{str(idx+1)}",
+                "review-ID": f"{market}-{str(idx+1)}-{str(i+1)}",
+                "star": [star],
+                "review": [review]
+            })
+            review_text_df = pd.concat([review_text_df, tmp], ignore_index=True)
+        reviews = pd.concat([review_meta_df, review_text_df], ignore_index=True)
+        review_df = pd.concat([review_df, reviews], ignore_index=True)
+    driver.quit()
+    return review_df
+
+def run_review_crawling(config):
+ print("\n[์ํ ์ ๋ณด ๋ฐ ๋ฆฌ๋ทฐ ํฌ๋กค๋ง ํ์ดํ๋ผ์ธ ์์]\n")
+ # config์์ ํฌ๋กค๋งํ ์ํ ์ ๋ฑ ์ต์
์ ์ฝ์ ์ ์์
+ CRAWL_LIMIT = 5
+ OUTPUT_FILE = os.path.join(config["paths"]["data_dir"], "product_reviews.csv")
+ mall_category_url = "https://shopping.naver.com/best100v2/main.nhn?catId=50000004"
+ print("์ํ URL ์ ๋ณด ์์ง ์ค...")
+ product_urls = get_product_urls(mall_category_url, market="emart", category="์์ด๊ฐ์", limit=CRAWL_LIMIT)
+ print("์ํ ์์ธ ์ ๋ณด ์์ง ์ค...")
+ product_details = get_product_details(product_urls, limit=CRAWL_LIMIT)
+ print("์ํ ์์ธ ์ ๋ณด๋ฅผ CSV๋ก ์ ์ฅ ์ค...")
+ details_to_csv(product_details)
+ print("์ํ ๋ฆฌ๋ทฐ ์์ง ์ค...")
+ review_df = get_product_reviews(product_details[:CRAWL_LIMIT])
+ review_df.to_csv(OUTPUT_FILE, index=False)
+ print("๋ฆฌ๋ทฐ ์ ๋ณด CSV ์ ์ฅ ์๋ฃ:", OUTPUT_FILE)
+ print("๋ชจ๋ ํฌ๋กค๋ง ์์
์๋ฃ๋์์ต๋๋ค.")
+
+if __name__ == "__main__":
+ run_review_crawling({
+ "paths": {
+ "data_dir": "./data"
+ }
+ })
diff --git a/models/review/src/sft_pipeline/review_preprocessing.py b/models/review/src/sft_pipeline/review_preprocessing.py
new file mode 100644
index 0000000..95dd5d1
--- /dev/null
+++ b/models/review/src/sft_pipeline/review_preprocessing.py
@@ -0,0 +1,151 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+[๋ฆฌ๋ทฐ ์ ์ฒ๋ฆฌ ํ์ดํ๋ผ์ธ]
+- ์๋ณธ ๋ฆฌ๋ทฐ CSV ํ์ผ์์ ํ
์คํธ๋ฅผ ์ ์ฒ๋ฆฌ (ํน์๋ฌธ์ ์ ๊ฑฐ, ๊ฐํ ๊ต์ฒด, ์์ด/์ซ์ ํํฐ, ๊ณต๋ฐฑ ์ ๊ทํ, ๋ฐ๋ณต ์ ๊ฑฐ, ์งง์ ํ
์คํธ ๋ฐฐ์ )
+- T5 ๋ง์ถค๋ฒ ๊ต์ ๋ชจ๋ธ๋ก ์คํ ๊ต์ ์ํ
+"""
+
+import re
+import os, sys
+import torch
+import pandas as pd
+from glob import glob
+from tqdm import tqdm
+from konlpy.tag import Hannanum
+from transformers import T5ForConditionalGeneration, T5Tokenizer
+from utils.utils import load_and_preprocess_reviews
+
+hannanum = Hannanum()
+
+def remove_special_chars(text):
+ return re.sub(r'[^a-zA-Z0-9๊ฐ-ํฃ\s]', '', text) if isinstance(text, str) else ""
+
+def replace_newlines(text):
+ return re.sub(r'[\r\n]+', ' ', text).strip() if isinstance(text, str) else ""
+
+def filter_text_by_english_ratio(text, ratio=0.3):
+ if not isinstance(text, str) or not text.strip():
+ return ""
+ total = len(text)
+ eng_count = len(re.findall(r"[a-zA-Z]", text))
+ return text if (eng_count / total if total > 0 else 0) <= ratio else ""
+
+def filter_text_by_number_ratio(text, ratio=0.3):
+ if not isinstance(text, str) or not text.strip():
+ return ""
+ total = len(text)
+ num_count = len(re.findall(r"[0-9]", text))
+ return text if (num_count / total if total > 0 else 0) <= ratio else ""
+
+def normalize_whitespace(text):
+ return re.sub(r'\s+', ' ', text).strip() if isinstance(text, str) else ""
+
+def remove_repetition(text):
+    """Collapse repeated substrings in a whitespace-tokenized text.
+
+    Two passes: (1) per token, shrink a token that is an exact repetition of
+    one of its prefixes (e.g. "aaaa" -> "a") unless the Hannanum POS tagger
+    already recognises the whole token as a word; (2) over the token list,
+    shrink a whole-list repetition of a shorter leading block of tokens.
+    """
+    def is_valid_word(word):
+        # Treat the token as a real word when the tagger returns it unsplit.
+        pos_tags = hannanum.pos(word)
+        for token, pos in pos_tags:
+            if token == word:
+                return True
+        return False
+    def compress_token(token):
+        if is_valid_word(token):
+            return token
+        n = len(token)
+        # Smallest prefix whose exact repetition reconstructs the token wins
+        # (the length check in `segment * (n // L) == token` implies L | n).
+        for L in range(1, n // 2 + 1):
+            segment = token[:L]
+            if segment * (n // L) == token:
+                return segment
+        return token
+    def compress_token_list(tokens):
+        # Same idea at the token-list level: find the shortest leading block
+        # whose repetition equals the whole list.
+        n = len(tokens)
+        for k in range(1, n // 2 + 1):
+            block = tokens[:k]
+            if block * (n // k) == tokens:
+                return block
+        return tokens
+    if not text or not isinstance(text, str):
+        return ""
+    tokens = text.split()
+    tokens = [compress_token(token) for token in tokens]
+    tokens = compress_token_list(tokens)
+    return " ".join(tokens).strip()
+
+def remove_short_text(text, n=5):
+ return text if len(text) > n else ''
+
+MODEL_NAME = "j5ng/et5-typos-corrector"
+model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)
+tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
+model = model.to(device)
+
+def batch_correct_typos(df, target_col, batch_size=100):
+    """Run the T5 typo corrector over `df[target_col]` in batches.
+
+    Adds/overwrites a "processed" column with the corrected text and returns
+    a copy of the dataframe. Raises ValueError when `target_col` is missing.
+    """
+    if target_col not in df.columns:
+        raise ValueError("๋์ ์ปฌ๋ผ์ด ๋ฐ์ดํฐํ๋ ์์ ์์ต๋๋ค.")
+    df = df.copy()
+    df["processed"] = None
+    for i in tqdm(range(0, len(df), batch_size), desc="๋ฐฐ์น ์ฒ๋ฆฌ"):
+        batch_texts = df[target_col].iloc[i : i + batch_size].tolist()
+        # Task prefix expected by the et5-typos-corrector checkpoint.
+        input_texts = ["๋ง์ถค๋ฒ์ ๊ณ ์ณ์ฃผ์ธ์: " + text for text in batch_texts]
+        encodings = tokenizer(input_texts, return_tensors="pt", padding=True, truncation=True, max_length=128)
+        input_ids = encodings.input_ids.to(device)
+        attention_mask = encodings.attention_mask.to(device)
+        outputs = model.generate(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            max_length=128,
+            num_beams=5,
+            early_stopping=True
+        )
+        output_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        # Positional write keeps alignment even with a non-default index.
+        num_rows = min(len(df) - i, len(output_texts))
+        col_idx = df.columns.get_loc("processed")
+        df.iloc[i : i + num_rows, col_idx] = output_texts[:num_rows]
+    return df
+
+def run_review_preprocessing(config):
+    """Preprocess every crawled-review CSV and save cleaned copies.
+
+    Expects config["paths"]["crawled_reviews_dir"] (input) and
+    config["paths"]["preprocessed_dir"] (output). For each "crawled_*.csv" it
+    writes "processed_*.csv" (cleaned reviews) and "meta_*.csv" (metadata).
+    """
+    print("\n[๋ฆฌ๋ทฐ ์ ์ฒ๋ฆฌ ์์]\n")
+    input_dir = os.path.join(config["paths"]["crawled_reviews_dir"])
+    output_dir = os.path.join(config["paths"]["preprocessed_dir"])
+    os.makedirs(output_dir, exist_ok=True)
+    csv_files = glob(os.path.join(input_dir, "*.csv"))
+    if not csv_files:
+        print("์ฒ๋ฆฌํ CSV ํ์ผ์ด ์์ต๋๋ค.")
+        return
+    for src_file in csv_files:
+        base_name = os.path.basename(src_file).replace("crawled_", "processed_", 1)
+        dest_file = os.path.join(output_dir, base_name)
+        meta_base_name = os.path.basename(src_file).replace("crawled_", "meta_", 1)
+        meta_dest_file = os.path.join(output_dir, meta_base_name)
+        try:
+            # Splits each file into (metadata rows, review rows).
+            meta_df, df = load_and_preprocess_reviews(src_file)
+            meta_df.to_csv(meta_dest_file, index=False)
+        except Exception as e:
+            print(f"ํ์ผ ๋ก๋ ์คํจ ({src_file}): {e}")
+            continue
+        # Cleaning cascade: each step reads the previous step's column.
+        tqdm.pandas()
+        df["step_special"] = df["review"].progress_apply(remove_special_chars)
+        df["step_newline"] = df["step_special"].apply(replace_newlines)
+        df["step_eng_filter"] = df["step_newline"].apply(filter_text_by_english_ratio)
+        df["step_num_filter"] = df["step_eng_filter"].apply(filter_text_by_number_ratio)
+        df["step_whitespace"] = df["step_num_filter"].apply(normalize_whitespace)
+        df["step_repetition"] = df["step_whitespace"].progress_apply(remove_repetition)
+        df["step_length"] = df["step_repetition"].apply(remove_short_text)
+        # Drop rows emptied by any step, then spell-correct the survivors.
+        df = df[(df["step_length"] != "") & df["step_length"].notna()]
+        df = batch_correct_typos(df, "step_length", batch_size=100)
+        drop_cols = ["step_special", "step_newline", "step_eng_filter",
+                     "step_num_filter", "step_whitespace", "step_repetition", "step_length"]
+        df.drop(columns=drop_cols, inplace=True)
+        df.to_csv(dest_file, index=False)
+        print(f"\n[์ ์ฒ๋ฆฌ ๋ฐ์ดํฐ ์ ์ฅ] {dest_file}\n")
+    print("\n[๋ฆฌ๋ทฐ ์ ์ฒ๋ฆฌ ์๋ฃ]\n")
+
+if __name__ == "__main__":
+ run_review_preprocessing({
+ "paths": {
+ "data_dir": "./data",
+ "crawled_reviews": "./data/crawled_reviews",
+ "preprocessed_dir": "./data/preprocessed"
+ }
+ })
diff --git a/models/review/src/sft_pipeline/sft.py b/models/review/src/sft_pipeline/sft.py
new file mode 100644
index 0000000..da54c66
--- /dev/null
+++ b/models/review/src/sft_pipeline/sft.py
@@ -0,0 +1,9 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+def run_sft(config):
+    """Placeholder stage: data preparation is done; the actual SFT run lives
+    in qwen_deepseek_14,32b_finetuning.py and is launched manually."""
+    print("๋ฐ์ดํฐ๊ฐ ์ค๋น๋์์ต๋๋ค.\n")
+    print("qwen_deepseek_14,32b_finetuning.py ํ์ผ๋ก SFT๋ฅผ ์งํํ์ธ์.\n")
+
+if __name__ == "__main__":
+    run_sft({"paths": {}})
diff --git a/models/review/src/sft_pipeline/train_data_annotating.py b/models/review/src/sft_pipeline/train_data_annotating.py
new file mode 100644
index 0000000..cc0b512
--- /dev/null
+++ b/models/review/src/sft_pipeline/train_data_annotating.py
@@ -0,0 +1,124 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+[ํ์ต๋ฐ์ดํฐ ์์ฑ ํ์ดํ๋ผ์ธ]
+- ์ํ ๋ฐ์ดํฐ ํ์ผ์์ ์์ง ์ฒ๋ฆฌ๋์ง ์์ ๋ฆฌ๋ทฐ์ ๋ํด GPT API ํธ์ถ
+- Few-shot ์์์ ํจ๊ป ๋ฉ์์ง ๊ตฌ์ฑํ์ฌ ์๋ต ์์ฑ
+- ์๋ต์์
ํ๊ทธ์ Reasoning ๋ฐ ```json``` ๋ธ๋ก์ Answer ์ถ์ถ
+- ๊ฒฐ๊ณผ๋ฅผ CSV ํ์ผ๋ก ์ ์ฅ
+"""
+
+import os
+import re
+import json
+import pandas as pd
+from tqdm import tqdm
+from openai import OpenAI
+from dotenv import load_dotenv
+# prompt ๋ชจ๋์์ ํ๋กฌํํธ์ few-shot ์์๋ฅผ ๋ถ๋ฌ์ต๋๋ค.
+from prompt.prompt_loader import load_prompt, load_fewshot
+
+def run_train_data_annotating(config):
+ # ํ๊ฒฝ ๋ณ์ ๋ก๋
+ load_dotenv(os.path.expanduser("~/.env"))
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
+ if not OPENAI_API_KEY:
+ raise ValueError("OPENAI_API_KEY๊ฐ ์ค์ ๋์ด ์์ง ์์ต๋๋ค. .env ํ์ผ์ ํ์ธํ์ธ์.")
+ os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
+
+ # ํ์ผ ๊ฒฝ๋ก๋ config๋ฅผ ํตํด ์ฝ์
+ aste_dir = config["paths"]["aste_dir"]
+ train_dir = config["paths"]["train_dir"]
+ input_file = os.path.join(aste_dir, "aste_sampled.csv")
+ output_temp_file = os.path.join(train_dir, "train_data_TEMP.csv")
+ output_final_file = os.path.join(train_dir, "train_data.csv")
+ BATCH_SIZE = 10
+ MODEL = config["train_data_annotating"]["annotation_model"]
+
+ PROMPT = load_prompt(prompt_filename="annotation_prompt.txt",
+ prompt_dir="./prompt/review_annotation/")
+ FEW_SHOT = load_fewshot(fewshot_filename="annotation_fewshot.json",
+ prompt_dir="./prompt/review_annotation/")
+
+ client = OpenAI()
+
+ print("\n[ํ์ต ๋ฐ์ดํฐ ์์ฑ ์์]\n")
+
+ df = pd.read_csv(input_file)
+ if os.path.exists(output_temp_file):
+ existing = pd.read_csv(output_temp_file)
+ else:
+ existing = pd.DataFrame()
+ processed_ids = set(existing["review-ID"]) if not existing.empty else set()
+ remaining = df[~df["review-ID"].isin(processed_ids)].copy()
+ print(f"๋จ์ ๋ฐ์ดํฐ ์: {len(remaining)}\n")
+
+ results = []
+ for i in tqdm(range(0, len(remaining), BATCH_SIZE), desc="GPT ์ฒ๋ฆฌ"):
+ batch = remaining.iloc[i : i + BATCH_SIZE]
+ for _, row in batch.iterrows():
+ messages = [{"role": "system", "content": PROMPT}]
+ for example in FEW_SHOT:
+ messages.append({"role": "user", "content": example["query"]})
+ messages.append({"role": "assistant", "content": example["answer"]})
+ messages.append({"role": "user", "content": row["processed"]})
+
+ completion = client.chat.completions.create(
+ model=MODEL,
+ messages=messages,
+ )
+ response = completion.choices[0].message.content
+ entry = row.to_dict()
+ entry["GPT_Response"] = response
+ results.append(entry)
+
+ df_batch = pd.DataFrame(results)
+ if os.path.exists(output_temp_file):
+ df_batch.to_csv(output_temp_file, mode="a", header=False, index=False)
+ else:
+ df_batch.to_csv(output_temp_file, mode="w", header=True, index=False)
+ print(f"์ฒ๋ฆฌ ์๋ฃ: {i + len(batch)} / {len(remaining)} ๊ฑด")
+ results = []
+ print("\n๋ชจ๋ ๋ฐ์ดํฐ GPT ์ฒ๋ฆฌ ์๋ฃ.\n")
+
+ df_temp = pd.read_csv(output_temp_file)
+ results = []
+ for _, row in df_temp.iterrows():
+ resp = row.get("GPT_Response", "")
+ reasoning = ""
+ answer = ""
+ try:
+ match_think = re.search(r"(.*?)", resp, re.DOTALL)
+ if match_think:
+ reasoning = match_think.group(1).strip()
+ except Exception:
+ reasoning = ""
+ try:
+ match_json = re.search(r"```json\n(.*?)\n```", resp, re.DOTALL)
+ if match_json:
+ answer = match_json.group(1).strip()
+ except Exception:
+ answer = ""
+ entry = row.copy()
+ entry["GPT_Reasoning"] = reasoning
+ entry["GPT_Answer"] = answer
+ results.append(entry)
+ df_final = pd.DataFrame(results)
+ df_final.to_csv(output_final_file, index=False)
+
+ if os.path.exists(output_temp_file) and (df_temp.shape[0] == df.shape[0]):
+ os.remove(output_temp_file)
+
+
+ print("[์ต์ข
์ด๋
ธํ
์ด์
๋ฐ์ดํฐ ์ ์ฅ]", output_final_file)
+
+ print("\n[ํ์ต ๋ฐ์ดํฐ ์์ฑ ์๋ฃ]\n")
+
+if __name__ == "__main__":
+ run_train_data_annotating({
+ "paths": {
+ "data_dir": "./data",
+ "prompt_dir": "./prompt"
+ },
+ "pipeline": {"sft": {"review_annotation": True}}
+ })
diff --git a/models/review/src/sft_pipeline/train_data_sampling.py b/models/review/src/sft_pipeline/train_data_sampling.py
new file mode 100644
index 0000000..01bd54a
--- /dev/null
+++ b/models/review/src/sft_pipeline/train_data_sampling.py
@@ -0,0 +1,80 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+[ํ์ต ๋ฐ์ดํฐ ์ํ๋ง ํ์ดํ๋ผ์ธ]
+1. ์ ์ฒ๋ฆฌ๋ ๋ฆฌ๋ทฐ ๋ฐ์ดํฐ์์ ๊ณจ๋ ๋ฐ ์๋ ๋ ์ด๋ธ ๋ฐ์ดํฐ๋ฅผ ์ ์ธ
+2. Sentence-BERT ์๋ฒ ๋ฉ์ ์ด์ฉํ์ฌ ๋ฆฌ๋ทฐ ๋ฒกํฐ ์์ฑ
+3. K-Means ํด๋ฌ์คํฐ๋ง์ผ๋ก 900๊ฐ ํด๋ฌ์คํฐ ์์ฑ ํ ๊ฐ ํด๋ฌ์คํฐ ์ค์ฌ ์ํ ์ ํ
+4. ๋ํ ์ํ CSV ํ์ผ๋ก ์ ์ฅ
+"""
+
+import os
+import sys
+import json
+import numpy as np
+import pandas as pd
+from collections import Counter
+from sklearn.cluster import KMeans
+from tqdm import tqdm
+from glob import glob
+
+# utils ๋ชจ๋์์ ์๋ฒ ๋ฉ ํจ์๋ฅผ import
+from utils.utils import sentenceBERT_embeddings
+
+def filter_data(config):
+ preprocessed_dir = os.path.join(config["paths"]["preprocessed_dir"])
+ csv_files = glob(os.path.join(preprocessed_dir, "*.csv"))
+ csv_files = [f for f in glob(os.path.join(preprocessed_dir, "*.csv"))
+ if not (os.path.basename(f).startswith("meta_")
+ or os.path.basename(f).endswith("_all.csv"))]
+
+ df_list = [pd.read_csv(file) for file in csv_files]
+ merged_df = pd.concat(df_list, ignore_index=True)
+ merged_output = os.path.join(preprocessed_dir, "processed_reviews_all.csv")
+ merged_df.to_csv(merged_output, index=False)
+ print(f"[์ ์ฒด ์ ์ฒ๋ฆฌ ๋ฐ์ดํฐ ์ ์ฅ] {merged_output}\n")
+
+ raw_df = pd.read_csv(merged_output)
+ golden_file = os.path.join(config["paths"]["eval_dir"], "aste_annotation_100_golden_label.csv")
+ golden_df = pd.read_csv(golden_file)
+ df_filtered = raw_df[~raw_df['review-ID'].isin(golden_df['review-ID'])]
+ output_filtered = os.path.join(config["paths"]["aste_dir"], "processed_except_GL.csv")
+ df_filtered.to_csv(output_filtered, index=False)
+ print(f"[๊ณจ๋ ๋ผ๋ฒจ ์ ์ธ ์ ์ฒ๋ฆฌ ๋ฐ์ดํฐ ์ ์ฅ] {output_filtered}")
+ return output_filtered
+
+def perform_kmeans_clustering(embeddings: np.ndarray, num_clusters=900, random_state=42):
+ print("K-Means ํด๋ฌ์คํฐ๋ง ์ํ ์ค...\n")
+ kmeans = KMeans(n_clusters=num_clusters, random_state=random_state, n_init=10)
+ labels = kmeans.fit_predict(embeddings)
+ return kmeans, labels, num_clusters
+
+def select_representative_samples(kmeans, cluster_labels: np.ndarray, num_clusters, embeddings: np.ndarray) -> list:
+ selected_idx = []
+ for cid in tqdm(range(num_clusters), desc="๋ํ ์ํ ์ ํ"):
+ indices = np.where(cluster_labels == cid)[0]
+ center = kmeans.cluster_centers_[cid]
+ closest = indices[np.argmin(np.linalg.norm(embeddings[indices] - center, axis=1))]
+ selected_idx.append(closest)
+ return selected_idx
+
+def run_train_data_sampling(config):
+    """Sample diverse training reviews via K-Means over SBERT embeddings.
+
+    Reads config["paths"] (preprocessed_dir / eval_dir / aste_dir /
+    embedding_dir) and config["train_data_annotating"]["num_train_data"]
+    (the cluster count = number of sampled reviews).
+    """
+    print("\n[ํ์ต ๋ฐ์ดํฐ ์ํ๋ง ์์]\n")
+    filtered_file = filter_data(config)
+    embedding_path = os.path.join(config["paths"]["embedding_dir"], "train_sampling.npy")
+    raw_data = pd.read_csv(filtered_file)
+    # Presumably caches/loads embeddings at embedding_path -- verify against
+    # utils.utils.sentenceBERT_embeddings.
+    embedding_matrix = sentenceBERT_embeddings(embedding_path, raw_data, column="processed")
+    kmeans, labels, num_clusters = perform_kmeans_clustering(embedding_matrix, num_clusters=config["train_data_annotating"]["num_train_data"])
+    selected_indices = select_representative_samples(kmeans, labels, num_clusters, embedding_matrix)
+    sampled_df = raw_data.iloc[selected_indices].reset_index(drop=True)
+    output_file = os.path.join(config["paths"]["aste_dir"], "aste_sampled.csv")
+    sampled_df.to_csv(output_file, index=False)
+    print("\n[ํ์ต ๋ฐ์ดํฐ ์ํ๋ง ์๋ฃ]", output_file)
+
+if __name__ == "__main__":
+ run_train_data_sampling({
+ "paths": {
+ "data_dir": "./data",
+ "embedding_dir": "./data/embedding_matrics"
+ }
+ })
diff --git a/models/review/utils/__init__.py b/models/review/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/models/review/utils/evaluate.py b/models/review/utils/evaluate.py
new file mode 100644
index 0000000..1a2a0b0
--- /dev/null
+++ b/models/review/utils/evaluate.py
@@ -0,0 +1,409 @@
+import json
+import numpy as np
+import pandas as pd
+import seaborn as sns
+from tqdm import tqdm
+from ast import literal_eval
+from bert_score import score
+import matplotlib.pyplot as plt
+import matplotlib.font_manager as fm
+from scipy.optimize import linear_sum_assignment
+from sklearn.metrics import confusion_matrix, classification_report
+
+
+"""
+์ฌ์ฉ ์์:
+sys.path.append(os.path.abspath('../src'))
+from evaluate import evaluate_aste
+
+evaluate_aste(
+ df,
+ golden_label_col="aste_golden_label",
+ model_prediction_col="aste_hcx",
+ # eval_threshold=0.85
+)
+"""
+
+
+
+# ํ๊ธ ํฐํธ ์ค์ (Ubuntu์์๋ 'NanumGothic' ์ฌ์ฉ)
+plt.rc('font', family='NanumGothic')
+# ๋ง์ด๋์ค ๊ธฐํธ ๊นจ์ง ๋ฐฉ์ง
+plt.rcParams['axes.unicode_minus'] = False
+
+# def extract_triplets(json_str):
+# """
+# JSON ํฌ๋งท ๋ฌธ์์ด์ ํ์ฑํ์ฌ triplet ๋ฆฌ์คํธ ๋ฐํ.
+# ๊ฐ triplet์ {"์์ฑ": ..., "ํ๊ฐ": ..., "๊ฐ์ ": ...} ํ์์.
+# """
+# try:
+# triplets = json.loads(json_str)
+# return triplets
+# except Exception as e:
+# print("JSON ํ์ฑ ์๋ฌ:", e)
+# print(json_str)
+# return []
+
+
+def extract_triplets(json_str):
+ """
+ JSON ํฌ๋งท ๋ฌธ์์ด์ ํ์ฑํ์ฌ triplet ๋ฆฌ์คํธ ๋ฐํ.
+ ๊ฐ triplet์ {"์์ฑ": ..., "ํ๊ฐ": ..., "๊ฐ์ ": ...} ํ์์.
+ ๋ง์ฝ "์์ฑ", "ํ๊ฐ", "๊ฐ์ " ํค๊ฐ ์์ผ๋ฉด ๊ฐ๊ฐ ๋น ๋ฌธ์์ด("")์ ๋ฃ์.
+ """
+ try:
+ triplets = literal_eval(json_str)
+
+ # ๊ฐ triplet์ ๋ํด ํค๊ฐ ์์ผ๋ฉด ๋น ๋ฌธ์์ด("")์ ๋ฃ์
+ for triplet in triplets:
+ triplet["์์ฑ"] = triplet.get("์์ฑ", "")
+ triplet["ํ๊ฐ"] = triplet.get("ํ๊ฐ", "")
+ triplet["๊ฐ์ "] = triplet.get("๊ฐ์ ", "")
+
+ return triplets
+ except Exception as e:
+ print("JSON ํ์ฑ ์๋ฌ:", e)
+ print(json_str)
+ return []
+
+
+def bertscore_similarity(text1, text2):
+ """
+ BERTScore๋ฅผ ์ฌ์ฉํ์ฌ ๋ ๋ฌธ์ฅ์ ์ ์ฌ๋๋ฅผ ์ธก์ ํจ.
+ F1-score๋ฅผ ๋ฐํ (0~1).
+ """
+ P, R, F1 = score([text1], [text2], lang="ko", verbose=False, device="cuda")
+ return F1.item()
+
+
+def match_triplets(gl_triplets, hcx_triplets, eval_threshold=0.85):
+    """1:1-match GL and HCX triplets on the 'ํ๊ฐ' (evaluation) field.
+
+    Builds a BERTScore similarity matrix over all (GL, HCX) pairs, solves the
+    assignment with the Hungarian algorithm, and keeps only the pairs whose
+    similarity is >= eval_threshold.
+
+    Returns: [(gl_index, hcx_index, similarity), ...] with similarity >= threshold.
+    """
+    if len(gl_triplets) == 0 or len(hcx_triplets) == 0:
+        return []  # nothing to match on either side
+
+    num_gl = len(gl_triplets)
+    num_hcx = len(hcx_triplets)
+    cost_matrix = np.zeros((num_gl, num_hcx))
+    sim_matrix = np.zeros((num_gl, num_hcx))
+
+    # Pairwise similarity of the 'ํ๊ฐ' texts; cost = 1 - similarity so that
+    # a higher similarity means a lower assignment cost.
+    for i, gl in enumerate(gl_triplets):
+        for j, hcx in enumerate(hcx_triplets):
+            sim = bertscore_similarity(gl["ํ๊ฐ"], hcx["ํ๊ฐ"])
+            sim_matrix[i, j] = sim
+            cost_matrix[i, j] = 1 - sim
+
+    # Hungarian algorithm: minimum-cost 1:1 assignment.
+    row_ind, col_ind = linear_sum_assignment(cost_matrix)
+    candidate_matches = []
+    for i, j in zip(row_ind, col_ind):
+        sim = sim_matrix[i, j]
+        if sim >= eval_threshold:
+            candidate_matches.append((i, j, sim))
+    return candidate_matches
+
+
+def evaluate_instance(gl_triplets, hcx_triplets, eval_threshold=0.85):
+    """Score one instance (one source text) of GL vs. HCX triplet sets.
+
+    1. 1:1-match triplets on the 'ํ๊ฐ' field (Hungarian + BERTScore).
+    2. For each matched pair:
+       - 'ํ๊ฐ': the similarity already cleared eval_threshold, so count a TP.
+       - '์์ฑ' / '๊ฐ์ ': exact string match -> TP; mismatch counts one FN
+         (label side) and one FP (prediction side).
+    3. Unmatched GL triplets count as FN for all three fields; unmatched HCX
+       triplets count as FP for all three fields.
+
+    Returns: {'์์ฑ': {"TP": ..., "FN": ..., "FP": ...},
+              'ํ๊ฐ': {"TP": ..., "FN": ..., "FP": ...},
+              '๊ฐ์ ': {"TP": ..., "FN": ..., "FP": ...}}
+    """
+    # Per-field counters.
+    counts = {
+        "์์ฑ": {"TP": 0, "FN": 0, "FP": 0},
+        "ํ๊ฐ": {"TP": 0, "FN": 0, "FP": 0},
+        "๊ฐ์ ": {"TP": 0, "FN": 0, "FP": 0}
+    }
+
+    # 1:1 matching keyed on the 'ํ๊ฐ' field.
+    candidate_matches = match_triplets(gl_triplets, hcx_triplets, eval_threshold)
+    matched_gl_indices = set()
+    matched_hcx_indices = set()
+
+    # Detailed scoring of matched candidates.
+    for i, j, sim in candidate_matches:
+        matched_gl_indices.add(i)
+        matched_hcx_indices.add(j)
+
+        # 'ํ๊ฐ' similarity is >= eval_threshold here -> TP.
+        counts["ํ๊ฐ"]["TP"] += 1
+
+        # '์์ฑ': exact match -> TP; mismatch -> FN (label) + FP (prediction).
+        if gl_triplets[i]["์์ฑ"] == hcx_triplets[j]["์์ฑ"]:
+            counts["์์ฑ"]["TP"] += 1
+        else:
+            counts["์์ฑ"]["FN"] += 1
+            counts["์์ฑ"]["FP"] += 1
+
+        # '๊ฐ์ ': exact match -> TP; mismatch -> FN + FP.
+        if gl_triplets[i]["๊ฐ์ "] == hcx_triplets[j]["๊ฐ์ "]:
+            counts["๊ฐ์ "]["TP"] += 1
+        else:
+            counts["๊ฐ์ "]["FN"] += 1
+            counts["๊ฐ์ "]["FP"] += 1
+
+    # Unmatched triplets:
+    # label-only triplets -> FN across all three fields.
+    for idx in range(len(gl_triplets)):
+        if idx not in matched_gl_indices:
+            counts["์์ฑ"]["FN"] += 1
+            counts["ํ๊ฐ"]["FN"] += 1
+            counts["๊ฐ์ "]["FN"] += 1
+    # prediction-only triplets -> FP across all three fields.
+    for idx in range(len(hcx_triplets)):
+        if idx not in matched_hcx_indices:
+            counts["์์ฑ"]["FP"] += 1
+            counts["ํ๊ฐ"]["FP"] += 1
+            counts["๊ฐ์ "]["FP"] += 1
+
+    return counts
+
+
def aggregate_evaluation(df, golden_label_col, model_prediction_col, eval_threshold=0.85):
    """
    Evaluate golden-label (GL) vs. model-prediction (HCX) triplet sets for every
    row of `df` and aggregate TP / FN / FP across the whole frame, producing
    Precision, Recall and F1 for each of the three triplet fields (the Korean
    dict keys below: aspect / opinion / sentiment).

    For confusion-matrix bookkeeping, unmatched triplets are also recorded:
      - a GL triplet with no matching prediction gets predicted label "NO_PRED",
      - a predicted triplet with no GL counterpart gets gold label "NO_GOLD".

    Similarity scores of matched opinion pairs (as returned by match_triplets)
    are collected into `eval_similarities`.

    Relies on module-level names not visible in this chunk: `extract_triplets`,
    `match_triplets`, `evaluate_instance`, `tqdm`.

    Returns:
        metrics: per-field dict with Precision/Recall/F1 and TP/FN/FP counts
        classification_data: [aspects_gold, aspects_pred, sentiments_gold, sentiments_pred]
        eval_similarities: similarity scores of matched pairs only
    """
    # Running TP/FN/FP counters, one bucket per triplet field.
    total_counts = {
        "์์ฑ": {"TP": 0, "FN": 0, "FP": 0},
        "ํ๊ฐ": {"TP": 0, "FN": 0, "FP": 0},
        "๊ฐ์ ": {"TP": 0, "FN": 0, "FP": 0}
    }

    # Flat gold/pred label lists over ALL rows (matched and unmatched cases).
    aspects_gold_all = []
    aspects_pred_all = []
    sentiments_gold_all = []
    sentiments_pred_all = []
    eval_similarities = []

    for idx, row in tqdm(df.iterrows(), total=len(df)):
        gl_triplets = extract_triplets(row[golden_label_col])
        hcx_triplets = extract_triplets(row[model_prediction_col])

        # candidate_matches is assumed to yield (gl_index, hcx_index, similarity)
        # tuples for pairs whose similarity clears eval_threshold.
        candidate_matches = match_triplets(gl_triplets, hcx_triplets, eval_threshold)
        matched_gl_indices = set([i for i, j, sim in candidate_matches])
        matched_hcx_indices = set([j for i, j, sim in candidate_matches])

        # Matched pairs: record the actual gold and predicted labels.
        for i, j, sim in candidate_matches:
            aspects_gold_all.append(gl_triplets[i]["์์ฑ"])
            aspects_pred_all.append(hcx_triplets[j]["์์ฑ"])
            sentiments_gold_all.append(gl_triplets[i]["๊ฐ์ "])
            sentiments_pred_all.append(hcx_triplets[j]["๊ฐ์ "])
            eval_similarities.append(sim)

        # GL triplets with no prediction: mark the prediction side as "NO_PRED".
        for idx_gl, gl_triplet in enumerate(gl_triplets):
            if idx_gl not in matched_gl_indices:
                aspects_gold_all.append(gl_triplet["์์ฑ"])
                aspects_pred_all.append("NO_PRED")
                sentiments_gold_all.append(gl_triplet["๊ฐ์ "])
                sentiments_pred_all.append("NO_PRED")

        # Predicted triplets with no GL counterpart: mark gold side as "NO_GOLD".
        for idx_hcx, hcx_triplet in enumerate(hcx_triplets):
            if idx_hcx not in matched_hcx_indices:
                aspects_gold_all.append("NO_GOLD")
                aspects_pred_all.append(hcx_triplet["์์ฑ"])
                sentiments_gold_all.append("NO_GOLD")
                sentiments_pred_all.append(hcx_triplet["๊ฐ์ "])

        # Per-instance evaluation feeding the global TP/FN/FP tally.
        counts = evaluate_instance(gl_triplets, hcx_triplets, eval_threshold)
        for field in total_counts:
            total_counts[field]["TP"] += counts[field]["TP"]
            total_counts[field]["FN"] += counts[field]["FN"]
            total_counts[field]["FP"] += counts[field]["FP"]

    # Derive Precision / Recall / F1 per field, guarding every zero denominator.
    metrics = {}
    for field, vals in total_counts.items():
        TP = vals["TP"]
        FN = vals["FN"]
        FP = vals["FP"]
        precision = TP / (TP + FP) if (TP + FP) > 0 else 0
        recall = TP / (TP + FN) if (TP + FN) > 0 else 0
        f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0
        metrics[field] = {
            "Precision": precision,
            "Recall": recall,
            "F1": f1,
            "TP": TP,
            "FN": FN,
            "FP": FP
        }

    # Print the final summary.
    # NOTE(review): the string below is garbled in this encoding (mojibake) and
    # appears split mid-literal — confirm the original bytes in the repo.
    print("์ต์ข
ํ๊ฐ ๊ฒฐ๊ณผ:")
    for field, m in metrics.items():
        print(f"{field} -> Precision: {m['Precision']:.4f}, Recall: {m['Recall']:.4f}, F1: {m['F1']:.4f} (TP: {m['TP']}, FN: {m['FN']}, FP: {m['FP']})")

    classification_data = [aspects_gold_all, aspects_pred_all, sentiments_gold_all, sentiments_pred_all]
    return metrics, classification_data, eval_similarities
+
def extract_unique_labels(df, golden_label_col, model_prediction_col, field):
    """
    Collect the unique values of one triplet field (e.g. aspect or sentiment)
    across both the golden-label and model-prediction columns of `df`.

    The sentinel labels "NO_PRED" and "NO_GOLD" are always included so that
    confusion matrices built from these labels cover unmatched triplets.

    Returns a sorted list of label strings.
    """
    # Seed with the sentinels; real labels are merged in below.
    unique_values = {"NO_PRED", "NO_GOLD"}

    for _, record in df.iterrows():
        for source_col in (golden_label_col, model_prediction_col):
            for triplet in extract_triplets(record[source_col]):
                unique_values.add(triplet[field])

    return sorted(unique_values)
+
+
def plot_confusion_matrix(y_true, y_pred, labels, title="Confusion Matrix"):
    """
    Render a confusion matrix of `y_true` vs. `y_pred` as a seaborn heatmap.

    `labels` fixes both the row/column order of the matrix and the tick labels.
    Relies on module-level imports: confusion_matrix (sklearn), pd, plt, sns.
    """
    matrix = confusion_matrix(y_true, y_pred, labels=labels)
    matrix_df = pd.DataFrame(matrix, index=labels, columns=labels)

    plt.figure(figsize=(6, 5))
    sns.heatmap(matrix_df, annot=True, fmt="d", cmap="Blues",
                xticklabels=labels, yticklabels=labels)
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.title(title)
    plt.show()
+
+
def plot_evaluation_similarity(eval_similarities):
    """
    Plot a histogram (with KDE) of the similarity values collected for matched
    opinion pairs, with dashed vertical lines marking the mean and the median.

    Prints a notice and returns early when the list is empty.
    Relies on module-level imports: np, plt, sns.
    """
    if not eval_similarities:
        print("ํ๊ฐ ์ ์ฌ๋ ๋ฐ์ดํฐ๊ฐ ์์ต๋๋ค.")
        return

    avg_similarity = np.mean(eval_similarities)
    median_similarity = np.median(eval_similarities)
    # Fix: the standard deviation was computed here but never used — removed.

    plt.figure(figsize=(6, 5))
    sns.histplot(eval_similarities, bins=20, kde=True, color='blue')
    plt.axvline(avg_similarity, color='red', linestyle='dashed', linewidth=2, label=f'ํ๊ท : {avg_similarity:.4f}')
    plt.axvline(median_similarity, color='green', linestyle='dashed', linewidth=2, label=f'์ค์๊ฐ: {median_similarity:.4f}')
    plt.xlabel("BERTScore Similarity")
    plt.ylabel("Frequency")
    plt.title("Evaluation BERTScore Similarity Distribution")
    plt.legend()
    plt.show()
+
+
def compute_confusion_and_report(df, golden_label_col, model_prediction_col, classification_data):
    """
    Print confusion matrices and classification reports for the aspect and
    sentiment labels gathered by aggregate_evaluation().

    `classification_data` is the 4-list bundle
    [aspects_gold, aspects_pred, sentiments_gold, sentiments_pred].
    Relies on siblings extract_unique_labels / plot_confusion_matrix and the
    module-level sklearn import classification_report.
    """
    aspects_gold, aspects_pred, sentiments_gold, sentiments_pred = classification_data

    # Same report sequence for both fields; the tuples carry the Korean field
    # key used in the data plus its English display name.
    report_specs = [
        ("์์ฑ", "Aspect", aspects_gold, aspects_pred),
        ("๊ฐ์ ", "Sentiment", sentiments_gold, sentiments_pred),
    ]

    for field_key, field_en, gold_labels, pred_labels in report_specs:
        print(f"=== {field_key} ({field_en}) Confusion Matrix ===")
        label_order = extract_unique_labels(df, golden_label_col, model_prediction_col, field_key)
        plot_confusion_matrix(gold_labels, pred_labels, labels=label_order,
                              title=f"{field_en} Confusion Matrix")

        print(f"\n=== {field_key} ({field_en}) Classification Report ===")
        print(classification_report(gold_labels, pred_labels))
+
+
def compute_eval_statistics(eval_similarities):
    """
    Print mean / median / standard deviation of the matched-pair similarity
    values and draw their distribution via plot_evaluation_similarity().

    Returns (mean, median, std), or None when there is no data.
    """
    if not eval_similarities:
        print("ํ๊ฐ ์ ์ฌ๋ ๋ฐ์ดํฐ๊ฐ ์์ต๋๋ค.")
        return None

    stats = (np.mean(eval_similarities),
             np.median(eval_similarities),
             np.std(eval_similarities))
    avg_similarity, median_similarity, std_similarity = stats

    plot_evaluation_similarity(eval_similarities)

    print("=== ํ๊ฐ (Evaluation) ์ ์ฌ๋ ํต๊ณ ===")
    print(f"ํ๊ท ์ ์ฌ๋: {avg_similarity:.4f}")
    print(f"์ค์๊ฐ ์ ์ฌ๋: {median_similarity:.4f}")
    print(f"ํ์คํธ์ฐจ: {std_similarity:.4f}")

    return stats
+
+
def evaluate_aste(df, golden_label_col, model_prediction_col, eval_threshold=0.85):
    """
    Run the full ASTE (aspect-sentiment triplet) evaluation pipeline once.

    1. aggregate_evaluation(): compute metrics, confusion-matrix data and the
       matched-pair similarity list.
    2. compute_confusion_and_report(): print confusion matrices and
       classification reports.
    3. compute_eval_statistics(): print similarity statistics.

    Args:
        df (pd.DataFrame): frame to evaluate
        golden_label_col (str): golden-label column name
        model_prediction_col (str): model-prediction column name
        eval_threshold (float): similarity threshold for triplet matching

    Returns:
        dict: overall metrics per field
        list: classification_data ([aspects_gold, aspects_pred, sentiments_gold, sentiments_pred])
        list: matched-pair similarity values (eval_similarities)
    """
    print("\n=== Step 1: Aggregate Evaluation ===")
    metrics, classification_data, eval_similarities = aggregate_evaluation(
        df,
        golden_label_col=golden_label_col,
        model_prediction_col=model_prediction_col,
        eval_threshold=eval_threshold
    )

    print("\n=== Step 2: Compute Confusion Matrix and Classification Report ===")
    compute_confusion_and_report(
        df=df,
        golden_label_col=golden_label_col,
        model_prediction_col=model_prediction_col,
        classification_data=classification_data
    )

    print("\n=== Step 3: Compute Evaluation Statistics (BERTScore Similarity) ===")
    compute_eval_statistics(eval_similarities)

    # Bug fix: the docstring promised these three return values, but the
    # function previously fell off the end and returned None.
    return metrics, classification_data, eval_similarities
diff --git a/models/review/utils/utils.py b/models/review/utils/utils.py
new file mode 100644
index 0000000..8e5a006
--- /dev/null
+++ b/models/review/utils/utils.py
@@ -0,0 +1,127 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+[์ ํธ๋ฆฌํฐ ํจ์ ๋ชจ๋]
+- ๋ฆฌ๋ทฐ ๋ฐ์ดํฐ ๋ก๋, ์ ์ฒ๋ฆฌ, ์๋ฒ ๋ฉ ์์ฑ, UMAP, ํด๋ฌ์คํฐ๋ง, ์๊ฐํ, API ํธ์ถ ๋ฑ
+"""
+
+import os
+import numpy as np
+import pandas as pd
+import json
+import requests
+import time
+import umap
+import hdbscan
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+from sklearn.cluster import AgglomerativeClustering
+from sklearn.manifold import TSNE
+from sklearn.metrics import silhouette_score, davies_bouldin_score
+from sentence_transformers import SentenceTransformer
+from dotenv import load_dotenv
+
def load_data(file_path):
    """Read a CSV file into a DataFrame, report its shape and return it."""
    frame = pd.read_csv(file_path)
    print("๋ฐ์ดํฐ ๋ก๋:", frame.shape)
    return frame
+
def load_and_preprocess_reviews(file_path):
    """
    Load the raw review CSV and split it into metadata rows and review rows.

    Product-level columns (category/name/url/meta) are forward-filled within
    each product "ID" group. Rows whose "review-ID" ends with "-0" are treated
    as product metadata; all other rows are individual reviews.

    Returns (meta_df, review_df).
    """
    frame = pd.read_csv(file_path)
    product_cols = ["category", "name", "url", "meta"]
    # Forward-fill product attributes down each product's row group.
    frame[product_cols] = frame.groupby("ID")[product_cols].transform(lambda col: col.ffill())

    is_meta = frame["review-ID"].astype(str).str.endswith("-0")
    return frame[is_meta], frame[~is_meta]
+
def expand_inference_data(df, json_column="unsloth_deepseek_32b"):
    """
    Explode model-inference JSON into one row per extracted triplet.

    Each value in `json_column` is either a JSON string or an already-parsed
    list of dicts keyed by the (Korean) aspect / opinion / sentiment field
    names. Every dict becomes a copy of the source row with `aspect`,
    `opinion` and `sentiment` columns attached. Unparseable strings and
    unexpected payload types are skipped; remaining gaps are forward-filled.
    """
    rows = []
    for _, source_row in df.iterrows():
        payload = source_row[json_column]

        # Strings must be JSON-decoded; lists pass through; anything else is skipped.
        if isinstance(payload, str):
            try:
                payload = json.loads(payload)
            except json.JSONDecodeError:
                print(f"JSON ํ์ฑ ์๋ฌ, review-ID: {source_row.get('review-ID', 'N/A')}")
                continue
        elif not isinstance(payload, list):
            continue

        if isinstance(payload, list):
            for triplet in payload:
                expanded_row = source_row.copy()
                expanded_row["aspect"] = triplet.get("์์ฑ", None)
                expanded_row["opinion"] = triplet.get("ํ๊ฐ", None)
                expanded_row["sentiment"] = triplet.get("๊ฐ์ ", None)
                rows.append(expanded_row)

    result = pd.DataFrame(rows)
    result.reset_index(drop=True, inplace=True)
    result.ffill(inplace=True)
    return result
+
def sentenceBERT_embeddings(embedding_path, df, column="processed", model_name="dragonkue/BGE-m3-ko"):
    """
    Encode every text in df[column] with a SentenceTransformer model, save the
    stacked embedding matrix to `embedding_path` (.npy) and return it.

    Relies on module-level imports: SentenceTransformer, np.
    """
    print("\n์๋ฒ ๋ฉ ํ์ผ์ ์๋ก ์์ฑํฉ๋๋ค...\n")
    encoder = SentenceTransformer(model_name)
    # Encode row by row (same order as the Series) and stack into one matrix.
    vectors = [encoder.encode(text, show_progress_bar=False) for text in df[column]]
    emb_matrix = np.array(vectors)
    np.save(embedding_path, emb_matrix)
    print(f"\n์๋ฒ ๋ฉ ํ์ผ ์ ์ฅ ์๋ฃ: {embedding_path}\n")

    print(f"์๋ฒ ๋ฉ Shape:{emb_matrix.shape}\n")
    return emb_matrix
+
def umap_reduce_embeddings(embedding_matrix, n_components=256, random_state=42):
    """
    Reduce a 2-D embedding matrix to `n_components` dimensions with UMAP.

    UMAP cannot produce more components than there are samples, so when
    n_components >= number of rows the input matrix is returned untouched.
    """
    sample_count, _ = embedding_matrix.shape
    if n_components >= sample_count:
        print("UMAP ์ ์ฉ ๋ถ๊ฐ: n_components๊ฐ ๋ฐ์ดํฐ ์๋ณด๋ค ํฝ๋๋ค. ์๋ณธ ๋ฐํ.")
        return embedding_matrix

    mapper = umap.UMAP(n_components=n_components, random_state=random_state)
    projected = mapper.fit_transform(embedding_matrix)
    print(f"์ฐจ์ ์ถ์ ์๋ฃ: {embedding_matrix.shape} -> {projected.shape}")
    return projected
+
def agglomerative_clustering(emb, distance_threshold=22.0, linkage="ward"):
    """
    Cluster embeddings hierarchically, cutting the dendrogram at
    `distance_threshold` instead of fixing the number of clusters.

    Relies on module-level imports: AgglomerativeClustering (sklearn), np.
    Returns the per-sample cluster labels.
    """
    model = AgglomerativeClustering(
        distance_threshold=distance_threshold,
        n_clusters=None,
        compute_full_tree=True,
        linkage=linkage,
    )
    labels = model.fit_predict(emb)
    print(f"Agglomerative Clustering ์๋ฃ: ํด๋ฌ์คํฐ ์ {np.unique(labels).shape[0]}")
    return labels
+
def visualize_clustering(emb, cluster_labels, config, category):
    """
    Project embeddings to 2-D with t-SNE, scatter-plot them coloured by
    cluster label, and save the figure as "<category>_cluster_result.png"
    in config["paths"]["embedding_dir"].

    Relies on module-level imports: TSNE (sklearn), plt, os.
    """
    projector = TSNE(n_components=2, perplexity=30, learning_rate=200, random_state=42)
    points = projector.fit_transform(emb)

    plt.figure(figsize=(8, 6))
    scatter = plt.scatter(points[:, 0], points[:, 1], c=cluster_labels, cmap="tab10", alpha=0.6)
    plt.colorbar(scatter, label="ํด๋ฌ์คํฐ ๋ฒํธ")
    plt.title("Agglomerative Clustering Algorithm Visualization")
    plt.xlabel("Component 1")
    plt.ylabel("Component 2")
    # Save to disk instead of showing interactively (module uses the Agg backend).
    plt.savefig(os.path.join(config["paths"]["embedding_dir"], f"{category}_cluster_result.png"))
+
+
def evaluate_clustering(emb, cluster_labels, config, category):
    """
    Score a clustering with the silhouette coefficient and the Davies-Bouldin
    index, dump the scores to "<category>_clustering_evaluation.json" in the
    configured embedding directory, and return them.

    With fewer than two distinct clusters the metrics are undefined, so a dict
    of Nones is returned instead (and nothing is written to disk).
    Relies on module-level imports: silhouette_score, davies_bouldin_score.
    """
    # Guard clause: both metrics require at least two clusters.
    if len(set(cluster_labels)) <= 1:
        print("๋จ์ผ ํด๋ฌ์คํฐ๋ก ํ๊ฐ ๋ถ๊ฐ")
        return {"Silhouette": None, "DBI": None}

    silhouette = silhouette_score(emb, cluster_labels)
    dbi = davies_bouldin_score(emb, cluster_labels)
    print(f"์ค๋ฃจ์ฃ ์ ์: {silhouette:.4f}, Davies-Bouldin Index: {dbi:.4f}")

    results = {"category": category, "Silhouette": float(silhouette), "DBI": float(dbi)}

    json_path = os.path.join(config["paths"]["embedding_dir"], f"{category}_clustering_evaluation.json")
    with open(json_path, "w") as f:
        json.dump(results, f, indent=4)

    return results
diff --git a/models/size_description/README.md b/models/size_description/README.md
new file mode 100644
index 0000000..0720ac0
--- /dev/null
+++ b/models/size_description/README.md
@@ -0,0 +1,40 @@
+# ์ ํ ํฌ๊ธฐ ๊ฐ์ง ๋ฐ ๋น๊ต
+- ๋ณธ ํ๋ก์ ํธ๋ YOLO ๊ธฐ๋ฐ์ Object Detection ๊ธฐ๋ฒ์ ํ์ฉํ์ฌ ์ ํ์ ํฌ๊ธฐ๋ฅผ ๊ฐ์งํ๊ณ , ์ด๋ฅผ ๋ฐํ์ผ๋ก ์ ํ ๊ฐ ํฌ๊ธฐ ๋น๊ต ๊ธฐ๋ฅ์ ์ ๊ณตํฉ๋๋ค.
+- ํ์ต ๋ฐ์ดํฐ์ ๊ธฐ๋ฐํ ๋ชจ๋ธ ๋ผ๋ฒจ๋ง ๋ฐ ํ์ต ๊ณผ์ ์ ๊ฑฐ์ณ, ์ค์ ์ธํผ๋ฐ์ค ๋จ๊ณ์์๋ ์
๋ ฅ ์ด๋ฏธ์ง์์ ์ ํ์ ํฌ๊ธฐ๋ฅผ ์๋์ผ๋ก ์ธก์ ํ๊ณ ๋น๊ตํ ์ ์์ต๋๋ค.
+
+## ์ฃผ์ ํน์ง
+1. **YOLO ๋ชจ๋ธ ํ์ฉ**: YOLO Object Detection์ ํตํด ์ค์๊ฐ์ผ๋ก ์ ํ์ ์์น์ ํฌ๊ธฐ๋ฅผ ๊ฐ์งํฉ๋๋ค.
+2. **๋ชจ๋ธ ๋ผ๋ฒจ๋ง ๋ฐ ํ์ต**: `data/train` ํด๋ ๋ด ๋ค์์ annotation ํ์ผ(์: `mall-101-3.txt`, `mall-106-3.txt` ๋ฑ)์ ํ์ฉํด ๋ชจ๋ธ์ ํ์ตํฉ๋๋ค.
+3. **์ ํ ํฌ๊ธฐ ์ธก์ **: ํ์ต๋ ๋ชจ๋ธ์ ํตํด ์ด๋ฏธ์ง ๋ด ์ ํ์ ํฌ๊ธฐ๋ฅผ ์ ํํ๊ฒ ๊ฐ์ง ๋ฐ ์ธก์ ํฉ๋๋ค.
+4. **ํฌ๊ธฐ ๋น๊ต ๊ธฐ๋ฅ**: ๊ฐ์ง๋ ํฌ๊ธฐ ์ ๋ณด๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ์ฌ๋ฌ ์ ํ ๊ฐ์ ํฌ๊ธฐ๋ฅผ ๋น๊ตํ ์ ์๋ ๊ธฐ๋ฅ์ ์ ๊ณตํฉ๋๋ค.
+
+## ํด๋ ๊ตฌ์กฐ
+```bash
+.
+โโโ data
+โ โโโ train
+โโโ size_info.yaml
+โโโ src
+ โโโ inference.py
+ โโโ train.py
+```
+
+## ์ค์น ๋ฐ ์คํ ๋ฐฉ๋ฒ
+### 1) ํ๊ฒฝ ๊ตฌ์ถ
+- Python 3.11.11 ๋ฒ์ ๊ถ์ฅ
+- ์์กด์ฑ ํจํค์ง ์ค์น: `pip install -r requirements.txt`
+
+### 2) ๋ฐ์ดํฐ ์ค๋น
+- ํ์ต ๋ฐ์ดํฐ
+ - data/train ํด๋์ ํ์ต์ฉ ์ด๋ฏธ์ง ๋ฐ ๊ทธ์ ํด๋นํ๋ annotation ํ์ผ๊ณผ classes.txt๋ฅผ ์์น์ํต๋๋ค.
+ - ๊ฐ ํ
์คํธ ํ์ผ์๋ ์ด๋ฏธ์ง์ ๋ํ ๊ฐ์ฒด ์ขํ ๋ฐ ํด๋์ค ์ ๋ณด๊ฐ ํฌํจ๋์ด ์์ต๋๋ค.
+
+### 3) ์คํ ๋ฐฉ๋ฒ
+- ๋ชจ๋ธ ํ์ต
+ - ํ์ต ์คํฌ๋ฆฝํธ๋ src/train.py์ ์์ต๋๋ค. ํ์ต์ ํ์ํ ํ์ดํผํ๋ผ๋ฏธํฐ๋ ๊ธฐํ ์ต์
์ ์คํฌ๋ฆฝํธ ๋ด์์ ์ค์ ํ๊ฑฐ๋ ์ธ์๊ฐ์ผ๋ก ์ ๋ฌํ ์ ์์ต๋๋ค.
+```bash
+python src/train.py
+```
+
+- ์ธํผ๋ฐ์ค
+ - ํ์ต๋ ๋ชจ๋ธ์ ์ฌ์ฉํ์ฌ ์ด๋ฏธ์ง ์์ ์ ํ ํฌ๊ธฐ๋ฅผ ๊ฐ์ง ๋ฐ ๋น๊ตํ๋ ค๋ฉด src/inference.py ์คํฌ๋ฆฝํธ๋ฅผ ์คํํฉ๋๋ค.
+ - ์ธํผ๋ฐ์ค ์คํ ์, ์
๋ ฅ ์ด๋ฏธ์ง ๊ฒฝ๋ก ๋ฑ ํ์ํ ์ธ์๋ฅผ ์ ๋ฌํ ์ ์์ต๋๋ค.
+```bash
+python src/inference.py --image <์ด๋ฏธ์ง_๊ฒฝ๋ก>
+```
\ No newline at end of file
diff --git a/models/size_description/data/train/classes.txt b/models/size_description/data/train/classes.txt
new file mode 100644
index 0000000..0c690fa
--- /dev/null
+++ b/models/size_description/data/train/classes.txt
@@ -0,0 +1,4 @@
+product
+compare_pet
+compare_can
+compare_cup
diff --git a/models/size_description/data/train/mall-101-3.txt b/models/size_description/data/train/mall-101-3.txt
new file mode 100644
index 0000000..dc27369
--- /dev/null
+++ b/models/size_description/data/train/mall-101-3.txt
@@ -0,0 +1,2 @@
+0 0.362760 0.529948 0.502604 0.450521
+2 0.768490 0.513021 0.183854 0.475000
diff --git a/models/size_description/data/train/mall-106-3.txt b/models/size_description/data/train/mall-106-3.txt
new file mode 100644
index 0000000..025e30d
--- /dev/null
+++ b/models/size_description/data/train/mall-106-3.txt
@@ -0,0 +1,2 @@
+0 0.375333 0.531333 0.444000 0.524000
+1 0.717333 0.514000 0.176000 0.553333
diff --git a/models/size_description/data/train/mall-108-4.txt b/models/size_description/data/train/mall-108-4.txt
new file mode 100644
index 0000000..b51e68d
--- /dev/null
+++ b/models/size_description/data/train/mall-108-4.txt
@@ -0,0 +1 @@
+0 0.436667 0.472667 0.785333 0.788000
diff --git a/models/size_description/data/train/mall-111-3.txt b/models/size_description/data/train/mall-111-3.txt
new file mode 100644
index 0000000..8fcdf35
--- /dev/null
+++ b/models/size_description/data/train/mall-111-3.txt
@@ -0,0 +1,2 @@
+0 0.342000 0.546000 0.481333 0.206667
+3 0.761333 0.519333 0.240000 0.260000
diff --git a/models/size_description/data/train/mall-117-3.txt b/models/size_description/data/train/mall-117-3.txt
new file mode 100644
index 0000000..2e7b284
--- /dev/null
+++ b/models/size_description/data/train/mall-117-3.txt
@@ -0,0 +1,2 @@
+0 0.334375 0.502344 0.387500 0.766146
+1 0.704948 0.543750 0.206771 0.673958
diff --git a/models/size_description/data/train/mall-121-3.txt b/models/size_description/data/train/mall-121-3.txt
new file mode 100644
index 0000000..98b7766
--- /dev/null
+++ b/models/size_description/data/train/mall-121-3.txt
@@ -0,0 +1,2 @@
+0 0.340667 0.501333 0.286667 0.800000
+1 0.655333 0.545333 0.220000 0.714667
diff --git a/models/size_description/data/train/mall-125-2.txt b/models/size_description/data/train/mall-125-2.txt
new file mode 100644
index 0000000..ca4c522
--- /dev/null
+++ b/models/size_description/data/train/mall-125-2.txt
@@ -0,0 +1,2 @@
+0 0.364063 0.528906 0.373958 0.571354
+1 0.683594 0.529687 0.173437 0.557292
diff --git a/models/size_description/data/train/mall-126-5.txt b/models/size_description/data/train/mall-126-5.txt
new file mode 100644
index 0000000..f64f6af
--- /dev/null
+++ b/models/size_description/data/train/mall-126-5.txt
@@ -0,0 +1,2 @@
+0 0.394000 0.500667 0.588000 0.412000
+2 0.810667 0.532667 0.136000 0.345333
diff --git a/models/size_description/data/train/mall-135-4.txt b/models/size_description/data/train/mall-135-4.txt
new file mode 100644
index 0000000..7c26232
--- /dev/null
+++ b/models/size_description/data/train/mall-135-4.txt
@@ -0,0 +1,2 @@
+0 0.362240 0.520052 0.186979 0.675521
+1 0.575781 0.518490 0.206771 0.672396
diff --git a/models/size_description/data/train/mall-137-4.txt b/models/size_description/data/train/mall-137-4.txt
new file mode 100644
index 0000000..0b1c133
--- /dev/null
+++ b/models/size_description/data/train/mall-137-4.txt
@@ -0,0 +1,2 @@
+0 0.396000 0.501333 0.584000 0.466667
+2 0.804667 0.555333 0.142667 0.369333
diff --git a/models/size_description/data/train/mall-144-2.txt b/models/size_description/data/train/mall-144-2.txt
new file mode 100644
index 0000000..8701419
--- /dev/null
+++ b/models/size_description/data/train/mall-144-2.txt
@@ -0,0 +1,2 @@
+0 0.387240 0.546615 0.333854 0.324479
+2 0.676562 0.509635 0.153125 0.392188
diff --git a/models/size_description/data/train/mall-151-3.txt b/models/size_description/data/train/mall-151-3.txt
new file mode 100644
index 0000000..318c741
--- /dev/null
+++ b/models/size_description/data/train/mall-151-3.txt
@@ -0,0 +1,4 @@
+0 0.161333 0.572667 0.090667 0.188000
+3 0.286000 0.597333 0.105333 0.130667
+0 0.568667 0.523333 0.350667 0.278667
+2 0.842667 0.510667 0.114667 0.298667
diff --git a/models/size_description/data/train/mall-152-3.txt b/models/size_description/data/train/mall-152-3.txt
new file mode 100644
index 0000000..27bec42
--- /dev/null
+++ b/models/size_description/data/train/mall-152-3.txt
@@ -0,0 +1,2 @@
+0 0.332667 0.524667 0.478667 0.636000
+1 0.736667 0.501333 0.212000 0.677333
diff --git a/models/size_description/data/train/mall-153-2.txt b/models/size_description/data/train/mall-153-2.txt
new file mode 100644
index 0000000..e6eb999
--- /dev/null
+++ b/models/size_description/data/train/mall-153-2.txt
@@ -0,0 +1,2 @@
+0 0.385677 0.529687 0.193229 0.407292
+2 0.644531 0.512760 0.174479 0.444271
diff --git a/models/size_description/data/train/mall-165-5.txt b/models/size_description/data/train/mall-165-5.txt
new file mode 100644
index 0000000..7f0f34d
--- /dev/null
+++ b/models/size_description/data/train/mall-165-5.txt
@@ -0,0 +1,2 @@
+0 0.343333 0.546667 0.476000 0.642667
+1 0.746667 0.520000 0.218667 0.698667
diff --git a/models/size_description/data/train/mall-167-3.txt b/models/size_description/data/train/mall-167-3.txt
new file mode 100644
index 0000000..01f6726
--- /dev/null
+++ b/models/size_description/data/train/mall-167-3.txt
@@ -0,0 +1,2 @@
+0 0.415885 0.499219 0.134896 0.477604
+2 0.613281 0.542708 0.150521 0.387500
diff --git a/models/size_description/data/train/mall-192-3.txt b/models/size_description/data/train/mall-192-3.txt
new file mode 100644
index 0000000..8edf418
--- /dev/null
+++ b/models/size_description/data/train/mall-192-3.txt
@@ -0,0 +1,2 @@
+0 0.338281 0.519010 0.554688 0.695312
+1 0.781510 0.522656 0.209896 0.678646
diff --git a/models/size_description/data/train/mall-201-5.txt b/models/size_description/data/train/mall-201-5.txt
new file mode 100644
index 0000000..33efb9c
--- /dev/null
+++ b/models/size_description/data/train/mall-201-5.txt
@@ -0,0 +1,2 @@
+0 0.387333 0.565333 0.577333 0.392000
+1 0.790000 0.513333 0.148000 0.496000
diff --git a/models/size_description/data/train/mall-203-4.txt b/models/size_description/data/train/mall-203-4.txt
new file mode 100644
index 0000000..3faa09a
--- /dev/null
+++ b/models/size_description/data/train/mall-203-4.txt
@@ -0,0 +1,2 @@
+0 0.403333 0.514000 0.588000 0.441333
+1 0.800000 0.513333 0.136000 0.448000
diff --git a/models/size_description/data/train/mall-222-3.txt b/models/size_description/data/train/mall-222-3.txt
new file mode 100644
index 0000000..cf26d8f
--- /dev/null
+++ b/models/size_description/data/train/mall-222-3.txt
@@ -0,0 +1,2 @@
+0 0.398698 0.493750 0.583854 0.816667
+1 0.783854 0.659375 0.146875 0.485417
diff --git a/models/size_description/data/train/mall-24-3.txt b/models/size_description/data/train/mall-24-3.txt
new file mode 100644
index 0000000..8075dbe
--- /dev/null
+++ b/models/size_description/data/train/mall-24-3.txt
@@ -0,0 +1,2 @@
+0 0.396354 0.526823 0.579167 0.542188
+1 0.805469 0.528125 0.168229 0.530208
diff --git a/models/size_description/data/train/mall-242-2.txt b/models/size_description/data/train/mall-242-2.txt
new file mode 100644
index 0000000..132660e
--- /dev/null
+++ b/models/size_description/data/train/mall-242-2.txt
@@ -0,0 +1,2 @@
+0 0.365625 0.510677 0.429167 0.727604
+1 0.752344 0.539062 0.206771 0.667708
diff --git a/models/size_description/data/train/mall-243-3.txt b/models/size_description/data/train/mall-243-3.txt
new file mode 100644
index 0000000..78480e3
--- /dev/null
+++ b/models/size_description/data/train/mall-243-3.txt
@@ -0,0 +1,2 @@
+0 0.348177 0.521354 0.509896 0.669792
+1 0.750000 0.519792 0.211458 0.672917
diff --git a/models/size_description/data/train/mall-249-2.txt b/models/size_description/data/train/mall-249-2.txt
new file mode 100644
index 0000000..2bcdaaf
--- /dev/null
+++ b/models/size_description/data/train/mall-249-2.txt
@@ -0,0 +1,2 @@
+0 0.388021 0.548177 0.504167 0.502604
+1 0.769271 0.516146 0.176042 0.563542
diff --git a/models/size_description/data/train/mall-256-3.txt b/models/size_description/data/train/mall-256-3.txt
new file mode 100644
index 0000000..a1bf234
--- /dev/null
+++ b/models/size_description/data/train/mall-256-3.txt
@@ -0,0 +1,2 @@
+0 0.367448 0.528125 0.438021 0.673958
+1 0.755469 0.535677 0.200521 0.652604
diff --git a/models/size_description/data/train/mall-266-2.txt b/models/size_description/data/train/mall-266-2.txt
new file mode 100644
index 0000000..5d10d1b
--- /dev/null
+++ b/models/size_description/data/train/mall-266-2.txt
@@ -0,0 +1,2 @@
+0 0.384115 0.503906 0.428646 0.652604
+1 0.746354 0.517448 0.194792 0.616146
diff --git a/models/size_description/data/train/mall-267-2.txt b/models/size_description/data/train/mall-267-2.txt
new file mode 100644
index 0000000..86f60eb
--- /dev/null
+++ b/models/size_description/data/train/mall-267-2.txt
@@ -0,0 +1,2 @@
+0 0.382552 0.526042 0.496354 0.497917
+1 0.744792 0.514583 0.166667 0.520833
diff --git a/models/size_description/data/train/mall-27-2.txt b/models/size_description/data/train/mall-27-2.txt
new file mode 100644
index 0000000..aaf5619
--- /dev/null
+++ b/models/size_description/data/train/mall-27-2.txt
@@ -0,0 +1,2 @@
+0 0.367448 0.501563 0.232813 0.712500
+1 0.652083 0.519271 0.211458 0.680208
diff --git a/models/size_description/data/train/mall-272-2.txt b/models/size_description/data/train/mall-272-2.txt
new file mode 100644
index 0000000..5efbdfa
--- /dev/null
+++ b/models/size_description/data/train/mall-272-2.txt
@@ -0,0 +1,2 @@
+0 0.397917 0.499219 0.306250 0.196354
+3 0.669792 0.513802 0.139583 0.164062
diff --git a/models/size_description/data/train/mall-273-3.txt b/models/size_description/data/train/mall-273-3.txt
new file mode 100644
index 0000000..7d5fb7c
--- /dev/null
+++ b/models/size_description/data/train/mall-273-3.txt
@@ -0,0 +1,2 @@
+0 0.358333 0.551302 0.163542 0.539062
+1 0.606510 0.490885 0.206771 0.672396
diff --git a/models/size_description/data/train/mall-274-2.txt b/models/size_description/data/train/mall-274-2.txt
new file mode 100644
index 0000000..0b5fa69
--- /dev/null
+++ b/models/size_description/data/train/mall-274-2.txt
@@ -0,0 +1,2 @@
+0 0.334667 0.569333 0.184000 0.570667
+1 0.645333 0.514667 0.210667 0.685333
diff --git a/models/size_description/data/train/mall-275-3.txt b/models/size_description/data/train/mall-275-3.txt
new file mode 100644
index 0000000..d4be780
--- /dev/null
+++ b/models/size_description/data/train/mall-275-3.txt
@@ -0,0 +1,2 @@
+0 0.307333 0.498667 0.422667 0.373333
+3 0.729333 0.530000 0.296000 0.316000
diff --git a/models/size_description/data/train/mall-28-3.txt b/models/size_description/data/train/mall-28-3.txt
new file mode 100644
index 0000000..543ba14
--- /dev/null
+++ b/models/size_description/data/train/mall-28-3.txt
@@ -0,0 +1,2 @@
+0 0.353333 0.516667 0.488000 0.489333
+2 0.772667 0.514000 0.190667 0.494667
diff --git a/models/size_description/data/train/mall-280-5.txt b/models/size_description/data/train/mall-280-5.txt
new file mode 100644
index 0000000..463c7cb
--- /dev/null
+++ b/models/size_description/data/train/mall-280-5.txt
@@ -0,0 +1,2 @@
+0 0.343333 0.502000 0.478667 0.646667
+2 0.759333 0.552667 0.214667 0.548000
diff --git a/models/size_description/data/train/mall-281-3.txt b/models/size_description/data/train/mall-281-3.txt
new file mode 100644
index 0000000..08d2d13
--- /dev/null
+++ b/models/size_description/data/train/mall-281-3.txt
@@ -0,0 +1,2 @@
+0 0.395052 0.495312 0.225521 0.678125
+2 0.666927 0.595573 0.183854 0.477604
diff --git a/models/size_description/data/train/mall-282-3.txt b/models/size_description/data/train/mall-282-3.txt
new file mode 100644
index 0000000..cbe262e
--- /dev/null
+++ b/models/size_description/data/train/mall-282-3.txt
@@ -0,0 +1,2 @@
+0 0.376302 0.500000 0.483854 0.332292
+3 0.735677 0.552083 0.206771 0.237500
diff --git a/models/size_description/data/train/mall-285-3.txt b/models/size_description/data/train/mall-285-3.txt
new file mode 100644
index 0000000..1e66fe8
--- /dev/null
+++ b/models/size_description/data/train/mall-285-3.txt
@@ -0,0 +1,2 @@
+0 0.384115 0.575781 0.459896 0.269271
+2 0.743490 0.509635 0.154688 0.392188
diff --git a/models/size_description/data/train/mall-286-3.txt b/models/size_description/data/train/mall-286-3.txt
new file mode 100644
index 0000000..0ba4879
--- /dev/null
+++ b/models/size_description/data/train/mall-286-3.txt
@@ -0,0 +1,2 @@
+0 0.321333 0.552000 0.442667 0.226667
+3 0.739333 0.522000 0.276000 0.297333
diff --git a/models/size_description/data/train/mall-287-3.txt b/models/size_description/data/train/mall-287-3.txt
new file mode 100644
index 0000000..49d360c
--- /dev/null
+++ b/models/size_description/data/train/mall-287-3.txt
@@ -0,0 +1,2 @@
+0 0.310667 0.500000 0.429333 0.368000
+3 0.735333 0.529333 0.294667 0.312000
diff --git a/models/size_description/data/train/mall-288-3.txt b/models/size_description/data/train/mall-288-3.txt
new file mode 100644
index 0000000..fb4f2be
--- /dev/null
+++ b/models/size_description/data/train/mall-288-3.txt
@@ -0,0 +1,2 @@
+0 0.339333 0.496667 0.486667 0.390667
+3 0.758000 0.559333 0.249333 0.268000
diff --git a/models/size_description/data/train/mall-290-4.txt b/models/size_description/data/train/mall-290-4.txt
new file mode 100644
index 0000000..c525bdb
--- /dev/null
+++ b/models/size_description/data/train/mall-290-4.txt
@@ -0,0 +1,2 @@
+0 0.308667 0.514000 0.428000 0.326667
+3 0.733333 0.523333 0.296000 0.310667
diff --git a/models/size_description/data/train/mall-293-2.txt b/models/size_description/data/train/mall-293-2.txt
new file mode 100644
index 0000000..1e5b6c2
--- /dev/null
+++ b/models/size_description/data/train/mall-293-2.txt
@@ -0,0 +1,2 @@
+0 0.385677 0.575000 0.410938 0.264583
+2 0.718750 0.512760 0.148958 0.392188
diff --git a/models/size_description/data/train/mall-295-3.txt b/models/size_description/data/train/mall-295-3.txt
new file mode 100644
index 0000000..1272168
--- /dev/null
+++ b/models/size_description/data/train/mall-295-3.txt
@@ -0,0 +1,2 @@
+0 0.404167 0.497396 0.214583 0.205208
+3 0.619271 0.514583 0.139583 0.165625
diff --git a/models/size_description/data/train/mall-302-2.txt b/models/size_description/data/train/mall-302-2.txt
new file mode 100644
index 0000000..623385c
--- /dev/null
+++ b/models/size_description/data/train/mall-302-2.txt
@@ -0,0 +1,2 @@
+0 0.390104 0.496875 0.300000 0.197917
+3 0.670573 0.513021 0.141146 0.168750
diff --git a/models/size_description/data/train/mall-303-3.txt b/models/size_description/data/train/mall-303-3.txt
new file mode 100644
index 0000000..19aedac
--- /dev/null
+++ b/models/size_description/data/train/mall-303-3.txt
@@ -0,0 +1,2 @@
+0 0.365104 0.500000 0.133333 0.338542
+3 0.576302 0.545833 0.203646 0.240625
diff --git a/models/size_description/data/train/mall-311-2.txt b/models/size_description/data/train/mall-311-2.txt
new file mode 100644
index 0000000..2733ef9
--- /dev/null
+++ b/models/size_description/data/train/mall-311-2.txt
@@ -0,0 +1,2 @@
+0 0.392448 0.496875 0.151562 0.537500
+2 0.584115 0.572656 0.160938 0.392188
diff --git a/models/size_description/data/train/mall-317-3.txt b/models/size_description/data/train/mall-317-3.txt
new file mode 100644
index 0000000..0c7092c
--- /dev/null
+++ b/models/size_description/data/train/mall-317-3.txt
@@ -0,0 +1,2 @@
+0 0.381771 0.502865 0.319792 0.213021
+3 0.693490 0.515104 0.177604 0.188542
diff --git a/models/size_description/data/train/mall-327-2.txt b/models/size_description/data/train/mall-327-2.txt
new file mode 100644
index 0000000..6fb7e1e
--- /dev/null
+++ b/models/size_description/data/train/mall-327-2.txt
@@ -0,0 +1,2 @@
+0 0.356771 0.569010 0.148958 0.678646
+1 0.625000 0.522135 0.234375 0.772396
diff --git a/models/size_description/data/train/mall-332-5.txt b/models/size_description/data/train/mall-332-5.txt
new file mode 100644
index 0000000..aad1831
--- /dev/null
+++ b/models/size_description/data/train/mall-332-5.txt
@@ -0,0 +1,2 @@
+0 0.334667 0.496667 0.448000 0.670667
+2 0.751333 0.555333 0.228000 0.580000
diff --git a/models/size_description/data/train/mall-354-2.txt b/models/size_description/data/train/mall-354-2.txt
new file mode 100644
index 0000000..3914f33
--- /dev/null
+++ b/models/size_description/data/train/mall-354-2.txt
@@ -0,0 +1,2 @@
+0 0.366406 0.523698 0.519271 0.284896
+3 0.787760 0.567969 0.173437 0.186979
diff --git a/models/size_description/data/train/mall-357-4.txt b/models/size_description/data/train/mall-357-4.txt
new file mode 100644
index 0000000..a8a482f
--- /dev/null
+++ b/models/size_description/data/train/mall-357-4.txt
@@ -0,0 +1,2 @@
+0 0.363333 0.552667 0.526667 0.385333
+2 0.779333 0.512667 0.182667 0.470667
diff --git a/models/size_description/data/train/mall-364-3.txt b/models/size_description/data/train/mall-364-3.txt
new file mode 100644
index 0000000..b711ead
--- /dev/null
+++ b/models/size_description/data/train/mall-364-3.txt
@@ -0,0 +1,2 @@
+0 0.400260 0.498177 0.353646 0.464062
+2 0.702604 0.537240 0.153125 0.389062
diff --git a/models/size_description/data/train/mall-366-4.txt b/models/size_description/data/train/mall-366-4.txt
new file mode 100644
index 0000000..8635a32
--- /dev/null
+++ b/models/size_description/data/train/mall-366-4.txt
@@ -0,0 +1,2 @@
+0 0.386458 0.536719 0.504167 0.396354
+2 0.771615 0.496094 0.183854 0.471354
diff --git a/models/size_description/data/train/mall-371-2.txt b/models/size_description/data/train/mall-371-2.txt
new file mode 100644
index 0000000..eeb2ba3
--- /dev/null
+++ b/models/size_description/data/train/mall-371-2.txt
@@ -0,0 +1,2 @@
+0 0.369531 0.575000 0.445312 0.562500
+1 0.752344 0.516146 0.209896 0.680208
diff --git a/models/size_description/data/train/mall-375-3.txt b/models/size_description/data/train/mall-375-3.txt
new file mode 100644
index 0000000..1014c91
--- /dev/null
+++ b/models/size_description/data/train/mall-375-3.txt
@@ -0,0 +1,2 @@
+0 0.353385 0.495312 0.303646 0.660417
+2 0.639062 0.585677 0.182292 0.470313
diff --git a/models/size_description/data/train/mall-379-3.txt b/models/size_description/data/train/mall-379-3.txt
new file mode 100644
index 0000000..a460f3e
--- /dev/null
+++ b/models/size_description/data/train/mall-379-3.txt
@@ -0,0 +1,2 @@
+0 0.391927 0.553385 0.604688 0.310937
+2 0.806250 0.509115 0.153125 0.393229
diff --git a/models/size_description/data/train/mall-41-3.txt b/models/size_description/data/train/mall-41-3.txt
new file mode 100644
index 0000000..97b9396
--- /dev/null
+++ b/models/size_description/data/train/mall-41-3.txt
@@ -0,0 +1,2 @@
+0 0.372656 0.505208 0.503646 0.692708
+1 0.754167 0.562760 0.185417 0.590104
diff --git a/models/size_description/data/train/mall-44-2.txt b/models/size_description/data/train/mall-44-2.txt
new file mode 100644
index 0000000..3a9f837
--- /dev/null
+++ b/models/size_description/data/train/mall-44-2.txt
@@ -0,0 +1,2 @@
+0 0.395573 0.500521 0.405729 0.514583
+2 0.730990 0.565625 0.154688 0.387500
diff --git a/models/size_description/data/train/mall-52-3.txt b/models/size_description/data/train/mall-52-3.txt
new file mode 100644
index 0000000..30a2293
--- /dev/null
+++ b/models/size_description/data/train/mall-52-3.txt
@@ -0,0 +1,2 @@
+0 0.369531 0.514583 0.240104 0.481250
+2 0.650781 0.530729 0.174479 0.439583
diff --git a/models/size_description/data/train/mall-54-3.txt b/models/size_description/data/train/mall-54-3.txt
new file mode 100644
index 0000000..de09b70
--- /dev/null
+++ b/models/size_description/data/train/mall-54-3.txt
@@ -0,0 +1,2 @@
+0 0.394792 0.579427 0.407292 0.255729
+2 0.731771 0.512760 0.153125 0.385937
diff --git a/models/size_description/data/train/mall-56-3.txt b/models/size_description/data/train/mall-56-3.txt
new file mode 100644
index 0000000..35f1999
--- /dev/null
+++ b/models/size_description/data/train/mall-56-3.txt
@@ -0,0 +1,2 @@
+0 0.392000 0.565333 0.581333 0.426667
+1 0.815333 0.493333 0.180000 0.573333
diff --git a/models/size_description/data/train/mall-59-3.txt b/models/size_description/data/train/mall-59-3.txt
new file mode 100644
index 0000000..d1cab61
--- /dev/null
+++ b/models/size_description/data/train/mall-59-3.txt
@@ -0,0 +1,2 @@
+0 0.378906 0.490625 0.214062 0.644792
+2 0.651823 0.575781 0.183854 0.474479
diff --git a/models/size_description/data/train/mall-77-3.txt b/models/size_description/data/train/mall-77-3.txt
new file mode 100644
index 0000000..c9310a8
--- /dev/null
+++ b/models/size_description/data/train/mall-77-3.txt
@@ -0,0 +1,2 @@
+0 0.418000 0.503333 0.641333 0.441333
+2 0.854000 0.548000 0.137333 0.352000
diff --git a/models/size_description/data/train/mall-83-3.txt b/models/size_description/data/train/mall-83-3.txt
new file mode 100644
index 0000000..3459476
--- /dev/null
+++ b/models/size_description/data/train/mall-83-3.txt
@@ -0,0 +1,2 @@
+0 0.366406 0.489844 0.301563 0.643229
+2 0.655990 0.577344 0.184896 0.477604
diff --git a/models/size_description/data/train/mall-87-3.txt b/models/size_description/data/train/mall-87-3.txt
new file mode 100644
index 0000000..6f8faa0
--- /dev/null
+++ b/models/size_description/data/train/mall-87-3.txt
@@ -0,0 +1,2 @@
+0 0.371875 0.523698 0.422917 0.613021
+1 0.725000 0.516667 0.182292 0.602083
diff --git a/models/size_description/data/train/mall-88-5.txt b/models/size_description/data/train/mall-88-5.txt
new file mode 100644
index 0000000..bc2db02
--- /dev/null
+++ b/models/size_description/data/train/mall-88-5.txt
@@ -0,0 +1,2 @@
+0 0.326000 0.497333 0.438667 0.624000
+2 0.746667 0.519333 0.234667 0.604000
diff --git a/models/size_description/data/train/mall-98-3.txt b/models/size_description/data/train/mall-98-3.txt
new file mode 100644
index 0000000..9568440
--- /dev/null
+++ b/models/size_description/data/train/mall-98-3.txt
@@ -0,0 +1,2 @@
+0 0.290667 0.500667 0.392000 0.398667
+3 0.716667 0.532000 0.313333 0.338667
diff --git a/models/size_description/size_info.yaml b/models/size_description/size_info.yaml
new file mode 100644
index 0000000..ac1fdff
--- /dev/null
+++ b/models/size_description/size_info.yaml
@@ -0,0 +1,10 @@
+# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
+train: train # train images (relative to 'path') 128 images
+val: train # validation reuses the training images (no separate val split; 128 images)
+# test: # test images (optional)
+
+names:
+ 0: product
+ 1: compare_pet
+ 2: compare_beer
+ 3: compare_cup
diff --git a/models/size_description/src/inference.py b/models/size_description/src/inference.py
new file mode 100644
index 0000000..52ec601
--- /dev/null
+++ b/models/size_description/src/inference.py
@@ -0,0 +1,173 @@
+import requests
+from io import BytesIO
+import pandas as pd
+from PIL import Image
+
+from ultralytics import YOLO
+
+
+# Real-world reference-object dimensions in centimeters, keyed by YOLO class id
+# (see size_info.yaml: 1 = compare_pet, 2 = compare_beer, 3 = compare_cup).
+# Used to convert detected pixel dimensions into an estimated physical product size.
+compare_actual_sizes = {
+    1: {"name": "2L ํํธ๋ณ", "width_cm": 9.0, "height_cm": 31.0},
+    2: {"name": "500ml ์บ", "width_cm": 6.5, "height_cm": 16.5},
+    3: {"name": "์ข
์ด์ปต", "width_cm": 7.0, "height_cm": 7.5}
+}
+
+
+def download_image(url):
+ response = requests.get(url)
+ response.raise_for_status() # ์ค๋ฅ ๋ฐ์ ์ ์์ธ ์ฒ๋ฆฌ
+ return Image.open(BytesIO(response.content))
+
+
+def get_detection_from_image(image_dir, model_path, show=False):
+ model = YOLO(model_path)
+ image = None
+
+ if image_dir.startswith("https"):
+ image = download_image(image_dir)
+ else:
+ image = Image.open(image_dir)
+ results = model(image, conf=0.3)
+
+ if show:
+ # ๊ฒฐ๊ณผ ์๊ฐํ ๋ฐ ์ ์ฅ
+ for idx, result in enumerate(results):
+ result.show()
+
+ # ๊ฐ์ง๋ ๊ฐ์ฒด ์ ๋ณด ์ถ๋ ฅ
+ for result in results:
+ print("Bounding Boxes:", result.boxes.xyxy)
+ print("Class IDs:", result.boxes.cls)
+ print("Confidence Scores:", result.boxes.conf)
+
+ return results
+
+
+def process_result_with_actual_size_desc(result):
+    """Build a Korean size-description sentence from one YOLO detection result.
+
+    Picks the right-most product box (class 0) and the right-most reference
+    object box (classes 1-3), converts the reference object's known physical
+    size (cm) into a pixel-to-cm scale, estimates the product's real width and
+    height from that scale, and phrases both as ratios against the reference.
+
+    Args:
+        result: a single ultralytics detection result exposing
+            ``result.boxes.xyxy`` and ``result.boxes.cls``.
+            NOTE(review): assumes one Result, not the list returned by a
+            model call — confirm against callers.
+
+    Returns:
+        The description string, or ``None`` when no usable product or
+        reference box was detected (a message is printed instead).
+    """
+    boxes = result.boxes.xyxy
+    classes = result.boxes.cls
+
+    # Split detections into product (class 0) and reference (classes 1-3) candidates.
+    product_candidates = []  # (box)
+    compare_candidates = []  # (box, class_id)
+
+    for box, cls_id in zip(boxes, classes):
+        cls_id_int = int(cls_id.item())  # class id tensor -> plain int
+        if cls_id_int == 0:  # product
+            product_candidates.append(box)
+        elif cls_id_int in [1, 2, 3]:  # compare
+            compare_candidates.append((box, cls_id_int))
+
+    # Bail out early when either side is missing; nothing sensible to describe.
+    if not product_candidates:
+        print("product(0) ํด๋์ค ๋ฐ์ด๋ฉ ๋ฐ์ค๊ฐ ์์ต๋๋ค.")
+        return
+    if not compare_candidates:
+        print("compare(1,2,3) ํด๋์ค ๋ฐ์ด๋ฉ ๋ฐ์ค๊ฐ ์์ต๋๋ค.")
+        return
+
+    # Among products, keep the box whose right edge (x2) is furthest right.
+    selected_product_box = max(product_candidates, key=lambda b: b[2].item())
+
+    # Same right-most selection for the reference object, keeping its class id.
+    selected_compare_box, selected_compare_class = max(compare_candidates, key=lambda x: x[0][2].item())
+
+    # NOTE(review): `global` is unnecessary for read-only access to a module constant.
+    global compare_actual_sizes
+
+    compare_info = compare_actual_sizes.get(selected_compare_class)
+    if not compare_info:
+        print("์ ์ ์๋ ๋น๊ต ๋์ ํด๋์ค์
๋๋ค.")
+        return
+
+    compare_name = compare_info["name"]
+
+    # Known physical size (cm) of the detected reference object.
+    compare_actual_width = compare_info["width_cm"]
+    compare_actual_height = compare_info["height_cm"]
+
+    compare_pixel_width = selected_compare_box[2].item() - selected_compare_box[0].item()
+    compare_pixel_height = selected_compare_box[3].item() - selected_compare_box[1].item()
+
+    # cm-per-pixel scale derived from the reference object.
+    scale_width = compare_actual_width / compare_pixel_width
+    scale_height = compare_actual_height / compare_pixel_height
+
+    product_pixel_width = selected_product_box[2].item() - selected_product_box[0].item()
+    product_pixel_height = selected_product_box[3].item() - selected_product_box[1].item()
+
+    # Estimated physical size of the product in cm.
+    product_actual_width = product_pixel_width * scale_width
+    product_actual_height = product_pixel_height * scale_height
+
+    # How many times larger the product is than the reference, per dimension.
+    width_ratio = product_actual_width / compare_actual_width
+    height_ratio = product_actual_height / compare_actual_height
+
+    # First-clause phrasing (sentence continues with the height clause).
+    def describe_ratio_first(ratio, compare_name, dimension):
+        ratio = round(ratio, 3)
+        base_text = f"{compare_name} {dimension}"
+
+        # Exact integer ratio: state the multiple directly.
+        if float(int(ratio)) == float(ratio):
+            if ratio == 1:
+                return f"{base_text}์ ๊ฐ๊ณ "
+            return f"{base_text}์ {int(ratio)}๋ฐฐ์ด๊ณ "
+
+        if ratio > 1.5:
+            # Round to the nearest multiple; say "slightly smaller/larger than N times".
+            rounded_ratio = round(ratio)
+            if rounded_ratio > ratio:  # 2.0 > 1.9
+                return f"{base_text}์ {rounded_ratio}๋ฐฐ๋ณด๋ค ์กฐ๊ธ ์๊ณ "
+            if rounded_ratio < ratio:  # 2.0 < 2.2
+                return f"{base_text}์ {rounded_ratio}๋ฐฐ๋ณด๋ค ์กฐ๊ธ ํฌ๊ณ "
+        if ratio > 1.0:
+            return f"{base_text}๋ณด๋ค ์กฐ๊ธ ํฌ๊ณ "
+        if 0.5 <= ratio < 1.0:
+            return f"{base_text}๋ณด๋ค ์กฐ๊ธ ์๊ณ "
+        return f"{base_text}๋ณด๋ค ๋ฐ ์ด์ ์๊ณ "
+
+    # Sentence-final phrasing (same thresholds, polite sentence endings).
+    def describe_ratio(ratio, compare_name, dimension):
+        ratio = round(ratio, 3)
+        base_text = f"{compare_name} {dimension}"
+
+        # Exact integer ratio: state the multiple directly.
+        if float(int(ratio)) == float(ratio):
+            if ratio == 1:
+                return f"{base_text}์ ๊ฐ์ต๋๋ค."
+            return f"{base_text}์ {int(ratio)}๋ฐฐ์
๋๋ค."
+
+        if ratio > 1.5:
+            # Round to the nearest multiple; say "slightly smaller/larger than N times".
+            rounded_ratio = round(ratio)
+            if rounded_ratio > ratio:  # 2.0 > 1.9
+                return f"{base_text}์ {rounded_ratio}๋ฐฐ๋ณด๋ค ์กฐ๊ธ ์์ต๋๋ค."
+            if rounded_ratio < ratio:  # 2.0 < 2.2
+                return f"{base_text}์ {rounded_ratio}๋ฐฐ๋ณด๋ค ์กฐ๊ธ ํฝ๋๋ค."
+        if ratio > 1.0:
+            return f"{base_text}๋ณด๋ค ์กฐ๊ธ ํฝ๋๋ค."
+        if 0.5 <= ratio < 1.0:
+            return f"{base_text}๋ณด๋ค ์กฐ๊ธ ์์ต๋๋ค."
+        return f"{base_text}๋ณด๋ค ๋ฐ ์ด์ ์์ต๋๋ค."
+
+    width_description = describe_ratio_first(width_ratio, compare_name, "๋๋น")
+    height_description = describe_ratio(height_ratio, compare_name, "๋์ด")
+
+    description = (
+        f"๋ฐฐ์ก๋ฐ๋ ์ ํ์ ์ค์ ๋๋น๋ {int(product_actual_width)}cm ์ ๋๋ก {width_description}, ์ค์ ๋์ด๋ {int(product_actual_height)}cm ์ ๋๋ก {height_description}"
+    )
+
+    return description
+
+
+def process_row(row, model):
+ detection_result = get_detection_from_image(row["ํฌ๊ธฐ ์ด๋ฏธ์ง URL"], model)
+
+ if detection_result:
+ return process_result_with_actual_size_desc(detection_result)
+ return "์ด๋ฏธ์ง ๋ถ์ ์คํจ"
+
+
+
+if __name__ == "__main__":
+    # Load the fine-tuned detection weights once and share the model across rows.
+    model = YOLO("/data/ephemeral/home/workspace/personal/size_info/ultralytics/runs/detect/train14/weights/best.pt")
+
+    # Generate one size description per product row and persist the result.
+    df = pd.read_csv("250201_image_comparison.csv")
+    df["size_description"] = df.apply(lambda row: process_row(row, model), axis=1)
+
+    df.to_csv("size_description_output.csv", index=False)
diff --git a/models/size_description/src/train.py b/models/size_description/src/train.py
new file mode 100644
index 0000000..08c809b
--- /dev/null
+++ b/models/size_description/src/train.py
@@ -0,0 +1,21 @@
+import requests
+from ultralytics import YOLO
+from PIL import Image
+from io import BytesIO
+
+
+def download_image(url):
+ response = requests.get(url)
+ response.raise_for_status()
+ return Image.open(BytesIO(response.content))
+
+def train_yolo():
+ # YOLO ๋ชจ๋ธ ํ์ต
+ model = YOLO("yolo11m.pt")
+ model.train(data="size_info.yaml", epochs=12)
+
+ results = model.val()
+ success = model.export(format="onnx")
+
+# Script entry point: kick off training when run directly.
+if __name__ == "__main__":
+    train_yolo()
\ No newline at end of file
diff --git a/models/thumbnail_description/.DS_Store b/models/thumbnail_description/.DS_Store
new file mode 100644
index 0000000..a605e32
Binary files /dev/null and b/models/thumbnail_description/.DS_Store differ
diff --git a/models/thumbnail_description/.gitkeep b/models/thumbnail_description/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/models/thumbnail_description/README.md b/models/thumbnail_description/README.md
new file mode 100644
index 0000000..4f5d443
--- /dev/null
+++ b/models/thumbnail_description/README.md
@@ -0,0 +1,157 @@
+# ๋ํ ์ด๋ฏธ์ง ์ค๋ช
+
+> ๋ํ ์ด๋ฏธ์ง ์ค๋ช
๊ธฐ๋ฅ์ ์๊ฐ์ฅ์ ์ธ ์ฌ์ฉ์๊ฐ ์ํ์ ์ธํ๊ณผ ํฌ์ฅ ์ํ๋ฅผ ๋ณด๋ค ์ฝ๊ฒ ํ์
ํ ์ ์๋๋ก ๋์ต๋๋ค.
+> ์ด ๊ธฐ๋ฅ์ ํตํด ์ํ ๋ํ ์ด๋ฏธ์ง์ ๋ํ ํฌ์ฅ ์ํ(์์, ์ฌ์ง, ํฌ๋ช
์ฑ), ๊ตฌ์ฑ, ๋์์ธ ๋ฑ์ ์๊ฐ์ ์ ๋ณด๋ฅผ ์ ๊ณตํฉ๋๋ค.
+> ์๊ฐ์ฅ์ ์ธ ์ฌ์ฉ์๊ฐ ์ํ์ ์ธํ๊ณผ ํฌ์ฅ ์ํ๋ฅผ ์ดํดํ์ฌ ๋ณด๋ค ํธ๋ฆฌํ ์จ๋ผ์ธ ์ผํ ๊ฒฝํ์ ์ง์ํฉ๋๋ค.
+> ์์ฑ์ : ์ค์ ์
+
+---
+
+## ๊ฐ์
+
+์๊ฐ์ฅ์ ์ธ ์ฌ์ฉ์์๊ฒ **์ํ ์ด๋ฏธ์ง**๋ฅผ ๊ฐ๊ด์ ์ด๊ณ ๊ฐ๊ฒฐํ๊ฒ ์ค๋ช
ํ๊ธฐ ์ํ **์ด๋ฏธ์ง ์บก์
๋** ์์คํ
์ ๊ตฌ์ฑํฉ๋๋ค.
+
+1. **VLM ๋ชจ๋ธ Inference**
+ - Janus-Pro, Qwen2.5_VL ๋ฑ์ ํตํด ๋ํ ์ด๋ฏธ์ง์ ๋ํ ์ค๋ช
์ ์์ฑํฉ๋๋ค.
+2. **ํ์ฒ๋ฆฌ (Post_processing)**
+ - HyperCLOVA HCX-003 ๋ชจ๋ธ์ ํ์ฉํ์ฌ ๋ฒ์ญ ๋ฐ Few-shot ๋ฐฉ์ ๋ฑ์ ํตํด ์ค๋ช
์ ํ์ง์ ๋์
๋๋ค.
+3. **ํ์ธํ๋ (Finetuning)**
+ - GPT-4o ์์ฑ๊ณผ ์๋๊ฒ์๋ก ๊ตฌ์ฑํ ๋ฐ์ดํฐ์
์ผ๋ก Janus-Pro๋ฅผ ํ์ธํ๋ํ์ฌ ์ ๊ตํ ์ค๋ช
์ฑ๋ฅ์ ๋์ ํฉ๋๋ค.
+4. **ํ๊ฐ (Evaluation)**
+ - OpenAI GPT-4o ๋ชจ๋ธ์ ํ์ฉํด ์ ์๋ฅผ ๋งค๊ฒจ ์ค๋ช
์ ํ์ง์ ์ธก์ ํ ์ ์์ต๋๋ค.
+
+> `config.yaml`์์ **API Key, ํ์ผ ๊ฒฝ๋ก**, ์คํ ์ฌ๋ถ ๋ฑ์ ๊ด๋ฆฌํ์ฌ ํ์ดํ๋ผ์ธ์ ์ ์ฐํ๊ฒ ์ ์ดํฉ๋๋ค.
+---
+## ํ์ดํ๋ผ์ธ
+
+
+## ์ฑ๋ฅ ๊ณ ๋ํ ๊ณผ์
+
+
+
+---
+
+## ํด๋ ๊ตฌ์กฐ
+
+```bash
+thumbnail_description/
+โโโ config
+โ โโโ config.yaml # ์ค์ ํ์ผ(API Key, ๊ฒฝ๋ก, ํ์ดํ๋ผ์ธ ์คํ ์ฌ๋ถ)
+โโโ data
+โ โโโ ... (๊ฐ์ข
CSV ๋ฐ์ดํฐ)
+โ โโโ ...
+โโโ hcx_prompt
+โ โโโ system_janus_pro_hcx_fewshot.txt
+โ โโโ system_janus_pro_hcx_translation.txt
+โ โโโ system_qwen2_5_pp_hcx.txt
+โ โโโ user_janus_pro_hcx_fewshot.txt
+โ โโโ user_janus_pro_hcx_translation.txt
+โ โโโ user_qwen2_5_pp_hcx.txt
+โโโ prompt
+โ โโโ deepseek_prompt.txt
+โ โโโ janus_prompt.txt
+โ โโโ maal_prompt.txt
+โ โโโ qwen2_5_prompt.txt
+โ โโโ qwen2_prompt.txt
+โ โโโ unsloth_prompt.txt
+โ src
+โ โโโ description_pipeline # ์ค๋ช
์์ฑ ํ์ดํ๋ผ์ธ
+โ โ โโโ inference_model # ๋ชจ๋ธ ์ถ๋ก ์ฝ๋
+โ โ โ โโโ deepseekvl.py # DeepSeek_VL์ ํ์ฉํ ์ธ๋ค์ผ ์ค๋ช
์์ฑ
+โ โ โ โโโ finetuned_janus_pro.py # ์ง์ ํ์ธํ๋ํ Janus Pro์ ํ์ฉํ ์ธ๋ค์ผ ์ค๋ช
์์ฑ
+โ โ โ โโโ janus_pro.py # Janus Pro์ ํ์ฉํ ์ธ๋ค์ผ ์ค๋ช
์์ฑ
+โ โ โ โโโ maal.py # MAAL์ ํ์ฉํ ์ธ๋ค์ผ ์ค๋ช
์์ฑ
+โ โ โ โโโ qwen2_5_vl.py # Qwen2.5_VL์ ํ์ฉํ ์ธ๋ค์ผ ์ค๋ช
์์ฑ
+โ โ โ โโโ qwen2_vl.py # Qwen2_VL์ ํ์ฉํ ์ธ๋ค์ผ ์ค๋ช
์์ฑ
+โ โ โ โโโ unsloth_qwen2_vl.py # Unsloth_Qwen2_VL์ ํ์ฉํ ์ธ๋ค์ผ ์ค๋ช
์์ฑ
+โ โ โโโ post_processing # ํ์ฒ๋ฆฌ ๊ด๋ จ ์ฝ๋
+โ โ โ โโโ janus_pro_hcx_translation.py # HCX ๋ฒ์ญ์ ํ์ฉํ Janus Pro ํ์ฒ๋ฆฌ
+โ โ โ โโโ janus_pro_papago_translation.py # Papago ๋ฒ์ญ์ ํ์ฉํ Janus Pro ํ์ฒ๋ฆฌ
+โ โ โ โโโ janus_pro_pp_hcx.py # Janus Pro ๋ชจ๋ธ์ PP-HCX ๊ธฐ๋ฐ ํ์ฒ๋ฆฌ
+โ โ โ โโโ qwen2_5_pp_hcx.py # Qwen2.5 ๋ชจ๋ธ์ PP-HCX ๊ธฐ๋ฐ ํ์ฒ๋ฆฌ
+โ โ โโโ evaluation # ํ๊ฐ ๊ด๋ จ ์ฝ๋
+โ โ โ โโโ gpt_eval_323.py # GPT ๊ธฐ๋ฐ ํ๊ฐ์
์ธ๋ค์ผ ์ค๋ช
ํ๊ฐ
+โ โ โ โโโ gpt_eval.py # GPT ๊ธฐ๋ฐ ์ ์ฒด ๋ฐ์ดํฐ ์
์ธ๋ค์ผ ์ค๋ช
ํ๊ฐ
+โโโ sft_pipeline # SFT(์ง๋ ํ์ต ๋ฏธ์ธ ์กฐ์ ) ๊ด๋ จ ์ฝ๋
+โ โโโ detailed_feature_description.py # 1327๊ฐ ๋ํ ์ด๋ฏธ์ง GPT๊ธฐ๋ฐ ์ค๋ฒ๋ผ๋ฒจ ์ถ์ถ ์ฝ๋
+โ โโโ janus_pro_7b_finetuning.py # ๊ณจ๋๋ผ๋ฒจ(์ค๋ฒ๋ผ๋ฒจ + ๊ฒ์)ํ์ฉ Janus Pro ํ์ธํ๋
+
+โโโ utils
+โ โโโ __init__.py
+โ โโโ common_utils.py # ๊ณตํต ์ ํธ๋ฆฌํฐ ํจ์ ์ ์
+โโโ main.py # ๋ฉ์ธ ์คํ ํ์ผ
+```
+---
+
+## ์
๋ ฅ(Input)๊ณผ ์ถ๋ ฅ(Output)
+
+### ์
๋ ฅ
+
+1. **์ํ ๋ํ ์ด๋ฏธ์ง ๋ฐ ๋ฉํ๋ฐ์ดํฐ**
+ - **์ด๋ฏธ์ง ํ์ผ**:
+ ์จ๋ผ์ธ ์ผํ๋ชฐ์์ ์ ๊ณตํ๋ ์ํ ๋ํ ์ด๋ฏธ์ง๊ฐ ์์คํ
์ ์ฃผ์ ์
๋ ฅ ๋ฐ์ดํฐ์
๋๋ค.
+ - **CSV ๋ฐ์ดํฐ**:
+ `data` ํด๋ ๋ด์ CSV ํ์ผ๋ค์ ๊ฐ ์ํ์ ๋ํ ์ถ๊ฐ ๋ฉํ๋ฐ์ดํฐ(์: ์ํ ์ฝ๋, ์นดํ
๊ณ ๋ฆฌ, ๊ธฐ์กด ์ค๋ช
๋ฑ)๋ฅผ ํฌํจํ๋ฉฐ, ์ด๋ฏธ์ง์ ์ฐ๊ณ๋์ด ํ์ฒ๋ฆฌ ๋ฐ ํ๊ฐ ๊ณผ์ ์์ ํ์ฉ๋ฉ๋๋ค.
+
+2. **์ค์ ์ ๋ณด ๋ฐ API ์ธ์ฆ**
+ - **config.yaml**:
+ API Key, ํ์ผ ๊ฒฝ๋ก, ํ์ดํ๋ผ์ธ ์คํ ์ฌ๋ถ ๋ฑ ์ ์ฒด ์์คํ
์ ์ค์ ์ ๋ณด๋ฅผ ํฌํจํฉ๋๋ค.
+ - **๋ฒ์ญ ๋ฐ ํ์ฒ๋ฆฌ ๊ด๋ จ ์ค์ **:
+ HyperCLOVA HCX-003, OpenAI API Key ๋ฑ ํ์ฒ๋ฆฌ์ ํ๊ฐ์ ํ์ํ ์ธ์ฆ ์ ๋ณด์ ํ๋ผ๋ฏธํฐ๋ฅผ ํฌํจํฉ๋๋ค.
+
+3. **ํ๋กฌํํธ ํ
์คํธ**
+ - **hcx_prompt ํด๋**:
+ Janus-Pro, Qwen2.5_VL ๋ฑ ๋ค์ํ ๋ชจ๋ธ์ ๋ฒ์ญ ๋ฐ Few-shot ํ์ต์ ํ์ํ ํ๋กฌํํธ ํ
์คํธ๋ฅผ ์ ์ฅํฉ๋๋ค.
+ - **prompt ํด๋**:
+ DeepSeek, Janus, MAAL, Qwen2_VL ๋ฑ ๋ค์ํ VLM ๋ชจ๋ธ์ ๋ํ ์ธํผ๋ฐ์ค ์์ฒญ ํ๋กฌํํธ๋ฅผ ํฌํจํฉ๋๋ค.
+
+---
+
+### ์ถ๋ ฅ (Output)
+
+1. **์ธ๋ค์ผ ์ด๋ฏธ์ง์ ๋ํ ํ
์คํธ ์ค๋ช
**
+ - **๊ธฐ๋ณธ ์์ฑ ๊ฒฐ๊ณผ**:
+ `src/description` ํด๋ ๋ด์ ๊ฐ ๋ชจ๋(์: `janus_pro.py`, `qwen2_5_vl.py` ๋ฑ)์ ์
๋ ฅ๋ ์ํ ์ด๋ฏธ์ง๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ํฌ์ฅ ์ํ(์์, ์ฌ์ง, ํฌ๋ช
์ฑ), ๊ตฌ์ฑ, ๋์์ธ ๋ฑ์ ์๊ฐ์ ์ ๋ณด๋ฅผ ํฌํจํ ํ
์คํธ ์ค๋ช
์ ์์ฑํฉ๋๋ค.
+ - **ํ์ฒ๋ฆฌ๋ ๊ฒฐ๊ณผ**:
+ - `src/post_processing` ํด๋ ๋ด์ ์คํฌ๋ฆฝํธ๋ค์ ์ด๊ธฐ ์์ฑ๋ ํ
์คํธ ์ค๋ช
์ ๋ฒ์ญ(Papago ๋๋ HCX ๊ธฐ๋ฐ)ํ๊ฑฐ๋ Few-shot ๊ธฐ๋ฒ์ ํ์ฉํด ํ์ง์ ๋ณด์ ํฉ๋๋ค.
+ - ์๋ฅผ ๋ค์ด, `janus_pro_hcx_translation.py`์ `qwen2_5_pp_hcx.py` ๋ชจ๋์ ๊ฐ๊ฐ ํด๋น ๋ชจ๋ธ์ ์ถ๋ ฅ์ ๋ํด ํ์ฒ๋ฆฌ๋ฅผ ์ํํฉ๋๋ค.
+
+2. **ํ๊ฐ ๋ฐ ์ ์ ์ ๋ณด**
+ - **GPT ํ๊ฐ ๊ฒฐ๊ณผ**:
+ `src/evaluation` ํด๋์ `gpt_eval.py` ๋ฐ `gpt_eval_323.py` ์คํฌ๋ฆฝํธ๋ GPT-4o ๋ชจ๋ธ ๋ฑ์ ํ์ฉํ์ฌ ์์ฑ๋ ํ
์คํธ ์ค๋ช
์ ํ์ง์ ํ๊ฐํฉ๋๋ค.
+ - **์ ์ฒด ํ์ดํ๋ผ์ธ ํ๊ฐ**:
+ ํ๊ฐ ๊ฒฐ๊ณผ๋ ์ต์ข
์ถ๋ ฅ๋ฌผ์ ์ ๋ขฐ๋์ ํ์ง ๊ฐ์ ์ ํ์ฉ๋๋ฉฐ, ํ์ธํ๋(์: Janus-Pro Finetuning) ๋ฑ์ ๋ฐ์๋ฉ๋๋ค.
+
+3. **์ต์ข
์ฌ์ฉ์ ์ ๊ณต ๊ฒฐ๊ณผ**
+ - **์๊ฐ์ฅ์ ์ธ ๋์ ์ค๋ช
ํ
์คํธ**:
+ ์ต์ข
์ถ๋ ฅ์ ์๊ฐ์ฅ์ ์ธ ์ฌ์ฉ์๊ฐ ์ํ์ ์ธํ๊ณผ ํฌ์ฅ ์ํ๋ฅผ ๋ณด๋ค ์ฝ๊ฒ ์ดํดํ ์ ์๋๋ก ๊ฐ๊ฒฐํ๊ณ ๊ฐ๊ด์ ์ธ ํ
์คํธ ์ค๋ช
ํํ๋ก ์ ๊ณต๋ฉ๋๋ค.
+ - **์ ์ฅ ๋ฐ ํ์ฉ**:
+ ์ต์ข
๊ฒฐ๊ณผ๋ ์ค์ ๋ ํ์ผ ๊ฒฝ๋ก์ ์ ์ฅ๋๋ฉฐ, ์ดํ ์ฌ์ฉ์ ์ธํฐํ์ด์ค(์: ์จ๋ผ์ธ ์ผํ๋ชฐ)์ ํตํฉ๋์ด ์ค์ ์๋น์ค์ ํ์ฉ๋ฉ๋๋ค.
+
+---
+## ์ค์น ๋ฐ ์คํ ๋ฐฉ๋ฒ
+### 1) ํ๊ฒฝ ๊ตฌ์ถ
+- Python 3.10.15 ๋ฒ์ ๊ถ์ฅ
+- ์์กด์ฑ ํจํค์ง ์ค์น
+```bash
+conda env create -f environment.yml
+```
+
+### 2) ์ค์
+- `config/config.yaml` ํ์ผ์์ ๋ค์ ์ ๋ณด๋ฅผ ์ ์ ํ ์ค์ ํฉ๋๋ค.
+ - **API Key / Request ID**: HyperCLOVA X ์ธ์ฆ ์ ๋ณด
+ - **OpenAI API Key**: GPT ๋ชจ๋ธ ์ฌ์ฉ ์ ํ์
+ - **ํ์ผ ๊ฒฝ๋ก**: ๋ฐ์ดํฐ ํ์ผ ์์น, ํ์ธํ๋ ๊ฒฐ๊ณผ ์ ์ฅ ๊ฒฝ๋ก ๋ฑ
+ - **ํ์ดํ๋ผ์ธ ์คํ ์ฌ๋ถ**: `pipeline` ์น์
์ `true`/`false` ๊ฐ์ผ๋ก ํฌ๋กค๋ง/์ธํผ๋ฐ์ค/ํ์ธํ๋ ๋ฑ ๋จ๊ณ๋ณ ์คํ ์ ์ด
+ - **ํ์ธํ๋ ์ค์ **: TBD
+
+### 3) ์คํ
+- ๊ธฐ๋ณธ ์คํ (๊ธฐ๋ณธ `config/config.yaml` ์ฌ์ฉ ์)
+```bash
+python main.py
+```
+- ๋ณ๋ ์ค์ ํ์ผ ์ฌ์ฉ
+```bash
+python main.py --config config/config_name.yaml
+```
+---
diff --git a/models/thumbnail_description/config/config.yaml b/models/thumbnail_description/config/config.yaml
new file mode 100644
index 0000000..7c6c5d3
--- /dev/null
+++ b/models/thumbnail_description/config/config.yaml
@@ -0,0 +1,46 @@
+hcx_api:
+ host: "https://clovastudio.stream.ntruss.com"
+ api_key: "YOUR_API_KEY"
+ request_id: "YOUR_REQUEST_ID"
+
+papago_api:
+ client_id: "YOUR_CLIENT_ID"
+ client_secret: "YOUR_CLIENT_SECRET"
+
+openai:
+ api_key: "OPENAI_API_KEY"
+
+paths:
+ data_dir: "./data"
+ prompt_dir: "./prompt"
+
+ # ์ฌ์ฉ๋ CSV ํ์ผ ๊ฒฝ๋ก๋ค
+ cleaned_text_contents: "cleaned_text_contents.csv"
+ Foodly_323_product_information: "Foodly_323_product_information.csv"
+ thumbnail_1347_gpt_human_labeling_train: "thumbnail_1347_gpt_human_labeling_train.csv"
+
+
+pipeline:
+ sft_pipeline:
+ detailed_feature_description: false
+ janus_pro_7b_finetuning: false
+
+ description_pipeline:
+ inference_model:
+ deepseekvl: false
+ finetuned_janus_pro : false
+ janus_pro: true
+ maal: false
+ qwen2_vl: false
+ qwen2_5_vl: false
+ unsloth_qwen2_vl: false
+
+ post_processing:
+ janus_pro_papago: false
+ janus_pro_hcx_translation: true
+ janus_pro_pp_hcx: true
+ qwen2_5_pp_hcx: false
+
+ evaluation:
+ gpt_eval: false
+ gpt_eval_323: true
\ No newline at end of file
diff --git a/models/thumbnail_description/data/.gitkeep b/models/thumbnail_description/data/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/models/thumbnail_description/environment.yml b/models/thumbnail_description/environment.yml
new file mode 100644
index 0000000..799d5e9
--- /dev/null
+++ b/models/thumbnail_description/environment.yml
@@ -0,0 +1,185 @@
+name: thumbnail
+channels:
+ - xformers
+ - pytorch
+ - nvidia
+ - conda-forge
+ - defaults
+dependencies:
+ - _libgcc_mutex=0.1=main
+ - _openmp_mutex=5.1=1_gnu
+ - asttokens=3.0.0=pyhd8ed1ab_1
+ - blas=1.0=mkl
+ - bzip2=1.0.8=h5eee18b_6
+ - ca-certificates=2024.12.14=hbcca054_0
+ - comm=0.2.2=pyhd8ed1ab_1
+ - cuda-cudart=12.1.105=0
+ - cuda-cupti=12.1.105=0
+ - cuda-libraries=12.1.0=0
+ - cuda-nvrtc=12.1.105=0
+ - cuda-nvtx=12.1.105=0
+ - cuda-opencl=12.4.127=0
+ - cuda-runtime=12.1.0=0
+ - cudatoolkit=11.7.0=hd8887f6_10
+ - debugpy=1.8.11=py311h6a678d5_0
+ - decorator=5.1.1=pyhd8ed1ab_1
+ - exceptiongroup=1.2.2=pyhd8ed1ab_1
+ - executing=2.1.0=pyhd8ed1ab_1
+ - filelock=3.13.1=py311h06a4308_0
+ - gmp=6.2.1=h295c915_3
+ - gmpy2=2.1.2=py311hc9b5ff0_0
+ - importlib-metadata=8.5.0=pyha770c72_1
+ - intel-openmp=2023.1.0=hdb19cb5_46306
+ - ipykernel=6.29.5=pyh3099207_0
+ - ipython=8.31.0=pyh707e725_0
+ - jedi=0.19.2=pyhd8ed1ab_1
+ - jinja2=3.1.4=py311h06a4308_1
+ - jupyter_client=8.6.3=pyhd8ed1ab_1
+ - jupyter_core=5.7.2=pyh31011fe_1
+ - ld_impl_linux-64=2.40=h12ee557_0
+ - libcublas=12.1.0.26=0
+ - libcufft=11.0.2.4=0
+ - libcufile=1.9.1.3=0
+ - libcurand=10.3.5.147=0
+ - libcusolver=11.4.4.55=0
+ - libcusparse=12.0.2.55=0
+ - libffi=3.4.4=h6a678d5_1
+ - libgcc-ng=11.2.0=h1234567_1
+ - libgomp=11.2.0=h1234567_1
+ - libnpp=12.0.2.50=0
+ - libnvjitlink=12.1.105=0
+ - libnvjpeg=12.1.1.14=0
+ - libsodium=1.0.18=h36c2ea0_1
+ - libstdcxx-ng=11.2.0=h1234567_1
+ - libuuid=1.41.5=h5eee18b_0
+ - llvm-openmp=14.0.6=h9e868ea_0
+ - markupsafe=2.1.3=py311h5eee18b_0
+ - matplotlib-inline=0.1.7=pyhd8ed1ab_1
+ - mkl=2023.1.0=h213fc3f_46344
+ - mpc=1.1.0=h10f8cd9_1
+ - mpfr=4.0.2=hb69a4c5_1
+ - mpmath=1.3.0=py311h06a4308_0
+ - ncurses=6.4=h6a678d5_0
+ - nest-asyncio=1.6.0=pyhd8ed1ab_1
+ - networkx=3.3=py311h06a4308_0
+ - openssl=3.0.15=h5eee18b_0
+ - packaging=24.2=pyhd8ed1ab_2
+ - parso=0.8.4=pyhd8ed1ab_1
+ - pexpect=4.9.0=pyhd8ed1ab_1
+ - pickleshare=0.7.5=pyhd8ed1ab_1004
+ - pip=24.2=py311h06a4308_0
+ - platformdirs=4.3.6=pyhd8ed1ab_1
+ - prompt-toolkit=3.0.48=pyha770c72_1
+ - ptyprocess=0.7.0=pyhd8ed1ab_1
+ - pure_eval=0.2.3=pyhd8ed1ab_1
+ - python=3.11.11=he870216_0
+ - python-dateutil=2.9.0.post0=pyhff2d567_1
+ - pytorch=2.5.1=py3.11_cuda12.1_cudnn9.1.0_0
+ - pytorch-cuda=12.1=ha16c6d3_6
+ - pytorch-mutex=1.0=cuda
+ - pyyaml=6.0.2=py311h5eee18b_0
+ - pyzmq=26.2.0=py311h6a678d5_0
+ - readline=8.2=h5eee18b_0
+ - setuptools=75.1.0=py311h06a4308_0
+ - six=1.17.0=pyhd8ed1ab_0
+ - sqlite=3.45.3=h5eee18b_0
+ - stack_data=0.6.3=pyhd8ed1ab_1
+ - tbb=2021.8.0=hdb19cb5_0
+ - tk=8.6.14=h39e8969_0
+ - torchtriton=3.1.0=py311
+ - tornado=6.4.2=py311h5eee18b_0
+ - traitlets=5.14.3=pyhd8ed1ab_1
+ - typing_extensions=4.12.2=py311h06a4308_0
+ - wcwidth=0.2.13=pyhd8ed1ab_1
+ - wheel=0.44.0=py311h06a4308_0
+ - xformers=0.0.28.post3=py311_cu12.1.0_pyt2.5.1
+ - xz=5.4.6=h5eee18b_1
+ - yaml=0.2.5=h7b6447c_0
+ - zeromq=4.3.5=h6a678d5_0
+ - zipp=3.21.0=pyhd8ed1ab_1
+ - zlib=1.2.13=h5eee18b_1
+ - pip:
+ - accelerate==1.2.1
+ - aiohappyeyeballs==2.4.4
+ - aiohttp==3.11.11
+ - aiosignal==1.3.2
+ - annotated-types==0.7.0
+ - anyio==4.8.0
+ - attrdict==2.0.1
+ - attrs==24.3.0
+ - av==14.0.1
+ - bitsandbytes==0.45.0
+ - certifi==2024.12.14
+ - charset-normalizer==3.4.1
+ - contourpy==1.3.1
+ - cut-cross-entropy==24.12.3
+ - cycler==0.12.1
+ - datasets==3.2.0
+ - decord==0.6.0
+ - deepseek-vl==1.0.0
+ - dill==0.3.8
+ - distro==1.9.0
+ - docstring-parser==0.16
+ - einops==0.8.0
+ - fonttools==4.55.6
+ - frozenlist==1.5.0
+ - fsspec==2024.9.0
+ - h11==0.14.0
+ - hf-transfer==0.1.8
+ - httpcore==1.0.7
+ - httpx==0.28.1
+ - huggingface-hub==0.27.0
+ - idna==3.10
+ - janus==1.0.0
+ - jiter==0.8.2
+ - joblib==1.4.2
+ - kiwisolver==1.4.8
+ - levenshtein==0.26.1
+ - markdown-it-py==3.0.0
+ - matplotlib==3.10.0
+ - mdurl==0.1.2
+ - multidict==6.1.0
+ - multiprocess==0.70.16
+ - numpy==2.2.1
+ - openai==1.61.1
+ - pandas==2.2.3
+ - peft==0.14.0
+ - pillow==11.1.0
+ - propcache==0.2.1
+ - protobuf==3.20.3
+ - psutil==6.1.1
+ - pyarrow==18.1.0
+ - pydantic==2.10.6
+ - pydantic-core==2.27.2
+ - pygments==2.19.0
+ - pyparsing==3.2.1
+ - python-levenshtein==0.26.1
+ - pytz==2024.2
+ - qwen-vl-utils==0.0.8
+ - rapidfuzz==3.11.0
+ - regex==2024.11.6
+ - requests==2.32.3
+ - rich==13.9.4
+ - safetensors==0.5.0
+ - scikit-learn==1.6.1
+ - scipy==1.15.1
+ - sentencepiece==0.2.0
+ - shtab==1.7.1
+ - sniffio==1.3.1
+ - sympy==1.13.1
+ - threadpoolctl==3.5.0
+ - timm==1.0.14
+ - tokenizers==0.21.0
+ - torchvision==0.20.1
+ - tqdm==4.67.1
+ - transformers==4.49.0.dev0
+ - trl==0.13.0
+ - typeguard==4.4.1
+ - tyro==0.9.5
+ - tzdata==2024.2
+ - unsloth==2025.1.6
+ - unsloth-zoo==2025.1.5
+ - urllib3==2.3.0
+ - xxhash==3.5.0
+ - yarl==1.18.3
+prefix: /data/ephemeral/home/.condaenv/envs/unsloth
diff --git a/models/thumbnail_description/hcx_prompt/system_janus_pro_hcx_fewshot.txt b/models/thumbnail_description/hcx_prompt/system_janus_pro_hcx_fewshot.txt
new file mode 100644
index 0000000..be4d58b
--- /dev/null
+++ b/models/thumbnail_description/hcx_prompt/system_janus_pro_hcx_fewshot.txt
@@ -0,0 +1,60 @@
+๋น์ ์ ํ
์คํธ๋ฅผ ์ ์ ํ๋ ์ ๋ฌธ๊ฐ์
๋๋ค. ์๋์ ๊ท์น์ ์๊ฒฉํ ์ ์ฉํ์ฌ ์
๋ ฅ ํ
์คํธ๋ฅผ ๋ณํํ์ธ์:
+
+1. ํ
์คํธ ์ค '๋์ด, ๋๋น, ๋ฌด๊ฒ' ๊ด๋ จ ์ ๋ณด๋ ์์ ํ ์ ๊ฑฐํฉ๋๋ค.
+2. ๋์์ธ ์์ฒด์ ๋ํ ํ๊ฐ(์: ๋์์ธ์ด ๋ฉ์ง๋ค, ์๋ฆ๋ต๋ค, ์ธ๋ จ๋๋ค ๋ฑ)๋ ๋ชจ๋ ์ ๊ฑฐํฉ๋๋ค.
+ - ๋จ, ์ ํ์ ์ฃผ์ ์คํ(์ฌ์ง, ์์, ํํ, ํฌ๋ช
์ฑ ๋ฑ)์ ์ ์งํฉ๋๋ค.
+3. '์ถ์ ' ํน์ '๊ฐ๋ฅ์ฑ์ด ๋๋ค' ๊ฐ์ ๋ถํ์คํ ํํ์ด๋ ์ถ์ธก ๋ฌธ์ฅ์ ์ ๊ฑฐํฉ๋๋ค.
+4. ์ต์ข
๊ฒฐ๊ณผ๋ฌผ์ ํ๊ตญ์ด๋ก ์์ฑํฉ๋๋ค.
+
+์๋๋ ๊ท์น์ ์ ์ฉํ ์์๋ค์
๋๋ค.
+
+[์์ 1]
+
+์
๋ ฅ:
+์ ํ์ ์์์ ๋ถํ์์ด๊ณ , ๋์์ธ์ด ๋งค์ฐ ์ธ๋ จ๋์ด ๋ณด์ธ๋ค.
+๋์ด๋ ์ฝ 30cm ์ ๋๋ก ์ถ์ ๋๋ค.
+๋ฌด๊ฒ๋ 200g ๋ด์ธ์ผ ๊ฒ์ผ๋ก ๋ณด์ธ๋ค.
+ํฌ๋ช
ํ ๋ถ๋ถ์ด ์์ด ์์ชฝ ๋ด์ฉ๋ฌผ์ด ๋ณด์.
+
+์ถ๋ ฅ:
+์ ํ์ ์์์ ๋ถํ์์ด๊ณ , ํฌ๋ช
ํ ๋ถ๋ถ์ด ์์ด ์์ชฝ ๋ด์ฉ๋ฌผ์ด ๋ณด์.
+
+(์ค๋ช
:
+- '๋์์ธ์ด ๋งค์ฐ ์ธ๋ จ๋์ด ๋ณด์ธ๋ค' โ ๋์์ธ ํ๊ฐ ๋ฌธ์ฅ ์ ๊ฑฐ
+- '๋์ด๋ ์ฝ 30cm ์ ๋๋ก ์ถ์ ๋๋ค' โ ๋์ด ์ ๋ณด ๋ฐ ์ถ์ ํํ ์ ๊ฑฐ
+- '๋ฌด๊ฒ๋ 200g ๋ด์ธ์ผ ๊ฒ์ผ๋ก ๋ณด์ธ๋ค' โ ๋ฌด๊ฒ ์ ๋ณด ๋ฐ ์ถ์ ํํ ์ ๊ฑฐ
+)
+
+[์์ 2]
+
+์
๋ ฅ:
+ํฌ์ฅ์ง ํํ๋ ์ง์ฌ๊ฐํ์ด๋ฉฐ, ๋ชจ๋ํ ๋์์ธ ๋๋ถ์ ์๊ฐ์ ์ผ๋ก ๊น๋ํ๋ค.
+๋๋น๊ฐ ๋๋ต 20cm๋ก ์์๋๋ค.
+์์์ ์ ์ฒด์ ์ผ๋ก ์ง์ ํ๋์์ด๊ณ , ์ฌ์ง์ ํ๋ผ์คํฑ์ธ ๊ฒ์ผ๋ก ์ถ์ ๋๋ค.
+
+์ถ๋ ฅ:
+ํฌ์ฅ์ง ํํ๋ ์ง์ฌ๊ฐํ์ด๋ฉฐ, ์์์ ์ ์ฒด์ ์ผ๋ก ์ง์ ํ๋์์ด๊ณ , ์ฌ์ง์ ํ๋ผ์คํฑ์ด๋ค.
+
+(์ค๋ช
:
+- '๋ชจ๋ํ ๋์์ธ ๋๋ถ์ ์๊ฐ์ ์ผ๋ก ๊น๋ํ๋ค' โ ๋์์ธ ํ๊ฐ ๋ฌธ์ฅ ์ ๊ฑฐ
+- '๋๋น๊ฐ ๋๋ต 20cm๋ก ์์๋๋ค' โ ๋๋น ์ ๋ณด ๋ฐ ์ถ์ ํํ ์ ๊ฑฐ
+- '์ฌ์ง์ ํ๋ผ์คํฑ์ธ ๊ฒ์ผ๋ก ์ถ์ ๋๋ค' โ '์ถ์ ' ์ ๊ฑฐ ํ '์ฌ์ง์ ํ๋ผ์คํฑ์ด๋ค'๋ก ๋ณ๊ฒฝ
+)
+
+[์์ 3]
+
+์
๋ ฅ:
+์ด ์ ํ์ ํฌ๋ช
์ฉ๊ธฐ์ ๋ค์ด ์์ด ์์ด ํคํ ๋ณด์ด๋ฉฐ,
+์ด๋์ด ์ด์ฝ๋ฆฟ์๊ณผ ์์ํ ๊ธ์ ๋ก๊ณ ๊ฐ ์ธ์์ ์ด๋ค.
+๋ฌด๊ฒ๋ 350g ์ ๋์ผ ๊ฒ ๊ฐ๊ณ ,
+์ ์ฒด์ ์ธ ๋์์ธ ์์ฑ๋๊ฐ ๋ฐ์ด๋ ๋ณด์ธ๋ค.
+
+์ถ๋ ฅ:
+์ด ์ ํ์ ํฌ๋ช
์ฉ๊ธฐ์ ๋ค์ด ์์ด ์์ด ํคํ ๋ณด์ด๋ฉฐ, ์ด๋์ด ์ด์ฝ๋ฆฟ์๊ณผ ์์ํ ๊ธ์ ๋ก๊ณ ๊ฐ ์๋ค.
+
+(์ค๋ช
:
+- '๋ฌด๊ฒ๋ 350g ์ ๋์ผ ๊ฒ ๊ฐ๊ณ ' โ ๋ฌด๊ฒ ์ ๋ณด ๋ฐ ์ถ์ ํํ ์ ๊ฑฐ
+- '์ ์ฒด์ ์ธ ๋์์ธ ์์ฑ๋๊ฐ ๋ฐ์ด๋ ๋ณด์ธ๋ค' โ ๋์์ธ ํ๊ฐ ๋ฌธ์ฅ ์ ๊ฑฐ
+)
+
+์ด์ ๊ท์น๊ณผ ์์๋ฅผ ๋ฐํ์ผ๋ก, ์
๋ ฅ ํ
์คํธ๋ฅผ ์ ์ ํ์ธ์.
\ No newline at end of file
diff --git a/models/thumbnail_description/hcx_prompt/system_janus_pro_hcx_translation.txt b/models/thumbnail_description/hcx_prompt/system_janus_pro_hcx_translation.txt
new file mode 100644
index 0000000..64a7220
--- /dev/null
+++ b/models/thumbnail_description/hcx_prompt/system_janus_pro_hcx_translation.txt
@@ -0,0 +1 @@
+๋น์ ์ ์ํ์ ํต์ฌ ์ ๋ณด๋ฅผ ์๊ฐํ๋ ์ ๋ฌธ ์นดํผ๋ผ์ดํฐ์
๋๋ค. ๋น์ ์ ๋ชฉํ๋ ์ํ์ ๊ฐ์น๋ฅผ ๋ช
ํํ๊ฒ ์ ๋ฌํ์ฌ ์๋น์์ ๊ตฌ๋งค ๊ฒฐ์ ์ ์ ๋ํ๋ ๊ฒ์
๋๋ค.
\ No newline at end of file
diff --git a/models/thumbnail_description/hcx_prompt/system_qwen2_5_pp_hcx.txt b/models/thumbnail_description/hcx_prompt/system_qwen2_5_pp_hcx.txt
new file mode 100644
index 0000000..a97844a
--- /dev/null
+++ b/models/thumbnail_description/hcx_prompt/system_qwen2_5_pp_hcx.txt
@@ -0,0 +1,2 @@
+๋น์ ์ ์ํ์ ํต์ฌ์ ์ผ๋ก ์๊ฐํ๋ ์ ํ ๋ด๋น๊ด์
๋๋ค.
+๋น์ ์ ๋ชฉํ๋ ์ต์ํ์ ๋ฌธ๊ตฌ๋ก ํน์ง์ ๋ช
ํํ๊ฒ ์ ๋ฌํ์ฌ ์๋น์์ ๊ตฌ๋งค ๊ฒฐ์ ์ ์ ๋ํ๋ ๊ฒ์
๋๋ค.
\ No newline at end of file
diff --git a/models/thumbnail_description/hcx_prompt/user_janus_pro_hcx_fewshot.txt b/models/thumbnail_description/hcx_prompt/user_janus_pro_hcx_fewshot.txt
new file mode 100644
index 0000000..d29268e
--- /dev/null
+++ b/models/thumbnail_description/hcx_prompt/user_janus_pro_hcx_fewshot.txt
@@ -0,0 +1 @@
+์ฃผ์ด์ง '{model_output_text}'์์ ๋์ด์ ๋๋น ๋ฐ ๋ฌด๊ฒ ์ ๋ณด๋ ์ ์ธํ์ฌ ํ๊ตญ์ด๋ก ํด์ํ์ธ์. ์ถ์ ์ ๋งํฌ๋ '๊ฐ๋ฅ์ฑ์ด ๋๋ค'์ ๊ฐ์ ํํ๋ค์ด ๋ค์ด๊ฐ ๋ฌธ์ฅ์ ์ ๊ฑฐํ์ธ์ ๋์์ธ์ ๋ํ ์ ๋ฐ์ ์ธ ํ๊ฐ๋ ๋ชจ๋ ์ ๊ฑฐํ์ธ์. ์ค๋ณต๋ ํํ๋ค์ ์ต๋ํ ์ ๊ฑฐํ์ธ์.
\ No newline at end of file
diff --git a/models/thumbnail_description/hcx_prompt/user_janus_pro_hcx_translation.txt b/models/thumbnail_description/hcx_prompt/user_janus_pro_hcx_translation.txt
new file mode 100644
index 0000000..ebdf5c2
--- /dev/null
+++ b/models/thumbnail_description/hcx_prompt/user_janus_pro_hcx_translation.txt
@@ -0,0 +1 @@
+์ฃผ์ด์ง '{model_output_text}'์์ ๋์ด์ ๋๋น ๋ฐ ๋ฌด๊ฒ ์ ๋ณด๋ ์ ์ธํ์ฌ ํ๊ตญ์ด๋ก ํด์ํ์ธ์. ์ถ์ ์ ๋งํฌ๋ ๊ฐ๋ฅ์ฑ์ด ๋๋ค์ ๊ฐ์ ํํ๋ค์ด ๋ค์ด๊ฐ ๋ฌธ์ฅ์ ์ ๊ฑฐํ๊ณ ๋์์ธ์ ๋ํ ์ ๋ฐ์ ์ธ ํ๊ฐ๋ ๋ชจ๋ ์ ๊ฑฐํด์ค
\ No newline at end of file
diff --git a/models/thumbnail_description/hcx_prompt/user_qwen2_5_pp_hcx.txt b/models/thumbnail_description/hcx_prompt/user_qwen2_5_pp_hcx.txt
new file mode 100644
index 0000000..86630a5
--- /dev/null
+++ b/models/thumbnail_description/hcx_prompt/user_qwen2_5_pp_hcx.txt
@@ -0,0 +1,6 @@
+{
+ '{model_output_text}'์ ๋ ๋ฒ์งธ ๋ฌธ์ฅ ์ดํ ๋ฌธ์ฅ๋ค์ ํ ๋ฌธ์ฅ์ผ๋ก๋ง ์ถ๋ ฅํด์ผ ํฉ๋๋ค.
+ ๊ตฌ์ฑํ, ์ธ์ฆ๋งํฌ๋ฅผ ์ ์์ฝํด์ ํ ์ค๋ก ๋ํ๋ด๊ณ ์กฐ๋ฆฌ๋ฐฉ๋ฒ๊ณผ ํ์ฉ๋ฒ์ ์ ๋ ์ธ๊ธํ์ง ๋ง์ธ์.
+ ์ ํ๋ช
์ ๋ง์ง๋ง์ ์ธ๊ธํ๊ณ ํด์์ฒด๊ฐ ์๋๋ผ ๋ฌธ์ฅ ๋ง๋ฌด๋ฆฌ ์์ ์ด๋ฅผ '์
๋๋ค'๋ก ์๋ฒฝํ๊ฒ ๋ง๋ฌด๋ฆฌํด์ผ ํฉ๋๋ค.
+ ๋์ด์ฐ๊ธฐ์ ๋ง์ถค๋ฒ์ ํ๋ฆฌ์ง ๋ง์ธ์.
+}
\ No newline at end of file
diff --git a/models/thumbnail_description/main.py b/models/thumbnail_description/main.py
new file mode 100644
index 0000000..6c8626c
--- /dev/null
+++ b/models/thumbnail_description/main.py
@@ -0,0 +1,124 @@
+import argparse
+import yaml
+import logging
+
+# SFT ํ์ดํ๋ผ์ธ ๋ชจ๋ import
+from src.sft_pipeline import (
+ detailed_feature_description,
+ janus_pro_7b_finetuning
+)
+
+# Description ํ์ดํ๋ผ์ธ ๋ด ๋ชจ๋ธ ์ถ๋ก ๋ชจ๋ import
+from src.description_pipeline.inference_model import (
+ deepseekvl,
+ janus_pro,
+ maal,
+ qwen2_vl,
+ qwen2_5_vl,
+ unsloth_qwen2_vl,
+ finetuned_janus_pro
+)
+# Description ํ์ดํ๋ผ์ธ ๋ด ํ์ฒ๋ฆฌ ๋ชจ๋ import
+from src.description_pipeline.post_processing import (
+ janus_pro_papago_translation,
+ janus_pro_hcx_translation,
+ janus_pro_pp_hcx,
+ qwen2_5_pp_hcx
+)
+# Description ํ์ดํ๋ผ์ธ ๋ด ํ๊ฐ ๋ชจ๋ import
+from src.description_pipeline.evaluation import (
+ gpt_eval,
+ gpt_eval_323
+)
+
+def setup_logger():
+ logging.basicConfig(
+ level=logging.INFO,
+ format="%(asctime)s [%(levelname)s] %(name)s - %(message)s"
+ )
+
+def run_sft_pipeline(config):
+ logging.info("SFT ํ์ดํ๋ผ์ธ ์์")
+ # config.yaml์์ sft ํ์ดํ๋ผ์ธ ๊ด๋ จ ์ค์ ์ pipeline > sft_pipeline ์ ์์นํจ
+ sft_config = config.get("pipeline", {}).get("sft_pipeline", {})
+ if sft_config.get("detailed_feature_description", False):
+ detailed_feature_description.run_detailed_feature_description(config)
+ if sft_config.get("janus_pro_7b_finetuning", False):
+ janus_pro_7b_finetuning.run_janus_pro_7b_finetuning(config)
+ logging.info("SFT ํ์ดํ๋ผ์ธ ์๋ฃ.")
+
+def run_description_pipeline(config):
+ logging.info("Description ํ์ดํ๋ผ์ธ ์์")
+ # config.yaml์์ description ๊ด๋ จ ์ค์ ์ pipeline > description_pipeline ์ ์์นํจ
+ desc_config = config.get("pipeline", {}).get("description_pipeline", {})
+
+ # 1) ๋ชจ๋ธ๋ณ Inference ๋จ๊ณ
+ inference_cfg = desc_config.get("inference_model", {})
+ if inference_cfg.get("deepseekvl", False):
+ deepseekvl.run_inference(config)
+ if inference_cfg.get("janus_pro", False):
+ janus_pro.run_inference(config)
+ if inference_cfg.get("maal", False):
+ maal.run_inference(config)
+ if inference_cfg.get("qwen2_vl", False):
+ qwen2_vl.run_inference(config)
+ if inference_cfg.get("qwen2_5_vl", False):
+ qwen2_5_vl.run_inference(config)
+ if inference_cfg.get("unsloth_qwen2_vl", False):
+ unsloth_qwen2_vl.run_inference(config)
+ if inference_cfg.get("finetuned_janus_pro", False):
+ finetuned_janus_pro.run_inference(config)
+
+ # 2) ํ์ฒ๋ฆฌ ๋จ๊ณ
+ postproc_cfg = desc_config.get("post_processing", {})
+ if postproc_cfg.get("janus_pro_papago_translation", False):
+ janus_pro_papago_translation.run_post_processing(config)
+ if postproc_cfg.get("janus_pro_hcx_translation", False):
+ janus_pro_hcx_translation.run_post_processing(config)
+ if postproc_cfg.get("janus_pro_pp_hcx", False):
+ janus_pro_pp_hcx.run_post_processing(config)
+ if postproc_cfg.get("qwen2_5_pp_hcx", False):
+ qwen2_5_pp_hcx.run_post_processing(config)
+
+ # 3) Evaluation ๋จ๊ณ
+ eval_cfg = desc_config.get("evaluation", {})
+ if eval_cfg.get("gpt_eval", False):
+ gpt_eval.run_evaluation(config)
+ if eval_cfg.get("gpt_eval_323", False):
+ gpt_eval_323.run_evaluation(config)
+
+ logging.info("Description ํ์ดํ๋ผ์ธ ์๋ฃ.")
+
+def main():
+ setup_logger()
+ parser = argparse.ArgumentParser(description="ํ์ดํ๋ผ์ธ ์คํ")
+ parser.add_argument(
+ "--config",
+ "-c",
+ default="config/config.yaml",
+ help="์ค์ ํ์ผ ๊ฒฝ๋ก (๊ธฐ๋ณธ๊ฐ: config/config.yaml)"
+ )
+ parser.add_argument(
+ "--pipeline",
+ "-p",
+ choices=["sft", "description", "all"],
+ default="all",
+ help="์คํํ ํ์ดํ๋ผ์ธ ์ ํ (sft, description, all)"
+ )
+ args = parser.parse_args()
+
+ # ์ค์ ํ์ผ ๋ก๋
+ with open(args.config, "r", encoding="utf-8") as f:
+ config = yaml.safe_load(f)
+
+ # ์ ํํ ํ์ดํ๋ผ์ธ ์คํ
+ if args.pipeline in ["sft", "all"]:
+ run_sft_pipeline(config)
+
+ if args.pipeline in ["description", "all"]:
+ run_description_pipeline(config)
+
+ logging.info("All pipeline tasks completed.")
+
+# Script entry point.
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/models/thumbnail_description/prompt/deepseek_prompt.txt b/models/thumbnail_description/prompt/deepseek_prompt.txt
new file mode 100644
index 0000000..cb4aa6d
--- /dev/null
+++ b/models/thumbnail_description/prompt/deepseek_prompt.txt
@@ -0,0 +1,4 @@
+Provide a detailed description of the packaging container (shape, material, design, color, and contents)
+that ensures the caption is of high quality and accessible for visually impaired individuals.
+Do not mention screen readers in the description under any circumstances.
+Write an image evaluation based solely on facts. Do not mention any opinions or evaluations about the design.
\ No newline at end of file
diff --git a/models/thumbnail_description/prompt/janus_prompt.txt b/models/thumbnail_description/prompt/janus_prompt.txt
new file mode 100644
index 0000000..97e26a3
--- /dev/null
+++ b/models/thumbnail_description/prompt/janus_prompt.txt
@@ -0,0 +1,15 @@
+You are a professional copywriter specializing in describing the appearance of food products.
+Your goal is to help visually impaired individuals objectively understand the productโs appearance,
+enabling them to make informed purchasing decisions.
+
+Provide a detailed description of the packaging container (shape, material, color) that ensures
+the caption is of high quality. If both the product design and the product are present, describe
+only the product design. Do not describe any text data on the design or labels.
+
+"output_requirements": [
+ "1. Output must be only 2 sentences.",
+ "2. The first sentence must describe the shape and material of the packaging container.",
+ "3. The second sentence must describe the color design and key features factually and concisely.",
+ "4. Do not mention screen readers or opinions in the description.",
+ "5. Reflect only text written in English and do not translate to Chinese."
+]
\ No newline at end of file
diff --git a/models/thumbnail_description/prompt/maal_prompt.txt b/models/thumbnail_description/prompt/maal_prompt.txt
new file mode 100644
index 0000000..c16efc6
--- /dev/null
+++ b/models/thumbnail_description/prompt/maal_prompt.txt
@@ -0,0 +1,14 @@
+{product_name}์ ํ์ ํน์ง(ํฌ์ฅ ์ฉ๊ธฐ ๋ชจ์, ์ฌ์ง, ๋์์ธ, ์, ๋ด์ฉ๋ฌผ ๋ฑ)์ ๋ํด
+์ต๋ 70์ ์ด๋ด๋ก ๊ฐ๋จํ ์ค๋ช
ํด์ฃผ์ธ์.
+์๊ฐ์ฅ์ ์ธ์ด ์ดํดํ ์ ์๋๋ก ํ์ง ๋์ ์บก์
ํ์์ ์ ์งํ์ธ์.
+์ฒซ์ค์๋ '{product_name}์ ๋ํ ๋ํ ์ด๋ฏธ์ง์
๋๋ค'๋ผ๊ณ ์์ํด์ผ ํฉ๋๋ค.
+
+0. ํฌ์ฅ์ ์ธ๋ถ ๋์์ธ๊ณผ ๋ด๋ถ ๋ด์ฉ๋ฌผ์ ๋ช
ํํ ๊ตฌ๋ถํ์ธ์.
+1. ํฌ์ฅ์ ๋ชจ์์ ๋ช
ํํ ์ค๋ช
ํ์ธ์ (์: ์ง์ฌ๊ฐํ ํ๋ผ์คํฑ ์ฉ๊ธฐ, ์ํ ์ ๋ฆฌ๋ณ ๋ฑ).
+2. ํฌ์ฅ ์ฌ์ง์ ๊ตฌ์ฒด์ ์ผ๋ก ์์ฑํ์ธ์ (์: ํฌ๋ช
ํ๋ผ์คํฑ, ์ข
์ด, ๋น๋ ํฉ ๋ฑ).
+3. ์์์ด๋ ํฌ๋ช
๋ ์ ๋ณด๊ฐ ๊ณ ๊ฐ ์ธ์์ ์ ์ฉํ๋๋ก ์์ฑํ์ธ์.
+4. ๋ด์ฉ๋ฌผ์ด ์ด๋ป๊ฒ ์ ์ฅ๋๋์ง ๊ตฌ์ฒด์ ์ผ๋ก ์ค๋ช
ํ์ธ์ (์: ์ง๊ณต ํฌ์ฅ, ๊ฐ๋ณ ํฌ์ฅ).
+5. ์ค์ ์ํ๊ณผ ์ผ์นํ๋๋ก ๋ด์ฉ๋ฌผ ์ํ๋ฅผ ๋ฌ์ฌํ์ธ์ (์: ์ก์ฒด ๋ด์ฉ๋ฌผ, ๋ฐ๋ด ์ฒ๋ฆฌ).
+6. ํฌ์ฅ์ ์ฃผ์ ์์๊ณผ ํฌ๋ช
์ฌ๋ถ๋ฅผ ๋ช
ํํ ์ธ๊ธํ์ธ์.
+
+โป 70์๋ฅผ ์ด๊ณผํ ๊ฒฝ์ฐ ๊ทธ๋๋ก ์ถ๋ ฅํ๊ฑฐ๋, ์ต๋ํ ๊ฐ๊ฒฐํ๊ฒ ์ถ์ฝ ๋ถํ๋๋ฆฝ๋๋ค.
\ No newline at end of file
diff --git a/models/thumbnail_description/prompt/qwen2_5_prompt.txt b/models/thumbnail_description/prompt/qwen2_5_prompt.txt
new file mode 100644
index 0000000..004d7a8
--- /dev/null
+++ b/models/thumbnail_description/prompt/qwen2_5_prompt.txt
@@ -0,0 +1,34 @@
+{
+ "task": "์๊ฐ์ฅ์ ์ธ์ ์ํ ์ด๋ฏธ์ง ์บก์
์์ฑ (Qwen2.5)",
+ "input": {
+ "image": "์ด ์ด๋ฏธ์ง๋ ์ํ ํจํค์ง๋ก ๊ตฌ์ฑ๋ ์ฅ๋ฉด์
๋๋ค.",
+ "product_name": "{product_name}"
+ },
+ "steps": [
+ {
+ "step": 1,
+ "instruction": "ํฌ์ฅ ์ฉ๊ธฐ, ๋ด์ฉ๋ฌผ, ์์, ์ฌ์ง, ํฌ๋ช
๋, ๋ฐ๋ด ์ํ ๋ฑ์ ์ค๋ช
ํ์ธ์.",
+ "actions": [
+ "1. {product_name}์ ์ฐ๊ด๋ ํต์ฌ ๋ฌธ๊ตฌ๋ฅผ ์์ฑํ์ธ์.",
+ "2. ํฌ์ฅ ์ฉ๊ธฐ์ ๊ตฌ์ฒด์ ํ์์ ์ธ๊ธํ์ธ์ (์: ์ํ, ์ง์ฌ๊ฐํ, ์ ๋ฆฌ๋ณ ๋ฑ).",
+ "3. ํฌ์ฅ ์ฌ์ง(ํ๋ผ์คํฑ, ์ข
์ด, ์ ๋ฆฌ ๋ฑ)๊ณผ ํฌ๋ช
์ฌ๋ถ(ํฌ๋ช
, ๋ถํฌ๋ช
)๋ฅผ ๋ช
ํํ ์ ์ผ์ธ์.",
+ "4. ๋ด์ฉ๋ฌผ์ ํํ(๋ฉ์ด๋ฆฌ, ์ก์ฒด, ๊ฐ๋ฃจ ๋ฑ)์ ํฌ์ฅ ์ํ(๊ฐ๋ณ, ์ง๊ณต ๋ฑ)๋ฅผ ๊ธฐ์ ํ์ธ์."
+ ]
+ },
+ {
+ "step": 2,
+ "instruction": "์๊ฐ์ฅ์ ์ธ์ ์ํ ์ ๊ทผ์ฑ ์ค๋ช
์ ์์ฑํ์ธ์.",
+ "actions": [
+ "1. ์ฒซ ๋ฌธ์ฅ์ '{product_name}์ ๋ํ ์ด๋ฏธ์ง์
๋๋ค'๋ก ์์ํ์ธ์.",
+ "2. 80์ ์ด๋ด๋ก ๊ฐ๊ฒฐํ๊ฒ ํต์ฌ ํฌ์ธํธ๋ง ์ ๋ฌํ์ธ์.",
+ "3. ์ค๋ณต ํํ ์์ด ์ง๊ด์ ์ธ ๋ฌธ์ฅ์ ์ฌ์ฉํ์ธ์.",
+ "4. ๋ถํ์ํ ์์์ด๋ ์ต๋ํ ๋ฐฐ์ ํ๊ณ ์ง์ค์ ์ผ๋ก ๋ฌ์ฌํ์ธ์."
+ ]
+ }
+ ],
+ "output_requirements": [
+ "1. ์ถ๋ ฅ์ ๋ฐ๋์ ํ๊ตญ์ด๋ก ์์ฑํ์ธ์.",
+ "2. 80์ ์ด๋ด๋ก ์์ฝํด์ฃผ์ธ์.",
+ "3. ์ค์ ์ํ ์ ๋ณด๋ฅผ ํฌ๊ฒ ๋ฒ์ด๋์ง ๋ง์ธ์."
+ ]
+}
\ No newline at end of file
diff --git a/models/thumbnail_description/prompt/qwen2_prompt.txt b/models/thumbnail_description/prompt/qwen2_prompt.txt
new file mode 100644
index 0000000..7d195e9
--- /dev/null
+++ b/models/thumbnail_description/prompt/qwen2_prompt.txt
@@ -0,0 +1,37 @@
+{
+ "task": "์๊ฐ์ฅ์ ์ธ์ ์ํ ์ด๋ฏธ์ง ์บก์
์์ฑ",
+ "input": {
+ "image": "ํด๋น ์ด๋ฏธ์ง๋ ์๋ฃํ ํฌ์ฅ๊ณผ ๋ด์ฉ์ ์ค๋ช
ํ๋ ์ด๋ฏธ์ง์
๋๋ค.",
+ "product_name": "{product_name}"
+ },
+ "steps": [
+ {
+ "step": 1,
+ "instruction": "์ ํ ์ด๋ฏธ์ง์ ์ฃผ์ ํน์ง์ ์ค์ฌ์ผ๋ก, ๋ด์ฉ๋ฌผ๊ณผ ํฌ์ฅ ์ํ, ๋์์ธ, ์ฌ์ง์ ๊ตฌ์ฒด์ ์ผ๋ก ์ค๋ช
ํ์ธ์.",
+ "actions": [
+ "1. {product_name}๊ณผ ๊ด๋ จ๋ ํต์ฌ ๋ฌธ๊ตฌ๋ฅผ ์์ฑํ์ธ์.",
+ "2. ๋ด์ฉ๋ฌผ์ ํํ์ ํฌ์ฅ ์ํ๋ฅผ ์ ๋ฆฌํ์ธ์.",
+ "3. ์ฃผ์ ์์๊ณผ ๋์์ธ ์์๋ฅผ ๊ธฐ์ ํ์ธ์.",
+ "4. ํฌ์ฅ์ ์ฌ์ง๊ณผ ํํ๋ฅผ ๊ฐ๊ด์ ์ผ๋ก ํํํ์ธ์.",
+ "5. ํฌ์ฅ์ด ํฌ๋ช
ํ์ง, ๋ฐ๋ด ์ฌ๋ถ๊ฐ ์ด๋ ํ์ง ์ ์ด์ฃผ์ธ์."
+ ]
+ },
+ {
+ "step": 2,
+ "instruction": "์๊ฐ์ฅ์ ์ธ์ ์ํ ์ ๊ทผ์ฑ ์ค๋ช
์ ์์ฑํ์ธ์.",
+ "actions": [
+ "1. ์ฒซ ๋ฌธ์ฅ์ '{product_name}์ ๋ํ ์ด๋ฏธ์ง์
๋๋ค'๋ก ์์ํ์ธ์.",
+ "2. 80์ ์ด๋ด๋ก ๊ฐ๊ฒฐํ๊ณ ๋ช
ํํ๊ฒ ํต์ฌ ์ ๋ณด๋ฅผ ์ ๋ฌํ์ธ์.",
+ "3. ์ค๋ณต ํํ ์์ด ์ง๊ด์ ์ธ ๋ฌธ์ฅ์ ์ฌ์ฉํ์ธ์.",
+ "4. ๋ถํ์ํ ๋จ์ด๋ฅผ ์ ๊ฑฐํ๊ณ ์ ํ์ฑ์ ์ ์งํ์ธ์."
+ ]
+ }
+ ],
+ "output_requirements": [
+ "1. ์ถ๋ ฅ์ ๋ฐ๋์ ํ๊ตญ์ด๋ก ์์ฑํ์ธ์.",
+ "2. ์ฒซ ๋ฌธ์ฅ์ '{product_name}์ ๋ํ ์ด๋ฏธ์ง์
๋๋ค'๋ก ์์ํด์ผ ํฉ๋๋ค.",
+ "3. ์ต๋ 80์๋ก ์์ฑํ๋ฉฐ, ํต์ฌ ์ ๋ณด๋ง ๋ด์์ผ ํฉ๋๋ค.",
+ "4. ์ค๋ณต๋ ํํ ์์ด ๊ฐ๊ฒฐํ๊ฒ ์ค๋ช
ํ์ธ์.",
+ "5. ์ค์ ์ด๋ฏธ์ง์ ์ผ์นํ์ง ์๋ ์ ๋ณด๋ ํฌํจํ์ง ๋ง์ธ์."
+ ]
+}
\ No newline at end of file
diff --git a/models/thumbnail_description/prompt/unsloth_prompt.txt b/models/thumbnail_description/prompt/unsloth_prompt.txt
new file mode 100644
index 0000000..ef0f8da
--- /dev/null
+++ b/models/thumbnail_description/prompt/unsloth_prompt.txt
@@ -0,0 +1,35 @@
+{
+ "task": "์ธ๋ค์ผ ์ด๋ฏธ์ง ๋ฌ์ฌ",
+ "input": {
+ "image": "{product_name}์ ์
๋ ฅ ์ด๋ฏธ์ง๋ ๋ํ์ฑ์ ๋ ๋ ์๋ฃํ์ ์ด๋ฏธ์ง์
๋๋ค.",
+ "product_name": "{product_name}"
+ },
+ "steps": [
+ {
+ "step": 1,
+ "instruction": "ํฌ์ฅ ์ฉ๊ธฐ, ๋ด์ฉ๋ฌผ, ๋์์ธ ์์๋ฅผ ๊ตฌ์ฒด์ ์ผ๋ก ์ค๋ช
ํ์ธ์.",
+ "actions": [
+ "0. ์ธ๋ถ ๋์์ธ๊ณผ ๋ด๋ถ ๋ด์ฉ๋ฌผ์ ๊ตฌ๋ถํ์ฌ ์์ฑํ์ธ์.",
+ "1. ํฌ์ฅ์ ๋ชจ์(์ง์ฌ๊ฐํ, ์ํ, ์ข
์ด ์์, ํ๋ผ์คํฑ ์ฉ๊ธฐ ๋ฑ)์ ๋ช
ํํ ์ค๋ช
ํ์ธ์.",
+ "2. ํฌ์ฅ ์ฌ์ง(์ข
์ด, ์ ๋ฆฌ, ๊ธ์, ๋น๋ ๋ฑ)์ ๊ตฌ์ฒด์ ์ผ๋ก ์ธ๊ธํ์ธ์.",
+ "3. ์์์ด๋ ํฌ๋ช
๋ ์ ๋ณด๋ฅผ ๋ถ๋ช
ํ ๊ธฐ์ฌํ์ธ์ (ํฌ๋ช
ํ๋ผ์คํฑ, ๋ถํฌ๋ช
์ข
์ด ๋ฑ).",
+ "4. ๋ด์ฉ๋ฌผ์ด ์ด๋ป๊ฒ ์ ์ฅ๋๋์ง (๊ฐ๋ณํฌ์ฅ, ์ง๊ณตํฌ์ฅ ๋ฑ)๋ฅผ ๊ฐ๋จํ ์ ์ผ์ธ์."
+ ]
+ },
+ {
+ "step": 2,
+ "instruction": "์๊ฐ์ฅ์ ์ธ์ ์ํ ์ ๊ทผ์ฑ ์ค๋ช
์ ์์ฑํ์ธ์.",
+ "actions": [
+ "1. ์ฒซ์ค์๋ '{product_name}์ ๋ํ ๋ํ ์ด๋ฏธ์ง์
๋๋ค'๋ก ์์ํด์ผ ํฉ๋๋ค.",
+ "2. ์ต๋ 70์ ์ด๋ด๋ก ๊ฐ๊ฒฐํ๊ณ ๋ช
ํํ๊ฒ ์์ฑํ์ธ์.",
+ "3. ๋ฐ๋ณต๋๋ ๋ฌธ์ฅ์ ์ค์ด๊ณ ํต์ฌ ์ ๋ณด๋ฅผ ์ ๋ฌํ์ธ์.",
+ "4. '์๊ฐ์ฅ์ ์ธ์ด ์ดํดํ๋๋ก ์์ฑ' ๊ฐ์ ํํ์ ์ง์ ์ธ๊ธํ์ง ๋ง์ธ์."
+ ]
+ }
+ ],
+ "output_requirements": [
+ "1. ์ต๋ 70์ ์ด๋ด๋ก ์์ฑํ์ธ์.",
+ "2. ์ค๋ณต ํํ์ ํผํ๊ณ , ์ง๊ด์ ์ผ๋ก ์์ฑํด์ฃผ์ธ์.",
+ "3. ์ค์ ์ด๋ฏธ์ง์ ๋ถํฉํ๋ ๋ด์ฉ์ ์ฐ์ ํ์ธ์."
+ ]
+}
\ No newline at end of file
diff --git a/models/thumbnail_description/src/description_pipeline/evaluation/gpt_eval.py b/models/thumbnail_description/src/description_pipeline/evaluation/gpt_eval.py
new file mode 100644
index 0000000..bcf9ffa
--- /dev/null
+++ b/models/thumbnail_description/src/description_pipeline/evaluation/gpt_eval.py
@@ -0,0 +1,260 @@
+from utils.common_utils import (
+ set_seed, requests, pd, time
+)
+import yaml
+import os
+import re
+import numpy as np
+import matplotlib.pyplot as plt
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+import openai
+
+def main():
+ # 1) config.yaml ๋ก๋
+ with open("config/config.yaml", "r", encoding="utf-8") as f:
+ config = yaml.safe_load(f)
+
+ # 2) OpenAI API ํค ์ค์
+ openai.api_key = config["openai"]["api_key"]
+
+ # 3) CSV ํ์ผ ๊ฒฝ๋ก ์ค์
+ data_dir = config["paths"]["data_dir"]
+ output_csv = "outputs/gpt_eval_result_janus+qwen2_5.csv"
+
+ internvl_eval_path = os.path.join(data_dir, config["paths"]["internVL_eval"])
+ maai_eval_path = os.path.join(data_dir, config["paths"]["maai_eval"])
+ qwen_unsloth_eval_path= os.path.join(data_dir, config["paths"]["unsloth_qwen2_eval"])
+ qwen2_eval_path = os.path.join(data_dir, config["paths"]["qwen2_eval"])
+ deepseekvl_eval_path = os.path.join(data_dir, config["paths"]["deepseekvl_eval"])
+ qwen2_5_eval_path = os.path.join(data_dir, config["paths"]["qwen2.5_eval"])
+ janus_eval_path = os.path.join(data_dir, config["paths"]["janus_pro_eval"])
+ qwen2_5_janus_eval_path = os.path.join(data_dir, config["paths"]["qwen2_5+janus_eval"])
+
+ # 4) CSV ๋ก๋
+ internvl_eng_eval = pd.read_csv(internvl_eval_path)
+ maai_eval = pd.read_csv(maai_eval_path)
+ qwen_unsloth_eval = pd.read_csv(qwen_unsloth_eval_path)
+ qwen2_eval = pd.read_csv(qwen2_eval_path)
+ deepseekvl_eval = pd.read_csv(deepseekvl_eval_path)
+ qwen2_5_eval = pd.read_csv(qwen2_5_eval_path)
+ janus_eval = pd.read_csv(janus_eval_path)
+ qwen2_5_janus_eval= pd.read_csv(qwen2_5_janus_eval_path)
+
+
+ # 5) GPT ํ๊ฐ ๋ก์ง ์ค๋น
+ def calculate_total_score_from_gpt_eval(eval_text):
+ """
+ GPT๊ฐ ๋ฐํํ ํ๊ฐ ํ
์คํธ ๋ด์ 'ํ๊ฐ: x/5' ํํ์ ํญ๋ชฉ์ ์ ๊ท์์ผ๋ก ์ฐพ์
+ ํฉ๊ณ๋ฅผ ๊ตฌํ๋ ํจ์ (10๊ฐ ํญ๋ชฉ * 0~5์ ).
+ """
+ try:
+ pattern = r"ํ๊ฐ:\s*([\d\.]+)/5"
+ scores = [float(match) for match in re.findall(pattern, eval_text)]
+ return sum(scores)
+ except Exception as e:
+ print(f"Error calculating total score: {e}")
+ return None
+
+ def evaluate_gpt(row):
+ """
+ GPT API ํธ์ถ ํจ์.
+ row: pd.Series ํํ (์ฃผ์ด์ง CSV ํ ํ)
+ """
+ model_output_text = row.get("Model Output", "")
+ prompt = f"""
+ ์ด๋ฏธ์ง ์บก์
๋ ๊ฒฐ๊ณผ:
+ {model_output_text}
+
+ ์ ์ ๋ณด๋ ์๋ณธ ์ด๋ฏธ์ง์ ์ด์ ๋ํ VLM ๋ชจ๋ธ์ ์ด๋ฏธ์ง ์บก์
๋ ๊ฒฐ๊ณผ์
๋๋ค. ์๋ ํญ๋ชฉ๋ค์ ๊ธฐ์ค์ผ๋ก ์บก์
์ ํ์ง์ ํ๊ฐํ๊ณ , ๊ฐ ํญ๋ชฉ์ ๋ํด ์ ์(0-5์ )๋ฅผ ๋ถ์ฌํ ํ ๊ฐ๊ฒฐํ๊ณ ๋ช
ํํ ํผ๋๋ฐฑ์ ์ ๊ณตํด์ฃผ์ธ์.
+
+ ํ๊ฐ ํญ๋ชฉ:
+ 1.ํฌ์ฅ ์ฉ๊ธฐ์ ๋ชจ์์ ๋ช
ํํ ์ค๋ช
ํ๋๊ฐ? (0~5์ )
+ 2.ํฌ์ฅ ์ฌ์ง์ด ์ ํํ ํํ๋์๋๊ฐ? (0~5์ )
+ 3.๋ด์ฉ๋ฌผ์ ๋ํ ์ ๋ณด๊ฐ ๋ช
ํํ ์ ์๋์๋๊ฐ? (0~5์ )
+ 4.์์์ ๋ํ ์ค๋ช
์ด ํฌํจ๋์๋๊ฐ? (0~5์ )
+ 5.๋ํ
์ผํ ๋ฌ์ฌ๊ฐ ์ด๋ค์ก๋๊ฐ? (0~5์ )
+ 6.์ ํ์ ๋ํ ์ถ๊ฐ ์ ๋ณด๋ฅผ ์ ๊ณตํ๋๊ฐ? (0~5์ )
+ 7.์ ํ์ ์ค์ ํน์ฑ์ ์ ํํ ๋ฌ์ฌํ๋๊ฐ? (์คํ๋ ์๊ณก ์๋์ง) (0~5์ )
+ 8.๋ถํ์ํ๊ฒ ๊ธธ์ง ์๊ณ , ํต์ฌ ์ ๋ณด์ ์ง์คํ๋๊ฐ? (0~5์ )
+ 9.์๊ฐ์ฅ์ ์ธ์ด ์ฝ๊ฒ ์ดํดํ ์ ์๋๋ก ์ง๊ด์ ์ผ๋ก ์์ฑ๋์๋๊ฐ? (0~5์ )
+ 10.ํน์ ์ ๋ณด๋ฅผ ์ค๋ณตํ์ง ์๊ณ , ์๋กญ๊ฑฐ๋ ํ์ํ ์ ๋ณด ์์ฃผ๋ก ์ ๋ฆฌ๋์๋๊ฐ? (0~5์ )
+
+ ์ ์ ๊ธฐ์ค(0~5์ ):
+ 0์ : ์ ํ ๋ฐ์๋์ง ์์
+ 1์ : ๋งค์ฐ ๋ถ์กฑํ๊ฒ ๋ฐ์๋จ
+ 2์ : ์ผ๋ถ ๋ฐ์๋์์ผ๋ ๋ถ์กฑํจ
+ 3์ : ๋ณดํต ์์ค์ผ๋ก ๋ฐ์๋จ
+ 4์ : ๋๋ถ๋ถ ์ ๋ฐ์๋จ
+ 5์ : ์๋ฒฝํ๊ฒ ๋ฐ์๋จ
+
+ ์ถ๋ ฅ ์์:
+ ํญ๋ชฉ1: 4/5
+ ํผ๋๋ฐฑ: ...
+ ...
+ (๋ง์ง๋ง์ 'ํ๊ฐ: x/5' ํํ๋ก ๊ฐ ํญ๋ชฉ ์ ์ ํฉ๊ณ๋ฅผ ํ๊ธฐํด ์ฃผ์ธ์.)
+ """
+
+ try:
+ response = openai.ChatCompletion.create(
+ model="gpt-4o",
+ messages=[
+ {"role": "system", "content": "๋น์ ์ ์ด๋ฏธ์ง ์บก์
๋ ๊ฒฐ๊ณผ๋ฅผ ํ๊ฐํ๋ ์ ๋ฌธ๊ฐ์
๋๋ค."},
+ {"role": "user", "content": prompt}
+ ],
+ temperature=0.2,
+ max_tokens=1024
+ )
+ gpt_eval_text = response.choices[0].message.content
+ return gpt_eval_text
+ except Exception as e:
+ print(f"Error calling GPT: {e}")
+ return None
+
+ def run_tasks(df: pd.DataFrame) -> pd.DataFrame:
+ """
+ ThreadPoolExecutor๋ฅผ ์ฌ์ฉํด ๋ณ๋ ฌ๋ก GPT ํ๊ฐ ์ํ.
+ """
+ with ThreadPoolExecutor() as executor:
+ futures = {}
+ for idx, row in df.iterrows():
+ futures[executor.submit(evaluate_gpt, row)] = idx
+
+ for future in as_completed(futures):
+ idx = futures[future]
+ result = future.result()
+ if result is not None:
+ df.at[idx, "Eval (gpt-4o)"] = result
+ df.at[idx, "Score (gpt-4o)"] = calculate_total_score_from_gpt_eval(result)
+ return df
+
+ # 6) ๊ฐ ๋ฐ์ดํฐํ๋ ์์ ๋ํด์ GPT ํ๊ฐ ์ถ๊ฐ ("Score (gpt-4o)" ์ด)
+ internvl_eng_eval = run_tasks(internvl_eng_eval)
+ maai_eval = run_tasks(maai_eval)
+ deepseekvl_eval = run_tasks(deepseekvl_eval)
+ qwen_unsloth_eval = run_tasks(qwen_unsloth_eval)
+ qwen2_eval = run_tasks(qwen2_eval)
+ qwen2_5_eval = run_tasks(qwen2_5_eval)
+ janus_eval = run_tasks(janus_eval)
+ qwen2_5_janus_eval= run_tasks(qwen2_5_janus_eval)
+
+ # 7) CSV๋ก ๋ค์ ์ ์ฅ (์ํ๋ฉด overwrite or new filename)
+ internvl_eng_eval.to_csv(internvl_eval_path, index=False, encoding='utf-8-sig')
+ maai_eval.to_csv(maai_eval_path, index=False, encoding='utf-8-sig')
+ deepseekvl_eval.to_csv(deepseekvl_eval_path, index=False, encoding='utf-8-sig')
+ qwen_unsloth_eval.to_csv(qwen_unsloth_eval_path, index=False, encoding='utf-8-sig')
+ qwen2_eval.to_csv(qwen2_eval_path, index=False, encoding='utf-8-sig')
+ qwen2_5_eval.to_csv(qwen2_5_eval_path, index=False, encoding='utf-8-sig')
+ janus_eval.to_csv(janus_eval_path, index=False, encoding='utf-8-sig')
+ qwen2_5_janus_eval.to_csv(qwen2_5_janus_eval_path, index=False, encoding='utf-8-sig')
+
+ print("[Info] GPT-4o Evaluation columns added to each CSV successfully.")
+
+ # 8) ๋ชจ๋ธ๋ณ ์ ์ ํฉ ๊ณ์ฐ
+ internvl_eng_total_score_gpt_4o = internvl_eng_eval['Score (gpt-4o)'].sum()
+ maai_total_score_gpt_4o = maai_eval['Score (gpt-4o)'].sum()
+ deepseekvl_total_score_gpt_4o = deepseekvl_eval['Score (gpt-4o)'].sum()
+ qwen2_unsloth_total_score_gpt_4o= qwen_unsloth_eval['Score (gpt-4o)'].sum()
+ qwen2_total_score_gpt_4o = qwen2_eval['Score (gpt-4o)'].sum()
+ qwen2_5_total_score_gpt_4o = qwen2_5_eval['Score (gpt-4o)'].sum()
+ janus_total_score_gpt_4o = janus_eval['Score (gpt-4o)'].sum()
+ qwen2_5_janus_total_score_gpt_4o= qwen2_5_janus_eval['Score (gpt-4o)'].sum()
+
+ print(f"internvl_eng Score (gpt-4o): {internvl_eng_total_score_gpt_4o}")
+ print(f"maai Score (gpt-4o): {maai_total_score_gpt_4o}")
+ print(f"deepseekvl Score (gpt-4o): {deepseekvl_total_score_gpt_4o}")
+ print(f"qwen2_unsloth Score (gpt-4o): {qwen2_unsloth_total_score_gpt_4o}")
+ print(f"qwen2 Score (gpt-4o): {qwen2_total_score_gpt_4o}")
+ print(f"qwen2_5 Score (gpt-4o): {qwen2_5_total_score_gpt_4o}")
+ print(f"janus Score (gpt-4o): {janus_total_score_gpt_4o}")
+ print(f"qwen2_5_janus Score (gpt-4o): {qwen2_5_janus_total_score_gpt_4o}")
+
+ # 9) ํ๊ท ์ถ๋ก ์๊ฐ ๊ณ์ฐ
+ internvl_eng_avg_time = internvl_eng_eval['Inference Time (s)'].mean()
+ maai_avg_time = maai_eval['Inference Time (s)'].mean()
+ deepseekvl_avg_time = deepseekvl_eval['Inference Time (s)'].mean()
+ qwen2_unsloth_avg_time= qwen_unsloth_eval['Inference Time (s)'].mean()
+ qwen2_avg_time = qwen2_eval['Inference Time (s)'].mean()
+ qwen2_5_avg_time = qwen2_5_eval['Inference Time (s)'].mean()
+ janus_avg_time = janus_eval['Inference Time (s)'].mean()
+ qwen2_5_janus_avg_time= qwen2_5_janus_eval['Inference Time (s)'].mean()
+
+ print(f"internvl Avg Inference Time: {internvl_eng_avg_time:.2f}")
+ print(f"maai Avg Inference Time: {maai_avg_time:.2f}")
+ print(f"deepseekvl Avg Inference Time: {deepseekvl_avg_time:.2f}")
+ print(f"qwen2(unsloth) Avg Inference Time: {qwen2_unsloth_avg_time:.2f}")
+ print(f"qwen2 Avg Inference Time: {qwen2_avg_time:.2f}")
+ print(f"qwen2_5 Avg Inference Time: {qwen2_5_avg_time:.2f}")
+ print(f"janus Avg Inference Time: {janus_avg_time:.2f}")
+ print(f"qwen2_5_janus Avg Inference Time: {qwen2_5_janus_avg_time:.2f}")
+
+ # 10) ์๊ฐํ ํํธ
+ # Performance Comparison
+ models = ['MAAI','InternVL','Qwen2_VL(Unsloth)','Qwen2_VL','DeepSeek_VL','Qwen2_5_VL','Janus_Pro','Qwen2_5_Janus' ]
+ total_scores_gpt_4o = [
+ maai_total_score_gpt_4o,
+ internvl_eng_total_score_gpt_4o,
+ qwen2_unsloth_total_score_gpt_4o,
+ qwen2_total_score_gpt_4o,
+ deepseekvl_total_score_gpt_4o,
+ qwen2_5_total_score_gpt_4o,
+ janus_total_score_gpt_4o,
+ qwen2_5_janus_total_score_gpt_4o
+ ]
+ # 10๊ฐ ํญ๋ชฉ * 5์ ๋ง์ * N๊ฐ ๋ฐ์ดํฐ?? ์: 50๊ฐ, => 2500์ = 100%
+ total_scores_gpt_4o_percent = [(val / 2500) * 100 for val in total_scores_gpt_4o]
+
+ colors = ['#FF7F50', '#008080', '#E6E6FA', '#FFD700', '#DDA0DD', '#6A5ACD', '#32CD32', '#DC143C']
+ fig, ax = plt.subplots(figsize=(12, 8))
+ bars = ax.bar(models, total_scores_gpt_4o_percent, color=colors, width=0.6)
+ for bar in bars:
+ h = bar.get_height()
+ ax.text(bar.get_x() + bar.get_width()/2., h,
+ f'{h:.2f} %',
+ ha='center', va='bottom', fontsize=12, fontweight='bold')
+ ax.set_ylabel('Total Score (%)', fontsize=14, fontweight='bold')
+ ax.set_title('Model Performance Comparison (gpt-4o)', fontsize=18, fontweight='bold', pad=20)
+ ax.set_ylim(0, max(total_scores_gpt_4o_percent) * 1.2)
+ ax.grid(axis='y', linestyle='--', alpha=0.7)
+ plt.xticks(rotation=0, ha='center', fontsize=12, fontweight='bold')
+ plt.yticks(fontsize=10)
+ for bar in bars:
+ bar.set_edgecolor('white')
+ bar.set_linewidth(2)
+ plt.tight_layout()
+ plt.savefig('model_performance_comparison_8.png', dpi=300, bbox_inches='tight')
+ plt.show()
+
+ # Inference Time Comparison
+ avg_times = [
+ maai_avg_time,
+ internvl_eng_avg_time,
+ qwen2_unsloth_avg_time,
+ qwen2_avg_time,
+ deepseekvl_avg_time,
+ qwen2_5_avg_time,
+ janus_avg_time,
+ qwen2_5_janus_avg_time
+ ]
+ fig2, ax2 = plt.subplots(figsize=(12, 8))
+ bars2 = ax2.bar(models, avg_times, color=colors, width=0.6)
+ for bar in bars2:
+ h = bar.get_height()
+ ax2.text(bar.get_x() + bar.get_width()/2., h,
+ f'{h:.2f} (s)',
+ ha='center', va='bottom', fontsize=12, fontweight='bold')
+ ax2.set_ylabel('Avg Inference Time (s)', fontsize=14, fontweight='bold')
+ ax2.set_title('Model Avg Inference Time Comparison', fontsize=18, fontweight='bold', pad=20)
+ ax2.set_ylim(0, max(avg_times) * 1.2)
+ ax2.grid(axis='y', linestyle='--', alpha=0.7)
+ plt.xticks(rotation=0, ha='center', fontsize=12, fontweight='bold')
+ plt.yticks(fontsize=10)
+ for bar in bars2:
+ bar.set_edgecolor('white')
+ bar.set_linewidth(2)
+ plt.tight_layout()
+ plt.savefig('model_avg_inference_time_comparison_8.png', dpi=300, bbox_inches='tight')
+ plt.show()
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/models/thumbnail_description/src/description_pipeline/evaluation/gpt_eval_323.py b/models/thumbnail_description/src/description_pipeline/evaluation/gpt_eval_323.py
new file mode 100644
index 0000000..77d6022
--- /dev/null
+++ b/models/thumbnail_description/src/description_pipeline/evaluation/gpt_eval_323.py
@@ -0,0 +1,162 @@
+from utils.common_utils import (
+ set_seed, requests, pd, time
+)
+import yaml
+import os
+import re
+import numpy as np
+import matplotlib.pyplot as plt
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+def main():
+ # 1) config.yaml ๋ก๋
+ with open("config/config.yaml", "r", encoding="utf-8") as f:
+ config = yaml.safe_load(f)
+
+ # 2) OpenAI API Key ์ค์
+ openai.api_key = config["openai"]["api_key"]
+
+ # 3) CSV ํ์ผ ๊ฒฝ๋ก ์ค์
+ data_dir = config["paths"]["data_dir"]
+ qwen2_5_janus_323_eval_path = os.path.join(data_dir, config["paths"]["qwen2_5+janus_323_eval"])
+ output_csv = "outputs/gpt_eval_result_janus+qwen2_5.csv" # ๊ฒฐ๊ณผ ์ ์ฅ์ฉ CSV ํ์ผ๋ช
+
+ # 4) CSV ํ์ผ ๋ก๋
+ df = pd.read_csv(qwen2_5_janus_323_eval_path)
+ print("Data loaded:", df.shape)
+
+ ### GPT ํ๊ฐ ๋ก์ง ###
+
+ def calculate_total_score_from_gpt_eval(eval_text):
+ """
+ GPT๊ฐ ์์ฑํ ํ๊ฐ ํ
์คํธ์์ 'ํ๊ฐ: x/5' ํํ๋ก ๋ ์ซ์๋ฅผ ์ฐพ์
+ ํฉ์ฐํ ๊ฐ์ ๋ฐํ (10๊ฐ ํญ๋ชฉ * 0~5์ ).
+ """
+ try:
+ pattern = r"ํ๊ฐ:\s*([\d\.]+)/5"
+ matches = re.findall(pattern, eval_text)
+ scores = [float(m) for m in matches]
+ return sum(scores)
+ except Exception as e:
+ print(f"Error calculating total score: {e}")
+ return None
+
+ def evaluate_gpt(model_name, idx, row):
+ model_output = row.get("Model Output", "")
+ prompt = f"""
+ ์ด๋ฏธ์ง ์บก์
๋ ๊ฒฐ๊ณผ:
+ {model_output}
+
+ ์ ์ ๋ณด๋ ์๋ณธ ์ด๋ฏธ์ง์ ์ด์ ๋ํ VLM ๋ชจ๋ธ์ ์ด๋ฏธ์ง ์บก์
๋ ๊ฒฐ๊ณผ์
๋๋ค.
+ ์๋ ํญ๋ชฉ๋ค์ ๊ธฐ์ค์ผ๋ก ์บก์
์ ํ์ง์ ํ๊ฐํ๊ณ , ๊ฐ ํญ๋ชฉ์ ๋ํด ์ ์(0-5์ )๋ฅผ ๋ถ์ฌํ ํ
+ ๊ฐ๊ฒฐํ๊ณ ๋ช
ํํ ํผ๋๋ฐฑ์ ์ ๊ณตํด์ฃผ์ธ์.
+
+ ํ๊ฐ ํญ๋ชฉ:
+ 1.ํฌ์ฅ ์ฉ๊ธฐ์ ๋ชจ์์ ๋ช
ํํ ์ค๋ช
ํ๋๊ฐ? (์: ์ง์ฌ๊ฐํ, ์ํ, ๋น๋ํ, ์ ์ฌ๊ฐํ ๋ฑ)
+ 2.ํฌ์ฅ ์ฌ์ง์ด ์ ํํ ํํ๋์๋๊ฐ? (์: ์ข
์ด, ํ๋ผ์คํฑ, ๋น๋, ํํธ๋ณ, ์บ, ์ ๋ฆฌ ๋ฑ)
+ 3.๋ด์ฉ๋ฌผ์ ๋ํ ์ ๋ณด๊ฐ ๋ช
ํํ ์ ์๋์๋๊ฐ? (์: ๊ณผ์, ์์ , ๊ณผ์ผ ๋ฑ)
+ 4.ํจํค์ง์ ๋ํ ๋ฌ์ฌ๊ฐ ์ด๋ค์ก๋๊ฐ?
+ 5.์์์ ๋ํ ์ค๋ช
์ด ํฌํจ๋์๋๊ฐ? (์: ๋นจ๊ฐ์ ๋๊ป, ํฌ๋ช
ํ ๋ณ, ์ด๋ก์ ํฌ์ฅ์ง ๋ฑ)
+ 6.์ ํ์ ๋ํ ์ถ๊ฐ ์ ๋ณด๋ฅผ ์ ๊ณตํ๋๊ฐ? (ํด๋น ์ ํ์ ๋ํ ์ค๋ช
์ด ๋ด๊ฒจ ์๋์ง)
+ 7.์ ํ์ ์ค์ ํน์ฑ์ ์ ํํ ๋ฌ์ฌํ๋๊ฐ? (์คํ๋ ์๊ณก๋ ์ ๋ณด๊ฐ ์๋์ง)
+ 8.์บก์
์ด ๋ถํ์ํ๊ฒ ๊ธธ์ง ์๊ณ , ํต์ฌ ์ ๋ณด์ ์ง์คํ๋๊ฐ?
+ 9.์๊ฐ์ฅ์ ์ธ์ด ์ฝ๊ฒ ์ดํดํ ์ ์๋๋ก ๋ช
ํํ๊ณ ์ง๊ด์ ์ผ๋ก ์์ฑ๋์๋๊ฐ?
+ 10.ํน์ ์ ๋ณด๋ฅผ ์ค๋ณตํ์ง ์๊ณ , ํ์ํ๊ฑฐ๋ ์๋ก์ด ์ ๋ณด ์์ฃผ๋ก ์ ๋ฆฌ๋์๋๊ฐ?
+
+ ์ ์ ๊ธฐ์ค(0~5์ ):
+ 0์ : ์ ํ ๋ฐ์๋์ง ์์
+ 1์ : ๋งค์ฐ ๋ถ์กฑํ๊ฒ ๋ฐ์๋จ
+ 2์ : ์ผ๋ถ ๋ฐ์๋์์ผ๋ ๋ถ์กฑํจ
+ 3์ : ๋ณดํต ์์ค์ผ๋ก ๋ฐ์๋จ
+ 4์ : ๋๋ถ๋ถ ์ ๋ฐ์๋จ
+ 5์ : ์๋ฒฝํ๊ฒ ๋ฐ์๋จ
+
+ ์ถ๋ ฅ ํ์ ์์:
+ ํญ๋ชฉ: 1.์บก์
์ด ํฌ์ฅ ์ฉ๊ธฐ์ ๋ชจ์์ ๋ช
ํํ ์ค๋ช
ํ๋๊ฐ?
+ ํ๊ฐ: 4/5
+ ํผ๋๋ฐฑ: ๋ชจ์์ด ๋๋ถ๋ถ ๋ช
ํํ๊ฒ ๋ฌ์ฌ๋์์ผ๋, ์ฝ๊ฐ์ ์ธ๋ถ ์ ๋ณด๊ฐ ๋ถ์กฑํจ.
+ (๋ง์ง๋ง์ ๊ฐ ํญ๋ชฉ ์ ์ ํฉ๊ณ๋ฅผ 'ํ๊ฐ: x/5' ํํ๋ก ํ๊ธฐํด์ฃผ์ธ์.)
+ """
+ try:
+ completion = openai.ChatCompletion.create(
+ model="gpt-4o",
+ messages=[
+ {"role": "system", "content": "๋น์ ์ VLM์ ์ด๋ฏธ์ง ์บก์
๋ ๊ฒฐ๊ณผ๋ฅผ ํ๊ฐํ๋ ์ ๋ฌธ๊ฐ์
๋๋ค."},
+ {"role": "user", "content": prompt}
+ ],
+ temperature=0.2,
+ max_tokens=1024
+ )
+ gpt_eval = completion.choices[0].message.content
+ return idx, gpt_eval
+ except Exception as e:
+ print(f"Error for idx {idx}: {e}")
+ return idx, None
+
+ def run_tasks(df_input, model_name):
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ futures = {}
+ with ThreadPoolExecutor() as executor:
+ for i, row in df_input.iterrows():
+ futures[executor.submit(evaluate_gpt, model_name, i, row)] = i
+
+ for future in as_completed(futures):
+ idx = futures[future]
+ res = future.result()
+ if res is not None:
+ row_idx, gpt_eval_text = res
+ if gpt_eval_text:
+ df_input.at[row_idx, f'Eval ({model_name})'] = gpt_eval_text
+ df_input.at[row_idx, f'Score ({model_name})'] = calculate_total_score_from_gpt_eval(gpt_eval_text)
+ return df_input
+
+ # 5) GPT ํ๊ฐ ์คํ
+ model_name = "gpt-4o"
+ df_eval = run_tasks(df, model_name)
+
+ # 6) ๊ฒฐ๊ณผ CSV ์ ์ฅ
+ df_eval.to_csv(output_csv, index=False, encoding="utf-8-sig")
+ print(f"[Info] GPT Evaluation completed and saved => {output_csv}")
+
+ # 7) ๊ทธ๋ํ ์์ฑ์ฉ ์ ์ ์ค๋น
+ # ์ ์ ํฉ ๊ณ์ฐ
+ sum_score = df_eval[f'Score ({model_name})'].sum()
+ print(f"Total Score (gpt-4o) = {sum_score}")
+
+ # ๋ชจ๋ธ๋ช
๋ฆฌ์คํธ์ ์ ์ ๋ฆฌ์คํธ
+ models = ['Janus_Qwen2_5']
+
+ # ์: 10๊ฐ ํญ๋ชฉ * 5์ ๋ง์ * N๊ฐ ๋ฐ์ดํฐ = 10*N*5 => ์ต๋์น
+ max_possible_score = 10 * df_eval.shape[0] * 5
+ # ์ ์ % ๊ณ์ฐ
+ percentage = (sum_score / max_possible_score) * 100
+ scores_gpt_4o = [percentage]
+
+ # ๊ทธ๋ํ
+ colors = ['#FF7F50']
+
+ fig, ax = plt.subplots(figsize=(8, 6))
+ bars = ax.bar(models, scores_gpt_4o, color=colors, width=0.6)
+ for bar in bars:
+ height = bar.get_height()
+ ax.text(bar.get_x() + bar.get_width()/2., height,
+ f'{height:.2f} %',
+ ha='center', va='bottom', fontsize=12, fontweight='bold')
+
+ ax.set_ylabel('Total Score (%)', fontsize=14, fontweight='bold')
+ ax.set_title('Model Performance Comparison (GPT-4o)', fontsize=18, fontweight='bold', pad=20)
+ ax.set_ylim(0, max(scores_gpt_4o) * 1.2 if len(scores_gpt_4o) > 0 else 100)
+ ax.grid(axis='y', linestyle='--', alpha=0.7)
+ plt.xticks(rotation=0, ha='center', fontsize=12, fontweight='bold')
+ plt.yticks(fontsize=10)
+
+ for bar in bars:
+ bar.set_edgecolor('white')
+ bar.set_linewidth(2)
+
+ plt.tight_layout()
+ plt.savefig('model_performance_comparison_total.png', dpi=300, bbox_inches='tight')
+ print("[Info] Bar chart saved => model_performance_comparison_total.png")
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/models/thumbnail_description/src/description_pipeline/inference_model/deepseekvl.py b/models/thumbnail_description/src/description_pipeline/inference_model/deepseekvl.py
new file mode 100644
index 0000000..ec034e9
--- /dev/null
+++ b/models/thumbnail_description/src/description_pipeline/inference_model/deepseekvl.py
@@ -0,0 +1,134 @@
+import os
+import time
+import urllib.request
+import pandas as pd
+import torch
+import yaml
+
+from transformers import AutoModelForCausalLM
+from deepseek_vl.models import VLChatProcessor, MultiModalityCausalLM
+from deepseek_vl.utils.io import load_pil_images
+
+from utils.common_utils import set_seed, load_and_filter_data
+
+def run_inference_deepseekvl():
+ """
+ DeepSeek-VL-7B-Chat ๋ชจ๋ธ์ ํ์ฉํ ์ถ๋ก ์คํฌ๋ฆฝํธ.
+ config.yaml์ ํตํด CSV ๊ฒฝ๋ก, ํ๋กฌํํธ ๊ฒฝ๋ก, ๊ฒฐ๊ณผ ์ ์ฅ ๊ฒฝ๋ก๋ฅผ ์ฝ์ด์จ ๋ค ์์
์ํ.
+ """
+ # 1) config.yaml ๋ก๋
+ with open("config/config.yaml", "r", encoding="utf-8") as f:
+ config = yaml.safe_load(f)
+
+ data_dir = config["paths"]["data_dir"]
+ prompt_dir = config["paths"]["prompt_dir"]
+ csv_name = config["paths"]["cleaned_text_contents"]
+ out_name = config["paths"]["deepseekvl_eval"]
+
+ # ์ค์ ๊ฒฝ๋ก ๊ตฌ์ฑ
+ csv_path = os.path.join(data_dir, csv_name)
+ prompt_path = os.path.join(prompt_dir, "deepseek_prompt.txt")
+ output_csv = os.path.join(data_dir, out_name)
+
+ # 2) ์๋ ์ค์
+ set_seed(42)
+
+ # 3) CSV ๋ก๋ & ํํฐ๋ง
+ df_filtered = load_and_filter_data(csv_path)
+ image_urls = df_filtered['url_clean'].to_list()
+ product_names = df_filtered['์ํ๋ช
'].to_list()
+
+ # 4) ํ๋กฌํํธ ํ
์คํธ ๋ถ๋ฌ์ค๊ธฐ
+ with open(prompt_path, "r", encoding="utf-8") as f:
+ base_prompt = f.read().strip()
+
+ # 5) DeepSeek ๋ชจ๋ธ ์ค๋น
+ model_path = "deepseek-ai/deepseek-vl-7b-chat"
+ vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
+ tokenizer = vl_chat_processor.tokenizer
+
+ vl_gpt = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
+ vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
+
+ model_name = "deepseek-ai/deepseek-vl-7b-chat"
+
+ def process_image(image_url, prompt):
+ """
+ ๊ฐ๋ณ ์ด๋ฏธ์ง URL์ ๋ํด DeepSeek-VL ๋ชจ๋ธ ์ถ๋ก ์ ์ํํ๋ ๋ด๋ถ ํจ์.
+ """
+ # ์์ ๋ค์ด๋ก๋ ํ์ผ๋ช
+ temp_filename = "temp_deepseek.jpg"
+
+ # ์ด๋ฏธ์ง ๋ค์ด๋ก๋
+ urllib.request.urlretrieve(image_url, temp_filename)
+ try:
+ # ๋ํ ๋ฐ์ดํฐ ๊ตฌ์ฑ
+ conversation = [
+ {
+ "role": "User",
+ "content": f"\n{prompt}",
+ "images": [temp_filename]
+ },
+ {
+ "role": "Assistant",
+ "content": ""
+ }
+ ]
+ # ์ด๋ฏธ์ง ๋ก๋
+ pil_images = load_pil_images(conversation)
+
+ # ์
๋ ฅ ์ค๋น
+ prepare_inputs = vl_chat_processor(
+ conversations=conversation, images=pil_images, force_batchify=True
+ ).to(vl_gpt.device)
+
+ # ์ด๋ฏธ์ง ์๋ฒ ๋ฉ
+ inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
+
+ # ํ
์คํธ ์์ฑ
+ outputs = vl_gpt.language_model.generate(
+ inputs_embeds=inputs_embeds,
+ attention_mask=prepare_inputs.attention_mask,
+ pad_token_id=tokenizer.eos_token_id,
+ bos_token_id=tokenizer.bos_token_id,
+ eos_token_id=tokenizer.eos_token_id,
+ max_new_tokens=512,
+ do_sample=True,
+ temperature=0.2,
+ top_p=0.95
+ )
+ answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
+
+ print(f"{prepare_inputs['sft_format'][0]}", answer)
+ return answer
+ except Exception as e:
+ print(f"Error processing image at {image_url}: {e}")
+ return "Error: An unexpected error occurred."
+ finally:
+ # ๋ค์ด๋ก๋ ํ์ผ ์ญ์ (ํ์ ์)
+ if os.path.exists(temp_filename):
+ os.remove(temp_filename)
+
+ # 6) ๋ฐ๋ณต ์ถ๋ก
+ results = []
+ for idx, (image_url, product_name) in enumerate(zip(image_urls, product_names)):
+ start_time = time.time()
+
+ prompt = base_prompt
+
+ output = process_image(image_url, prompt)
+ elapsed = time.time() - start_time
+
+ results.append({
+ "Model": model_name,
+ "ImageURL": image_url,
+ "Prompt": prompt,
+ "Inference Time (s)": elapsed,
+ "Model Output": output
+ })
+
+ print(f"[{idx+1}/{len(image_urls)}] Processed in {elapsed:.2f}s => {product_name}")
+
+ # 7) ๊ฒฐ๊ณผ ์ ์ฅ
+ pd.DataFrame(results).to_csv(output_csv, index=False, encoding="utf-8-sig")
+ print(f"[DeepSeek-VL] Saved results => {output_csv}")
\ No newline at end of file
diff --git a/models/thumbnail_description/src/description_pipeline/inference_model/finetuned_janus_pro.py b/models/thumbnail_description/src/description_pipeline/inference_model/finetuned_janus_pro.py
new file mode 100644
index 0000000..67ae18e
--- /dev/null
+++ b/models/thumbnail_description/src/description_pipeline/inference_model/finetuned_janus_pro.py
@@ -0,0 +1,110 @@
+import os
+import time
+import urllib.request
+import pandas as pd
+import torch
+import yaml
+
+from transformers import AutoModelForCausalLM
+from janus.models import MultiModalityCausalLM, VLChatProcessor
+from janus.utils.io import load_pil_images
+
+from utils.common_utils import set_seed, load_and_filter_data
+
+def run_inference_janus_pro():
+ """
+ itsmenlp/finetuned-Janus-Pro-7B ๋ชจ๋ธ ์ถ๋ก ํ ๊ฒฐ๊ณผ CSV ์ ์ฅ.
+ config.yaml์ ์ฐธ๊ณ ํ์ฌ CSV, ํ๋กฌํํธ, ๊ฒฐ๊ณผ ๊ฒฝ๋ก๋ฅผ ์ค์ .
+ """
+ # 1) config.yaml ๋ก๋
+ with open("config/config.yaml", "r", encoding="utf-8") as f:
+ config = yaml.safe_load(f)
+
+ data_dir = config["paths"]["data_dir"]
+ prompt_dir = config["paths"]["prompt_dir"]
+ csv_name = config["paths"]["cleaned_text_contents"]
+ out_name = config["paths"]["janus_pro_eval"]
+
+ # ์ค์ ๊ฒฝ๋ก ๊ตฌ์ฑ
+ csv_path = os.path.join(data_dir, csv_name)
+ prompt_path = os.path.join(prompt_dir, "janus_prompt.txt")
+ output_csv = os.path.join(data_dir, out_name)
+
+ # 2) ์๋ ์ค์
+ set_seed(42)
+
+ # 3) CSV ๋ก๋ & ํํฐ๋ง
+ df_filtered = load_and_filter_data(csv_path)
+ image_urls = df_filtered['url_clean'].to_list()
+ product_names = df_filtered['์ํ๋ช
'].to_list()
+
+ # 4) ํ๋กฌํํธ ํ
์คํธ ์ฝ๊ธฐ
+ with open(prompt_path, "r", encoding="utf-8") as f:
+ base_prompt = f.read().strip()
+
+ # 5) Janus-Pro ๋ชจ๋ธ ์ค๋น
+ model_path = "itsmenlp/finetuned-Janus-Pro-7B" #private
+ vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
+ tokenizer = vl_chat_processor.tokenizer
+
+ vl_gpt = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
+ vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
+
+ model_name = "itsmenlp/finetuned-Janus-Pro-7B"
+
+ def process_image(img_url, prompt):
+ temp_filename = "temp_janus.jpg"
+ urllib.request.urlretrieve(img_url, temp_filename)
+ try:
+ conversation = [
+ {"role": "<|User|>", "content": f"\n{prompt}", "images": [temp_filename]},
+ {"role": "<|Assistant|>", "content": ""}
+ ]
+ pil_images = load_pil_images(conversation)
+ prepare_inputs = vl_chat_processor(
+ conversations=conversation, images=pil_images, force_batchify=True
+ ).to(vl_gpt.device)
+
+ inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
+ outputs = vl_gpt.language_model.generate(
+ inputs_embeds=inputs_embeds,
+ attention_mask=prepare_inputs.attention_mask,
+ pad_token_id=tokenizer.eos_token_id,
+ bos_token_id=tokenizer.bos_token_id,
+ eos_token_id=tokenizer.eos_token_id,
+ max_new_tokens=512,
+ do_sample=True,
+ temperature=0.2,
+ top_p=0.95
+ )
+ answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
+ return answer
+ except Exception as e:
+ print(f"Error processing image at {img_url}: {e}")
+ return "Error: An unexpected error occurred."
+ finally:
+ if os.path.exists(temp_filename):
+ os.remove(temp_filename)
+
+ # 6) ๋ฐ๋ณต ์ถ๋ก
+ results = []
+ for idx, (image_url, product_name) in enumerate(zip(image_urls, product_names)):
+ start_time = time.time()
+ prompt = base_prompt
+
+ output = process_image(image_url, prompt)
+ elapsed_time = time.time() - start_time
+
+ results.append({
+ "Model": model_name,
+ "ImageURL": image_url,
+ "Prompt": prompt,
+ "Inference Time (s)": elapsed_time,
+ "Model Output": output
+ })
+
+ print(f"[JanusPro] Processed {idx+1}/{len(image_urls)} in {elapsed_time:.2f}s => {product_name}")
+
+ # 7) ๊ฒฐ๊ณผ ์ ์ฅ
+ pd.DataFrame(results).to_csv(output_csv, index=False, encoding="utf-8-sig")
+ print(f"[JanusPro] Saved results => {output_csv}")
\ No newline at end of file
diff --git a/models/thumbnail_description/src/description_pipeline/inference_model/janus_pro.py b/models/thumbnail_description/src/description_pipeline/inference_model/janus_pro.py
new file mode 100644
index 0000000..e380481
--- /dev/null
+++ b/models/thumbnail_description/src/description_pipeline/inference_model/janus_pro.py
@@ -0,0 +1,110 @@
+import os
+import time
+import urllib.request
+import pandas as pd
+import torch
+import yaml
+
+from transformers import AutoModelForCausalLM
+from janus.models import MultiModalityCausalLM, VLChatProcessor
+from janus.utils.io import load_pil_images
+
+from utils.common_utils import set_seed, load_and_filter_data
+
+def run_inference_janus_pro():
+ """
+ Janus-Pro-7B ๋ชจ๋ธ ์ถ๋ก ํ ๊ฒฐ๊ณผ CSV ์ ์ฅ.
+ config.yaml์ ์ฐธ๊ณ ํ์ฌ CSV, ํ๋กฌํํธ, ๊ฒฐ๊ณผ ๊ฒฝ๋ก๋ฅผ ์ค์ .
+ """
+ # 1) config.yaml ๋ก๋
+ with open("config/config.yaml", "r", encoding="utf-8") as f:
+ config = yaml.safe_load(f)
+
+ data_dir = config["paths"]["data_dir"]
+ prompt_dir = config["paths"]["prompt_dir"]
+ csv_name = config["paths"]["cleaned_text_contents"]
+ out_name = config["paths"]["janus_pro_eval"]
+
+ # ์ค์ ๊ฒฝ๋ก ๊ตฌ์ฑ
+ csv_path = os.path.join(data_dir, csv_name)
+ prompt_path = os.path.join(prompt_dir, "janus_prompt.txt")
+ output_csv = os.path.join(data_dir, out_name)
+
+ # 2) ์๋ ์ค์
+ set_seed(42)
+
+ # 3) CSV ๋ก๋ & ํํฐ๋ง
+ df_filtered = load_and_filter_data(csv_path)
+ image_urls = df_filtered['url_clean'].to_list()
+ product_names = df_filtered['์ํ๋ช
'].to_list()
+
+ # 4) ํ๋กฌํํธ ํ
์คํธ ์ฝ๊ธฐ
+ with open(prompt_path, "r", encoding="utf-8") as f:
+ base_prompt = f.read().strip()
+
+ # 5) Janus-Pro ๋ชจ๋ธ ์ค๋น
+ model_path = "deepseek-ai/Janus-Pro-7B"
+ vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
+ tokenizer = vl_chat_processor.tokenizer
+
+ vl_gpt = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
+ vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
+
+ model_name = "deepseek-ai/Janus-Pro-7B"
+
+ def process_image(img_url, prompt):
+ temp_filename = "temp_janus.jpg"
+ urllib.request.urlretrieve(img_url, temp_filename)
+ try:
+ conversation = [
+ {"role": "<|User|>", "content": f"\n{prompt}", "images": [temp_filename]},
+ {"role": "<|Assistant|>", "content": ""}
+ ]
+ pil_images = load_pil_images(conversation)
+ prepare_inputs = vl_chat_processor(
+ conversations=conversation, images=pil_images, force_batchify=True
+ ).to(vl_gpt.device)
+
+ inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
+ outputs = vl_gpt.language_model.generate(
+ inputs_embeds=inputs_embeds,
+ attention_mask=prepare_inputs.attention_mask,
+ pad_token_id=tokenizer.eos_token_id,
+ bos_token_id=tokenizer.bos_token_id,
+ eos_token_id=tokenizer.eos_token_id,
+ max_new_tokens=512,
+ do_sample=True,
+ temperature=0.2,
+ top_p=0.95
+ )
+ answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
+ return answer
+ except Exception as e:
+ print(f"Error processing image at {img_url}: {e}")
+ return "Error: An unexpected error occurred."
+ finally:
+ if os.path.exists(temp_filename):
+ os.remove(temp_filename)
+
+ # 6) ๋ฐ๋ณต ์ถ๋ก
+ results = []
+ for idx, (image_url, product_name) in enumerate(zip(image_urls, product_names)):
+ start_time = time.time()
+ prompt = base_prompt
+
+ output = process_image(image_url, prompt)
+ elapsed_time = time.time() - start_time
+
+ results.append({
+ "Model": model_name,
+ "ImageURL": image_url,
+ "Prompt": prompt,
+ "Inference Time (s)": elapsed_time,
+ "Model Output": output
+ })
+
+ print(f"[JanusPro] Processed {idx+1}/{len(image_urls)} in {elapsed_time:.2f}s => {product_name}")
+
+ # 7) ๊ฒฐ๊ณผ ์ ์ฅ
+ pd.DataFrame(results).to_csv(output_csv, index=False, encoding="utf-8-sig")
+ print(f"[JanusPro] Saved results => {output_csv}")
\ No newline at end of file
diff --git a/models/thumbnail_description/src/description_pipeline/inference_model/maal.py b/models/thumbnail_description/src/description_pipeline/inference_model/maal.py
new file mode 100644
index 0000000..66b638b
--- /dev/null
+++ b/models/thumbnail_description/src/description_pipeline/inference_model/maal.py
@@ -0,0 +1,127 @@
+import os
+import time
+import re
+import requests
+import torch
+import pandas as pd
+import yaml
+from PIL import Image
+
+from transformers import MllamaForConditionalGeneration, AutoProcessor
+from utils.common_utils import set_seed, load_and_filter_data
+
+def run_inference_maal():
+ """
+ MAAL ๋ชจ๋ธ ์ถ๋ก ํ CSV ํ์ผ์ ์ ์ฅ.
+ config.yaml์ ์ฐธ๊ณ ํ์ฌ CSV, ํ๋กฌํํธ, ๊ฒฐ๊ณผ ๊ฒฝ๋ก๋ฅผ ์ค์ .
+ """
+ # 1) config.yaml ๋ก๋
+ with open("config/config.yaml", "r", encoding="utf-8") as f:
+ config = yaml.safe_load(f)
+
+ data_dir = config["paths"]["data_dir"]
+ prompt_dir = config["paths"]["prompt_dir"]
+ csv_name = config["paths"]["cleaned_text_contents"]
+ out_name = config["paths"]["maai_pro_eval"]
+
+ # ์ค์ ๊ฒฝ๋ก ๊ตฌ์ฑ
+ csv_path = os.path.join(data_dir, csv_name)
+ prompt_path = os.path.join(prompt_dir, "maal_prompt.txt")
+ output_csv = os.path.join(data_dir, out_name)
+
+ # 2) ์๋ ์ค์
+ set_seed(42)
+
+ # 3) CSV ๋ก๋ & ํํฐ๋ง
+ df_filtered = load_and_filter_data(csv_path)
+ image_urls = df_filtered['url_clean'].to_list()
+ product_names = df_filtered['์ํ๋ช
'].to_list()
+
+ # 4) ํ๋กฌํํธ ํ
์คํธ ๋ถ๋ฌ์ค๊ธฐ
+ with open(prompt_path, "r", encoding="utf-8") as f:
+ base_prompt = f.read().strip()
+
+ # 5) MAAL ๋ชจ๋ธ ์ค๋น
+ model_id = "maum-ai/Llama-3.2-MAAL-11B-Vision-v0.1"
+ print("[MAAL] Loading model...")
+ model = MllamaForConditionalGeneration.from_pretrained(
+ model_id,
+ torch_dtype=torch.float16,
+ device_map="auto",
+ )
+ processor = AutoProcessor.from_pretrained(model_id)
+
+ # ๋ชจ๋ธ output embeddings ์กฐ์
+ old_embeddings = model.get_output_embeddings()
+ num_tokens = model.vocab_size + 1
+ resized_embeddings = model._get_resized_lm_head(old_embeddings, new_num_tokens=num_tokens, mean_resizing=True)
+ resized_embeddings.requires_grad_(old_embeddings.weight.requires_grad)
+ model.set_output_embeddings(resized_embeddings)
+
+ model_name = "maai"
+ results = []
+
+ def process_image_with_maai(img_url, prompt):
+ try:
+ # ์ด๋ฏธ์ง ๋ก๋ (HTTP GET ํ PIL Image ๋ณํ)
+ image = Image.open(requests.get(img_url, stream=True).raw)
+
+ # ๋ฉ์์ง ๊ตฌ์ฑ (์ ์ ์ญํ ์ [์ด๋ฏธ์ง + ํ
์คํธ ํ๋กฌํํธ])
+ messages = [
+ {"role": "user", "content": [
+ {"type": "image"},
+ {"type": "text", "text": prompt}
+ ]}
+ ]
+ # chat ํ
ํ๋ฆฟ ์ ์ฉ
+ input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
+ inputs = processor(
+ image,
+ input_text,
+ add_special_tokens=False,
+ return_tensors="pt"
+ ).to(model.device)
+
+ # ๋ชจ๋ธ ์ถ๋ก
+ output = model.generate(
+ **inputs,
+ max_new_tokens=256,
+ no_repeat_ngram_size=3,
+ do_sample=False
+ )
+ result = processor.decode(output[0])
+
+ # ์ ๊ท์์ผ๋ก MAAL ๋ชจ๋ธ ๊ฒฐ๊ณผ์์ assistant ๋ถ๋ถ ์ถ์ถ
+ pattern = r"<\|start_header_id\|>assistant<\|end_header_id\|>\n\n(.+?)(?:<\|eot_id\|>|$)"
+ match = re.search(pattern, result, re.DOTALL)
+ if match:
+ return match.group(1).strip()
+ else:
+ return "No match found"
+ except Exception as e:
+ print(f"Error processing image at {img_url}: {e}")
+ return "Error"
+
+ # 6) ๊ฐ ์ด๋ฏธ์ง์ ๋ํด ์ถ๋ก ์ํ
+ for idx, (image_url, product_name) in enumerate(zip(image_urls, product_names)):
+ start_time = time.time()
+ # prompt์ {product_name} ์นํ (ํ์ ์)
+ prompt = base_prompt.replace("{product_name}", product_name)
+
+ output = process_image_with_maai(image_url, prompt)
+ elapsed_time = time.time() - start_time
+
+ results.append({
+ "Model": model_name,
+ "ImageURL": image_url,
+ "Prompt": prompt[:100], # ๊ธธ์ด ์ ํ
+ "Inference Time (s)": elapsed_time,
+ "Model Output": output
+ })
+
+ print(f"[MAAL] Processed {idx+1}/{len(image_urls)} => {elapsed_time:.2f}s => {product_name}")
+
+ # 7) ๊ฒฐ๊ณผ๋ฅผ CSV๋ก ์ ์ฅ
+ df = pd.DataFrame(results)
+ df.to_csv(output_csv, index=False, encoding="utf-8-sig")
+ print(f"[MAAL] Results saved => {output_csv}")
\ No newline at end of file
diff --git a/models/thumbnail_description/src/description_pipeline/inference_model/qwen2_5_vl.py b/models/thumbnail_description/src/description_pipeline/inference_model/qwen2_5_vl.py
new file mode 100644
index 0000000..09eda50
--- /dev/null
+++ b/models/thumbnail_description/src/description_pipeline/inference_model/qwen2_5_vl.py
@@ -0,0 +1,123 @@
+import os
+import time
+import yaml
+import pandas as pd
+import torch
+
+from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
+from qwen_vl_utils import process_vision_info
+
+from utils.common_utils import set_seed, load_and_filter_data
+
+def run_inference_qwen2_5_vl():
+ """
+ Qwen2.5-VL-7B-Instruct ๋ชจ๋ธ์ ์ฌ์ฉํด CSV ๋ฐ์ดํฐ ๋ด ์ด๋ฏธ์ง์ ๋ํด
+ ์ถ๋ก ํ ๊ฒฐ๊ณผ CSV๋ฅผ ์ ์ฅ.
+ config.yaml์ ์ฐธ๊ณ ํ์ฌ ๊ฒฝ๋ก(์
๋ ฅ/ํ๋กฌํํธ/์ถ๋ ฅ)๋ฅผ ์ค์ ํฉ๋๋ค.
+ """
+ # 1) config.yaml ๋ก๋
+ with open("config/config.yaml", "r", encoding="utf-8") as f:
+ config = yaml.safe_load(f)
+
+ data_dir = config["paths"]["data_dir"]
+ prompt_dir = config["paths"]["prompt_dir"]
+ csv_name = config["paths"]["cleaned_text_contents"]
+ out_name = config["paths"]["qwen2.5_eval"]
+
+ # ์ค์ ๊ฒฝ๋ก ๊ตฌ์ฑ
+ csv_path = os.path.join(data_dir, csv_name)
+ prompt_path = os.path.join(prompt_dir, "qwen2_5_prompt.txt")
+ output_csv = os.path.join(data_dir, out_name)
+
+ # 2) ์๋ ์ค์
+ set_seed(42)
+
+ # 3) CSV ๋ก๋ & ํํฐ๋ง
+ df_filtered = load_and_filter_data(csv_path)
+ image_urls = df_filtered['url_clean'].to_list()
+ product_names = df_filtered['์ํ๋ช
'].to_list()
+
+ # 4) ํ๋กฌํํธ ํ
์คํธ ๋ถ๋ฌ์ค๊ธฐ
+ with open(prompt_path, "r", encoding="utf-8") as f:
+ base_prompt = f.read().strip()
+
+ # 5) ๋ชจ๋ธ ๋ก๋ฉ
+ print("[Qwen2.5-VL] Loading model...")
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+ "Qwen/Qwen2.5-VL-7B-Instruct", torch_dtype="auto", device_map="auto"
+ )
+
+ # Qwen2.5-VL ๋ชจ๋ธ์ ์๊ฐ ํ ํฐ ๋ฒ์ ์ค์
+ min_pixels = 256 * 28 * 28
+ max_pixels = 1280 * 28 * 28
+ processor = AutoProcessor.from_pretrained(
+ "Qwen/Qwen2.5-VL-7B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
+ )
+
+ model_name = "qwen2.5"
+ results = []
+
+ def process_qwen2_5_vl(img_url, prompt_text):
+ """
+ ๋จ์ผ ์ด๋ฏธ์ง URL๊ณผ ํ๋กฌํํธ๋ฅผ ์ด์ฉํด Qwen2.5-VL ๋ชจ๋ธ ์ถ๋ก ์ ์ํํ๋ ๋ด๋ถ ํจ์.
+ """
+ try:
+ # ๋ฉ์์ง(์ ์ ์ญํ ) ๊ตฌ์ฑ
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "image", "image": img_url},
+ {"type": "text", "text": prompt_text},
+ ],
+ }
+ ]
+ # ์ฑํ
ํ
ํ๋ฆฟ ์ ์ฉ
+ input_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+ image_inputs, video_inputs = process_vision_info(messages)
+
+ # ๋ชจ๋ธ ์
๋ ฅ
+ inputs = processor(
+ text=[input_text],
+ images=image_inputs,
+ videos=video_inputs,
+ padding=True,
+ return_tensors="pt",
+ ).to("cuda")
+
+ # ํ
์คํธ ์์ฑ
+ generated_ids = model.generate(**inputs, max_new_tokens=512)
+ generated_ids_trimmed = [
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+ ]
+ output_text = processor.batch_decode(
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+ )
+ return output_text[0]
+ except Exception as e:
+ print(f"Error processing image at {img_url}: {e}")
+ return "Error"
+
+ # 6) ๋ฐ๋ณต ์ถ๋ก
+ for idx, (img_url, product_name) in enumerate(zip(image_urls, product_names)):
+ start_time = time.time()
+
+ prompt = base_prompt.replace("{product_name}", product_name)
+
+ output = process_qwen2_5_vl(img_url, prompt)
+ elapsed_time = time.time() - start_time
+
+ results.append({
+ "Model": model_name,
+ "ImageURL": img_url,
+            "Prompt": prompt,
+ "Inference Time (s)": elapsed_time,
+ "Model Output": output
+ })
+
+ print(f"[Qwen2.5-VL] {idx+1}/{len(image_urls)} => {elapsed_time:.2f}s => {product_name}")
+
+ # 7) ๊ฒฐ๊ณผ ์ ์ฅ
+ df = pd.DataFrame(results)
+ df.to_csv(output_csv, index=False, encoding="utf-8-sig")
+ print(f"[Qwen2.5-VL] Results => {output_csv}")
\ No newline at end of file
diff --git a/models/thumbnail_description/src/description_pipeline/inference_model/qwen2_vl.py b/models/thumbnail_description/src/description_pipeline/inference_model/qwen2_vl.py
new file mode 100644
index 0000000..b4ce82a
--- /dev/null
+++ b/models/thumbnail_description/src/description_pipeline/inference_model/qwen2_vl.py
@@ -0,0 +1,124 @@
+import os
+import time
+import yaml
+import pandas as pd
+import torch
+
+from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
+from qwen_vl_utils import process_vision_info
+
+from utils.common_utils import set_seed, load_and_filter_data
+
+def run_inference_qwen2_vl():
+ """
+ Qwen2-VL-7B-Instruct ๋ชจ๋ธ์ ํ์ฉํ์ฌ CSV ๋ฐ์ดํฐ ๋ด ์ด๋ฏธ์ง์ ๋ํด
+ ์ถ๋ก ์ ์ํํ ๋ค ๊ฒฐ๊ณผ๋ฅผ CSV ํ์ผ์ ์ ์ฅ.
+ config.yaml์์ ๊ฒฝ๋ก๋ฅผ ๋ถ๋ฌ์ต๋๋ค.
+ """
+ # 1) config.yaml ๋ก๋
+ with open("config/config.yaml", "r", encoding="utf-8") as f:
+ config = yaml.safe_load(f)
+
+ data_dir = config["paths"]["data_dir"]
+ prompt_dir = config["paths"]["prompt_dir"]
+ csv_name = config["paths"]["cleaned_text_contents"]
+ out_name = config["paths"]["qwen2_eval"]
+
+ # ์ค์ ๊ฒฝ๋ก ๊ตฌ์ฑ
+ csv_path = os.path.join(data_dir, csv_name)
+ prompt_path = os.path.join(prompt_dir, "qwen2_prompt.txt")
+ output_csv = os.path.join(data_dir, out_name)
+
+ # ์๋ ์ค์
+ set_seed(42)
+
+ # CSV ๋ก๋ & ํํฐ๋ง
+ df_filtered = load_and_filter_data(csv_path)
+ image_urls = df_filtered['url_clean'].to_list()
+ product_names = df_filtered['์ํ๋ช
'].to_list()
+
+ # ํ๋กฌํํธ ์ฝ๊ธฐ
+ with open(prompt_path, "r", encoding="utf-8") as f:
+ base_prompt = f.read().strip()
+
+ print("[Qwen2-VL] Loading model and processor...")
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
+ "Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto"
+ )
+ # ์ด๋ฏธ์ง ํ ํฐํ ํ๋ผ๋ฏธํฐ
+ min_pixels = 256 * 28 * 28
+ max_pixels = 1280 * 28 * 28
+ processor = AutoProcessor.from_pretrained(
+ "Qwen/Qwen2-VL-7B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
+ )
+
+ model_name = "qwen2-vl"
+ results = []
+
+ def process_qwen2_vl(img_url, prompt_text):
+ """
+ ๋จ์ผ ์ด๋ฏธ์ง URL๊ณผ ํ๋กฌํํธ๋ก Qwen2-VL ๋ชจ๋ธ์ ์ถ๋ก ํ๋ ๋ด๋ถ ํจ์.
+ """
+ try:
+ # ๋ฉ์์ง ๊ตฌ์ฑ
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "image", "image": img_url},
+ {"type": "text", "text": prompt_text},
+ ],
+ }
+ ]
+ # ์ฑํ
ํ
ํ๋ฆฟ ์ ์ฉ
+ input_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+ image_inputs, video_inputs = process_vision_info(messages)
+
+ # ๋ชจ๋ธ ์
๋ ฅ ์์ฑ
+ inputs = processor(
+ text=[input_text],
+ images=image_inputs,
+ videos=video_inputs,
+ padding=True,
+ return_tensors="pt",
+ ).to("cuda")
+
+ # ๋ชจ๋ธ ์ถ๋ก
+ generated_ids = model.generate(**inputs, max_new_tokens=512)
+ generated_ids_trimmed = [
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+ ]
+ output_text = processor.batch_decode(
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+ )
+ return output_text[0]
+ except Exception as e:
+ print(f"Error processing image at {img_url}: {e}")
+ return "Error"
+
+ start_total = time.time()
+ for idx, (img_url, product_name) in enumerate(zip(image_urls, product_names)):
+ start_time = time.time()
+
+ # ํ์ํ ๊ฒฝ์ฐ {product_name} ์นํ
+ prompt = base_prompt.replace("{product_name}", product_name)
+
+ output = process_qwen2_vl(img_url, prompt)
+ elapsed_time = time.time() - start_time
+
+ results.append({
+ "Model": model_name,
+ "ImageURL": img_url,
+ "Prompt": prompt,
+ "Inference Time (s)": elapsed_time,
+ "Model Output": output
+ })
+
+ print(f"[Qwen2-VL] {idx+1}/{len(image_urls)} => {elapsed_time:.2f}s => {product_name}")
+
+ total_time = time.time() - start_total
+ print(f"[Qwen2-VL] Total time: {total_time:.2f}s")
+
+ # ๊ฒฐ๊ณผ CSV ์ ์ฅ
+ pd.DataFrame(results).to_csv(output_csv, index=False, encoding="utf-8-sig")
+ print(f"[Qwen2-VL] Results saved => {output_csv}")
\ No newline at end of file
diff --git a/models/thumbnail_description/src/description_pipeline/inference_model/unsloth_qwen2_vl.py b/models/thumbnail_description/src/description_pipeline/inference_model/unsloth_qwen2_vl.py
new file mode 100644
index 0000000..f945037
--- /dev/null
+++ b/models/thumbnail_description/src/description_pipeline/inference_model/unsloth_qwen2_vl.py
@@ -0,0 +1,145 @@
+import os
+import time
+import re
+import yaml
+import pandas as pd
+import torch
+import requests
+from PIL import Image
+from unsloth import FastVisionModel
+from transformers import TextStreamer
+from qwen_vl_utils import process_vision_info
+
+from utils.common_utils import set_seed, load_and_filter_data
+
+def run_inference_unsloth_qwen2_vl():
+ """
+ unsloth ๊ธฐ๋ฐ Qwen2-VL ๋ชจ๋ธ ์ถ๋ก .
+ config.yaml์ ํตํด CSV, ํ๋กฌํํธ, ๊ฒฐ๊ณผ ๊ฒฝ๋ก๋ฅผ ์ค์ ํด ์คํ.
+ """
+ # 1) config.yaml ๋ก๋
+ with open("config/config.yaml", "r", encoding="utf-8") as f:
+ config = yaml.safe_load(f)
+
+ data_dir = config["paths"]["data_dir"]
+ prompt_dir = config["paths"]["prompt_dir"]
+ csv_name = config["paths"]["cleaned_text_contents"]
+ out_name = config["paths"]["unsloth_qwen2_eval"]
+
+ # ์ค์ ๊ฒฝ๋ก ๊ตฌ์ฑ
+ csv_path = os.path.join(data_dir, csv_name)
+ prompt_path = os.path.join(prompt_dir, "unsloth_prompt.txt")
+ output_csv = os.path.join(data_dir, out_name)
+
+ # 2) ์๋ ์ค์
+ set_seed(42)
+
+ # 3) CSV ๋ก๋ & ํํฐ๋ง
+ df_filtered = load_and_filter_data(csv_path)
+ image_urls = df_filtered['url_clean'].to_list()
+ product_names = df_filtered['์ํ๋ช
'].to_list()
+
+ # 4) ํ๋กฌํํธ ํ
์คํธ ๋ถ๋ฌ์ค๊ธฐ
+ with open(prompt_path, "r", encoding="utf-8") as f:
+ base_prompt = f.read().strip()
+
+ # 5) unsloth ๊ธฐ๋ฐ Qwen2-VL ๋ชจ๋ธ ๋ก๋ฉ
+ print("[unsloth] Loading Qwen2-VL-7B-Instruct model...")
+ model, tokenizer = FastVisionModel.from_pretrained(
+ "unsloth/Qwen2-VL-7B-Instruct",
+ load_in_4bit=False,
+ use_gradient_checkpointing="True",
+ trust_remote_code=True,
+ )
+
+ # LoRA ์ค์ ๋ฑ
+ model = FastVisionModel.get_peft_model(
+ model,
+ finetune_vision_layers=True,
+ finetune_language_layers=True,
+ finetune_attention_modules=True,
+ finetune_mlp_modules=True,
+ r=16,
+ lora_alpha=16,
+ lora_dropout=0,
+ bias="none",
+ random_state=3407,
+ use_rslora=False,
+ loftq_config=None
+ )
+ FastVisionModel.for_inference(model)
+
+ # ์ด๋ฏธ์ง ํฌ๊ธฐ ์ค์
+ min_pixels = 256 * 28 * 28
+ max_pixels = 960 * 28 * 28
+ model_name = "qwen_unsloth"
+ results = []
+
+ def process_single_image(img_url, product_name):
+ """
+ ๋จ์ผ ์ด๋ฏธ์ง URL๊ณผ product_name์ผ๋ก ๋ชจ๋ธ ์ถ๋ก ์คํ.
+ """
+ start_time = time.time()
+
+ prompt_text = base_prompt.replace("{product_name}", product_name)
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image",
+ "image": img_url,
+ "min_pixels": min_pixels,
+ "max_pixels": max_pixels,
+ },
+ {"type": "text", "text": prompt_text}
+ ]
+ }
+ ]
+ input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
+ image_inputs, video_inputs = process_vision_info(messages)
+
+ # ์
๋ ฅ ํ
์ ์์ฑ
+ inputs = tokenizer(
+ image_inputs,
+ input_text,
+ add_special_tokens=False,
+ return_tensors="pt",
+ ).to("cuda")
+
+ # ๋ชจ๋ธ ์ถ๋ก
+ output_ids = model.generate(
+ **inputs,
+ max_new_tokens=256,
+ use_cache=True,
+ temperature=1.5,
+ min_p=0.1
+ )
+ output_texts = tokenizer.batch_decode(
+ output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+ )
+
+ # assistant\n (ํ
์คํธ) ์ ๊ทํํ์์ผ๋ก ์ถ์ถ
+ pattern = r'assistant\n(.+)'
+ matches = re.findall(pattern, output_texts[0], re.DOTALL)
+ final_output = matches[0] if matches else "No valid output generated"
+
+ elapsed = time.time() - start_time
+ return final_output, elapsed
+
+ # 6) ๊ฐ ์ด๋ฏธ์ง์ ๋ํด ์ถ๋ก
+ for idx, (img_url, prod_name) in enumerate(zip(image_urls, product_names)):
+ caption, elapsed_time = process_single_image(img_url, prod_name)
+ results.append({
+ "Model": model_name,
+ "ImageURL": img_url,
+ "Product Name": prod_name,
+ "Inference Time (s)": elapsed_time,
+ "Model Output": caption
+ })
+ print(f"[unsloth Qwen2-VL] {idx+1}/{len(image_urls)} => {elapsed_time:.2f}s => {prod_name}")
+
+ # 7) ๊ฒฐ๊ณผ CSV ์ ์ฅ
+ df = pd.DataFrame(results)
+ df.to_csv(output_csv, index=False, encoding="utf-8-sig")
+ print(f"[unsloth Qwen2-VL] Saved results => {output_csv}")
\ No newline at end of file
diff --git a/models/thumbnail_description/src/description_pipeline/post_processing/janus_pro_hcx_translation.py b/models/thumbnail_description/src/description_pipeline/post_processing/janus_pro_hcx_translation.py
new file mode 100644
index 0000000..64e3ac9
--- /dev/null
+++ b/models/thumbnail_description/src/description_pipeline/post_processing/janus_pro_hcx_translation.py
@@ -0,0 +1,105 @@
+import os
+from utils.common_utils import (
+    set_seed, requests, pd, time
+)
+import yaml
+
+class CompletionExecutor:
+ def __init__(self, host, api_key, request_id):
+ self._host = host
+ self._api_key = api_key
+ self._request_id = request_id
+
+ def execute(self, completion_request, max_retries=5, retry_delay=20):
+ headers = {
+ 'Authorization': self._api_key,
+ 'X-NCP-CLOVASTUDIO-REQUEST-ID': self._request_id,
+ 'Content-Type': 'application/json; charset=utf-8',
+ }
+ for attempt in range(max_retries):
+ try:
+ url = self._host + '/testapp/v1/chat-completions/HCX-003'
+ response = requests.post(url, headers=headers, json=completion_request)
+ response_data = response.json()
+
+ if response_data.get("status", {}).get("code") == "20000":
+ return response_data["result"]["message"]["content"]
+ else:
+ raise ValueError(f"Invalid status code: {response_data.get('status', {}).get('code')}")
+ except (requests.RequestException, ValueError, KeyError) as e:
+ if attempt < max_retries - 1:
+ print(f"์๋ฌ ๋ฐ์: {str(e)}. {retry_delay}์ด ํ ์ฌ์๋ํฉ๋๋ค. (์๋ {attempt+1}/{max_retries})")
+ time.sleep(retry_delay)
+ else:
+ print(f"์ต๋ ์ฌ์๋ ํ์ {max_retries}ํ๋ฅผ ์ด๊ณผํ์ต๋๋ค. ์ต์ข
์๋ฌ: {str(e)}")
+ return None
+
+def main():
+ # 1) config.yaml ๋ก๋
+ with open("config/config.yaml", "r", encoding="utf-8") as f:
+ config_data = yaml.safe_load(f)
+
+ # HCX API ์ ๋ณด ๊ฐ์ ธ์ค๊ธฐ
+ host = config_data["hcx_api"]["host"]
+ api_key = config_data["hcx_api"]["api_key"]
+ request_id = config_data["hcx_api"]["request_id"]
+
+ # paths ์น์
์์ ๊ฒฝ๋ก ๊ฐ์ ธ์ค๊ธฐ
+ data_dir = config_data["paths"]["data_dir"]
+ hcx_prompt_dir = config_data["paths"].get("hcx_prompt_dir", "hcx_prompt")
+ csv_name = config_data["paths"].get("qwen2_5+janus_323_eval", "qwen2_5+janus_323_eval.csv")
+
+ # ์ค์ CSV ๊ฒฝ๋ก ๊ตฌ์ฑ
+ csv_path = os.path.join(data_dir, csv_name)
+
+ # prompt ํ์ผ ๊ฒฝ๋ก ๊ตฌ์ฑ
+ system_prompt_path = os.path.join(hcx_prompt_dir, "system_janus_pro_hcx_translation.txt")
+ user_prompt_path = os.path.join(hcx_prompt_dir, "user_janus_pro_hcx_translation.txt")
+
+ completion_executor = CompletionExecutor(host, api_key, request_id)
+
+ # system / user prompt ํ์ผ ๋ถ๋ฌ์ค๊ธฐ
+ with open(system_prompt_path, "r", encoding="utf-8") as f:
+ system_prompt = f.read().strip()
+
+ with open(user_prompt_path, "r", encoding="utf-8") as f:
+ user_prompt_template = f.read().strip()
+
+ # CSV ๋ก๋
+ df = pd.read_csv(csv_path)
+
+ # ๊ฐ ํ์ ๋ํด ์ถ๋ก ์คํ
+ for idx, row in df.iterrows():
+ # "Model Output" ์นผ๋ผ์์ ํ
์คํธ ๊ฐ์ ธ์ค๊ธฐ
+ model_output_text = row.get("Model Output", "")
+
+ # user ํ๋กฌํํธ ์์ฑ
+ user_prompt = user_prompt_template.replace("{model_output_text}", model_output_text)
+
+ request_data = {
+ 'messages': [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": user_prompt}
+ ],
+ 'topP': 0.9,
+ 'topK': 0,
+ 'maxTokens': 1024,
+ 'temperature': 0.1,
+ 'repeatPenalty': 5.0,
+ 'stopBefore': [],
+ 'includeAiFilters': True,
+ 'seed': 42
+ }
+
+ # HCX ๋ชจ๋ธ ํธ์ถ
+ model_output_ko = completion_executor.execute(request_data)
+ df.loc[idx, "Model Output HCX"] = model_output_ko
+
+ print(idx, model_output_ko)
+
+ # ๊ฒฐ๊ณผ ์ ์ฅ (๋ฎ์ด์ฐ๊ธฐ)
+ df.to_csv(csv_path, index=False, encoding="utf-8-sig")
+ print(f"[janus_pro_hcx_translation] ํ์ผ ์ ์ฅ ์๋ฃ: {csv_path}")
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/models/thumbnail_description/src/description_pipeline/post_processing/janus_pro_papago_translation.py b/models/thumbnail_description/src/description_pipeline/post_processing/janus_pro_papago_translation.py
new file mode 100644
index 0000000..7c34af9
--- /dev/null
+++ b/models/thumbnail_description/src/description_pipeline/post_processing/janus_pro_papago_translation.py
@@ -0,0 +1,46 @@
+from utils.common_utils import (
+ set_seed, requests, pd
+)
+import yaml
+
+def translate_text(text, source_lang, target_lang, client_id, client_secret):
+ url = "https://naveropenapi.apigw.ntruss.com/nmt/v1/translation"
+ headers = {
+ "x-ncp-apigw-api-key-id": client_id,
+ "x-ncp-apigw-api-key": client_secret
+ }
+ data = {
+ "source": source_lang,
+ "target": target_lang,
+ "text": text
+ }
+
+ resp = requests.post(url, headers=headers, data=data)
+ if resp.status_code == 200:
+ js = resp.json()
+ return js['message']['result']['translatedText']
+ else:
+ print(f"Papago Error Code: {resp.status_code}")
+ return None
+
+def main():
+ # config.yaml์์ Papago API ์ ๋ณด ๋ก๋
+ with open("config/config.yaml", "r", encoding="utf-8") as f:
+ cfg = yaml.safe_load(f)
+
+ client_id = cfg["papago_api"]["client_id"]
+ client_secret = cfg["papago_api"]["client_secret"]
+
+ csv_path = "qwen2_5+janus_323_eval.csv"
+ df = pd.read_csv(csv_path)
+
+ # Model Output ์ด ์์ด๋ฅผ ํ๊ตญ์ด๋ก ๋ฒ์ญ
+ df["Model Output Papago"] = df["Model Output"].apply(
+ lambda x: translate_text(str(x), "en", "ko", client_id, client_secret)
+ )
+
+ df.to_csv(csv_path, index=False, encoding="utf-8-sig")
+ print(f"[papago_translation] ๋ฒ์ญ ์๋ฃ => {csv_path}")
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/models/thumbnail_description/src/description_pipeline/post_processing/janus_pro_pp_hcx.py b/models/thumbnail_description/src/description_pipeline/post_processing/janus_pro_pp_hcx.py
new file mode 100644
index 0000000..fb5d124
--- /dev/null
+++ b/models/thumbnail_description/src/description_pipeline/post_processing/janus_pro_pp_hcx.py
@@ -0,0 +1,107 @@
+import os
+from utils.common_utils import (
+    set_seed, requests, pd, time
+)
+import yaml
+
+class CompletionExecutor:
+ def __init__(self, host, api_key, request_id):
+ self._host = host
+ self._api_key = api_key
+ self._request_id = request_id
+
+ def execute(self, completion_request, max_retries=5, retry_delay=20):
+ headers = {
+ 'Authorization': self._api_key,
+ 'X-NCP-CLOVASTUDIO-REQUEST-ID': self._request_id,
+ 'Content-Type': 'application/json; charset=utf-8',
+ }
+ for attempt in range(max_retries):
+ try:
+ response = requests.post(
+ self._host + '/testapp/v1/chat-completions/HCX-003',
+ headers=headers, json=completion_request
+ )
+ data = response.json()
+
+ if data.get("status", {}).get("code") == "20000":
+ return data["result"]["message"]["content"]
+ else:
+ raise ValueError(f"Invalid status code: {data.get('status', {}).get('code')}")
+ except (requests.RequestException, ValueError, KeyError) as e:
+ if attempt < max_retries - 1:
+ print(f"์๋ฌ ๋ฐ์: {str(e)}. {retry_delay}์ด ํ ์ฌ์๋ (์๋ {attempt+1}/{max_retries})")
+ time.sleep(retry_delay)
+ else:
+ print(f"์ต๋ ์ฌ์๋ ํ์ {max_retries}ํ ์ด๊ณผ. ์ต์ข
์๋ฌ: {str(e)}")
+ return None
+
+def main():
+ # 1) config.yaml ๋ก๋
+ with open("config/config.yaml", "r", encoding="utf-8") as f:
+ config = yaml.safe_load(f)
+
+ # HCX API ์ ๋ณด
+ host = config["hcx_api"]["host"]
+ api_key = config["hcx_api"]["api_key"]
+ request_id = config["hcx_api"]["request_id"]
+
+ # data_dir & prompt_dir ๋ฑ ๊ฐ์ ธ์ค๊ธฐ
+ data_dir = config["paths"]["data_dir"]
+ prompt_dir = config["paths"]["prompt_dir"]
+
+ # Foodly_323_product_information.csv ๊ฒฝ๋ก
+ csv_name = config["paths"]["Foodly_323_product_information"]
+ csv_path = os.path.join(data_dir, csv_name)
+
+ # system / user ํ๋กฌํํธ ํ์ผ ๊ฒฝ๋ก
+ system_prompt_path = os.path.join(prompt_dir, "system_janus_pro_hcx_fewshot.txt")
+ user_prompt_path = os.path.join(prompt_dir, "user_janus_pro_hcx_fewshot.txt")
+
+ completion_executor = CompletionExecutor(host, api_key, request_id)
+
+ # CSV ํ์ผ ๋ก๋
+ df = pd.read_csv(csv_path)
+
+ # system / user prompt ๋ถ๋ฌ์ค๊ธฐ
+ with open(system_prompt_path, "r", encoding="utf-8") as sf:
+ system_prompt_fewshot = sf.read().strip()
+
+ with open(user_prompt_path, "r", encoding="utf-8") as uf:
+ user_prompt_template_fewshot = uf.read().strip()
+
+ # ์๋ ์ค์
+ set_seed(42)
+
+ # 3) ๊ฐ ํ์ ๋ํด HCX ๋ชจ๋ธ ํธ์ถ
+ for idx, row in df.iterrows():
+ model_output_text = row.get("Janus_Pro_Model_Output", "")
+ # user ํ๋กฌํํธ์ CSV์์ ๊ฐ์ ธ์จ ํ
์คํธ๋ฅผ ์นํ
+ user_prompt = user_prompt_template_fewshot.replace("{model_output_text}", model_output_text)
+
+ request_data = {
+ "messages": [
+ {"role": "system", "content": system_prompt_fewshot},
+ {"role": "user", "content": user_prompt}
+ ],
+ 'topP': 0.9,
+ 'topK': 0,
+ 'maxTokens': 1024,
+ 'temperature': 0.1,
+ 'repeatPenalty': 5.0,
+ 'stopBefore': [],
+ 'includeAiFilters': True,
+ 'seed': 42
+ }
+
+ result_ko = completion_executor.execute(request_data)
+ df.loc[idx, "Janus_Pro_Model_Output_HCX"] = result_ko
+
+ print(idx, result_ko)
+
+ # 4) ๊ฒฐ๊ณผ CSV ์ ์ฅ (๊ฐ์ ํ์ผ์ ๋ฎ์ด์ฐ๊ธฐ)
+ df.to_csv(csv_path, index=False, encoding="utf-8-sig")
+ print(f"[fewshot_janus_pro_hcx] ํ์ผ ์ ์ฅ ์๋ฃ => {csv_path}")
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/models/thumbnail_description/src/description_pipeline/post_processing/qwen2_5_pp_hcx.py b/models/thumbnail_description/src/description_pipeline/post_processing/qwen2_5_pp_hcx.py
new file mode 100644
index 0000000..53ac0af
--- /dev/null
+++ b/models/thumbnail_description/src/description_pipeline/post_processing/qwen2_5_pp_hcx.py
@@ -0,0 +1,94 @@
+from utils.common_utils import (
+ set_seed, requests, pd, time
+)
+import yaml
+
+class CompletionExecutor:
+ def __init__(self, host, api_key, request_id):
+ self._host = host
+ self._api_key = api_key
+ self._request_id = request_id
+
+ def execute(self, completion_request, max_retries=5, retry_delay=20):
+ headers = {
+ 'Authorization': self._api_key,
+ 'X-NCP-CLOVASTUDIO-REQUEST-ID': self._request_id,
+ 'Content-Type': 'application/json; charset=utf-8'
+ }
+ for attempt in range(max_retries):
+ try:
+ response = requests.post(
+ self._host + '/testapp/v1/chat-completions/HCX-003',
+ headers=headers, json=completion_request
+ )
+ data = response.json()
+
+ if data.get("status", {}).get("code") == "20000":
+ return data["result"]["message"]["content"]
+ else:
+ raise ValueError(f"Invalid status code: {data.get('status', {}).get('code')}")
+ except (requests.RequestException, ValueError, KeyError) as e:
+ if attempt < max_retries - 1:
+ print(f"์๋ฌ ๋ฐ์: {str(e)}. {retry_delay}์ด ํ ์ฌ์๋. (์๋ {attempt+1}/{max_retries})")
+ time.sleep(retry_delay)
+ else:
+ print(f"์ต๋ ์ฌ์๋ ํ์ {max_retries}ํ ์ด๊ณผ. ์ต์ข
์๋ฌ: {str(e)}")
+ return None
+
+def main():
+ # 1) config.yaml ๋ก๋
+ with open("config/config.yaml", "r", encoding="utf-8") as f:
+ cfg = yaml.safe_load(f)
+
+ host = cfg["hcx_api"]["host"]
+ api_key = cfg["hcx_api"]["api_key"]
+ request_id = cfg["hcx_api"]["request_id"]
+
+ completion_executor = CompletionExecutor(host, api_key, request_id)
+
+ # 2) system / user prompt๋ฅผ ํ์ผ์์ ๋ถ๋ฌ์ค๊ธฐ
+ with open("hcx_prompt/system_qwen2_5_pp_hcx.txt", "r", encoding="utf-8") as sf:
+ system_prompt = sf.read().strip()
+
+ with open("hcx_prompt/user_qwen2_5_pp_hcx.txt", "r", encoding="utf-8") as uf:
+ user_prompt_template = uf.read().strip()
+
+ # 3) CSV ๋ก๋
+ csv_path = "qwen2.5_323_eval.csv"
+ df = pd.read_csv(csv_path)
+
+ # ์๋ ์ค์
+ set_seed(42)
+
+ # 4) ๊ฐ ํ๋ณ๋ก HCX ์์ฒญ
+ for idx, row in df.iterrows():
+ model_output_text = row.get("Model Output", "")
+
+ # user prompt ๊ตฌ์ฑ
+ user_prompt = user_prompt_template.replace("{model_output_text}", model_output_text)
+
+ request_data = {
+ "messages": [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": user_prompt}
+ ],
+ 'topP': 0.9,
+ 'topK': 0,
+ 'maxTokens': 1024,
+ 'temperature': 0.1,
+ 'repeatPenalty': 5.0,
+ 'stopBefore': [],
+ 'includeAiFilters': True,
+ 'seed': 42
+ }
+
+ model_output_HCX_PP = completion_executor.execute(request_data)
+ df.loc[idx, "Model Output HCX PP"] = model_output_HCX_PP
+ print(idx, model_output_HCX_PP)
+
+ # 5) CSV ์ ์ฅ
+ df.to_csv(csv_path, index=False, encoding="utf-8-sig")
+ print(f"[qwen2_5_pp_hcx] ํ์ผ ์ ์ฅ ์๋ฃ => {csv_path}")
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/models/thumbnail_description/src/sft_pipeline/detailed_feature_description.py b/models/thumbnail_description/src/sft_pipeline/detailed_feature_description.py
new file mode 100644
index 0000000..3b824a9
--- /dev/null
+++ b/models/thumbnail_description/src/sft_pipeline/detailed_feature_description.py
@@ -0,0 +1,173 @@
+import argparse
+import csv
+import os
+import requests
+import json
+from io import BytesIO
+from PIL import Image
+import yaml
+import logging
+import openai # pip install openai
+
def load_config(config_path: str) -> dict:
    """Read a YAML configuration file and return its contents as a dict."""
    with open(config_path, encoding="utf-8") as fh:
        cfg = yaml.safe_load(fh)
    return cfg
+
def clean_response_text(text: str) -> str:
    """Strip a surrounding Markdown code fence (```...```) from *text*.

    Removes the opening fence line and, when present, the closing fence
    line. Text without a leading fence is returned unchanged (trimmed).
    """
    text = text.strip()
    if not text.startswith("```"):
        return text
    lines = text.splitlines()
    # Drop the opening fence; drop the closing fence only when it exists.
    body = lines[1:-1] if lines[-1].strip().startswith("```") else lines[1:]
    return "\n".join(body).strip()
+
def analyze_image_with_gpt4o(img_url: str, client: object) -> dict:
    """Extract product attributes from an image with GPT-4o.

    Asks the model for a strict-JSON answer containing texture, shape,
    color, transparency and design, then parses it.

    Returns:
        dict with keys "texture", "shape", "color", "transparency",
        "design"; every value is an empty string when the API call or the
        JSON parsing fails.
    """
    prompt = (
        "Analyze the following product image and extract the details in English. "
        "Provide the answer strictly in JSON format with the following keys:\n\n"
        "1. texture\n2. shape\n3. color\n4. transparency\n5. design\n\n"
        "Return your answer only as a JSON object."
    )

    try:
        response = client.ChatCompletion.create(
            model="gpt-4o",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {"type": "image_url", "image_url": {"url": img_url}},
                    ],
                }
            ],
            temperature=0.1,
            max_tokens=1024,
        )
        # Pull the message text; the model often wraps it in a code fence.
        raw_text = response.choices[0].message.content
        logging.info("DEBUG: Raw response text: %s", raw_text)
        cleaned = clean_response_text(raw_text)
        logging.info("DEBUG: Cleaned response text: %s", cleaned)
        if not cleaned.strip():
            raise ValueError("Received empty response text from the API.")
        return json.loads(cleaned)
    except Exception as e:
        logging.error("Error during GPT-4o analysis: %s", e)
        return {"texture": "", "shape": "", "color": "", "transparency": "", "design": ""}
+
def process_csv(input_csv: str, output_csv: str, client: object, start_row: int = 324):
    """Analyze every product image listed in *input_csv* with GPT-4o and
    write the extracted attributes to *output_csv*.

    The output CSV always contains the columns:
    "๋ํ ์ด๋ฏธ์ง URL", "texture", "shape", "color", "transparency", "design".

    Args:
        input_csv: Path of a CSV containing a "๋ํ ์ด๋ฏธ์ง URL" column.
        output_csv: Path of the CSV to create or append to.
        client: OpenAI client/module passed through to the analysis call.
        start_row: 0-based row index at which processing starts; earlier
            rows are skipped (used to resume an interrupted run).
    """
    # Append when the output already exists so a resumed run does not
    # clobber earlier results; write the header only for a fresh file.
    if os.path.exists(output_csv):
        output_mode = 'a'
        write_header = False
    else:
        output_mode = 'w'
        write_header = True

    with open(input_csv, newline='', encoding='utf-8') as infile:
        reader = csv.DictReader(infile)
        with open(output_csv, output_mode, newline='', encoding='utf-8') as outfile:
            fieldnames = ["๋ํ ์ด๋ฏธ์ง URL", "texture", "shape", "color", "transparency", "design"]
            writer = csv.DictWriter(outfile, fieldnames=fieldnames)
            if write_header:
                writer.writeheader()
            for i, row in enumerate(reader, start=0):
                # Skip rows before start_row (resume support).
                if i < start_row:
                    continue

                image_url = row.get("๋ํ ์ด๋ฏธ์ง URL", "").strip()
                if not image_url:
                    logging.info("Row %s: Empty image_url; skipping row.", i)
                    continue

                # Probe the URL for accessibility. stream=True checks the
                # status without downloading the whole image body (the
                # previous code fetched the full image just to discard it).
                try:
                    resp = requests.get(image_url, timeout=10, stream=True)
                    try:
                        resp.raise_for_status()
                    finally:
                        resp.close()
                except Exception as e:
                    logging.error("Row %s: Failed to access image from %s: %s", i, image_url, e)
                    continue

                # Run the GPT-4o image analysis.
                analysis = analyze_image_with_gpt4o(image_url, client)

                output_row = {
                    "๋ํ ์ด๋ฏธ์ง URL": image_url,
                    "texture": analysis.get("texture", ""),
                    "shape": analysis.get("shape", ""),
                    "color": analysis.get("color", ""),
                    "transparency": analysis.get("transparency", ""),
                    "design": analysis.get("design", "")
                }
                writer.writerow(output_row)
                logging.info("Row %s: Processed: %s", i, image_url)
+
def main():
    """Entry point: load the YAML config, configure the OpenAI client, and
    run GPT-4o attribute extraction over the configured thumbnail CSV."""
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(name)s - %(message)s"
    )
    parser = argparse.ArgumentParser(description="GPT-4o ์ด๋ฏธ์ง ๋ถ์ ๋ฐ CSV ์ฒ๋ฆฌ")
    parser.add_argument(
        "--config",
        "-c",
        default="config/config.yaml",
        help="์ค์ ํ์ผ ๊ฒฝ๋ก (๊ธฐ๋ณธ๊ฐ: config/config.yaml)"
    )
    args = parser.parse_args()

    # Load the configuration file
    config = load_config(args.config)

    # OpenAI API key (from the "openai" section of config.yaml)
    openai_api_key = config.get("openai", {}).get("api_key", "")
    if not openai_api_key:
        raise ValueError("OpenAI API key is not provided in config.")
    openai.api_key = openai_api_key

    # The openai module itself acts as the client object
    client = openai

    # CSV paths (from the "paths" section of config.yaml)
    data_dir = config.get("paths", {}).get("data_dir", "./data")
    input_csv_filename = config.get("paths", {}).get("thumbnail_1347", "thumbnail_1347.csv")
    # Bug fix: the fallback output filename previously lacked the ".csv"
    # extension ("thumbnail_1347_gpt_train"), unlike every other default.
    output_csv_filename = config.get("paths", {}).get("thumbnail_1347_gpt_train", "thumbnail_1347_gpt_train.csv")
    input_csv = os.path.join(data_dir, input_csv_filename)
    output_csv = os.path.join(data_dir, output_csv_filename)

    logging.info("Input CSV: %s", input_csv)
    logging.info("Output CSV: %s", output_csv)

    process_csv(input_csv, output_csv, client)

if __name__ == "__main__":
    main()
\ No newline at end of file
diff --git a/models/thumbnail_description/src/sft_pipeline/janus_pro_7b_finetuning.py b/models/thumbnail_description/src/sft_pipeline/janus_pro_7b_finetuning.py
new file mode 100644
index 0000000..43022c2
--- /dev/null
+++ b/models/thumbnail_description/src/sft_pipeline/janus_pro_7b_finetuning.py
@@ -0,0 +1,199 @@
+import argparse
+import yaml
+import os
+import urllib.request
+import logging
+from io import BytesIO
+from PIL import Image
+import torch
+from transformers import (
+ AutoTokenizer,
+ AutoModelForCausalLM,
+ TrainingArguments,
+ Trainer,
+ TrainerCallback
+)
+from datasets import load_dataset
+from dataclasses import dataclass
+from typing import Any, Dict
+from janus.models import VLChatProcessor, MultiModalityCausalLM
+from janus.utils.io import load_pil_images
+
# Trainer callback: release cached GPU memory after every optimizer step
class MoveToCPUTensorCallback(TrainerCallback):
    """Frees unused CUDA caching-allocator blocks at the end of each
    training step to keep peak GPU memory low during fine-tuning."""

    def on_step_end(self, args, state, control, **kwargs):
        # Both calls are no-ops when CUDA is not initialized.
        for release in (torch.cuda.empty_cache, torch.cuda.ipc_collect):
            release()
+
def main():
    """Fine-tune Janus-Pro 7B on thumbnail-description data.

    Loads model/path settings from a YAML config, maps a CSV of image URLs
    plus attribute columns into (inputs_embeds, labels) pairs via the
    VLChatProcessor, and runs a Hugging Face Trainer over them.
    """
    # Take the config file path from the command line
    parser = argparse.ArgumentParser(description="Janus-Pro 7B ํ์ธํ๋ ์คํ")
    parser.add_argument(
        "--config",
        "-c",
        default="config/config.yaml",
        help="์ค์ ํ์ผ ๊ฒฝ๋ก (๊ธฐ๋ณธ๊ฐ: config/config.yaml)"
    )
    args = parser.parse_args()

    # Load config.yaml
    with open(args.config, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # Model name: use the janus_pro_finetuning section if present, else the default
    model_name = config.get("janus_pro_finetuning", {}).get("model_name", "deepseek-ai/Janus-Pro-7B")

    # CSV path: paths.data_dir joined with the labeled training CSV name
    data_dir = config.get("paths", {}).get("data_dir", "./data")
    csv_file = config.get("paths", {}).get("thumbnail_1347_gpt_human_labeling_train", "thumbnail_1347_gpt_human_labeling_train.csv")
    csv_path = os.path.join(data_dir, csv_file)

    # Output directory for the fine-tuned model (default when not configured)
    output_dir = config.get("paths", {}).get("janus_pro_output_dir", "./janus_pro7b_finetuned")

    logging.info(f"๋ชจ๋ธ๋ช
: {model_name}")
    logging.info(f"ํ์ต CSV ํ์ผ ๊ฒฝ๋ก: {csv_path}")
    logging.info(f"์ถ๋ ฅ ๋๋ ํ ๋ฆฌ: {output_dir}")

    # 1. Load the model and processor/tokenizer
    # VLChatProcessor handles image preprocessing and wraps the tokenizer
    vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_name)
    tokenizer = vl_chat_processor.tokenizer

    # Load the model (no quantization, FP16 precision)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.float16
    )

    # 2. Load the dataset from the CSV file
    data_files = {"train": csv_path}
    dataset = load_dataset("csv", data_files=data_files)

    # 3. Preprocessing function: one CSV row -> (inputs_embeds, labels)
    def preprocess_function(example):
        # Build the text prompt from the attribute columns
        text_prompt = (
            f"Texture: {example['texture']}. "
            f"Shape: {example['shape']}. "
            f"Color: {example['color']}. "
            f"Transparency: {example['transparency']}. "
            f"Design: {example['design']}."
        )

        # Download and resize the image to at most 224x224
        # NOTE(review): fixed filenames ("full_test.jpg"/"test.jpg") are
        # shared across examples — fine for single-process map, unsafe if
        # this map is ever parallelized.
        urllib.request.urlretrieve(example["๋ํ ์ด๋ฏธ์ง URL"], "full_test.jpg")
        img = Image.open("full_test.jpg").convert("RGB")
        target_size = (224, 224)
        img.thumbnail(target_size, Image.Resampling.LANCZOS)
        img.save("test.jpg")

        # Build the conversation-style input
        # NOTE(review): the leading "\n" and trailing ">" in content look
        # like remnants of an "<image_placeholder>" tag lost in transit —
        # confirm against the Janus chat-template documentation.
        conversation = {
            "role": "<|User|>",
            "content": f"\n{text_prompt}>",
            "images": ["test.jpg"]
        }

        # Load the PIL image(s) referenced by the conversation
        pil_images = load_pil_images([conversation])

        # Convert the conversation + images with the VLChatProcessor
        prepare_inputs = vl_chat_processor(
            conversations=[conversation],
            images=pil_images,
            force_batchify=True
        ).to(model.device)

        # Convert the processor output object to a plain dict; cast any
        # floating-point tensors to FP16 to match the model dtype
        prepare_inputs_dict = {
            k: (v.to(torch.float16) if isinstance(v, torch.Tensor) and v.is_floating_point() else v)
            for k, v in vars(prepare_inputs).items() if not k.startswith("_")
        }

        # Produce the combined text+image input embeddings
        inputs_embeds = model.prepare_inputs_embeds(**prepare_inputs_dict)
        if not isinstance(inputs_embeds, torch.Tensor):
            inputs_embeds = torch.tensor(inputs_embeds)
        inputs_embeds = inputs_embeds.to(torch.float16)

        # Labels: the tokenized text prompt, padded/truncated to 128 tokens
        # NOTE(review): labels are the prompt itself while the model
        # consumes inputs_embeds — verify sequence lengths line up with the
        # model's loss computation.
        labels = tokenizer(
            text_prompt,
            padding="max_length",
            truncation=True,
            max_length=128,
            return_tensors="pt"
        )["input_ids"].squeeze(0)

        return {
            "inputs_embeds": inputs_embeds.squeeze(0),
            "labels": labels
        }

    # Apply preprocessing to the dataset (one example at a time)
    dataset = dataset.map(preprocess_function, batched=False)

    # 4. Data collator: stack per-example tensors into a batch
    @dataclass
    class DataCollatorForVLChat:
        """Stacks per-example inputs_embeds (float16) and labels (int64)
        into batch tensors; rebuilds tensors from lists when the dataset
        map serialized them."""

        def __call__(self, features: list) -> Dict[str, torch.Tensor]:
            inputs_embeds = []
            labels = []
            for f in features:
                # inputs_embeds: tensor already, or rebuild as float16
                if isinstance(f["inputs_embeds"], torch.Tensor):
                    inputs_embeds.append(f["inputs_embeds"])
                else:
                    inputs_embeds.append(torch.tensor(f["inputs_embeds"], dtype=torch.float16))
                # labels: same treatment, as int64
                if isinstance(f["labels"], torch.Tensor):
                    labels.append(f["labels"])
                else:
                    labels.append(torch.tensor(f["labels"], dtype=torch.long))
            inputs_embeds_batch = torch.stack(inputs_embeds)
            labels_batch = torch.stack(labels)
            return {
                "inputs_embeds": inputs_embeds_batch,
                "labels": labels_batch
            }

    data_collator = DataCollatorForVLChat()

    # 5. TrainingArguments (effective batch size 1 x 4 accumulation)
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=1,
        per_device_train_batch_size=1,
        learning_rate=5e-5,
        save_steps=80,
        fp16=True,
        logging_steps=100,
        save_total_limit=1,
        gradient_accumulation_steps=4,
        remove_unused_columns=False
    )

    # 6. Build the Trainer and train
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=[MoveToCPUTensorCallback()]
    )

    trainer.train()
    trainer.save_model(training_args.output_dir)

if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(name)s - %(message)s"
    )
    main()
\ No newline at end of file
diff --git a/models/thumbnail_description/utils/__init__.py b/models/thumbnail_description/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/models/thumbnail_description/utils/common_utils.py b/models/thumbnail_description/utils/common_utils.py
new file mode 100644
index 0000000..41eb602
--- /dev/null
+++ b/models/thumbnail_description/utils/common_utils.py
@@ -0,0 +1,64 @@
+# utils/common_utils.py
+
+import os
+import random
+import numpy as np
+import torch
+import pandas as pd
+
def set_seed(seed: int = 42) -> None:
    """Seed every RNG used in the project (hash, random, NumPy, PyTorch)
    so results are reproducible across libraries and platforms."""
    os.environ["PYTHONHASHSEED"] = str(seed)
    # Seed each library's generator with the same value; the CUDA seeders
    # are no-ops on CPU-only machines.
    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    # Trade cuDNN autotuning for deterministic kernel selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
+
+
def get_second_to_last(df: pd.DataFrame) -> pd.Series:
    """Return the second-to-last row of *df* after sorting by the numeric
    suffix of 'img-ID' (e.g. "abc-12" sorts by 12) in ascending order.

    Falls back to the only row when the group has a single row.

    Bug fix: the previous implementation returned ``df.iloc[0]`` — the
    FIRST row after the ascending sort — contradicting both the function
    name and its stated intent ("๋ค์์ ๋ ๋ฒ์งธ ํ").
    """
    df = df.sort_values(
        by='img-ID',
        key=lambda s: s.str.split('-').str[-1].astype(int),
        ascending=True
    )
    # iloc[-2] is the second-to-last; single-row groups keep their only row.
    return df.iloc[-2] if len(df) >= 2 else df.iloc[-1]
+
+
+def load_and_filter_data(csv_path: str) -> pd.DataFrame:
+ """
+ 1. CSV ๋ก๋
+ 2. ID๋ณ '์ ์ฒด' ํ ์ ๋ณด๋ฅผ '๊ฐ๋ณ' ํ์ ์ฑ์๋ฃ๊ธฐ
+ 3. ID๋ณ ๋ค์์ ๋ ๋ฒ์งธ ํ๋ง ์ถ์ถ
+ 4. '?ref=storefarm' ์ ๊ฑฐ
+ 5. ์ ๋ ฌ ํ ๋ฐํ
+ """
+ # 1) CSV ๋ก๋
+ df_raw = pd.read_csv(csv_path)
+ df = df_raw.copy()
+
+ # 2) '์ ์ฒด' ๋ฐ์ดํฐ ์ถ์ถ ํ, '๊ฐ๋ณ'์ ์ ๋ณด ์ฑ์๋ฃ๊ธฐ
+ df_total = df[df['์ ์ฒด/๊ฐ๋ณ'] == '์ ์ฒด'].copy()
+ fill_cols = ['row', 'img-ID', '์นดํ
๊ณ ๋ฆฌ', '์ํ๋ช
', '์ํ ์์ธ URL']
+ info_dict = df_total.set_index('ID')[fill_cols].to_dict('index')
+ for col in fill_cols:
+ df[col] = df['ID'].map(lambda x: info_dict[x][col] if x in info_dict else None)
+
+ # 3) ๊ทธ๋ฃน๋ณ ๋ค์์ ๋ ๋ฒ์งธ ํ ์ถ์ถ
+ df_filtered = df.groupby('ID', group_keys=False).apply(get_second_to_last).reset_index(drop=True)
+
+ # 4) ์ด๋ฏธ์ง URL์์ '?ref=storefarm' ์ ๊ฑฐ
+ df_filtered['url_clean'] = df_filtered['์ด๋ฏธ์ง URL'].str.replace('?ref=storefarm', '', regex=False)
+
+ # 5) row ๊ธฐ์ค ์ ๋ ฌ
+ df_filtered = df_filtered.sort_values(by="row").reset_index(drop=True)
+
+ return df_filtered
\ No newline at end of file