From e88df26ed7d085cdfd89e7d9e21884ba0deb4e02 Mon Sep 17 00:00:00 2001 From: mountain Date: Wed, 30 Aug 2023 18:23:19 +0900 Subject: [PATCH] =?UTF-8?q?wshandler=EB=8F=84=20session.Consumer=EB=A1=9C?= =?UTF-8?q?=20=EA=B5=90=EC=B2=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- session/consumer_common.go | 4 +- session/consumer_mongo.go | 98 ++++++++++--- session/consumer_redis.go | 111 +++++++++------ session/session_test.go | 10 +- wshandler/wshandler.go | 285 +++++++++++++++---------------------- 5 files changed, 268 insertions(+), 240 deletions(-) diff --git a/session/consumer_common.go b/session/consumer_common.go index b9db40a..f8acba1 100644 --- a/session/consumer_common.go +++ b/session/consumer_common.go @@ -21,8 +21,8 @@ func make_cache_stage[T any]() *cache_stage[T] { } type Consumer interface { - Query(string) *Authorization - Touch(string) bool + Query(string) (*Authorization, error) + Touch(string) (*Authorization, error) } type consumer_common[T any] struct { diff --git a/session/consumer_mongo.go b/session/consumer_mongo.go index b34b1a3..b290e1c 100644 --- a/session/consumer_mongo.go +++ b/session/consumer_mongo.go @@ -19,7 +19,7 @@ type sessionMongo struct { } type consumer_mongo struct { - consumer_common[sessionMongo] + consumer_common[*sessionMongo] ids map[primitive.ObjectID]string mongoClient gocommon.MongoClient ttl time.Duration @@ -30,7 +30,7 @@ type sessionPipelineDocument struct { DocumentKey struct { Id primitive.ObjectID `bson:"_id"` } `bson:"documentKey"` - Session sessionMongo `bson:"fullDocument"` + Session *sessionMongo `bson:"fullDocument"` } func NewConsumerWithMongo(ctx context.Context, mongoUrl string, dbname string, ttl time.Duration) (Consumer, error) { @@ -40,10 +40,10 @@ func NewConsumerWithMongo(ctx context.Context, mongoUrl string, dbname string, t } consumer := &consumer_mongo{ - consumer_common: consumer_common[sessionMongo]{ + consumer_common: consumer_common[*sessionMongo]{ ttl: ttl, ctx: ctx, - stages: [2]*cache_stage[sessionMongo]{make_cache_stage[sessionMongo](), make_cache_stage[sessionMongo]()}, + stages: [2]*cache_stage[*sessionMongo]{make_cache_stage[*sessionMongo](), make_cache_stage[*sessionMongo]()}, startTime: time.Now(), }, ids: make(map[primitive.ObjectID]string), @@ -134,16 +134,13 @@ func NewConsumerWithMongo(ctx context.Context, mongoUrl string, dbname string, t return consumer, nil } -func (c *consumer_mongo) Query(key string) *Authorization { - c.lock.Lock() - defer c.lock.Unlock() - +func (c *consumer_mongo) query_internal(key string) (*sessionMongo, bool, error) { if _, deleted := c.stages[0].deleted[key]; deleted { - return nil + return nil, false, nil } if _, deleted := c.stages[1].deleted[key]; deleted { - return nil + return nil, false, nil } found, ok := c.stages[0].cache[key] @@ -151,11 +148,8 @@ func (c *consumer_mongo) Query(key string) *Authorization { found, ok = c.stages[1].cache[key] } - now := time.Now().UTC() if ok { - if now.Before(found.Ts.Time().Add(c.ttl)) { - return found.Authorization - } + return found, false, nil } var si sessionMongo @@ -165,21 +159,42 @@ func (c *consumer_mongo) Query(key string) *Authorization { if err != nil { logger.Println("consumer Query :", err) - return nil + return nil, false, err } if len(si.Key) > 0 { - c.add_internal(key, si) - return si.Authorization + siptr := &si + c.add_internal(key, siptr) + return siptr, true, nil } - return nil + return nil, false, nil } -func (c *consumer_mongo) Touch(key string) bool { +func (c *consumer_mongo) Query(key string) (*Authorization, error) { c.lock.Lock() defer c.lock.Unlock() - _, _, err := c.mongoClient.Update(session_collection_name, bson.M{ + si, _, err := c.query_internal(key) + if err != nil { + return nil, err + } + + if si == nil { + return nil, nil + } + + if time.Now().After(si.Ts.Time().Add(c.ttl)) { + return nil, nil + } + + return si.Authorization, nil +} + +func (c *consumer_mongo) Touch(key string) (*Authorization, error) { + c.lock.Lock() + defer c.lock.Unlock() + + worked, _, err := c.mongoClient.Update(session_collection_name, bson.M{ "key": key, }, bson.M{ "$currentDate": bson.M{ @@ -189,13 +204,50 @@ func (c *consumer_mongo) Touch(key string) bool { if err != nil { logger.Println("consumer Touch :", err) - return false + return nil, err } - return true + if !worked { + // 이미 만료되서 사라짐 + return nil, nil + } + + si, added, err := c.query_internal(key) + if err != nil { + return nil, err + } + + if si == nil { + return nil, nil + } + + if !added { + var doc struct { + sessionMongo `bson:",inline"` + Id primitive.ObjectID `bson:"_id"` + } + + err := c.mongoClient.FindOneAs(session_collection_name, bson.M{ + "key": key, + }, &doc) + + if err != nil { + logger.Println("consumer Query :", err) + return nil, err + } + + if len(si.Key) > 0 { + c.add_internal(key, &doc.sessionMongo) + c.ids[doc.Id] = key + + return doc.Authorization, nil + } + } + + return si.Authorization, nil } -func (c *consumer_mongo) add(key string, id primitive.ObjectID, si sessionMongo) { +func (c *consumer_mongo) add(key string, id primitive.ObjectID, si *sessionMongo) { c.lock.Lock() defer c.lock.Unlock() diff --git a/session/consumer_redis.go b/session/consumer_redis.go index e77de0c..1f7f68a 100644 --- a/session/consumer_redis.go +++ b/session/consumer_redis.go @@ -16,7 +16,7 @@ type sessionRedis struct { } type consumer_redis struct { - consumer_common[sessionRedis] + consumer_common[*sessionRedis] redisClient *redis.Client } @@ -27,10 +27,10 @@ func NewConsumerWithRedis(ctx context.Context, redisUrl string, ttl time.Duratio } consumer := &consumer_redis{ - consumer_common: consumer_common[sessionRedis]{ + consumer_common: consumer_common[*sessionRedis]{ ttl: ttl, ctx: ctx, - stages: [2]*cache_stage[sessionRedis]{make_cache_stage[sessionRedis](), make_cache_stage[sessionRedis]()}, + stages: [2]*cache_stage[*sessionRedis]{make_cache_stage[*sessionRedis](), make_cache_stage[*sessionRedis]()}, startTime: time.Now(), }, redisClient: redisClient, @@ -73,7 +73,7 @@ func NewConsumerWithRedis(ctx context.Context, redisUrl string, ttl time.Duratio } else if len(raw) > 0 { var si Authorization if bson.Unmarshal([]byte(raw), &si) == nil { - consumer.add(key, sessionRedis{ + consumer.add(key, &sessionRedis{ Authorization: &si, expireAt: time.Now().Add(consumer.ttl), }) @@ -90,16 +90,13 @@ func NewConsumerWithRedis(ctx context.Context, redisUrl string, ttl time.Duratio return consumer, nil } -func (c *consumer_redis) Query(key string) *Authorization { - c.lock.Lock() - defer c.lock.Unlock() - +func (c *consumer_redis) query_internal(key string) (*sessionRedis, bool, error) { if _, deleted := c.stages[0].deleted[key]; deleted { - return nil + return nil, false, nil } if _, deleted := c.stages[1].deleted[key]; deleted { - return nil + return nil, false, nil } found, ok := c.stages[0].cache[key] @@ -108,62 +105,92 @@ func (c *consumer_redis) Query(key string) *Authorization { } if ok { - if found.expireAt.After(time.Now()) { - return found.Authorization - } + return found, false, nil } payload, err := c.redisClient.Get(c.ctx, key).Result() if err == redis.Nil { - return nil + return nil, false, nil } else if err != nil { logger.Println("consumer Query :", err) - return nil + return nil, false, err } - if len(payload) > 0 { - var si Authorization - if bson.Unmarshal([]byte(payload), &si) == nil { - ttl, err := c.redisClient.TTL(c.ctx, key).Result() - if err != nil { - logger.Println("consumer Query :", err) - return nil - } - - c.add_internal(key, sessionRedis{ - Authorization: &si, - expireAt: time.Now().Add(ttl), - }) - return &si - } + if len(payload) == 0 { + return nil, false, nil } - return nil + + var auth Authorization + if err := bson.Unmarshal([]byte(payload), &auth); err != nil { + return nil, false, err + } + + ttl, err := c.redisClient.TTL(c.ctx, key).Result() + if err != nil { + logger.Println("consumer Query :", err) + return nil, false, err + } + + si := &sessionRedis{ + Authorization: &auth, + expireAt: time.Now().Add(ttl), + } + c.add_internal(key, si) + + return si, true, nil } -func (c *consumer_redis) Touch(key string) bool { +func (c *consumer_redis) Query(key string) (*Authorization, error) { + c.lock.Lock() + defer c.lock.Unlock() + + si, _, err := c.query_internal(key) + if err != nil { + return nil, err + } + + if si == nil { + return nil, nil + } + + if time.Now().After(si.expireAt) { + return nil, nil + } + + return si.Authorization, nil +} + +func (c *consumer_redis) Touch(key string) (*Authorization, error) { c.lock.Lock() defer c.lock.Unlock() ok, err := c.redisClient.Expire(c.ctx, key, c.ttl).Result() if err == redis.Nil { - return false + return nil, nil } else if err != nil { logger.Println("consumer Touch :", err) - return false + return nil, err } if ok { - newexpire := time.Now().Add(c.ttl) - found, ok := c.stages[0].cache[key] - if ok { - found.expireAt = newexpire + // redis에 살아있다. + si, added, err := c.query_internal(key) + if err != nil { + return nil, err } - found, ok = c.stages[1].cache[key] - if ok { - found.expireAt = newexpire + if si == nil { + return nil, nil } + + if !added { + si.expireAt = time.Now().Add(c.ttl) + // stage 0으로 옮기기 위해 add_internal을 다시 부름 + c.add_internal(key, si) + } + + return si.Authorization, nil } - return ok + return nil, nil } diff --git a/session/session_test.go b/session/session_test.go index 60f2c60..570b3ae 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -35,8 +35,11 @@ func TestExpTable(t *testing.T) { sk2 := primitive.NewObjectID().Hex() go func() { for { - logger.Println("query :", cs.Query(sk1)) - logger.Println("query :", cs.Query(sk2)) + q1, err := cs.Query(sk1) + logger.Println("query :", q1, err) + + q2, err := cs.Query(sk2) + logger.Println("query :", q2, err) time.Sleep(time.Second) } }() @@ -73,6 +76,7 @@ func TestExpTable(t *testing.T) { t.Error(err) } - logger.Println("queryf :", cs2.Query(sk2)) + q2, err := cs2.Query(sk2) + logger.Println("queryf :", q2, err) time.Sleep(20 * time.Second) } diff --git a/wshandler/wshandler.go b/wshandler/wshandler.go index ef36619..d623f1c 100644 --- a/wshandler/wshandler.go +++ b/wshandler/wshandler.go @@ -19,6 +19,7 @@ import ( "repositories.action2quare.com/ayo/gocommon" "repositories.action2quare.com/ayo/gocommon/flagx" "repositories.action2quare.com/ayo/gocommon/logger" + "repositories.action2quare.com/ayo/gocommon/session" "github.com/go-redis/redis/v8" "github.com/gorilla/websocket" @@ -94,7 +95,6 @@ const ( ) type Sender struct { - Region string Accid primitive.ObjectID Alias string disconnectedCallbacks map[string]func() @@ -118,8 +118,8 @@ func (s *Sender) PopDisconnectedCallback(name string) func() { type EventReceiver interface { OnClientMessageReceived(sender *Sender, messageType WebSocketMessageType, body io.Reader) - OnRoomCreated(region, name string) - OnRoomDestroyed(region, name string) + OnRoomCreated(name string) + OnRoomDestroyed(name string) } type send_msg_queue_elem struct { @@ -128,7 +128,7 @@ type send_msg_queue_elem struct { msg []byte } -type subhandler struct { +type WebsocketHandler struct { redisMsgChanName string redisCmdChanName string redisSync *redis.Client @@ -138,17 +138,12 @@ type subhandler struct { sendMsgChan chan send_msg_queue_elem callReceiver EventReceiver connWaitGroup sync.WaitGroup - region string receiverChain []EventReceiver -} - -// WebsocketHandler : -type WebsocketHandler struct { - subhandlers map[string]*subhandler + sessionConsumer session.Consumer } type wsConfig struct { - gocommon.RegionStorageConfig + gocommon.StorageAddr Maingate string `json:"maingate_service_url"` } @@ -165,62 +160,51 @@ func init() { gob.Register([]any{}) } -func NewWebsocketHandler() (*WebsocketHandler, error) { - subhandlers := make(map[string]*subhandler) - for region, cfg := range config.RegionStorage { - redisSync, err := gocommon.NewRedisClient(cfg.Redis["wshandler"]) - if err != nil { - return nil, err - } - - sendchan := make(chan send_msg_queue_elem, 1000) - go func() { - sender := func(elem *send_msg_queue_elem) { - defer func() { - r := recover() - if r != nil { - logger.Println(r) - } - }() - elem.to.WriteMessage(elem.mt, elem.msg) - } - - for elem := range sendchan { - sender(&elem) - } - }() - - sh := &subhandler{ - redisMsgChanName: fmt.Sprintf("_wsh_msg_%s_%d", region, redisSync.Options().DB), - redisCmdChanName: fmt.Sprintf("_wsh_cmd_%s_%d", region, redisSync.Options().DB), - redisSync: redisSync, - connInOutChan: make(chan *wsconn), - deliveryChan: make(chan any, 1000), - localDeliveryChan: make(chan any, 100), - sendMsgChan: sendchan, - region: region, - } - - subhandlers[region] = sh +func NewWebsocketHandler(consumer session.Consumer) (*WebsocketHandler, error) { + redisSync, err := gocommon.NewRedisClient(config.Redis["wshandler"]) + if err != nil { + return nil, err } + sendchan := make(chan send_msg_queue_elem, 1000) + go func() { + sender := func(elem *send_msg_queue_elem) { + defer func() { + r := recover() + if r != nil { + logger.Println(r) + } + }() + elem.to.WriteMessage(elem.mt, elem.msg) + } + + for elem := range sendchan { + sender(&elem) + } + }() + return &WebsocketHandler{ - subhandlers: subhandlers, + redisMsgChanName: fmt.Sprintf("_wsh_msg_%d", redisSync.Options().DB), + redisCmdChanName: fmt.Sprintf("_wsh_cmd_%d", redisSync.Options().DB), + redisSync: redisSync, + connInOutChan: make(chan *wsconn), + deliveryChan: make(chan any, 1000), + localDeliveryChan: make(chan any, 100), + sendMsgChan: sendchan, + sessionConsumer: consumer, }, nil } -func (ws *WebsocketHandler) RegisterReceiver(region string, receiver EventReceiver) { - if sh := ws.subhandlers[region]; sh != nil { - sh.receiverChain = append(sh.receiverChain, receiver) - } +func (ws *WebsocketHandler) RegisterReceiver(receiver EventReceiver) { + ws.receiverChain = append(ws.receiverChain, receiver) } type nilReceiver struct{} func (r *nilReceiver) OnClientMessageReceived(sender *Sender, messageType WebSocketMessageType, body io.Reader) { } -func (r *nilReceiver) OnRoomCreated(region, name string) {} -func (r *nilReceiver) OnRoomDestroyed(region, name string) {} +func (r *nilReceiver) OnRoomCreated(name string) {} +func (r *nilReceiver) OnRoomDestroyed(name string) {} type chainReceiver struct { chain []EventReceiver @@ -232,100 +216,77 @@ func (r *chainReceiver) OnClientMessageReceived(sender *Sender, messageType WebS } } -func (r *chainReceiver) OnRoomCreated(region, name string) { +func (r *chainReceiver) OnRoomCreated(name string) { for _, cr := range r.chain { - cr.OnRoomCreated(region, name) + cr.OnRoomCreated(name) } } -func (r *chainReceiver) OnRoomDestroyed(region, name string) { +func (r *chainReceiver) OnRoomDestroyed(name string) { for _, cr := range r.chain { - cr.OnRoomDestroyed(region, name) + cr.OnRoomDestroyed(name) } } func (ws *WebsocketHandler) Start(ctx context.Context) { - for _, sh := range ws.subhandlers { - chain := sh.receiverChain - if len(chain) == 0 { - sh.callReceiver = &nilReceiver{} - } else if len(chain) == 1 { - sh.callReceiver = chain[0] - } else { - sh.callReceiver = &chainReceiver{chain: sh.receiverChain} - } - - sh.connWaitGroup.Add(1) - go sh.mainLoop(ctx) + chain := ws.receiverChain + if len(chain) == 0 { + ws.callReceiver = &nilReceiver{} + } else if len(chain) == 1 { + ws.callReceiver = chain[0] + } else { + ws.callReceiver = &chainReceiver{chain: ws.receiverChain} } + + ws.connWaitGroup.Add(1) + go ws.mainLoop(ctx) } func (ws *WebsocketHandler) Cleanup() { - for _, sh := range ws.subhandlers { - sh.connWaitGroup.Wait() - } + ws.connWaitGroup.Wait() } func (ws *WebsocketHandler) RegisterHandlers(serveMux *http.ServeMux, prefix string) error { - for region, sh := range ws.subhandlers { - if region == "default" { - region = "" - } - url := gocommon.MakeHttpHandlerPattern(prefix, region, "ws") - if *noAuthFlag { - serveMux.HandleFunc(url, sh.upgrade_nosession) - } else { - serveMux.HandleFunc(url, sh.upgrade) - } + url := gocommon.MakeHttpHandlerPattern(prefix, "ws") + if *noAuthFlag { + serveMux.HandleFunc(url, ws.upgrade_nosession) + } else { + serveMux.HandleFunc(url, ws.upgrade) } return nil } -func (ws *WebsocketHandler) GetState(region string, accid primitive.ObjectID) string { - if sh := ws.subhandlers[region]; sh != nil { - state, _ := sh.redisSync.Get(context.Background(), accid.Hex()).Result() - return state - } - return "" +func (ws *WebsocketHandler) GetState(accid primitive.ObjectID) string { + state, _ := ws.redisSync.Get(context.Background(), accid.Hex()).Result() + return state } -func (ws *WebsocketHandler) SetState(region string, accid primitive.ObjectID, state string) { - if sh := ws.subhandlers[region]; sh != nil { - sh.redisSync.SetArgs(context.Background(), accid.Hex(), state, redis.SetArgs{Mode: "XX"}).Result() +func (ws *WebsocketHandler) SetState(accid primitive.ObjectID, state string) { + ws.redisSync.SetArgs(context.Background(), accid.Hex(), state, redis.SetArgs{Mode: "XX"}).Result() +} + +func (ws *WebsocketHandler) SendUpstreamMessage(msg *UpstreamMessage) { + ws.localDeliveryChan <- msg +} + +func (ws *WebsocketHandler) EnterRoom(room string, accid primitive.ObjectID) { + ws.localDeliveryChan <- &commandMessage{ + Cmd: commandType_EnterRoom, + Args: []any{room, accid}, } } -func (ws *WebsocketHandler) SendUpstreamMessage(region string, msg *UpstreamMessage) { - sh := ws.subhandlers[region] - if sh != nil { - sh.localDeliveryChan <- msg +func (ws *WebsocketHandler) LeaveRoom(room string, accid primitive.ObjectID) { + ws.localDeliveryChan <- &commandMessage{ + Cmd: commandType_LeaveRoom, + Args: []any{room, accid}, } } -func (ws *WebsocketHandler) EnterRoom(region string, room string, accid primitive.ObjectID) { - sh := ws.subhandlers[region] - if sh != nil { - sh.localDeliveryChan <- &commandMessage{ - Cmd: commandType_EnterRoom, - Args: []any{room, accid}, - } - } -} - -func (ws *WebsocketHandler) LeaveRoom(region string, room string, accid primitive.ObjectID) { - sh := ws.subhandlers[region] - if sh != nil { - sh.localDeliveryChan <- &commandMessage{ - Cmd: commandType_LeaveRoom, - Args: []any{room, accid}, - } - } -} - -func (sh *subhandler) mainLoop(ctx context.Context) { +func (ws *WebsocketHandler) mainLoop(ctx context.Context) { defer func() { - sh.connWaitGroup.Done() + ws.connWaitGroup.Done() s := recover() if s != nil { logger.Error(s) @@ -337,7 +298,7 @@ func (sh *subhandler) mainLoop(ctx context.Context) { var pubsub *redis.PubSub for { if pubsub == nil { - pubsub = sh.redisSync.Subscribe(ctx, sh.redisMsgChanName, sh.redisCmdChanName) + pubsub = ws.redisSync.Subscribe(ctx, ws.redisMsgChanName, ws.redisCmdChanName) } raw, err := pubsub.ReceiveMessage(ctx) @@ -345,17 +306,17 @@ func (sh *subhandler) mainLoop(ctx context.Context) { buffer := bytes.NewBuffer([]byte(raw.Payload)) dec := gob.NewDecoder(buffer) - if raw.Channel == sh.redisMsgChanName { + if raw.Channel == ws.redisMsgChanName { var msg UpstreamMessage if err := dec.Decode(&msg); err == nil { - sh.deliveryChan <- &msg + ws.deliveryChan <- &msg } else { logger.Println("decode UpstreamMessage failed :", err) } - } else if raw.Channel == sh.redisCmdChanName { + } else if raw.Channel == ws.redisCmdChanName { var cmd commandMessage if err := dec.Decode(&cmd); err == nil { - sh.deliveryChan <- &cmd + ws.deliveryChan <- &cmd } else { logger.Println("decode UpstreamMessage failed :", err) } @@ -378,10 +339,10 @@ func (sh *subhandler) mainLoop(ctx context.Context) { findRoom := func(name string, create bool) *room { room := rooms[name] if room == nil && create { - room = makeRoom(name, roomDestroyChan, sh.sendMsgChan) + room = makeRoom(name, roomDestroyChan, ws.sendMsgChan) rooms[name] = room room.start(ctx) - go sh.callReceiver.OnRoomCreated(sh.region, name) + go ws.callReceiver.OnRoomCreated(name) } return room } @@ -393,7 +354,7 @@ func (sh *subhandler) mainLoop(ctx context.Context) { roomnames = append(roomnames, room.name) } bt, _ := json.Marshal(roomnames) - sh.callReceiver.OnClientMessageReceived(conn.sender, Disconnected, bytes.NewBuffer(bt)) + ws.callReceiver.OnClientMessageReceived(conn.sender, Disconnected, bytes.NewBuffer(bt)) conn.Close() } }() @@ -431,7 +392,7 @@ func (sh *subhandler) mainLoop(ctx context.Context) { Body: usermsg.Body, Tag: usermsg.Tag, }) - sh.sendMsgChan <- send_msg_queue_elem{ + ws.sendMsgChan <- send_msg_queue_elem{ to: conn, mt: websocket.TextMessage, msg: ds, @@ -489,9 +450,9 @@ func (sh *subhandler) mainLoop(ctx context.Context) { case destroyedRoom := <-roomDestroyChan: delete(rooms, destroyedRoom) - go sh.callReceiver.OnRoomDestroyed(sh.region, destroyedRoom) + go ws.callReceiver.OnRoomDestroyed(destroyedRoom) - case usermsg := <-sh.localDeliveryChan: + case usermsg := <-ws.localDeliveryChan: // 로컬에 connection이 있는지 먼저 확인해 보기 위한 채널 // 없으면 publish한다. switch usermsg := usermsg.(type) { @@ -501,7 +462,7 @@ func (sh *subhandler) mainLoop(ctx context.Context) { enc := gob.NewEncoder(buffer) var err error if err = enc.Encode(usermsg); err == nil { - _, err = sh.redisSync.Publish(context.Background(), sh.redisMsgChanName, buffer.Bytes()).Result() + _, err = ws.redisSync.Publish(context.Background(), ws.redisMsgChanName, buffer.Bytes()).Result() } if err != nil { @@ -520,7 +481,7 @@ func (sh *subhandler) mainLoop(ctx context.Context) { var err error enc := gob.NewEncoder(buffer) if err = enc.Encode(usermsg); err == nil { - _, err = sh.redisSync.Publish(context.Background(), sh.redisCmdChanName, buffer.Bytes()).Result() + _, err = ws.redisSync.Publish(context.Background(), ws.redisCmdChanName, buffer.Bytes()).Result() } if err != nil { logger.Println("gob.Encode or Publish failed :", err) @@ -528,7 +489,7 @@ func (sh *subhandler) mainLoop(ctx context.Context) { } } - case usermsg := <-sh.deliveryChan: + case usermsg := <-ws.deliveryChan: switch usermsg := usermsg.(type) { case *UpstreamMessage: target := usermsg.Target @@ -553,7 +514,7 @@ func (sh *subhandler) mainLoop(ctx context.Context) { logger.Println("usermsg is unknown type") } - case c := <-sh.connInOutChan: + case c := <-ws.connInOutChan: if c.Conn == nil { delete(entireConns, c.sender.Accid.Hex()) var roomnames []string @@ -564,29 +525,28 @@ func (sh *subhandler) mainLoop(ctx context.Context) { c.joinedRooms = nil bt, _ := json.Marshal(roomnames) - go sh.callReceiver.OnClientMessageReceived(c.sender, Disconnected, bytes.NewBuffer(bt)) + go ws.callReceiver.OnClientMessageReceived(c.sender, Disconnected, bytes.NewBuffer(bt)) } else { entireConns[c.sender.Accid.Hex()] = c - go sh.callReceiver.OnClientMessageReceived(c.sender, Connected, nil) + go ws.callReceiver.OnClientMessageReceived(c.sender, Connected, nil) } } } } -func upgrade_core(sh *subhandler, conn *websocket.Conn, accid primitive.ObjectID, alias string) { +func upgrade_core(ws *WebsocketHandler, conn *websocket.Conn, accid primitive.ObjectID, alias string) { newconn := &wsconn{ Conn: conn, sender: &Sender{ - Region: sh.region, - Alias: alias, - Accid: accid, + Alias: alias, + Accid: accid, }, } - sh.connInOutChan <- newconn + ws.connInOutChan <- newconn - sh.connWaitGroup.Add(1) + ws.connWaitGroup.Add(1) go func(c *wsconn, accid primitive.ObjectID, deliveryChan chan<- any) { - sh.redisSync.Set(context.Background(), accid.Hex(), "online", 0) + ws.redisSync.Set(context.Background(), accid.Hex(), "online", 0) for { messageType, r, err := c.NextReader() if err != nil { @@ -600,9 +560,9 @@ func upgrade_core(sh *subhandler, conn *websocket.Conn, accid primitive.ObjectID if messageType == websocket.TextMessage { // 유저가 직접 보낸 메시지 - sh.callReceiver.OnClientMessageReceived(c.sender, TextMessage, r) + ws.callReceiver.OnClientMessageReceived(c.sender, TextMessage, r) } else if messageType == websocket.BinaryMessage { - sh.callReceiver.OnClientMessageReceived(c.sender, BinaryMessage, r) + ws.callReceiver.OnClientMessageReceived(c.sender, BinaryMessage, r) } } if c.sender.disconnectedCallbacks != nil { @@ -611,15 +571,15 @@ func upgrade_core(sh *subhandler, conn *websocket.Conn, accid primitive.ObjectID } } - sh.redisSync.Del(context.Background(), accid.Hex()) - sh.connWaitGroup.Done() + ws.redisSync.Del(context.Background(), accid.Hex()) + ws.connWaitGroup.Done() c.Conn = nil - sh.connInOutChan <- c - }(newconn, accid, sh.deliveryChan) + ws.connInOutChan <- c + }(newconn, accid, ws.deliveryChan) } -func (sh *subhandler) upgrade_nosession(w http.ResponseWriter, r *http.Request) { +func (ws *WebsocketHandler) upgrade_nosession(w http.ResponseWriter, r *http.Request) { // 클라이언트 접속 defer func() { s := recover() @@ -665,10 +625,10 @@ func (sh *subhandler) upgrade_nosession(w http.ResponseWriter, r *http.Request) alias = accid.Hex() } - upgrade_core(sh, conn, accid, alias) + upgrade_core(ws, conn, accid, alias) } -func (sh *subhandler) upgrade(w http.ResponseWriter, r *http.Request) { +func (ws *WebsocketHandler) upgrade(w http.ResponseWriter, r *http.Request) { // 클라이언트 접속 defer func() { s := recover() @@ -680,27 +640,12 @@ func (sh *subhandler) upgrade(w http.ResponseWriter, r *http.Request) { }() sk := r.Header.Get("AS-X-SESSION") - auth := r.Header.Get("Authorization") - - req, _ := http.NewRequest("GET", fmt.Sprintf("%s/query?sk=%s", config.Maingate, sk), nil) - req.Header.Add("Authorization", auth) - - client := http.Client{} - resp, err := client.Do(req) + authinfo, err := ws.sessionConsumer.Query(sk) if err != nil { w.WriteHeader(http.StatusInternalServerError) logger.Error("authorize query failed :", err) return } - defer resp.Body.Close() - - var authinfo gocommon.Authinfo - dec := json.NewDecoder(resp.Body) - if err = dec.Decode(&authinfo); err != nil { - w.WriteHeader(http.StatusInternalServerError) - logger.Error("authorize query failed :", err) - return - } var upgrader = websocket.Upgrader{} // use default options conn, err := upgrader.Upgrade(w, r, nil) @@ -714,8 +659,8 @@ func (sh *subhandler) upgrade(w http.ResponseWriter, r *http.Request) { vt, _ := base64.StdEncoding.DecodeString(v) alias = string(vt) } else { - alias = authinfo.Accid.Hex() + alias = authinfo.Account.Hex() } - upgrade_core(sh, conn, authinfo.Accid, alias) + upgrade_core(ws, conn, authinfo.Account, alias) }