diff --git a/websocket/connection.go b/websocket/connection.go index b70d8b1ed1..dc0c20ee0b 100644 --- a/websocket/connection.go +++ b/websocket/connection.go @@ -86,14 +86,40 @@ type ManagedConnection struct { // Used for the exponential backoff when connecting connectionBackoff wait.Backoff + + // OnConnect is called when a connection is successfully established. + // This callback is invoked each time the connection is established, + // including reconnections. + OnConnect func() + + // OnDisconnect is called when a connection is lost. + // The error parameter contains the reason for the disconnection. + OnDisconnect func(error) +} + +// ConnectionOption is a functional option for configuring ManagedConnection. +type ConnectionOption func(*ManagedConnection) + +// WithOnConnect sets a callback that is invoked when a connection is established. +func WithOnConnect(f func()) ConnectionOption { + return func(c *ManagedConnection) { + c.OnConnect = f + } +} + +// WithOnDisconnect sets a callback that is invoked when a connection is lost. +func WithOnDisconnect(f func(error)) ConnectionOption { + return func(c *ManagedConnection) { + c.OnDisconnect = f + } } // NewDurableSendingConnection creates a new websocket connection // that can only send messages to the endpoint it connects to. // The connection will continuously be kept alive and reconnected // in case of a loss of connectivity. -func NewDurableSendingConnection(target string, logger *zap.SugaredLogger) *ManagedConnection { - return NewDurableConnection(target, nil, logger) +func NewDurableSendingConnection(target string, logger *zap.SugaredLogger, opts ...ConnectionOption) *ManagedConnection { + return NewDurableConnection(target, nil, logger, opts...) } // NewDurableSendingConnectionGuaranteed creates a new websocket connection @@ -127,7 +153,7 @@ func NewDurableSendingConnectionGuaranteed(target string, duration time.Duration // // go func() {conn.Shutdown(); close(messageChan)} // go func() {for range messageChan {}} -func NewDurableConnection(target string, messageChan chan []byte, logger *zap.SugaredLogger) *ManagedConnection { +func NewDurableConnection(target string, messageChan chan []byte, logger *zap.SugaredLogger, opts ...ConnectionOption) *ManagedConnection { websocketConnectionFactory := func() (rawConnection, error) { dialer := &websocket.Dialer{ // This needs to be relatively short to avoid the connection getting blackholed for a long time @@ -149,6 +175,11 @@ func NewDurableConnection(target string, messageChan chan []byte, logger *zap.Su c := newConnection(websocketConnectionFactory, messageChan) + // Apply options before starting the goroutine + for _, opt := range opts { + opt(c) + } + // Keep the connection alive asynchronously and reconnect on // connection failure. c.processingWg.Add(1) @@ -164,8 +195,14 @@ func NewDurableConnection(target string, messageChan chan []byte, logger *zap.Su continue } logger.Debug("Connected to ", target) + if c.OnConnect != nil { + c.OnConnect() + } if err := c.keepalive(); err != nil { logger.Errorw(fmt.Sprintf("Connection to %s broke down, reconnecting...", target), zap.Error(err)) + if c.OnDisconnect != nil { + c.OnDisconnect(err) + } } if err := c.closeConnection(); err != nil { logger.Errorw("Failed to close the connection after crashing", zap.Error(err)) diff --git a/websocket/connection_test.go b/websocket/connection_test.go index 7b35a6278f..59f0edddbc 100644 --- a/websocket/connection_test.go +++ b/websocket/connection_test.go @@ -419,6 +419,106 @@ func TestDurableConnectionSendsPingsRegularly(t *testing.T) { } } +func TestOnConnectAndOnDisconnectCallbacks(t *testing.T) { + reconnectChan := make(chan struct{}) + + upgrader := websocket.Upgrader{} + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + + // Wait for signal to drop the connection. + <-reconnectChan + c.Close() + })) + defer s.Close() + + logger := ktesting.TestLogger(t) + target := "ws" + strings.TrimPrefix(s.URL, "http") + + onConnectCalled := make(chan struct{}, 10) + onDisconnectCalled := make(chan error, 10) + + conn := NewDurableSendingConnection(target, logger, + WithOnConnect(func() { + onConnectCalled <- struct{}{} + }), + WithOnDisconnect(func(err error) { + onDisconnectCalled <- err + }), + ) + defer conn.Shutdown() + + // Wait for the first OnConnect call + select { + case <-onConnectCalled: + // Success - OnConnect was called + case <-time.After(propagationTimeout): + t.Fatal("Timed out waiting for OnConnect to be called") + } + + // Trigger a disconnect by closing the server-side connection + reconnectChan <- struct{}{} + + // Wait for OnDisconnect to be called + select { + case err := <-onDisconnectCalled: + if err == nil { + t.Error("Expected OnDisconnect to receive an error, got nil") + } + case <-time.After(propagationTimeout): + t.Fatal("Timed out waiting for OnDisconnect to be called") + } + + // Wait for reconnection (OnConnect should be called again) + select { + case <-onConnectCalled: + // Success - OnConnect was called again after reconnection + case <-time.After(propagationTimeout): + t.Fatal("Timed out waiting for OnConnect to be called after reconnection") + } +} + +func TestOnConnectAndOnDisconnectCallbacksNotSet(t *testing.T) { + reconnectChan := make(chan struct{}) + + upgrader := websocket.Upgrader{} + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + + // Wait for signal to drop the connection. + <-reconnectChan + c.Close() + })) + defer s.Close() + + logger := ktesting.TestLogger(t) + target := "ws" + strings.TrimPrefix(s.URL, "http") + + // Create connection without setting callbacks - should not panic + conn := NewDurableSendingConnection(target, logger) + defer conn.Shutdown() + + // Wait for connection to be established + err := wait.PollUntilContextTimeout(context.Background(), 50*time.Millisecond, propagationTimeout, true, func(ctx context.Context) (bool, error) { + return conn.Status() == nil, nil + }) + if err != nil { + t.Fatal("Timed out waiting for connection to be established:", err) + } + + // Trigger disconnect - should not panic even without callbacks + reconnectChan <- struct{}{} + + // Wait a bit and verify no panic occurred + time.Sleep(100 * time.Millisecond) +} + func TestNewDurableSendingConnectionGuaranteed(t *testing.T) { // Unhappy case. logger := ktesting.TestLogger(t)