From a5f19366328a484488fe3ac640dee621d35b9475 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henrik=20S=C3=B6lver?= Date: Fri, 24 Feb 2023 18:22:39 +0100 Subject: [PATCH] Add tcp keep alive --- client.go | 61 ++++++++++++++++++++++++++++++++++++-------------- client_test.go | 40 ++++++++++++++++++++++++++++++--- 2 files changed, 81 insertions(+), 20 deletions(-) diff --git a/client.go b/client.go index 9668676..462ff3c 100644 --- a/client.go +++ b/client.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "net" + "sync" "time" ) @@ -13,23 +14,44 @@ type Mbclient struct { address string header [7]byte unit uint8 + conn net.Conn + t *time.Timer + keepAliveDuration time.Duration + wg sync.WaitGroup } -func New(address string, unit uint8) (*Mbclient, error) { +func New(Address string, Unit uint8, KeepAlive time.Duration) (*Mbclient, error) { c := new(Mbclient) - c.address = address - c.unit = unit + c.address = Address + c.unit = Unit + c.t = time.NewTimer(0) + <-c.t.C + c.keepAliveDuration = KeepAlive return c, nil } -func (m *Mbclient) ReadRegisters(first uint16, numRegs uint16) ([]uint16, error) { +func (m *Mbclient) closer() { - conn, err := net.Dial("tcp", m.address) - if err != nil { - return nil, err + <-m.t.C + m.conn.Close() + m.wg.Done() +} + +func (m *Mbclient) ReadRegisters(first uint16, numRegs uint16) ([]uint16, error) { + var err error + // If The timer is expired, conn is closed and needs to be reopened + if !m.t.Stop() { + // Wait for closer to exit to mitigate race condiion + // between closer routine and this code path + m.wg.Wait() + m.conn, err = net.Dial("tcp", m.address) + if err != nil { + return nil, err + } + m.wg.Add(1) + go m.closer() } - defer conn.Close() const requestLength = 12 m.transactionCounter++ @@ -44,22 +66,27 @@ func (m *Mbclient) ReadRegisters(first uint16, numRegs uint16) ([]uint16, error) req[7] = 3 //FunctionCode binary.BigEndian.PutUint16(req[8:10], first-1) binary.BigEndian.PutUint16(req[10:12], numRegs) - conn.SetDeadline(time.Now().Add(10 * time.Second)) - byteswritten, err := conn.Write(req) + m.conn.SetDeadline(time.Now().Add(10 * time.Second)) + byteswritten, err := m.conn.Write(req) if err != nil { return nil, err } if byteswritten != requestLength { - return nil, fmt.Errorf("Failed to send request") + return nil, fmt.Errorf("failed to send request") + } + m.conn.SetDeadline(time.Now().Add(10 * time.Second)) + _, err = io.ReadFull(m.conn, m.header[:]) + if err != nil { + return nil, err } - conn.SetDeadline(time.Now().Add(10 * time.Second)) - _, err = io.ReadFull(conn, m.header[:]) responseHeader.unMarshal(m.header) expectedDataLength := responseHeader.length - 1 response := make([]byte, expectedDataLength) - _, err = conn.Read(response) - + _, err = m.conn.Read(response) + if err != nil { + return nil, err + } err = mbpayload.unMarshal(response) if mbpayload.functionCode != 3 { return nil, fmt.Errorf("modbus exception %v", mbpayload.functionCode&0x7F) @@ -67,7 +94,7 @@ func (m *Mbclient) ReadRegisters(first uint16, numRegs uint16) ([]uint16, error) if err != nil { return nil, err } - + m.t.Reset(m.keepAliveDuration) return mbpayload.registers, nil } @@ -101,7 +128,7 @@ func (d *mbPDU) unMarshal(data []byte) error { d.functionCode = data[0] d.length = data[1] if d.length+2 != uint8(len(data)) { - return fmt.Errorf("Lenght mismatch in modbus payload") + return fmt.Errorf("lenght mismatch in modbus payload") } d.registers = make([]uint16, d.length/2) var n uint8 diff --git a/client_test.go b/client_test.go index aa5b638..24d8098 100644 --- a/client_test.go +++ b/client_test.go @@ -2,12 +2,14 @@ package modbustcpclient import ( "testing" + "time" "github.com/stretchr/testify/assert" ) -func TestReadOneRegister(t *testing.T) { - c, err := New("IAM_248000012514.solver.nu:502", 1) +func TestReadOneRegisterKeepAlive(t *testing.T) { + c, err := New("IAM_248000012514.solver.nu:502", 1, 100*time.Millisecond) + t.Log("Connect") assert.NoError(t, err) for n := 0; n < 5; n++ { res, err := c.ReadRegisters(12401, 2) @@ -30,5 +32,37 @@ func TestReadOneRegister(t *testing.T) { assert.Len(t, res, 1) t.Log(res) } - + time.Sleep(1 * time.Second) +} + +func TestReadOneRegisterShortKeepAlive(t *testing.T) { + c, err := New("IAM_248000012514.solver.nu:502", 1, 10*time.Nanosecond) + t.Log("Connect") + assert.NoError(t, err) + for n := 0; n < 5; n++ { + res, err := c.ReadRegisters(12401, 2) + assert.NoError(t, err) + assert.Len(t, res, 2) + t.Log(res) + time.Sleep(100 * time.Millisecond) + + res, err = c.ReadRegisters(12102, 2) + assert.NoError(t, err) + assert.Len(t, res, 2) + t.Log(res) + time.Sleep(100 * time.Millisecond) + + res, err = c.ReadRegisters(12544, 1) + assert.NoError(t, err) + assert.Len(t, res, 1) + t.Log(float32(res[0]) / 10) + time.Sleep(100 * time.Millisecond) + + res, err = c.ReadRegisters(12136, 1) + assert.NoError(t, err) + assert.Len(t, res, 1) + t.Log(res) + time.Sleep(100 * time.Millisecond) + } + time.Sleep(1 * time.Second) }