Skip to content

Commit 8d624a7

Browse files
authored
fix: support \n\n. (#274)
1 parent e819dc6 commit 8d624a7

2 files changed

Lines changed: 20 additions & 14 deletions

File tree

src/api-types.hpp

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -26,9 +26,10 @@ struct ChatMessage {
2626
struct ChunkChoice {
2727
int index;
2828
ChatMessageDelta delta;
29+
bool has_delta;
2930
std::string finish_reason;
3031

31-
ChunkChoice() : index(0) {}
32+
ChunkChoice() : index(0), delta(), has_delta(false) {}
3233
};
3334

3435

@@ -121,7 +122,10 @@ void to_json(json& j, const ChatMessage& msg) {
121122
}
122123

123124
void to_json(json& j, const ChunkChoice& choice) {
124-
j = json{{"index", choice.index}, {"delta", choice.delta}, {"finish_reason", choice.finish_reason}};
125+
j = json{{"index", choice.index}, {"finish_reason", choice.finish_reason}};
126+
if (choice.has_delta) {
127+
j["delta"] = choice.delta;
128+
}
125129
}
126130

127131
void to_json(json& j, const Choice& choice) {

src/dllama-api.cpp

Lines changed: 14 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -114,25 +114,26 @@ class HttpRequest {
114114

115115
// First, read all headers
116116
std::string headerData;
117-
size_t headerEnd;
118-
bool headerDone = false;
119117
std::string extraReadPastHeader;
120-
while (!headerDone) {
118+
for (;;) {
121119
bytesRead = recv(serverSocket, buffer, sizeof(buffer) - 1, 0);
122120
if (bytesRead <= 0) {
123121
throw std::runtime_error("Error while reading headers from socket");
124122
}
125123
buffer[bytesRead] = '\0';
126124
headerData.append(buffer);
127125

128-
// Check for end of headers (http header says "\r\n\r\n")
129-
headerEnd = headerData.find("\r\n\r\n");
130-
if (headerEnd != std::string::npos) {
131-
headerDone = true;
132-
if (headerEnd < headerData.size()-4) {
133-
// We read something past the header
134-
extraReadPastHeader = headerData.substr(headerEnd+4);
135-
}
126+
const size_t endRnRn = headerData.find("\r\n\r\n");
127+
if (endRnRn != std::string::npos) {
128+
if (endRnRn < headerData.size() - 4)
129+
extraReadPastHeader = headerData.substr(endRnRn + 4);
130+
break;
131+
}
132+
const size_t endNN = headerData.find("\n\n");
133+
if (endNN != std::string::npos) {
134+
if (endNN < headerData.size() - 2)
135+
extraReadPastHeader = headerData.substr(endNN + 2);
136+
break;
136137
}
137138
}
138139

@@ -142,7 +143,7 @@ class HttpRequest {
142143
std::istringstream headerStream(headerData);
143144
std::string line;
144145
ssize_t contentLength = 0;
145-
while (std::getline(headerStream, line) && line != "\r") {
146+
while (std::getline(headerStream, line) && line != "\r" && line != "\n") {
146147
size_t pos = line.find(':');
147148
if (pos != std::string::npos) {
148149
std::string key = line.substr(0, pos);
@@ -280,6 +281,7 @@ void writeChatCompletionChunk(HttpRequest &request, const std::string &delta, co
280281
choice.finish_reason = "stop";
281282
} else {
282283
choice.delta = ChatMessageDelta("assistant", delta);
284+
choice.has_delta = true;
283285
}
284286
ChatCompletionChunk chunk = ChatCompletionChunk(choice);
285287

0 commit comments

Comments
 (0)