diff --git a/wshandler/wshandler.go b/wshandler/wshandler.go index e8597e3..381c5b4 100644 --- a/wshandler/wshandler.go +++ b/wshandler/wshandler.go @@ -8,6 +8,7 @@ import ( "io" "net/http" "strings" + "sync" "go.mongodb.org/mongo-driver/bson/primitive" "repositories.action2quare.com/ayo/gocommon" @@ -65,7 +66,6 @@ const ( type WebSocketMessageReceiver func(accid primitive.ObjectID, alias string, messageType WebSocketMessageType, body io.Reader) type subhandler struct { - name string authCache *gocommon.AuthCollection redisMsgChanName string redisCmdChanName string @@ -73,6 +73,7 @@ type subhandler struct { connInOutChan chan *wsconn deliveryChan chan any callReceiver WebSocketMessageReceiver + connWaitGroup sync.WaitGroup } // WebsocketHandler : @@ -112,7 +113,6 @@ func NewWebsocketHandler(authglobal gocommon.AuthCollectionGlobal, receiver WebS authCaches := make(map[string]*subhandler) for _, region := range authglobal.Regions() { sh := &subhandler{ - name: region, authCache: authglobal.Get(region), redisMsgChanName: fmt.Sprintf("_wsh_msg_%s", region), redisCmdChanName: fmt.Sprintf("_wsh_cmd_%s", region), @@ -131,7 +131,10 @@ func NewWebsocketHandler(authglobal gocommon.AuthCollectionGlobal, receiver WebS } } -func (ws *WebsocketHandler) Destructor() { +func (ws *WebsocketHandler) Cleanup() { + for _, sh := range ws.authCaches { + sh.connWaitGroup.Wait() + } } func (ws *WebsocketHandler) RegisterHandlers(ctx context.Context, serveMux *http.ServeMux, prefix string) error { @@ -153,13 +156,17 @@ func (ws *WebsocketHandler) RegisterHandlers(ctx context.Context, serveMux *http } func (ws *WebsocketHandler) GetState(region string, accid primitive.ObjectID) string { - state, err := ws.RedisSync.HGet(context.Background(), region, accid.Hex()).Result() + state, err := ws.RedisSync.Get(context.Background(), accid.Hex()).Result() if err == redis.Nil { return "" } return state } +func (ws *WebsocketHandler) SetState(region string, accid primitive.ObjectID, state string) { + ws.RedisSync.SetArgs(context.Background(), accid.Hex(), state, redis.SetArgs{Mode: "XX"}).Result() +} + func (sh *subhandler) mainLoop(ctx context.Context) { defer func() { s := recover() @@ -281,8 +288,9 @@ func upgrade_core(sh *subhandler, conn *websocket.Conn, accid primitive.ObjectID } sh.connInOutChan <- newconn + sh.connWaitGroup.Add(1) go func(c *wsconn, accid primitive.ObjectID, deliveryChan chan<- any) { - sh.redisSync.HSet(context.Background(), sh.name, accid.Hex(), "online") + sh.redisSync.Set(context.Background(), accid.Hex(), "online", 0) for { messageType, r, err := c.NextReader() if err != nil { @@ -301,7 +309,8 @@ func upgrade_core(sh *subhandler, conn *websocket.Conn, accid primitive.ObjectID sh.callReceiver(accid, c.alias, BinaryMessage, r) } } - sh.redisSync.HDel(context.Background(), sh.name, accid.Hex()) + sh.redisSync.Del(context.Background(), accid.Hex()) + sh.connWaitGroup.Done() c.Conn = nil sh.connInOutChan <- c