Skip to content

Commit 636a0c0

Browse files
committed
refactor: update proxy for per-register cache
Decompose upstream responses into per-register cache entries and reassemble from cache on hits. Write invalidation now correctly removes individual registers in the written range, fixing stale data when writes overlap with larger cached read ranges. New helpers: - decomposeResponse: extract per-register values from Modbus PDU - assembleResponse: reconstruct Modbus PDU from cached values - Roundtrip tests for all function codes (registers + coils) - Tests verifying write invalidation of overlapping reads
1 parent 7b92f28 commit 636a0c0

File tree

2 files changed

+338
-26
lines changed

2 files changed

+338
-26
lines changed

internal/proxy/proxy.go

Lines changed: 103 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ func New(cfg *config.Config, logger *slog.Logger) (*Proxy, error) {
2828
cfg: cfg,
2929
logger: logger,
3030
client: modbus.NewClient(cfg.Upstream, cfg.Timeout, cfg.RequestDelay, cfg.ConnectDelay, logger),
31-
cache: cache.New(cfg.CacheTTL),
31+
cache: cache.New(cfg.CacheTTL, cfg.CacheServeStale),
3232
}
3333

3434
p.server = modbus.NewServer(p, logger)
@@ -104,10 +104,21 @@ func (p *Proxy) HandleRequest(ctx context.Context, req *modbus.Request) ([]byte,
104104
}
105105

