Skip to content

Commit c6cb514

Browse files
committed
add bytefallback bpe
1 parent 43914a1 commit c6cb514

File tree

5 files changed

+80
-19
lines changed

5 files changed

+80
-19
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
name = "BytePairEncoding"
22
uuid = "a4280ba5-8788-555a-8ca8-4a8c3d966a71"
33
authors = ["chengchingwen <[email protected]>"]
4-
version = "0.3.1"
4+
version = "0.3.2"
55

66
[deps]
7+
DoubleArrayTries = "abbaa0e5-f788-499c-92af-c35ff4258c82"
78
StructWalk = "31cdf514-beb7-4750-89db-dda9d2eb8d3d"
89
TextEncodeBase = "f92c20c0-9f2a-4705-8116-881385faba05"
910
Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
1011

1112
[compat]
13+
DoubleArrayTries = "0.1"
1214
StructWalk = "0.2"
1315
TextEncodeBase = "0.5.4, 0.6"
1416
julia = "1.6"

src/BytePairEncoding.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ export BPE, BPETokenization
88

99
include("mstring.jl")
1010
include("bpe.jl")
11+
include("bytefallback.jl")
1112
include("tokenization.jl")
1213
include("learn.jl")
1314
include("gpt2_utils.jl")

src/bpe.jl

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ end
1010
# Keyword convenience constructor: forwards `sepsym`/`endsym` (both default to
# `nothing`) to the positional `BPE(merging_rank, sepsym, endsym)` constructor.
function BPE(merging_rank::Dict; sepsym = nothing, endsym = nothing)
    return BPE(merging_rank, sepsym, endsym)
end
1111
# Build a `BPE` from a merges file: the ranks are loaded with `read_merges`
# (which receives `endsym` and any extra keyword arguments), then handed to the
# keyword constructor together with `sepsym` and `endsym`.
function BPE(bpefile; sepsym = nothing, endsym = nothing, kws...)
    rank = read_merges(bpefile, endsym; kws...)
    return BPE(rank; sepsym = sepsym, endsym = endsym)
end
1212

13-
(bpe::BPE)(x) = bytepairencode(x, bpe.merging_rank, bpe.sepsym, bpe.endsym)
13+
# Calling a `BPE` object on `x` is shorthand for `bytepairencode(bpe, x)`.
function (bpe::BPE)(x)
    return bytepairencode(bpe, x)
end
1414

1515
function Base.show(io::IO, bpe::BPE)
1616
print(io, "BPE(")
@@ -101,20 +101,21 @@ function merge!(ms, i)
101101
return @inbounds @view(ms[1:desidx])
102102
end
103103

104-
function merges(x::AbstractString, endsym = nothing)
105-
buf = map(Merge, graphemes(x))
106-
if endsym !== nothing
107-
@inbounds buf[end] = Merge(buf[end], true)
108-
end
109-
return buf
104+
"""
    merges(x::AbstractString, endsym = nothing)

Split `x` into one `Merge` per grapheme and return them as a vector.

When `endsym` is not `nothing`, the last element is rebuilt with its `extra`
flag set to `true`. An empty `x` yields an empty vector regardless of `endsym`.
"""
function merges(x::AbstractString, endsym::Union{Nothing, AbstractString} = nothing)
    buf = map(Merge, graphemes(x))
    # `buf` is empty when `x` is empty; indexing `buf[end]` under `@inbounds`
    # would then be an unchecked out-of-bounds access, so guard it explicitly.
    if endsym !== nothing && !isempty(buf)
        @inbounds buf[end] = Merge(buf[end], true)
    end
    return buf
end
111+
# Generic entry point for any `AbstractBPE`: split `x` per grapheme, using the
# encoder's own `endsym` to decide whether the last piece is flagged.
function merges(bpe::AbstractBPE, x::AbstractString)
    return merges(x, bpe.endsym)
end
111112

