Skip to content

Commit c3fa55b

Browse files
committed
feat: implement caching proxy with main entry point
1 parent 572ca42 commit c3fa55b

2 files changed

Lines changed: 307 additions & 0 deletions

File tree

cmd/mbproxy/main.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"log/slog"
6+
"os"
7+
"os/signal"
8+
"syscall"
9+
10+
"github.com/tma/mbproxy/internal/config"
11+
"github.com/tma/mbproxy/internal/logging"
12+
"github.com/tma/mbproxy/internal/proxy"
13+
)
14+
15+
func main() {
16+
cfg, err := config.Load()
17+
if err != nil {
18+
slog.Error("failed to load configuration", "error", err)
19+
os.Exit(1)
20+
}
21+
22+
logger := logging.New(cfg.LogLevel)
23+
24+
ctx, cancel := context.WithCancel(context.Background())
25+
defer cancel()
26+
27+
// Handle shutdown signals
28+
sigCh := make(chan os.Signal, 1)
29+
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
30+
31+
p, err := proxy.New(cfg, logger)
32+
if err != nil {
33+
logger.Error("failed to create proxy", "error", err)
34+
os.Exit(1)
35+
}
36+
37+
// Start proxy in background
38+
errCh := make(chan error, 1)
39+
go func() {
40+
errCh <- p.Run(ctx)
41+
}()
42+
43+
// Wait for shutdown signal or error
44+
select {
45+
case sig := <-sigCh:
46+
logger.Info("received signal, shutting down", "signal", sig)
47+
case err := <-errCh:
48+
if err != nil {
49+
logger.Error("proxy error", "error", err)
50+
os.Exit(1)
51+
}
52+
}
53+
54+
// Graceful shutdown
55+
cancel()
56+
if err := p.Shutdown(cfg.ShutdownTimeout); err != nil {
57+
logger.Error("shutdown error", "error", err)
58+
os.Exit(1)
59+
}
60+
61+
logger.Info("shutdown complete")
62+
}

