diff --git a/wshandler/wshandler.go b/wshandler/wshandler.go index 7170bbd..9bb7acb 100644 --- a/wshandler/wshandler.go +++ b/wshandler/wshandler.go @@ -83,7 +83,7 @@ type subhandler struct { type WebsocketHandler struct { authCaches map[string]*subhandler RedisSync *redis.Client - receiverChain []WebSocketMessageReceiver + receiverChain map[string][]WebSocketMessageReceiver } type wsConfig struct { @@ -130,22 +130,26 @@ func NewWebsocketHandler(authglobal gocommon.AuthCollectionGlobal) (wsh *Websock } return &WebsocketHandler{ - authCaches: authCaches, - RedisSync: redisSync, + authCaches: authCaches, + RedisSync: redisSync, + receiverChain: make(map[string][]WebSocketMessageReceiver), } } -func (ws *WebsocketHandler) RegisterReceiver(receiver WebSocketMessageReceiver) { - ws.receiverChain = append(ws.receiverChain, receiver) +func (ws *WebsocketHandler) RegisterReceiver(region string, receiver WebSocketMessageReceiver) { + ws.receiverChain[region] = append(ws.receiverChain[region], receiver) } func (ws *WebsocketHandler) Start(ctx context.Context) { - for _, sh := range ws.authCaches { - if len(ws.receiverChain) == 1 { - sh.callReceiver = ws.receiverChain[0] + for region, sh := range ws.authCaches { + chain := ws.receiverChain[region] + if len(chain) == 0 { + sh.callReceiver = func(accid primitive.ObjectID, alias string, messageType WebSocketMessageType, body io.Reader) {} + } else if len(chain) == 1 { + sh.callReceiver = chain[0] } else { sh.callReceiver = func(accid primitive.ObjectID, alias string, messageType WebSocketMessageType, body io.Reader) { - for _, r := range ws.receiverChain { + for _, r := range chain { r(accid, alias, messageType, body) } }