112-
function bytepairencode(x, merging_rank, sepsym = nothing, endsym = nothing)
113-
ms = merges(x, endsym)
113+
"""
    bytepairencode(bpe::AbstractBPE, x::AbstractString)

Encode `x` with `bpe` and return the resulting pieces as strings.

`x` is first split into initial `Merge` pieces by `merges(bpe, x)`. With two or
more pieces they are combined by `merge_loop!` using `bpe.merging_rank`. Each
final piece is rendered by `as_string` with `bpe.sepsym` / `bpe.endsym`.
"""
function bytepairencode(bpe::AbstractBPE, x::AbstractString)
    ms = merges(bpe, x)
    n = length(ms)
    if n >= 2
        y = merge_loop!(bpe.merging_rank, ms, x)
    elseif n == 1
        # Keep the piece produced by `merges` instead of rebuilding it from `x`:
        # rebuilding resets the `byte` flag, so a single byte-fallback piece
        # would be emitted as raw text rather than in its `<0xNN>` form.
        y = ms
    else
        # No pieces (e.g. empty input): fall back to a single Merge over `x`.
        y = [Merge(x, !isnothing(bpe.endsym))]
    end
    return as_string.(y, bpe.sepsym, bpe.endsym)
end

src/bytefallback.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import DoubleArrayTries
2+
using DoubleArrayTries: DoubleArrayTrie, StringView
3+
4+
"""
    ByteFallbackBPE(vocab, merging_rank, sepsym, endsym)

BPE encoder that consults a vocabulary trie while splitting its input.

# Fields
- `vocab`: trie queried with `DoubleArrayTries.lookup` for each character.
- `merging_rank`: rank of every pair of `Merge`s, passed to `merge_loop!`.
- `sepsym`: suffix appended to pieces whose `extra` flag is `false`, or `nothing`.
- `endsym`: suffix appended to pieces whose `extra` flag is `true`, or `nothing`.
"""
struct ByteFallbackBPE <: AbstractBPE
    vocab::DoubleArrayTrie
    merging_rank::Dict{Tuple{Merge, Merge}, Int}
    sepsym::Union{Nothing, String}
    endsym::Union{Nothing, String}
end
10+
11+
# Convenience constructor: build the vocabulary trie from a list of strings,
# then forward everything to the field constructor.
function ByteFallbackBPE(vocab_list::AbstractVector{String}, merging_rank, sepsym, endsym)
    trie = DoubleArrayTrie(collect(vocab_list))
    return ByteFallbackBPE(trie, merging_rank, sepsym, endsym)
end
13+
14+
# Calling a `ByteFallbackBPE` object on `x` is shorthand for `bytepairencode(bpe, x)`.
function (bpe::ByteFallbackBPE)(x)
    return bytepairencode(bpe, x)
end
15+
16+
# Compact one-line display: the number of merge rules, followed by `sepsym`
# and `endsym` only when they are set.
function Base.show(io::IO, bpe::ByteFallbackBPE)
    print(io, "ByteFallbackBPE(", length(bpe.merging_rank), " merges")
    if bpe.sepsym !== nothing
        print(io, ", sepsym = ", bpe.sepsym)
    end
    if bpe.endsym !== nothing
        print(io, ", endsym = ", bpe.endsym)
    end
    print(io, ')')
    return nothing
end
24+
25+
"""
    merges(bpe::ByteFallbackBPE, x::AbstractString)

Split `x` into the initial `Merge` pieces used for byte-fallback encoding.

`x` is split character by character. A character found in `bpe.vocab` becomes
one `Merge` covering all of its code units. A character that is not found
becomes one single-code-unit `Merge` per code unit, each flagged `byte = true`.
When `bpe.endsym` is set, the last piece is rebuilt with `extra = true`.
An input that produces no pieces returns an empty vector.
"""
function merges(bpe::ByteFallbackBPE, x::AbstractString)
    vocab = bpe.vocab
    y = Vector{Merge}()
    offset = 0  # running code-unit offset of the current character within `x`
    for c in split(x, "")
        nu = ncodeunits(c)
        if iszero(DoubleArrayTries.lookup(vocab, c))
            # Lookup returned zero: fall back to one byte-flagged piece per code unit.
            for _ in 1:nu
                push!(y, Merge(x, offset, 1, false, true))
                offset += 1
            end
        else
            push!(y, Merge(x, offset, nu, false))
            offset += nu
        end
    end
    # `y` can be empty (e.g. `x == ""`); `y[end]` under `@inbounds` would then
    # be an unchecked out-of-bounds access, so only mark the end when it exists.
    if bpe.endsym !== nothing && !isempty(y)
        @inbounds y[end] = Merge(y[end], true)
    end
    return y
end

