Skip to content

Commit c9c4fde

Browse files
Copilottma
andcommitted
fix: normalize wildcard hosts in CheckHealth, isolate config test env
Co-authored-by: tma <4719+tma@users.noreply.github.com>
1 parent 971d77b commit c9c4fde

File tree

3 files changed

+65
-8
lines changed

3 files changed

+65
-8
lines changed

internal/config/config_test.go

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -198,12 +198,19 @@ func TestLoad_InvalidDuration(t *testing.T) {
198198
}
199199

200200
func TestLoad_HealthListenCustom(t *testing.T) {
201-
os.Setenv("MODBUS_UPSTREAM", "localhost:502")
202-
os.Setenv("HEALTH_LISTEN", ":9090")
203-
defer func() {
204-
os.Unsetenv("MODBUS_UPSTREAM")
205-
os.Unsetenv("HEALTH_LISTEN")
206-
}()
201+
// Ensure optional env vars that Load() may read do not inherit
202+
// potentially invalid values from the surrounding environment.
203+
t.Setenv("MODBUS_LISTEN", "")
204+
t.Setenv("MODBUS_READONLY", "")
205+
t.Setenv("MODBUS_CACHE_TTL", "")
206+
t.Setenv("MODBUS_TIMEOUT", "")
207+
t.Setenv("MODBUS_REQUEST_DELAY", "")
208+
t.Setenv("MODBUS_CONNECT_DELAY", "")
209+
t.Setenv("MODBUS_SHUTDOWN_TIMEOUT", "")
210+
211+
// Set required and explicitly tested env vars.
212+
t.Setenv("MODBUS_UPSTREAM", "localhost:502")
213+
t.Setenv("HEALTH_LISTEN", ":9090")
207214

208215
cfg, err := Load()
209216
if err != nil {

internal/health/health.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,17 +107,21 @@ func (s *Server) Shutdown(ctx context.Context) error {
107107

108108
// CheckHealth performs an HTTP health check against the given address.
109109
// It returns nil if the endpoint responds with 200 OK.
110+
// Wildcard listen addresses (e.g. ":8080", "0.0.0.0:8080", "[::]:8080") are
111+
// normalized to localhost so they can be used as dial targets. IPv6 addresses
112+
// are handled correctly via net.JoinHostPort.
110113
func CheckHealth(addr string) error {
111114
// Resolve the address so we can build a proper URL.
112115
host, port, err := net.SplitHostPort(addr)
113116
if err != nil {
114117
return fmt.Errorf("invalid address %q: %w", addr, err)
115118
}
116-
if host == "" {
119+
// Normalize wildcard and empty hosts to localhost.
120+
if host == "" || host == "0.0.0.0" || host == "::" {
117121
host = "localhost"
118122
}
119123

120-
url := fmt.Sprintf("http://%s:%s/health", host, port)
124+
url := fmt.Sprintf("http://%s/health", net.JoinHostPort(host, port))
121125

122126
client := &http.Client{Timeout: 3 * time.Second}
123127
resp, err := client.Get(url)

internal/health/health_test.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,3 +149,49 @@ func TestCheckHealth_ConnectionRefused(t *testing.T) {
149149
t.Error("expected error for connection refused")
150150
}
151151
}
152+
153+
func TestCheckHealth_WildcardAddresses(t *testing.T) {
154+
// Start a test server bound to localhost so wildcard addresses can reach it.
155+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
156+
w.WriteHeader(http.StatusOK)
157+
json.NewEncoder(w).Encode(Response{Status: "ok"})
158+
}))
159+
defer ts.Close()
160+
161+
_, port, err := net.SplitHostPort(ts.Listener.Addr().String())
162+
if err != nil {
163+
t.Fatalf("failed to parse test server address: %v", err)
164+
}
165+
166+
// Each of these listen-style addresses should be normalized to localhost.
167+
wildcards := []string{
168+
":" + port, // empty host (":8080")
169+
"0.0.0.0:" + port, // IPv4 wildcard
170+
"[::]:" + port, // IPv6 wildcard
171+
}
172+
for _, addr := range wildcards {
173+
if err := CheckHealth(addr); err != nil {
174+
t.Errorf("CheckHealth(%q) expected success, got: %v", addr, err)
175+
}
176+
}
177+
}
178+
179+
func TestCheckHealth_IPv6Loopback(t *testing.T) {
180+
// Start a test server bound to the IPv6 loopback address.
181+
ln, err := net.Listen("tcp6", "[::1]:0")
182+
if err != nil {
183+
t.Skip("IPv6 loopback not available:", err)
184+
}
185+
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
186+
w.WriteHeader(http.StatusOK)
187+
json.NewEncoder(w).Encode(Response{Status: "ok"})
188+
}))
189+
ts.Listener = ln
190+
ts.Start()
191+
defer ts.Close()
192+
193+
addr := ts.Listener.Addr().String() // "[::1]:PORT"
194+
if err := CheckHealth(addr); err != nil {
195+
t.Errorf("CheckHealth(%q) expected success for IPv6 loopback, got: %v", addr, err)
196+
}
197+
}

0 commit comments

Comments
 (0)