
Commit d21c92b: support tiktoken encoder

1 parent 08bc710

File tree

4 files changed: +126 additions, -3 deletions


README.md

Lines changed: 17 additions & 0 deletions
@@ -36,4 +36,21 @@ julia> tkr2("hello world aaaaaaaaaaaa")
  "aaaa"
  "aaa"
 
+julia> enc = BytePairEncoding.load_tiktoken_encoder("cl100k_base")
+┌ Warning: The maximum encoded value (`length(BPEEncoder.vocab)`) is larger than the number of possible tokens
+│ because there are some "gaps" in the vocabulary. Be careful if used to initialize embedding table.
+└ @ BytePairEncoding
+BPEEncoder(BPETokenizer(MatchTokenization(BPETokenization(Cl100kBaseTokenization, bpe = TikTokenBPE(100256 merges)), 5 patterns)), Vocab(size = 100277))
+
+julia> enc.encode("hello world aaaaaaaaaaaa") # === enc(...)
+5-element Vector{Int64}:
+ 15340
+  1918
+   265
+ 70541
+ 33747
+
+julia> enc.decode(enc("hello world aaaaaaaaaaaa"))
+"hello world aaaaaaaaaaaa"
+
 ```
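One detail worth calling out in the new README example: as the docstring added in this commit warns, the ids returned by `BPEEncoder` are 1-based, i.e. shifted by one relative to the 0-based ranks used by the python/rust tiktoken. Below is a minimal sketch (not part of the commit) of converting between the two conventions, assuming the package is installed and the encoder loads as shown above:

```julia
using BytePairEncoding

# Assumes the cl100k_base encoder loads as in the README example above.
enc = BytePairEncoding.load_tiktoken_encoder("cl100k_base")

ids = enc.encode("hello world aaaaaaaaaaaa")  # 1-based ids, e.g. [15340, 1918, 265, 70541, 33747]
ranks = ids .- 1                              # subtract 1 to match tiktoken's 0-based ranks

# Shifting back restores ids that BPEEncoder can decode again.
@assert enc.decode(ranks .+ 1) == "hello world aaaaaaaaaaaa"
```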

src/BytePairEncoding.jl

Lines changed: 1 addition & 0 deletions
@@ -14,5 +14,6 @@ include("tokenization.jl")
 include("learn.jl")
 include("gpt2_utils.jl")
 include("tiktoken.jl")
+include("encoder.jl")
 
 end # module

src/encoder.jl

Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
+using TextEncodeBase
+using TextEncodeBase: PerforatedOverwritableLookupVector, DictBackedLookupDict, DATLookupVector
+
+struct BPEEncoder{T<:BPETokenizer, V<:Vocab} <: AbstractTextEncoder
+    tokenizer::T
+    vocab::V
+end
+TextEncodeBase.process(e::BPEEncoder) = identity
+(e::BPEEncoder)(x::AbstractString) = TextEncodeBase.lookup(e.vocab, encode_indices(e, x))
+
+Base.propertynames(e::BPEEncoder) = (:encode, :decode, fieldnames(BPEEncoder)...)
+function Base.getproperty(e::BPEEncoder, sym::Symbol)
+    if sym == :encode
+        return e
+    elseif sym == :decode
+        return Base.Fix1(TextEncodeBase.decode_text, e)
+    else
+        return getfield(e, sym)
+    end
+end
+
+function Base.show(io::IO, e::BPEEncoder)
+    print(io, "BPEEncoder(")
+    show(io, e.tokenizer)
+    print(io, ", Vocab(size = ")
+    print(io, length(e.vocab))
+    print(io, "))")
+end
+
+"""
+    load_tiktoken_encoder(name)
+
+Load the tiktoken encoder (tokenizer + predefined vocabulary).
+
+!!! warning
+    The encoded value is off by 1 compared to the python/rust tiktoken.
+
+```julia-repl
+julia> enc = BytePairEncoding.load_tiktoken_encoder("cl100k_base")
+┌ Warning: The maximum encoded value (`length(BPEEncoder.vocab)`) is larger than the number of possible tokens
+│ because there are some "gaps" in the vocabulary. Be careful if used to initialize embedding table.
+└ @ BytePairEncoding
+BPEEncoder(BPETokenizer(MatchTokenization(BPETokenization(Cl100kBaseTokenization, bpe = TikTokenBPE(100256 merges)), 5 patterns)), Vocab(size = 100277))
+
+julia> enc.encode("hello world aaaaaaaaaaaa") # === enc(...)
+5-element Vector{Int64}:
+ 15340
+  1918
+   265
+ 70541
+ 33747
+
+julia> enc.decode(enc("hello world aaaaaaaaaaaa"))
+"hello world aaaaaaaaaaaa"
+
+```
+"""
+function load_tiktoken_encoder(name)
+    ENDOFTEXT = "<|endoftext|>"
+    FIM_PREFIX = "<|fim_prefix|>"
+    FIM_MIDDLE = "<|fim_middle|>"
+    FIM_SUFFIX = "<|fim_suffix|>"
+    ENDOFPROMPT = "<|endofprompt|>"
+    tkr = load_tiktoken(name)
+    bpe = tkr.tokenization.base.bpe
+    warn = true
+    if name == "o200k_base"
+        sptk = Dict(
+            ENDOFTEXT => 199999 + 1,
+            ENDOFPROMPT => 200018 + 1,
+        )
+    elseif name == "cl100k_base"
+        sptk = Dict(
+            ENDOFTEXT => 100257 + 1,
+            FIM_PREFIX => 100258 + 1,
+            FIM_MIDDLE => 100259 + 1,
+            FIM_SUFFIX => 100260 + 1,
+            ENDOFPROMPT => 100276 + 1,
+        )
+    elseif name == "p50k_edit"
+        sptk = Dict(
+            ENDOFTEXT => 50256 + 1,
+            FIM_PREFIX => 50281 + 1,
+            FIM_MIDDLE => 50282 + 1,
+            FIM_SUFFIX => 50283 + 1,
+        )
+    else
+        sptk = Dict(
+            ENDOFTEXT => 50256 + 1,
+        )
+        warn = false
+    end
+    if warn
+        @warn """The maximum encoded value (`length(BPEEncoder.vocab)`) is larger than the number of possible tokens
+        because there are some "gaps" in the vocabulary. Be careful if used to initialize embedding table."""
+    end
+    vector = PerforatedOverwritableLookupVector(
+        DATLookupVector(bpe.encoder),
+        DictBackedLookupDict(sptk, Dict(v=>k for (k, v) in sptk)))
+    vocab = Vocab(vector, "", 0) # byte level bpe should be free from unknown token
+    return BPEEncoder(tkr, vocab)
+end
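A note on the API shape above: `encode` and `decode` are not fields of `BPEEncoder`. `Base.getproperty` is overloaded so that `enc.encode` returns the (callable) encoder itself and `enc.decode` returns a partially applied decoder built with `Base.Fix1`. The following is a self-contained toy sketch of that property-forwarding pattern (illustrative only, not the package's code; the `ToyEncoder` type and `toy_decode` helper are made up for the example):

```julia
# Toy reproduction of the property-based encode/decode API used by BPEEncoder.
struct ToyEncoder
    table::Dict{String,Int}
end

# Calling the object performs encoding (one id per character here).
(e::ToyEncoder)(x::AbstractString) = [e.table[string(c)] for c in x]

# Inverse lookup: find the key whose value matches each id.
toy_decode(e::ToyEncoder, ids) = join(findfirst(==(i), e.table) for i in ids)

Base.propertynames(e::ToyEncoder) = (:encode, :decode, fieldnames(ToyEncoder)...)
function Base.getproperty(e::ToyEncoder, sym::Symbol)
    if sym == :encode
        return e                          # `e.encode` is just the callable object itself
    elseif sym == :decode
        return Base.Fix1(toy_decode, e)   # `e.decode(ids)` calls toy_decode(e, ids)
    else
        return getfield(e, sym)
    end
end

enc = ToyEncoder(Dict("a" => 1, "b" => 2))
@assert enc.encode("ab") == [1, 2]
@assert enc.decode([2, 1]) == "ba"
```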

test/test_tiktoken.jl

Lines changed: 6 additions & 3 deletions
@@ -2,7 +2,7 @@ const xnli = readlines(joinpath(artifact_dir, "xnli-dev.txt"))
 using PythonCall
 const tiktoken = pyimport("tiktoken")
 
-using BytePairEncoding: load_tiktoken, load_gpt2, tiktoken2bbpe, bbpe2tiktoken, gpt2_codemap
+using BytePairEncoding: load_tiktoken_encoder, load_tiktoken, load_gpt2, tiktoken2bbpe, bbpe2tiktoken, gpt2_codemap
 
 @testset "TikToken" begin
     codemap = gpt2_codemap()
@@ -15,11 +15,14 @@ using BytePairEncoding: load_tiktoken, load_gpt2, tiktoken2bbpe, bbpe2tiktoken,
         "r50k_base",
         "gpt2",
     )
-        tkr = load_tiktoken(model)
+        enc = load_tiktoken_encoder(model)
+        tkr = enc.tokenizer
         tkr2 = tiktoken2bbpe(tkr, codemap)
-        @test tkr.tokenization.base.bpe.encoder == bbpe2tiktoken(tkr2).tokenization.base.bpe.encoder
+        @test collect(tkr.tokenization.base.bpe.encoder) == collect(bbpe2tiktoken(tkr2).tokenization.base.bpe.encoder)
         pytkr = tiktoken.get_encoding(model)
         for line in xnli
+            @test enc(line) == pyconvert(Array{Int}, pytkr.encode(line)) .+ 1
+            @test enc.decode(enc(line)) == line
             tokens = tkr(line)
             @test join(tokens) == line
             @test tokens == map(py->pyconvert(Base.CodeUnits, py).s,
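The new assertions mirror the off-by-one note in the docstring: `enc(line)` should equal the python tiktoken ranks shifted by one, and decoding should round-trip the input. A sketch of running the same cross-check interactively (assuming PythonCall is installed and `tiktoken` is importable from the linked Python environment):

```julia
using BytePairEncoding: load_tiktoken_encoder
using PythonCall

enc = load_tiktoken_encoder("cl100k_base")
pytkr = pyimport("tiktoken").get_encoding("cl100k_base")

line = "hello world aaaaaaaaaaaa"
# BPEEncoder ids are the python ranks shifted by +1 (Julia-style 1-based indexing).
@assert enc(line) == pyconvert(Array{Int}, pytkr.encode(line)) .+ 1
@assert enc.decode(enc(line)) == line
```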