internal/proxy/proxy.go

Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
// Package proxy implements the Modbus caching proxy.
2+
package proxy
3+
4+
import (
5+
"context"
6+
"encoding/binary"
7+
"fmt"
8+
"log/slog"
9+
"time"
10+
11+
"github.com/tma/mbproxy/internal/cache"
12+
"github.com/tma/mbproxy/internal/config"
13+
"github.com/tma/mbproxy/internal/modbus"
14+
)
15+
16+
// Proxy is a caching Modbus proxy server.
17+
type Proxy struct {
18+
cfg *config.Config
19+
logger *slog.Logger
20+
server *modbus.Server
21+
client *modbus.Client
22+
cache *cache.Cache
23+
}
24+
25+
// New creates a new proxy instance.
26+
func New(cfg *config.Config, logger *slog.Logger) (*Proxy, error) {
27+
p := &Proxy{
28+
cfg: cfg,
29+
logger: logger,
30+
client: modbus.NewClient(cfg.Upstream, cfg.Timeout, logger),
31+
cache: cache.New(cfg.CacheTTL),
32+
}
33+
34+
p.server = modbus.NewServer(p, logger)
35+
36+
return p, nil
37+
}
38+
39+
// Run starts the proxy server.
40+
func (p *Proxy) Run(ctx context.Context) error {
41+
p.logger.Info("starting proxy",
42+
"listen", p.cfg.Listen,
43+
"upstream", p.cfg.Upstream,
44+
"readonly", p.cfg.ReadOnly,
45+
"cache_ttl", p.cfg.CacheTTL,
46+
)
47+
48+
if err := p.server.Listen(p.cfg.Listen); err != nil {
49+
return fmt.Errorf("listen: %w", err)
50+
}
51+
52+
// Connect to upstream
53+
if err := p.client.Connect(); err != nil {
54+
p.logger.Warn("initial upstream connection failed, will retry on first request", "error", err)
55+
}
56+
57+
return p.server.Serve(ctx)
58+
}
59+
60+
// Shutdown gracefully shuts down the proxy.
61+
func (p *Proxy) Shutdown(timeout time.Duration) error {
62+
p.logger.Info("shutting down", "timeout", timeout)
63+
64+
// Stop accepting new connections
65+
p.server.Close()
66+
67+
// Wait for in-flight requests with timeout
68+
done := make(chan struct{})
69+
go func() {
70+
p.server.Wait()
71+
close(done)
72+
}()
73+
74+
select {
75+
case <-done:
76+
p.logger.Debug("all connections closed")
77+
case <-time.After(timeout):
78+
p.logger.Warn("shutdown timeout, forcing close")
79+
}
80+
81+
// Close cache cleanup goroutine
82+
p.cache.Close()
83+
84+
// Close upstream connection
85+
return p.client.Close()
86+
}
87+
88+
// HandleRequest implements modbus.Handler interface.
89+
func (p *Proxy) HandleRequest(ctx context.Context, req *modbus.Request) ([]byte, error) {
90+
start := time.Now()
91+
92+
if modbus.IsWriteFunction(req.FunctionCode) {
93+
return p.handleWrite(ctx, req, start)
94+
}
95+
96+
if modbus.IsReadFunction(req.FunctionCode) {
97+
return p.handleRead(ctx, req, start)
98+
}
99+
100+
// Unknown function code
101+
p.logger.Debug("unknown function code",
102+
"func", fmt.Sprintf("0x%02X", req.FunctionCode),
103+
"slave_id", req.SlaveID,
104+
)
105+
return modbus.BuildExceptionResponse(req.FunctionCode, modbus.ExcIllegalFunction), nil
106+
}
107+
108+
func (p *Proxy) handleRead(ctx context.Context, req *modbus.Request, start time.Time) ([]byte, error) {
109+
key := cache.Key(req.SlaveID, req.FunctionCode, req.Address, req.Quantity)
110+
111+
// Use GetOrFetch for request coalescing
112+
data, cacheHit, err := p.cache.GetOrFetch(ctx, key, func(ctx context.Context) ([]byte, error) {
113+
p.logger.Debug("cache miss",
114+
"slave_id", req.SlaveID,
115+
"func", fmt.Sprintf("0x%02X", req.FunctionCode),
116+
"addr", req.Address,
117+
"qty", req.Quantity,
118+
)
119+
120+
resp, err := p.client.Execute(ctx, req)
121+
if err != nil {
122+
return nil, err
123+
}
124+
125+
p.logger.Debug("upstream read",
126+
"slave_id", req.SlaveID,
127+
"func", fmt.Sprintf("0x%02X", req.FunctionCode),
128+
"addr", req.Address,
129+
"qty", req.Quantity,
130+
"duration", time.Since(start),
131+
)
132+
133+
return resp, nil
134+
})
135+
136+
if err != nil {
137+
// Try serving stale data if configured
138+
if p.cfg.CacheServeStale {
139+
if stale, ok := p.cache.GetStale(key); ok {
140+
p.logger.Warn("upstream error, serving stale",
141+
"slave_id", req.SlaveID,
142+
"error", err,
143+
)
144+
return stale, nil
145+
}
146+
}
147+
return nil, err
148+
}
149+
150+
if cacheHit {
151+
p.logger.Debug("cache hit",
152+
"slave_id", req.SlaveID,
153+
"func", fmt.Sprintf("0x%02X", req.FunctionCode),
154+
"addr", req.Address,
155+
"qty", req.Quantity,
156+
)
157+
}
158+
159+
return data, nil
160+
}
161+
162+
func (p *Proxy) handleWrite(ctx context.Context, req *modbus.Request, start time.Time) ([]byte, error) {
163+
switch p.cfg.ReadOnly {
164+
case config.ReadOnlyOn:
165+
// Silently ignore, return success response
166+
p.logger.Debug("write ignored (readonly mode)",
167+
"slave_id", req.SlaveID,
168+
"func", fmt.Sprintf("0x%02X", req.FunctionCode),
169+
"addr", req.Address,
170+
)
171+
return p.buildFakeWriteResponse(req), nil
172+
173+
case config.ReadOnlyDeny:
174+
// Reject with exception
175+
p.logger.Debug("write denied (readonly mode)",
176+
"slave_id", req.SlaveID,
177+
"func", fmt.Sprintf("0x%02X", req.FunctionCode),
178+
"addr", req.Address,
179+
)
180+
return modbus.BuildExceptionResponse(req.FunctionCode, modbus.ExcIllegalFunction), nil
181+
182+
case config.ReadOnlyOff:
183+
// Forward to upstream
184+
resp, err := p.client.Execute(ctx, req)
185+
if err != nil {
186+
return nil, err
187+
}
188+
189+
p.logger.Debug("upstream write",
190+
"slave_id", req.SlaveID,
191+
"func", fmt.Sprintf("0x%02X", req.FunctionCode),
192+
"addr", req.Address,
193+
"qty", req.Quantity,
194+
"duration", time.Since(start),
195+
)
196+
197+
// Invalidate exact matching cache entries for all read function codes
198+
p.invalidateCache(req)
199+
200+
return resp, nil
201+
}
202+
203+
return nil, fmt.Errorf("unknown readonly mode: %s", p.cfg.ReadOnly)
204+
}
205+
206+
func (p *Proxy) invalidateCache(req *modbus.Request) {
207+
// Invalidate exact matches for all read function codes that could overlap
208+
readFuncs := []byte{
209+
modbus.FuncReadCoils,
210+
modbus.FuncReadDiscreteInputs,
211+
modbus.FuncReadHoldingRegisters,
212+
modbus.FuncReadInputRegisters,
213+
}
214+
215+
for _, fc := range readFuncs {
216+
key := cache.Key(req.SlaveID, fc, req.Address, req.Quantity)
217+
p.cache.Delete(key)
218+
}
219+
}
220+
221+
func (p *Proxy) buildFakeWriteResponse(req *modbus.Request) []byte {
222+
switch req.FunctionCode {
223+
case modbus.FuncWriteSingleCoil, modbus.FuncWriteSingleRegister:
224+
// Echo back the request: func + address + value
225+
resp := make([]byte, 5)
226+
resp[0] = req.FunctionCode
227+
binary.BigEndian.PutUint16(resp[1:3], req.Address)
228+
if len(req.Data) >= 2 {
229+
copy(resp[3:5], req.Data[:2])
230+
}
231+
return resp
232+
233+
case modbus.FuncWriteMultipleCoils, modbus.FuncWriteMultipleRegs:
234+
// Return: func + address + quantity
235+
resp := make([]byte, 5)
236+
resp[0] = req.FunctionCode
237+
binary.BigEndian.PutUint16(resp[1:3], req.Address)
238+
binary.BigEndian.PutUint16(resp[3:5], req.Quantity)
239+
return resp
240+
241+
default:
242+
// Should not happen, but return empty success
243+
return []byte{req.FunctionCode}
244+
}
245+
}

0 commit comments

Comments
 (0)