@@ -151,8 +151,13 @@ Tokenizer::Tokenizer(const char* tokenizerPath)
151151 specialVocab[i].id = i + regularVocabSize;
152152 }
153153
154- strBufferSize = maxTokenLength * 2 + 1 + 2 ;
154+ strBufferSize = maxTokenLength * 2 ;
155+ if (strBufferSize < (4 * 2 )) { // ensure place for 2 utf-8 multi-byte sequence
156+ strBufferSize = 4 * 2 ;
157+ }
158+ strBufferSize += 1 + 2 ;
155159 strBuffer = new char [strBufferSize];
160+ utf8Buffer = new char [strBufferSize];
156161
157162 if (bosId >= 0 ) printf (" 📄 BosId: %d (%s)\n " , bosId, vocab[bosId]);
158163 if (eosTokenIds.size () > 0 ) {
@@ -179,6 +184,7 @@ Tokenizer::~Tokenizer() {
179184 delete[] regularVocab;
180185 delete[] specialVocab;
181186 delete[] strBuffer;
187+ delete[] utf8Buffer;
182188}
183189
184190int Tokenizer::findSpecialTokenStartWith (char *piece) {
@@ -209,6 +215,73 @@ void Tokenizer::resetDecoder() {
209215 strBufferPos = 0 ;
210216}
211217
218+ char *Tokenizer::detokUtf8 () {
219+ char * src = strBuffer;
220+ char * dst = utf8Buffer;
221+ char * checkpoint_src = src;
222+ char * checkpoint = dst;
223+ unsigned expect_continuation = 0 ;
224+
225+ while (unsigned char c = *src) {
226+ bool need_recovery = false ;
227+ if (expect_continuation) {
228+ if ((c & 0xc0 ) == 0x80 ) {
229+ *dst++ = *src++;
230+ expect_continuation--;
231+ } else {
232+ need_recovery = true ;
233+ }
234+ } else if (c <= 0x7f ) {
235+ *dst++ = *src++;
236+ } else if (c >= 0xc0 && c <= 0xdf ) {
237+ *dst++ = *src++;
238+ expect_continuation = 1 ;
239+ } else if (c >= 0xe0 && c <= 0xef ) {
240+ *dst++ = *src++;
241+ expect_continuation = 2 ;
242+ } else if (c >= 0xf0 && c <= 0xf7 ) {
243+ *dst++ = *src++;
244+ expect_continuation = 3 ;
245+ } else {
246+ need_recovery = true ;
247+ }
248+
249+ if (!need_recovery) {
250+ if (!expect_continuation) {
251+ checkpoint = dst;
252+ checkpoint_src = src;
253+ }
254+ } else {
255+ // perform stream recover
256+ if (expect_continuation) {
257+ expect_continuation = 0 ;
258+ } else {
259+ ++src;
260+ }
261+ dst = checkpoint;
262+ // emit 0xfffd
263+ *dst++ = 0xef ;
264+ *dst++ = 0xbf ;
265+ *dst++ = 0xbd ;
266+
267+ fprintf (stderr, " \n Tokenizer: decoded invalid utf8 -- attempting stream recover\n " );
268+ }
269+ }
270+
271+ if (src > checkpoint_src) {
272+ memmove (strBuffer, checkpoint_src, src - checkpoint_src + 1 );
273+ strBufferPos = src - checkpoint_src;
274+ } else {
275+ strBufferPos = 0 ;
276+ }
277+ *checkpoint = ' \0 ' ;
278+ if (checkpoint > utf8Buffer) {
279+ return utf8Buffer;
280+ } else {
281+ return nullptr ;
282+ }
283+ }
284+
212285char *Tokenizer::decode (int token) {
213286 if (token == bosId)
214287 return nullptr ;
@@ -220,18 +293,13 @@ char *Tokenizer::decode(int token) {
220293
221294 char *piece = vocab[token];
222295 int pieceLen = vocabLength[token];
223- bool hasContinuation = (piece[pieceLen - 1 ] & 0xC0 ) == 0x80 ;
224296
225297 assert (strBufferPos + pieceLen + 1 < strBufferSize);
226298 std::memcpy (&strBuffer[strBufferPos], piece, pieceLen * sizeof (char ));
227299 strBufferPos += pieceLen;
228300 strBuffer[strBufferPos] = ' \0 ' ;
229301
230- if (!hasContinuation) {
231- strBufferPos = 0 ;
232- return strBuffer;
233- }
234- return nullptr ;
302+ return detokUtf8 ();
235303}
236304
237305void Tokenizer::encode (char *text, int *tokens, int *nTokens, bool addBos, bool addSpecialTokens) {
0 commit comments