106106
func (p *Proxy) handleRead(ctx context.Context, req *modbus.Request) ([]byte, error) {
107-
key := cache.Key(req.SlaveID, req.FunctionCode, req.Address, req.Quantity)
107+
// Check per-register cache
108+
values, cacheHit := p.cache.GetRange(req.SlaveID, req.FunctionCode, req.Address, req.Quantity)
109+
if cacheHit {
110+
p.logger.Debug("cache hit",
111+
"slave_id", req.SlaveID,
112+
"func", fmt.Sprintf("0x%02X", req.FunctionCode),
113+
"addr", req.Address,
114+
"qty", req.Quantity,
115+
)
116+
return assembleResponse(req.FunctionCode, req.Quantity, values), nil
117+
}
108118

109-
// Use GetOrFetch for request coalescing
110-
data, cacheHit, err := p.cache.GetOrFetch(ctx, key, func(ctx context.Context) ([]byte, error) {
119+
// Cache miss — fetch with coalescing
120+
rangeKey := cache.RangeKey(req.SlaveID, req.FunctionCode, req.Address, req.Quantity)
121+
data, err := p.cache.Coalesce(ctx, rangeKey, func(ctx context.Context) ([]byte, error) {
111122
p.logger.Debug("cache miss",
112123
"slave_id", req.SlaveID,
113124
"func", fmt.Sprintf("0x%02X", req.FunctionCode),
@@ -121,24 +132,21 @@ func (p *Proxy) handleRead(ctx context.Context, req *modbus.Request) ([]byte, er
121132
if err != nil {
122133
// Try serving stale data if configured
123134
if p.cfg.CacheServeStale {
124-
if stale, ok := p.cache.GetStale(key); ok {
135+
if staleValues, ok := p.cache.GetRangeStale(req.SlaveID, req.FunctionCode, req.Address, req.Quantity); ok {
125136
p.logger.Warn("upstream error, serving stale",
126137
"slave_id", req.SlaveID,
127138
"error", err,
128139
)
129-
return stale, nil
140+
return assembleResponse(req.FunctionCode, req.Quantity, staleValues), nil
130141
}
131142
}
132143
return nil, err
133144
}
134145

135-
if cacheHit {
136-
p.logger.Debug("cache hit",
137-
"slave_id", req.SlaveID,
138-
"func", fmt.Sprintf("0x%02X", req.FunctionCode),
139-
"addr", req.Address,
140-
"qty", req.Quantity,
141-
)
146+
// Decompose response and store per-register
147+
regValues := decomposeResponse(req.FunctionCode, req.Quantity, data)
148+
if regValues != nil {
149+
p.cache.SetRange(req.SlaveID, req.FunctionCode, req.Address, regValues)
142150
}
143151

144152
return data, nil
@@ -171,7 +179,7 @@ func (p *Proxy) handleWrite(ctx context.Context, req *modbus.Request) ([]byte, e
171179
return nil, err
172180
}
173181

174-
// Invalidate exact matching cache entries for all read function codes
182+
// Invalidate per-register cache entries for the written range
175183
p.invalidateCache(req)
176184

177185
return resp, nil
@@ -181,7 +189,7 @@ func (p *Proxy) handleWrite(ctx context.Context, req *modbus.Request) ([]byte, e
181189
}
182190

183191
func (p *Proxy) invalidateCache(req *modbus.Request) {
184-
// Invalidate exact matches for all read function codes that could overlap
192+
// Invalidate per-register entries for all read function codes
185193
readFuncs := []byte{
186194
modbus.FuncReadCoils,
187195
modbus.FuncReadDiscreteInputs,
@@ -190,9 +198,87 @@ func (p *Proxy) invalidateCache(req *modbus.Request) {
190198
}
191199

192200
for _, fc := range readFuncs {
193-
key := cache.Key(req.SlaveID, fc, req.Address, req.Quantity)
194-
p.cache.Delete(key)
201+
p.cache.DeleteRange(req.SlaveID, fc, req.Address, req.Quantity)
202+
}
203+
}
204+
205+
// decomposeResponse extracts per-register/coil values from a Modbus read response.
206+
// Response format: [funcCode, byteCount, data...]
207+
// For registers (FC 0x03, 0x04): each register is 2 bytes.
208+
// For coils/discrete inputs (FC 0x01, 0x02): each coil is 1 bit, stored as 1 byte (0 or 1).
209+
func decomposeResponse(functionCode byte, quantity uint16, data []byte) [][]byte {
210+
if len(data) < 2 {
211+
return nil
212+
}
213+
214+
payload := data[2:] // Skip funcCode and byteCount
215+
216+
switch functionCode {
217+
case modbus.FuncReadHoldingRegisters, modbus.FuncReadInputRegisters:
218+
values := make([][]byte, quantity)
219+
for i := uint16(0); i < quantity; i++ {
220+
offset := i * 2
221+
if int(offset+2) > len(payload) {
222+
return nil
223+
}
224+
reg := make([]byte, 2)
225+
copy(reg, payload[offset:offset+2])
226+
values[i] = reg
227+
}
228+
return values
229+
230+
case modbus.FuncReadCoils, modbus.FuncReadDiscreteInputs:
231+
values := make([][]byte, quantity)
232+
for i := uint16(0); i < quantity; i++ {
233+
byteIdx := i / 8
234+
bitIdx := i % 8
235+
if int(byteIdx) >= len(payload) {
236+
return nil
237+
}
238+
if payload[byteIdx]&(1<<bitIdx) != 0 {
239+
values[i] = []byte{1}
240+
} else {
241+
values[i] = []byte{0}
242+
}
243+
}
244+
return values
195245
}
246+
247+
return nil
248+
}
249+
250+
// assembleResponse reconstructs a Modbus read response from per-register/coil values.
251+
func assembleResponse(functionCode byte, quantity uint16, values [][]byte) []byte {
252+
switch functionCode {
253+
case modbus.FuncReadHoldingRegisters, modbus.FuncReadInputRegisters:
254+
byteCount := quantity * 2
255+
resp := make([]byte, 2+byteCount)
256+
resp[0] = functionCode
257+
resp[1] = byte(byteCount)
258+
for i, v := range values {
259+
if len(v) >= 2 {
260+
resp[2+i*2] = v[0]
261+
resp[2+i*2+1] = v[1]
262+
}
263+
}
264+
return resp
265+
266+
case modbus.FuncReadCoils, modbus.FuncReadDiscreteInputs:
267+
byteCount := (quantity + 7) / 8
268+
resp := make([]byte, 2+byteCount)
269+
resp[0] = functionCode
270+
resp[1] = byte(byteCount)
271+
for i, v := range values {
272+
if len(v) > 0 && v[0] != 0 {
273+
byteIdx := i / 8
274+
bitIdx := uint(i % 8)
275+
resp[2+byteIdx] |= 1 << bitIdx
276+
}
277+
}
278+
return resp
279+
}
280+
281+
return nil
196282
}
197283

198284
func (p *Proxy) buildFakeWriteResponse(req *modbus.Request) []byte {

0 commit comments

Comments
 (0)