Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 40 additions & 3 deletions websocket/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,40 @@ type ManagedConnection struct {

// Used for the exponential backoff when connecting
connectionBackoff wait.Backoff

// OnConnect is called when a connection is successfully established.
// This callback is invoked each time the connection is established,
// including reconnections.
OnConnect func()

// OnDisconnect is called when a connection is lost.
// The error parameter contains the reason for the disconnection.
OnDisconnect func(error)
}

// ConnectionOption is a functional option for configuring ManagedConnection.
type ConnectionOption func(*ManagedConnection)

// WithOnConnect sets a callback that is invoked when a connection is established.
func WithOnConnect(f func()) ConnectionOption {
return func(c *ManagedConnection) {
c.OnConnect = f
}
}

// WithOnDisconnect sets a callback that is invoked when a connection is lost.
func WithOnDisconnect(f func(error)) ConnectionOption {
return func(c *ManagedConnection) {
c.OnDisconnect = f
}
}

// NewDurableSendingConnection creates a new websocket connection
// that can only send messages to the endpoint it connects to.
// The connection will continuously be kept alive and reconnected
// in case of a loss of connectivity.
func NewDurableSendingConnection(target string, logger *zap.SugaredLogger) *ManagedConnection {
return NewDurableConnection(target, nil, logger)
func NewDurableSendingConnection(target string, logger *zap.SugaredLogger, opts ...ConnectionOption) *ManagedConnection {
return NewDurableConnection(target, nil, logger, opts...)
}

// NewDurableSendingConnectionGuaranteed creates a new websocket connection
Expand Down Expand Up @@ -127,7 +153,7 @@ func NewDurableSendingConnectionGuaranteed(target string, duration time.Duration
//
// go func() {conn.Shutdown(); close(messageChan)}
// go func() {for range messageChan {}}
func NewDurableConnection(target string, messageChan chan []byte, logger *zap.SugaredLogger) *ManagedConnection {
func NewDurableConnection(target string, messageChan chan []byte, logger *zap.SugaredLogger, opts ...ConnectionOption) *ManagedConnection {
websocketConnectionFactory := func() (rawConnection, error) {
dialer := &websocket.Dialer{
// This needs to be relatively short to avoid the connection getting blackholed for a long time
Expand All @@ -149,6 +175,11 @@ func NewDurableConnection(target string, messageChan chan []byte, logger *zap.Su

c := newConnection(websocketConnectionFactory, messageChan)

// Apply options before starting the goroutine
for _, opt := range opts {
opt(c)
}

// Keep the connection alive asynchronously and reconnect on
// connection failure.
c.processingWg.Add(1)
Expand All @@ -164,8 +195,14 @@ func NewDurableConnection(target string, messageChan chan []byte, logger *zap.Su
continue
}
logger.Debug("Connected to ", target)
if c.OnConnect != nil {
c.OnConnect()
}
if err := c.keepalive(); err != nil {
logger.Errorw(fmt.Sprintf("Connection to %s broke down, reconnecting...", target), zap.Error(err))
if c.OnDisconnect != nil {
c.OnDisconnect(err)
}
}
if err := c.closeConnection(); err != nil {
logger.Errorw("Failed to close the connection after crashing", zap.Error(err))
Expand Down
100 changes: 100 additions & 0 deletions websocket/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,106 @@ func TestDurableConnectionSendsPingsRegularly(t *testing.T) {
}
}

func TestOnConnectAndOnDisconnectCallbacks(t *testing.T) {
reconnectChan := make(chan struct{})

upgrader := websocket.Upgrader{}
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}

// Wait for signal to drop the connection.
<-reconnectChan
c.Close()
}))
defer s.Close()

logger := ktesting.TestLogger(t)
target := "ws" + strings.TrimPrefix(s.URL, "http")

onConnectCalled := make(chan struct{}, 10)
onDisconnectCalled := make(chan error, 10)

conn := NewDurableSendingConnection(target, logger,
WithOnConnect(func() {
onConnectCalled <- struct{}{}
}),
WithOnDisconnect(func(err error) {
onDisconnectCalled <- err
}),
)
defer conn.Shutdown()

// Wait for the first OnConnect call
select {
case <-onConnectCalled:
// Success - OnConnect was called
case <-time.After(propagationTimeout):
t.Fatal("Timed out waiting for OnConnect to be called")
}

// Trigger a disconnect by closing the server-side connection
reconnectChan <- struct{}{}

// Wait for OnDisconnect to be called
select {
case err := <-onDisconnectCalled:
if err == nil {
t.Error("Expected OnDisconnect to receive an error, got nil")
}
case <-time.After(propagationTimeout):
t.Fatal("Timed out waiting for OnDisconnect to be called")
}

// Wait for reconnection (OnConnect should be called again)
select {
case <-onConnectCalled:
// Success - OnConnect was called again after reconnection
case <-time.After(propagationTimeout):
t.Fatal("Timed out waiting for OnConnect to be called after reconnection")
}
}

func TestOnConnectAndOnDisconnectCallbacksNotSet(t *testing.T) {
reconnectChan := make(chan struct{})

upgrader := websocket.Upgrader{}
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}

// Wait for signal to drop the connection.
<-reconnectChan
c.Close()
}))
defer s.Close()

logger := ktesting.TestLogger(t)
target := "ws" + strings.TrimPrefix(s.URL, "http")

// Create connection without setting callbacks - should not panic
conn := NewDurableSendingConnection(target, logger)
defer conn.Shutdown()

// Wait for connection to be established
err := wait.PollUntilContextTimeout(context.Background(), 50*time.Millisecond, propagationTimeout, true, func(ctx context.Context) (bool, error) {
return conn.Status() == nil, nil
})
if err != nil {
t.Fatal("Timed out waiting for connection to be established:", err)
}

// Trigger disconnect - should not panic even without callbacks
reconnectChan <- struct{}{}

// Wait a bit and verify no panic occurred
time.Sleep(100 * time.Millisecond)
Copy link
Member

@dprotaso dprotaso Jan 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this? Seems unnecessary - what would panic?

}

func TestNewDurableSendingConnectionGuaranteed(t *testing.T) {
// Unhappy case.
logger := ktesting.TestLogger(t)
Expand Down
Loading