Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 79 additions & 14 deletions synthetic_data_kit/utils/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,30 +9,95 @@
from typing import List, Dict, Any

def split_into_chunks(text: str, chunk_size: int = 4000, overlap: int = 200) -> List[str]:
"""Split text into chunks with optional overlap"""
paragraphs = text.split("\n\n")
"""Split text into chunks with optional overlap using hierarchical approach"""

# Ensure overlap is not larger than chunk_size
overlap = min(overlap, chunk_size // 2)

if "\n\n" in text:
# Paragraph-based splitting
segments = text.split("\n\n")
join_str = "\n\n"
overlap_split_str = '. '
elif "\n" in text:
# Line-based splitting
segments = text.split("\n")
join_str = "\n"
overlap_split_str = '. '
else:
# Sentence-based splitting - try to find sentences first
sentences = re.split(r'([.!?])\s+(?=[A-Z])', text)
segments = []
for i in range(0, len(sentences) - 1, 2):
if i + 1 < len(sentences):
sentence = sentences[i] + sentences[i + 1]
segments.append(sentence.strip())

# Add remaining part if exists
if len(sentences) % 2 == 1:
segments.append(sentences[-1].strip())

# If no proper sentences found, fall back to word splitting
if len(segments) <= 1 and len(text) > chunk_size:
segments = text.split(' ')
join_str = " "
overlap_split_str = ' '
else:
join_str = " "
overlap_split_str = '. '

chunks = []
current_chunk = ""

for para in paragraphs:
if len(current_chunk) + len(para) > chunk_size and current_chunk:
chunks.append(current_chunk)
# Keep some overlap for context
sentences = current_chunk.split('. ')
if len(sentences) > 3:
current_chunk = '. '.join(sentences[-3:]) + "\n\n" + para
for segment in segments:
potential_length = len(current_chunk) + (len(join_str) if current_chunk else 0) + len(segment)

if potential_length > chunk_size and current_chunk:
chunks.append(current_chunk.strip())

# Create overlap for next chunk
if overlap > 0:
overlap_parts = current_chunk.split(overlap_split_str)
if len(overlap_parts) > 1:
# Keep overlap amount of characters from the end
overlap_text = current_chunk[-overlap:] if len(current_chunk) > overlap else current_chunk
space_pos = overlap_text.find(' ')
if space_pos > 0:
overlap_text = overlap_text[space_pos + 1:]
current_chunk = overlap_text + join_str + segment
else:
current_chunk = segment
else:
current_chunk = para
current_chunk = segment
else:
if current_chunk:
current_chunk += "\n\n" + para
current_chunk += join_str + segment
else:
current_chunk = para
current_chunk = segment

# Add final chunk if it exists
if current_chunk:
chunks.append(current_chunk)
chunks.append(current_chunk.strip())

# Fallback: if only one chunk and text is longer than chunk_size, force character-based splitting
if len(chunks) == 1 and len(text) > chunk_size:
chunks = []
step_size = max(1, chunk_size - overlap)

for i in range(0, len(text), step_size):
chunk_end = min(i + chunk_size, len(text))
chunk = text[i:chunk_end]

# Try to end at word boundary if not at end
if chunk_end < len(text) and ' ' in chunk:
last_space = chunk.rfind(' ')
if last_space > len(chunk) * 0.7: # Don't lose too much content
chunk = chunk[:last_space]

if chunk.strip():
chunks.append(chunk.strip())

return chunks
return [chunk for chunk in chunks if chunk.strip()]

def extract_json_from_text(text: str) -> Dict[str, Any]:
"""Extract JSON from text that might contain markdown or other content"""
Expand Down
15 changes: 14 additions & 1 deletion tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from synthetic_data_kit.utils import config, text


@pytest.mark.unit
def test_split_into_chunks():
"""Test splitting text into chunks."""
# Create multi-paragraph text
Expand All @@ -33,7 +32,21 @@ def test_split_into_chunks():
# Empty text should produce an empty list, not a list with an empty string
assert empty_chunks == []

non_paragraph_tests = [
"This is a sample example of inputs without paragraphs.\n" *16,
"This is a sample example of inputs without sentences." *16,
]
# Using a small chunk size to ensure splitting
non_paragraph_chunks = text.split_into_chunks(non_paragraph_tests[0], chunk_size=20, overlap=10)
assert len(non_paragraph_chunks) > 1
assert len(non_paragraph_chunks) >= 16

non_sentence_chunks = text.split_into_chunks(non_paragraph_tests[1], chunk_size=20, overlap=5)
assert len(non_sentence_chunks) > 1
assert len(non_sentence_chunks) >=20


test_split_into_chunks()
@pytest.mark.unit
def test_extract_json_from_text():
"""Test extracting JSON from text."""
Expand Down