src/mstring.jl

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
1+
using DoubleArrayTries: StringView
2+
13
# A piece of a backing string, described by position rather than by copying text.
#   string     - the backing string the piece refers to
#   offset     - number of code units preceding the piece in `string`
#   ncodeunits - length of the piece in code units
#   extra      - `as_string` appends `endsym` when true, `sepsym` otherwise
#   byte       - when true, `as_string` renders each code unit as `<0xNN>`
struct Merge
    string::String
    offset::UInt16
    ncodeunits::UInt16
    extra::Bool
    byte::Bool
end
710

8-
Merge(str, offset::Int, ncodeunits::Int, extra) = Merge(str, UInt16(offset), UInt16(ncodeunits), extra)
9-
Merge(a::Merge, e::Bool) = Merge(a.string, a.offset, a.ncodeunits, e)
10-
Merge(s::SubString, e::Bool = false) = Merge(s.string, s.offset, s.ncodeunits, e)
11+
# Accept `Int` positions and narrow them to the `UInt16` fields; `byte`
# defaults to `false`.
function Merge(str, offset::Int, ncodeunits::Int, extra, byte = false)
    return Merge(str, UInt16(offset), UInt16(ncodeunits), extra, byte)
end
12+
# Copy `a` with its `extra` flag replaced by `e`; every other field is kept.
function Merge(a::Merge, e::Bool)
    return Merge(a.string, a.offset, a.ncodeunits, e, a.byte)
end
13+
# Build a non-byte `Merge` from a `SubString`, reusing its parent string,
# offset and code-unit count.
function Merge(s::SubString, e::Bool = false)
    return Merge(s.string, s.offset, s.ncodeunits, e, false)
end
1114
# Build a `Merge` covering a whole `String` by going through `SubString`.
function Merge(s::String, e::Bool = false)
    return Merge(SubString(s), e)
end
1215

1316
function Merge(a::Merge, b::Merge)
@@ -22,7 +25,7 @@ function Merge(a::Merge, b::Merge)
2225
error("merge two Merge at same offset: partial string?")
2326
end
2427
nunits = a.ncodeunits + b.ncodeunits
25-
return Merge(a.string, offset, nunits, b.extra)
28+
return Merge(a.string, offset, nunits, b.extra, a.byte & b.byte)
2629
else
2730
error("merge different Merge")
2831
end
@@ -92,7 +95,7 @@ function write_merges(io::IO, rank, endsym = nothing; limit = typemax(Int), comm
9295
end
9396

9497
function Base.hash(m::Merge, h::UInt)
95-
h = hash(m.extra, h) + Base.memhash_seed
98+
h = hash(m.byte, hash(m.extra, h)) + Base.memhash_seed
9699
str_size = m.ncodeunits * sizeof(UInt8)
97100
str = m.string
98101
ptr = convert(Ptr{UInt8}, pointer(str)) + m.offset
@@ -103,6 +106,7 @@ function Base.:(==)(m1::Merge, m2::Merge)
103106
m1.extra == m2.extra || return false
104107
s = m1.ncodeunits
105108
s == m2.ncodeunits || return false
109+
m1.byte == m2.byte || return false
106110
str1 = m1.string
107111
str2 = m2.string
108112
p1 = convert(Ptr{UInt8}, pointer(str1)) + m1.offset
@@ -113,7 +117,13 @@ end
113117
"""
    as_string(m::Merge, sepsym, endsym)

Render the piece described by `m` as a `String`.

A byte-flagged piece is written as one `<0xNN>` token (uppercase, two-digit
hex) per code unit; any other piece is the corresponding slice of `m.string`.
The suffix is `endsym` when `m.extra` is true and `sepsym` otherwise; a
`nothing` suffix is omitted.
"""
function as_string(m::Merge, sepsym, endsym)
    cu = codeunits(m.string)
    # Widen before adding: `offset` and `ncodeunits` are both `UInt16`, so their
    # sum would wrap around once it exceeds 65535.
    start = Int(m.offset) + 1
    stop = Int(m.offset) + Int(m.ncodeunits)
    if m.byte
        s = join("<0x$(uppercase(string(cu[i]; base = 16, pad = 2)))>" for i in start:stop)
    else
        s = StringView(@view(cu[start:stop]))
    end
    sym = m.extra ? endsym : sepsym
    return isnothing(sym) ? String(s) : string(s, sym)
end

0 commit comments

Comments
 (0)