Add tcp keep alive

This commit was merged in pull request #6.
This commit is contained in:
2023-02-24 18:22:39 +01:00
parent 07ad6f4b24
commit a5f1936632
2 changed files with 81 additions and 20 deletions

View File

@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"sync"
"time" "time"
) )
@@ -13,23 +14,44 @@ type Mbclient struct {
address string address string
header [7]byte header [7]byte
unit uint8 unit uint8
conn net.Conn
t *time.Timer
keepAliveDuration time.Duration
wg sync.WaitGroup
} }
func New(address string, unit uint8) (*Mbclient, error) { func New(Address string, Unit uint8, KeepAlive time.Duration) (*Mbclient, error) {
c := new(Mbclient) c := new(Mbclient)
c.address = address c.address = Address
c.unit = unit c.unit = Unit
c.t = time.NewTimer(0)
<-c.t.C
c.keepAliveDuration = KeepAlive
return c, nil return c, nil
} }
func (m *Mbclient) ReadRegisters(first uint16, numRegs uint16) ([]uint16, error) { func (m *Mbclient) closer() {
conn, err := net.Dial("tcp", m.address) <-m.t.C
if err != nil { m.conn.Close()
return nil, err m.wg.Done()
}
func (m *Mbclient) ReadRegisters(first uint16, numRegs uint16) ([]uint16, error) {
var err error
// If The timer is expired, conn is closed and needs to be reopened
if !m.t.Stop() {
// Wait for closer to exit to mitigate race condiion
// between closer routine and this code path
m.wg.Wait()
m.conn, err = net.Dial("tcp", m.address)
if err != nil {
return nil, err
}
m.wg.Add(1)
go m.closer()
} }
defer conn.Close()
const requestLength = 12 const requestLength = 12
m.transactionCounter++ m.transactionCounter++
@@ -44,22 +66,27 @@ func (m *Mbclient) ReadRegisters(first uint16, numRegs uint16) ([]uint16, error)
req[7] = 3 //FunctionCode req[7] = 3 //FunctionCode
binary.BigEndian.PutUint16(req[8:10], first-1) binary.BigEndian.PutUint16(req[8:10], first-1)
binary.BigEndian.PutUint16(req[10:12], numRegs) binary.BigEndian.PutUint16(req[10:12], numRegs)
conn.SetDeadline(time.Now().Add(10 * time.Second)) m.conn.SetDeadline(time.Now().Add(10 * time.Second))
byteswritten, err := conn.Write(req) byteswritten, err := m.conn.Write(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if byteswritten != requestLength { if byteswritten != requestLength {
return nil, fmt.Errorf("Failed to send request") return nil, fmt.Errorf("failed to send request")
}
m.conn.SetDeadline(time.Now().Add(10 * time.Second))
_, err = io.ReadFull(m.conn, m.header[:])
if err != nil {
return nil, err
} }
conn.SetDeadline(time.Now().Add(10 * time.Second))
_, err = io.ReadFull(conn, m.header[:])
responseHeader.unMarshal(m.header) responseHeader.unMarshal(m.header)
expectedDataLength := responseHeader.length - 1 expectedDataLength := responseHeader.length - 1
response := make([]byte, expectedDataLength) response := make([]byte, expectedDataLength)
_, err = conn.Read(response) _, err = m.conn.Read(response)
if err != nil {
return nil, err
}
err = mbpayload.unMarshal(response) err = mbpayload.unMarshal(response)
if mbpayload.functionCode != 3 { if mbpayload.functionCode != 3 {
return nil, fmt.Errorf("modbus exception %v", mbpayload.functionCode&0x7F) return nil, fmt.Errorf("modbus exception %v", mbpayload.functionCode&0x7F)
@@ -67,7 +94,7 @@ func (m *Mbclient) ReadRegisters(first uint16, numRegs uint16) ([]uint16, error)
if err != nil { if err != nil {
return nil, err return nil, err
} }
m.t.Reset(m.keepAliveDuration)
return mbpayload.registers, nil return mbpayload.registers, nil
} }
@@ -101,7 +128,7 @@ func (d *mbPDU) unMarshal(data []byte) error {
d.functionCode = data[0] d.functionCode = data[0]
d.length = data[1] d.length = data[1]
if d.length+2 != uint8(len(data)) { if d.length+2 != uint8(len(data)) {
return fmt.Errorf("Lenght mismatch in modbus payload") return fmt.Errorf("lenght mismatch in modbus payload")
} }
d.registers = make([]uint16, d.length/2) d.registers = make([]uint16, d.length/2)
var n uint8 var n uint8

View File

@@ -2,12 +2,14 @@ package modbustcpclient
import ( import (
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestReadOneRegister(t *testing.T) { func TestReadOneRegisterKeepAlive(t *testing.T) {
c, err := New("IAM_248000012514.solver.nu:502", 1) c, err := New("IAM_248000012514.solver.nu:502", 1, 100*time.Millisecond)
t.Log("Connect")
assert.NoError(t, err) assert.NoError(t, err)
for n := 0; n < 5; n++ { for n := 0; n < 5; n++ {
res, err := c.ReadRegisters(12401, 2) res, err := c.ReadRegisters(12401, 2)
@@ -30,5 +32,37 @@ func TestReadOneRegister(t *testing.T) {
assert.Len(t, res, 1) assert.Len(t, res, 1)
t.Log(res) t.Log(res)
} }
time.Sleep(1 * time.Second)
}
func TestReadOneRegisterShortKeepAlive(t *testing.T) {
c, err := New("IAM_248000012514.solver.nu:502", 1, 10*time.Nanosecond)
t.Log("Connect")
assert.NoError(t, err)
for n := 0; n < 5; n++ {
res, err := c.ReadRegisters(12401, 2)
assert.NoError(t, err)
assert.Len(t, res, 2)
t.Log(res)
time.Sleep(100 * time.Millisecond)
res, err = c.ReadRegisters(12102, 2)
assert.NoError(t, err)
assert.Len(t, res, 2)
t.Log(res)
time.Sleep(100 * time.Millisecond)
res, err = c.ReadRegisters(12544, 1)
assert.NoError(t, err)
assert.Len(t, res, 1)
t.Log(float32(res[0]) / 10)
time.Sleep(100 * time.Millisecond)
res, err = c.ReadRegisters(12136, 1)
assert.NoError(t, err)
assert.Len(t, res, 1)
t.Log(res)
time.Sleep(100 * time.Millisecond)
}
time.Sleep(1 * time.Second)
} }