Skip to content

Commit e819dc6

Browse files
authored
feat: add host arg. (#273)
1 parent 142a48b commit e819dc6

7 files changed

Lines changed: 22 additions & 11 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ Worker, API
100100

101101
| Argument | Description | Example |
102102
| ---------------------------- | --------------------------------- | ----------------- |
103+
| `--host <addr>` | Binding address. | `127.0.0.1` |
103104
| `--port <port>` | Binding port. | `9999` |
104105

105106
Inference

docs/HOW_TO_RUN_RASPBERRYPI.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ sudo nice -n -20 ./dllama inference \
7878

7979
```sh
8080
sudo nice -n -20 ./dllama-api \
81+
--host 0.0.0.0 \
8182
--port 9999 \
8283
--model models/llama3_2_3b_instruct_q40/dllama_model_llama3_2_3b_instruct_q40.m \
8384
--tokenizer models/llama3_2_3b_instruct_q40/dllama_tokenizer_llama3_2_3b_instruct_q40.t \

src/app.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ AppCliArgs AppCliArgs::parse(int argc, char* *argv, bool requireMode) {
3535
args.nWorkers = 0;
3636
args.workerHosts = nullptr;
3737
args.workerPorts = nullptr;
38+
args.host = "0.0.0.0";
3839
args.port = 9990;
3940
args.temperature = 0.8f;
4041
args.topp = 0.9f;
@@ -97,6 +98,8 @@ AppCliArgs AppCliArgs::parse(int argc, char* *argv, bool requireMode) {
9798
i += count - 1;
9899
} else if (std::strcmp(name, "--port") == 0) {
99100
args.port = atoi(value);
101+
} else if (std::strcmp(name, "--host") == 0) {
102+
args.host = value;
100103
} else if (std::strcmp(name, "--nthreads") == 0) {
101104
args.nThreads = atoi(value);
102105
} else if (std::strcmp(name, "--steps") == 0) {
@@ -302,7 +305,7 @@ void runInferenceApp(AppCliArgs *args, void (*handler)(AppInferenceContext *cont
302305

303306
void runWorkerApp(AppCliArgs *args) {
304307
while (true) {
305-
std::unique_ptr<NnNetwork> networkPtr = NnNetwork::serve(args->port);
308+
std::unique_ptr<NnNetwork> networkPtr = NnNetwork::serve(args->host, args->port);
306309
NnNetwork *network = networkPtr.get();
307310

308311
NnWorkerConfigReader configReader(network);

src/app.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ class AppCliArgs {
3535
int gpuSegmentFrom;
3636
int gpuSegmentTo;
3737

38-
// worker
38+
// binding
39+
const char *host;
3940
NnUint port;
4041

4142
static AppCliArgs parse(int argc, char **argv, bool hasMode);

src/dllama-api.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -534,14 +534,16 @@ void handleModelsRequest(HttpRequest& request, const char* modelPath) {
534534
}
535535

536536
static void server(AppInferenceContext *context) {
537-
NnSocket serverSocket(createServerSocket(context->args->port));
537+
NnSocket serverSocket(createServerSocket(context->args->host, context->args->port));
538538

539539
TokenizerChatStops stops(context->tokenizer);
540540
ChatTemplateGenerator templateGenerator(context->args->chatTemplateType, context->tokenizer->chatTemplate, stops.stops[0]);
541541
EosDetector eosDetector(stops.nStops, context->tokenizer->eosTokenIds.data(), stops.stops, stops.maxStopLength, stops.maxStopLength);
542542
ApiServer api(context->inference, context->tokenizer, context->sampler, context->args, context->header, &eosDetector, &templateGenerator);
543543

544-
printf("Server URL: http://127.0.0.1:%d/v1/\n", context->args->port);
544+
if (strcmp(context->args->host, "0.0.0.0") == 0 ||
545+
strcmp(context->args->host, "127.0.0.1") == 0)
546+
printf("Server URL: http://localhost:%d/v1/\n", context->args->port);
545547

546548
std::vector<Route> routes = {
547549
{
@@ -577,7 +579,7 @@ static void server(AppInferenceContext *context) {
577579
#endif
578580

579581
void usage() {
580-
fprintf(stderr, "Usage: %s {--model <path>} {--tokenizer <path>} [--port <p>]\n", EXECUTABLE_NAME);
582+
fprintf(stderr, "Usage: %s {--model <path>} {--tokenizer <path>} [--host <addr>] [--port <p>]\n", EXECUTABLE_NAME);
581583
fprintf(stderr, " [--buffer-float-type {f32|f16|q40|q80}]\n");
582584
fprintf(stderr, " [--weights-float-type {f32|f16|q40|q80}]\n");
583585
fprintf(stderr, " [--max-seq-len <max>]\n");

src/nn/nn-network.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,7 @@ static inline int connectSocket(char *host, int port) {
172172
return sock;
173173
}
174174

175-
int createServerSocket(int port) {
176-
const char *host = "0.0.0.0";
175+
int createServerSocket(const char *host, const int port) {
177176
struct sockaddr_in serverAddr;
178177

179178
int serverSocket = ::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
@@ -185,6 +184,10 @@ int createServerSocket(int port) {
185184
serverAddr.sin_family = AF_INET;
186185
serverAddr.sin_port = htons(port);
187186
serverAddr.sin_addr.s_addr = inet_addr(host);
187+
if (serverAddr.sin_addr.s_addr == INADDR_NONE) {
188+
destroySocket(serverSocket);
189+
throw std::runtime_error("Invalid bind host");
190+
}
188191

189192
int bindResult;
190193
#ifdef _WIN32
@@ -289,8 +292,8 @@ int NnSocket::release() {
289292
return fd;
290293
}
291294

292-
std::unique_ptr<NnNetwork> NnNetwork::serve(int port) {
293-
NnSocket socketSocket(createServerSocket(port));
295+
std::unique_ptr<NnNetwork> NnNetwork::serve(const char *host, const int port) {
296+
NnSocket socketSocket(createServerSocket(host, port));
294297

295298
NnUint nSockets;
296299
NnUint nodeIndex;

src/nn/nn-network.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ int acceptSocket(int serverSocket);
1111
void setReuseAddr(int socket);
1212
void writeSocket(int socket, const void* data, NnSize size);
1313
void readSocket(int socket, void* data, NnSize size);
14-
int createServerSocket(int port);
14+
int createServerSocket(const char *host, const int port);
1515
void destroySocket(int serverSocket);
1616

1717
class NnConnectionSocketException : public std::runtime_error {
@@ -48,7 +48,7 @@ class NnNetwork {
4848
NnSize *recvBytes;
4949

5050
public:
51-
static std::unique_ptr<NnNetwork> serve(int port);
51+
static std::unique_ptr<NnNetwork> serve(const char *host, const int port);
5252
static std::unique_ptr<NnNetwork> connect(NnUint nSockets, char **hosts, NnUint *ports);
5353

5454
NnUint nSockets;

0 commit comments

Comments
 (0)