Skip to content

Commit 905e6e7

Browse files
authored
fix: tokenizer utf-8 handling (#228)
1 parent f3a902b commit 905e6e7

3 files changed

Lines changed: 106 additions & 14 deletions

File tree

src/tokenizer-test.cpp

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,22 @@ void dev_testEncode(Tokenizer *tokenizer) {
6868
}
6969
}
7070

71+
void dev_testDecoderEmojiStreamRecover(Tokenizer *tokenizer) {
72+
char *x0 = tokenizer->decode(128000);
73+
assert(x0 == nullptr);
74+
75+
char *x1 = tokenizer->decode(76460);
76+
assert(x1 == nullptr);
77+
78+
char *x2 = tokenizer->decode(76460);
79+
assert(x2 == nullptr);
80+
81+
char *x3 = tokenizer->decode(225);
82+
assert(strcmp(x3, "�😃") == 0);
83+
84+
printOk("testDecoderEmojiStreamRecover");
85+
}
86+
7187
void dev_testDecoderEmoji(Tokenizer *tokenizer) {
7288
char *x0 = tokenizer->decode(128000);
7389
assert(x0 == nullptr);
@@ -76,27 +92,30 @@ void dev_testDecoderEmoji(Tokenizer *tokenizer) {
7692
assert(x1 == nullptr);
7793

7894
char *x2 = tokenizer->decode(225);
79-
assert(x2 == nullptr);
95+
assert(strcmp(x2, "😃") == 0);
8096

8197
char *x3 = tokenizer->decode(0);
82-
assert(strstr(x3, "😃!") != NULL);
98+
assert(strcmp(x3, "!") == 0);
8399

84100
char *x4 = tokenizer->decode(56);
85-
assert(strstr(x3, "Y") != NULL);
101+
assert(strcmp(x4, "Y") == 0);
86102

87103
printOk("testDecoderEmoji");
88104
}
89105

90106
void dev_testDecoderEmojiWithEos(Tokenizer *tokenizer) {
91107
char *x0 = tokenizer->decode(128000);
108+
assert(x0 == nullptr);
109+
92110
char *x1 = tokenizer->decode(76460);
111+
assert(x1 == nullptr);
112+
93113
char *x2 = tokenizer->decode(225);
114+
assert(strcmp(x2, "😃") == 0);
115+
94116
char *x3 = tokenizer->decode(128001);
117+
assert(x3 == nullptr); // piece should not contain <|end_of_text|>
95118

96-
assert(x0 == nullptr);
97-
assert(x1 == nullptr);
98-
assert(x2 == nullptr);
99-
assert(strstr(x3, "😃") != NULL); // piece should not contain <|end_of_text|>
100119
printOk("decoderEmojiWithEos");
101120
}
102121

@@ -289,6 +308,7 @@ int main() {
289308
dev_testEncode(&tokenizer);
290309
dev_testDecoderEmoji(&tokenizer);
291310
dev_testDecoderEmojiWithEos(&tokenizer);
311+
dev_testDecoderEmojiStreamRecover(&tokenizer);
292312
#endif
293313

294314
testChatTemplateDetection();

src/tokenizer.cpp

Lines changed: 75 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,13 @@ Tokenizer::Tokenizer(const char* tokenizerPath)
151151
specialVocab[i].id = i + regularVocabSize;
152152
}
153153

154-
strBufferSize = maxTokenLength * 2 + 1 + 2;
154+
strBufferSize = maxTokenLength * 2;
155+
if (strBufferSize < (4 * 2)) { // ensure place for 2 utf-8 multi-byte sequence
156+
strBufferSize = 4 * 2;
157+
}
158+
strBufferSize += 1 + 2;
155159
strBuffer = new char[strBufferSize];
160+
utf8Buffer = new char[strBufferSize];
156161

157162
if (bosId >= 0) printf("📄 BosId: %d (%s)\n", bosId, vocab[bosId]);
158163
if (eosTokenIds.size() > 0) {
@@ -179,6 +184,7 @@ Tokenizer::~Tokenizer() {
179184
delete[] regularVocab;
180185
delete[] specialVocab;
181186
delete[] strBuffer;
187+
delete[] utf8Buffer;
182188
}
183189

184190
int Tokenizer::findSpecialTokenStartWith(char *piece) {
@@ -209,6 +215,73 @@ void Tokenizer::resetDecoder() {
209215
strBufferPos = 0;
210216
}
211217

218+
char *Tokenizer::detokUtf8() {
219+
char* src = strBuffer;
220+
char* dst = utf8Buffer;
221+
char* checkpoint_src = src;
222+
char* checkpoint = dst;
223+
unsigned expect_continuation = 0;
224+
225+
while (unsigned char c = *src) {
226+
bool need_recovery = false;
227+
if (expect_continuation) {
228+
if ((c & 0xc0) == 0x80) {
229+
*dst++ = *src++;
230+
expect_continuation--;
231+
} else {
232+
need_recovery = true;
233+
}
234+
} else if (c <= 0x7f) {
235+
*dst++ = *src++;
236+
} else if (c >= 0xc0 && c <= 0xdf) {
237+
*dst++ = *src++;
238+
expect_continuation = 1;
239+
} else if (c >= 0xe0 && c <= 0xef) {
240+
*dst++ = *src++;
241+
expect_continuation = 2;
242+
} else if (c >= 0xf0 && c <= 0xf7) {
243+
*dst++ = *src++;
244+
expect_continuation = 3;
245+
} else {
246+
need_recovery = true;
247+
}
248+
249+
if (!need_recovery) {
250+
if (!expect_continuation) {
251+
checkpoint = dst;
252+
checkpoint_src = src;
253+
}
254+
} else {
255+
// perform stream recover
256+
if (expect_continuation) {
257+
expect_continuation = 0;
258+
} else {
259+
++src;
260+
}
261+
dst = checkpoint;
262+
// emit 0xfffd
263+
*dst++ = 0xef;
264+
*dst++ = 0xbf;
265+
*dst++ = 0xbd;
266+
267+
fprintf(stderr, "\nTokenizer: decoded invalid utf8 -- attempting stream recover\n");
268+
}
269+
}
270+
271+
if (src > checkpoint_src) {
272+
memmove(strBuffer, checkpoint_src, src - checkpoint_src + 1);
273+
strBufferPos = src - checkpoint_src;
274+
} else {
275+
strBufferPos = 0;
276+
}
277+
*checkpoint = '\0';
278+
if (checkpoint > utf8Buffer) {
279+
return utf8Buffer;
280+
} else {
281+
return nullptr;
282+
}
283+
}
284+
212285
char *Tokenizer::decode(int token) {
213286
if (token == bosId)
214287
return nullptr;
@@ -220,18 +293,13 @@ char *Tokenizer::decode(int token) {
220293

221294
char *piece = vocab[token];
222295
int pieceLen = vocabLength[token];
223-
bool hasContinuation = (piece[pieceLen - 1] & 0xC0) == 0x80;
224296

225297
assert(strBufferPos + pieceLen + 1 < strBufferSize);
226298
std::memcpy(&strBuffer[strBufferPos], piece, pieceLen * sizeof(char));
227299
strBufferPos += pieceLen;
228300
strBuffer[strBufferPos] = '\0';
229301

230-
if (!hasContinuation) {
231-
strBufferPos = 0;
232-
return strBuffer;
233-
}
234-
return nullptr;
302+
return detokUtf8();
235303
}
236304

237305
void Tokenizer::encode(char *text, int *tokens, int *nTokens, bool addBos, bool addSpecialTokens) {

src/tokenizer.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ class Tokenizer {
4242
TokenIndex *specialVocab;
4343
size_t strBufferSize;
4444
char *strBuffer;
45+
char *utf8Buffer;
4546
size_t strBufferPos;
4647

4748

@@ -60,6 +61,9 @@ class Tokenizer {
6061
bool isEos(int token);
6162
char *decode(int token);
6263
void resetDecoder();
64+
65+
private:
66+
char *detokUtf8();
6367
};
6468

6569
typedef struct {

0 commit comments

Comments
 (0)