From 8c33fd2a89c76042d39cbc5cc9ff1e2eeadae17d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henrik=20S=C3=B6lver?= Date: Sat, 25 Feb 2023 22:30:31 +0100 Subject: [PATCH] Set read deadline correctly --- client.go | 9 ++++++--- client_test.go | 6 +++--- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/client.go b/client.go index cd25539..bfca98e 100644 --- a/client.go +++ b/client.go @@ -17,10 +17,11 @@ type Mbclient struct { conn net.Conn t *time.Timer keepAliveDuration time.Duration + timeOut time.Duration wg sync.WaitGroup } -func New(Address string, Unit uint8, KeepAlive time.Duration) (*Mbclient, error) { +func New(Address string, Unit uint8, KeepAlive, TimeOut time.Duration) (*Mbclient, error) { c := new(Mbclient) c.address = Address @@ -28,6 +29,7 @@ func New(Address string, Unit uint8, KeepAlive time.Duration) (*Mbclient, error) c.t = time.NewTimer(0) <-c.t.C c.keepAliveDuration = KeepAlive + c.timeOut = TimeOut return c, nil } @@ -54,7 +56,7 @@ func (m *Mbclient) ReadRegisters(first uint16, numRegs uint16) ([]uint16, error) // Wait for closer to exit to mitigate race condiion // between closer routine and this code path m.wg.Wait() - m.conn, err = net.DialTimeout("tcp", m.address, 5*time.Second) + m.conn, err = net.DialTimeout("tcp", m.address, m.timeOut) if err != nil { return nil, err } @@ -86,7 +88,7 @@ func (m *Mbclient) ReadRegisters(first uint16, numRegs uint16) ([]uint16, error) m.closeConn() return nil, fmt.Errorf("failed to send request") } - m.conn.SetDeadline(time.Now().Add(5 * time.Second)) + m.conn.SetDeadline(time.Now().Add(m.timeOut)) _, err = io.ReadFull(m.conn, m.header[:]) if err != nil { m.closeConn() @@ -100,6 +102,7 @@ func (m *Mbclient) ReadRegisters(first uint16, numRegs uint16) ([]uint16, error) return nil, fmt.Errorf("modbus transaction mismatch %v != %v", m.transactionCounter, responseHeader.transactionID) } response := make([]byte, expectedDataLength) + m.conn.SetDeadline(time.Now().Add(m.timeOut)) bytesRead, err := io.ReadFull(m.conn, response) if err != nil { m.closeConn() diff --git a/client_test.go b/client_test.go index fe0253c..95091f8 100644 --- a/client_test.go +++ b/client_test.go @@ -8,7 +8,7 @@ import ( ) func TestReadOneRegisterKeepAlive(t *testing.T) { - c, err := New("IAM_248000012514.solver.nu:502", 1, 100*time.Millisecond) + c, err := New("IAM_248000012514.solver.nu:502", 1, 100*time.Millisecond, 5*time.Second) t.Log("Connect") assert.NoError(t, err) for n := 0; n < 5; n++ { @@ -36,7 +36,7 @@ func TestReadOneRegisterKeepAlive(t *testing.T) { } func TestReadOneRegisterShortKeepAlive(t *testing.T) { - c, err := New("IAM_248000012514.solver.nu:502", 1, 10*time.Nanosecond) + c, err := New("IAM_248000012514.solver.nu:502", 1, 10*time.Nanosecond, 5*time.Second) t.Log("Connect") assert.NoError(t, err) for n := 0; n < 5; n++ { @@ -68,7 +68,7 @@ func TestReadOneRegisterShortKeepAlive(t *testing.T) { } func TestReadALot(t *testing.T) { - c, err := New("IAM_248000012514.solver.nu:502", 1, 100*time.Millisecond) + c, err := New("IAM_248000012514.solver.nu:502", 1, 100*time.Millisecond, 5*time.Second) t.Log("Connect") assert.NoError(t, err) for n := 0; n < 500; n++ {