diff --git a/rpc/rpc.go b/rpc/rpc.go new file mode 100644 index 0000000..3bcff82 --- /dev/null +++ b/rpc/rpc.go @@ -0,0 +1,217 @@ +package rpc + +import ( + "bytes" + "context" + "crypto/md5" + "encoding/gob" + "encoding/hex" + "errors" + "fmt" + "path" + "reflect" + "runtime" + "strings" + "time" + + "github.com/go-redis/redis/v8" + "go.mongodb.org/mongo-driver/bson/primitive" + "repositories.action2quare.com/ayo/gocommon/logger" +) + +type Receiver interface { + TargetExists(primitive.ObjectID) bool +} + +type receiverManifest struct { + r Receiver + methods map[string]reflect.Method +} + +type rpcEngine struct { + receivers map[string]receiverManifest + publish func([]byte) error +} + +var engine = rpcEngine{ + receivers: make(map[string]receiverManifest), +} + +func RegistReceiver(ptr Receiver) { + rname := reflect.TypeOf(ptr).Elem().Name() + rname = fmt.Sprintf("(*%s)", rname) + + methods := make(map[string]reflect.Method) + for i := 0; i < reflect.TypeOf(ptr).NumMethod(); i++ { + method := reflect.TypeOf(ptr).Method(i) + methods[method.Name] = method + } + engine.receivers[rname] = receiverManifest{ + r: ptr, + methods: methods, + } +} + +func Start(ctx context.Context, redisClient *redis.Client) { + if engine.publish != nil { + return + } + + hash := md5.New() + for k, manifest := range engine.receivers { + hash.Write([]byte(k)) + for m, r := range manifest.methods { + hash.Write([]byte(m)) + hash.Write([]byte(r.Name)) + for i := 0; i < r.Type.NumIn(); i++ { + inName := r.Type.In(i).Name() + hash.Write([]byte(inName)) + } + } + } + + pubsubName := hex.EncodeToString(hash.Sum(nil))[:16] + + engine.publish = func(s []byte) error { + _, err := redisClient.Publish(ctx, pubsubName, s).Result() + return err + } + + go engine.loop(ctx, redisClient, pubsubName) +} + +func (re *rpcEngine) callFromMessage(msg *redis.Message) { + defer func() { + r := recover() + if r != nil { + logger.Error(r) + } + }() + + encoded := []byte(msg.Payload) + var target primitive.ObjectID + copy(target[:], encoded[:12]) + + encoded = encoded[12:] + for i, c := range encoded { + if c == ')' { + if manifest, ok := re.receivers[string(encoded[:i+1])]; ok { + // 리시버 찾음 + if manifest.r.TargetExists(target) { + // 이 리시버가 타겟을 가지고 있음 + encoded = encoded[i+1:] + decoder := gob.NewDecoder(bytes.NewBuffer(encoded)) + var params []any + if decoder.Decode(¶ms) == nil { + method := manifest.methods[params[0].(string)] + args := []reflect.Value{ + reflect.ValueOf(manifest.r), + } + for _, arg := range params[1:] { + args = append(args, reflect.ValueOf(arg)) + } + method.Func.Call(args) + } + } + } + } + } +} + +func (re *rpcEngine) loop(ctx context.Context, redisClient *redis.Client, chanName string) { + defer func() { + r := recover() + if r != nil { + logger.Error(r) + } + }() + + pubsub := redisClient.Subscribe(ctx, chanName) + for { + if ctx.Err() != nil { + return + } + + if pubsub == nil { + pubsub = redisClient.Subscribe(ctx, chanName) + } + + msg, err := pubsub.ReceiveMessage(ctx) + + if err != nil { + if err == redis.ErrClosed { + time.Sleep(time.Second) + } + pubsub = nil + } else { + re.callFromMessage(msg) + } + } +} + +var errNoReceiver = errors.New("no receiver") + +type CallContext struct { + r Receiver + t primitive.ObjectID +} + +var ErrCanExecuteHere = errors.New("go ahead") + +func (c *CallContext) Call(args ...any) error { + if c.r.TargetExists(c.t) { + // 여기 있네? + return ErrCanExecuteHere + } + + pc := make([]uintptr, 1) + n := runtime.Callers(3, pc[:]) + if n < 1 { + return errNoReceiver + } + + frame, _ := runtime.CallersFrames(pc).Next() + fullname := path.Base(frame.Function) + prf := strings.Split(fullname, ".") + rname := prf[1] + funcname := prf[2] + + serialized, err := encode(c.t, rname, funcname, args...) + if err != nil { + return err + } + + return engine.publish(serialized) +} + +func Make(r Receiver) *CallContext { + return &CallContext{ + r: r, + } +} + +func (cc *CallContext) To(target primitive.ObjectID) *CallContext { + cc.t = target + return cc +} + +func encode(target primitive.ObjectID, receiver string, funcname string, args ...any) ([]byte, error) { + buff := new(bytes.Buffer) + + // 타겟을 가장 먼저 기록 + buff.Write(target[:]) + + // receiver + buff.Write([]byte(receiver)) + + // 다음 call context 기록 + m := append([]any{funcname}, args...) + encoder := gob.NewEncoder(buff) + err := encoder.Encode(m) + if err != nil { + logger.Error("rpcCallContext.send err :", err) + return nil, err + } + + return buff.Bytes(), nil +} diff --git a/rpc/rpc_test.go b/rpc/rpc_test.go new file mode 100644 index 0000000..ccc3071 --- /dev/null +++ b/rpc/rpc_test.go @@ -0,0 +1,48 @@ +package rpc + +import ( + "context" + "math/rand" + "testing" + "time" + + "go.mongodb.org/mongo-driver/bson/primitive" + "repositories.action2quare.com/ayo/gocommon" + "repositories.action2quare.com/ayo/gocommon/logger" +) + +type testReceiver struct { +} + +func (tr *testReceiver) TargetExists(tid primitive.ObjectID) bool { + logger.Println(tid.Hex()) + return tid[0] >= 10 +} + +func (tr *testReceiver) TestFunc(a string, b string, c int) { + target := primitive.NewObjectID() + target[0] = byte(rand.Intn(2) * 20) + if Make(tr).To(target).Call(a, b, c) != ErrCanExecuteHere { + return + } + + logger.Println(" ", a, b, target[0]) +} + +func TestRpc(t *testing.T) { + var tr testReceiver + RegistReceiver(&tr) + myctx, cancel := context.WithCancel(context.Background()) + + redisClient, _ := gocommon.NewRedisClient("redis://192.168.8.94:6379", 0) + go func() { + for { + tr.TestFunc("aaaa", "bbbb", 333) + time.Sleep(time.Second) + } + }() + + Start(myctx, redisClient) + <-myctx.Done() + cancel() +} diff --git a/wshandler/room.go b/wshandler/room.go new file mode 100644 index 0000000..762e165 --- /dev/null +++ b/wshandler/room.go @@ -0,0 +1,85 @@ +package wshandler + +import ( + "context" + "encoding/json" + + "github.com/gorilla/websocket" + "repositories.action2quare.com/ayo/gocommon/logger" +) + +type room struct { + inChan chan *wsconn + outChan chan *wsconn + messageChan chan *UpstreamMessage + name string +} + +func makeRoom(name string) *room { + return &room{ + inChan: make(chan *wsconn, 10), + outChan: make(chan *wsconn, 10), + messageChan: make(chan *UpstreamMessage, 100), + name: name, + } +} + +func (r *room) broadcast(msg *UpstreamMessage) { + r.messageChan <- msg +} + +func (r *room) in(conn *wsconn) { + r.inChan <- conn +} + +func (r *room) out(conn *wsconn) { + r.outChan <- conn +} + +func (r *room) start(ctx context.Context) { + go func(ctx context.Context) { + conns := make(map[string]*wsconn) + normal := false + for !normal { + normal = r.loop(ctx, &conns) + } + }(ctx) +} + +func (r *room) loop(ctx context.Context, conns *map[string]*wsconn) (normalEnd bool) { + defer func() { + s := recover() + if s != nil { + logger.Error(s) + normalEnd = false + } + }() + + tag := "#" + r.name + for { + select { + case <-ctx.Done(): + return true + + case conn := <-r.inChan: + (*conns)[conn.sender.Accid.Hex()] = conn + + case conn := <-r.outChan: + delete((*conns), conn.sender.Accid.Hex()) + + case msg := <-r.messageChan: + ds := DownstreamMessage{ + Alias: msg.Alias, + Body: msg.Body, + Tag: append(msg.Tag, tag), + } + bt, _ := json.Marshal(ds) + + for _, conn := range *conns { + writer, _ := conn.NextWriter(websocket.TextMessage) + writer.Write(bt) + writer.Close() + } + } + } +} diff --git a/wshandler/wshandler.go b/wshandler/wshandler.go index 0ad7fb0..e4427c0 100644 --- a/wshandler/wshandler.go +++ b/wshandler/wshandler.go @@ -7,505 +7,460 @@ import ( "fmt" "io" "net/http" - "os" "strings" "sync" + "time" - common "repositories.action2quare.com/ayo/gocommon" + "go.mongodb.org/mongo-driver/bson/primitive" + "repositories.action2quare.com/ayo/gocommon" "repositories.action2quare.com/ayo/gocommon/flagx" "repositories.action2quare.com/ayo/gocommon/logger" "github.com/go-redis/redis/v8" "github.com/gorilla/websocket" - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/bson/primitive" ) var noSessionFlag = flagx.Bool("nosession", false, "nosession=[true|false]") +type wsconn struct { + *websocket.Conn + sender *Sender +} + +type UpstreamMessage struct { + Alias string + Accid primitive.ObjectID + Target string + Body any + Tag []string +} + +type DownstreamMessage struct { + Alias string `json:",omitempty"` + Body any `json:",omitempty"` + Tag []string `json:",omitempty"` +} + +type commandType string + const ( - connStateCachePrefix = "conn_state_" - connStateScript = ` - local hosts = redis.call('keys',KEYS[1]) - for index, key in ipairs(hosts) do - local ok = redis.call('hexists', key, KEYS[2]) - if ok == 1 then - return redis.call('hget', key, KEYS[2]) - end - end - return "" - ` + commandType_JoinRoom = commandType("join_room") + commandType_LeaveRoom = commandType("leave_room") + commandType_WriteControl = commandType("write_control") ) -var ConnStateCacheKey = func() string { - hn, _ := os.Hostname() - return connStateCachePrefix + hn -}() - -type Richconn struct { - *websocket.Conn - lock sync.Mutex - alias primitive.ObjectID - tags []string - onClose map[string]func() +type commandMessage struct { + Cmd commandType + Args []any } -func (rc *Richconn) AddTag(name, val string) { - rc.lock.Lock() - defer rc.lock.Unlock() +type WebSocketMessageType int - prefix := name + "=" - for i, tag := range rc.tags { - if strings.HasPrefix(tag, prefix) { - rc.tags[i] = prefix + val - return - } - } - rc.tags = append(rc.tags, prefix+val) +const ( + TextMessage = WebSocketMessageType(websocket.TextMessage) + BinaryMessage = WebSocketMessageType(websocket.BinaryMessage) + CloseMessage = WebSocketMessageType(websocket.CloseMessage) + PingMessage = WebSocketMessageType(websocket.PingMessage) + PongMessage = WebSocketMessageType(websocket.PongMessage) + Connected = WebSocketMessageType(100) + Disconnected = WebSocketMessageType(101) +) + +type Sender struct { + Region string + Accid primitive.ObjectID + Alias string } -func (rc *Richconn) GetTag(name string) string { - rc.lock.Lock() - defer rc.lock.Unlock() - - prefix := name + "=" - for _, tag := range rc.tags { - if strings.HasPrefix(tag, prefix) { - return tag[len(prefix):] - } - } - return "" -} - -func (rc *Richconn) RemoveTag(name string, val string) { - rc.lock.Lock() - defer rc.lock.Unlock() - - whole := fmt.Sprintf("%s=%s", name, val) - for i, tag := range rc.tags { - if tag == whole { - if i == 0 && len(rc.tags) == 1 { - rc.tags = nil - } else { - lastidx := len(rc.tags) - 1 - if i < lastidx { - rc.tags[i] = rc.tags[lastidx] - } - rc.tags = rc.tags[:lastidx] - } - return - } - } -} - -func (rc *Richconn) RegistOnCloseFunc(name string, f func()) { - rc.lock.Lock() - defer rc.lock.Unlock() - - if rc.onClose == nil { - f() - return - } - rc.onClose[name] = f -} - -func (rc *Richconn) HasOnCloseFunc(name string) bool { - rc.lock.Lock() - defer rc.lock.Unlock() - - if rc.onClose == nil { - return false - } - - _, ok := rc.onClose[name] - return ok -} - -func (rc *Richconn) UnregistOnCloseFunc(name string) (out func()) { - rc.lock.Lock() - defer rc.lock.Unlock() - - if rc.onClose == nil { - return - } - out = rc.onClose[name] - delete(rc.onClose, name) - return -} - -func (rc *Richconn) WriteBytes(data []byte) error { - rc.lock.Lock() - defer rc.lock.Unlock() - return rc.WriteMessage(websocket.TextMessage, data) -} - -type DeliveryMessage struct { - Alias primitive.ObjectID - Body []byte - Command string - Conn *Richconn -} - -func (dm *DeliveryMessage) Parse(out any) error { - return json.Unmarshal(dm.Body, out) -} - -func (dm *DeliveryMessage) MarshalBinary() (data []byte, err error) { - return append(dm.Alias[:], dm.Body...), nil -} - -func (dm *DeliveryMessage) UnmarshalBinary(data []byte) error { - copy(dm.Alias[:], data[:12]) - dm.Body = data[12:] - return nil -} - -type tagconn struct { - rc *Richconn - state string -} -type tagconnsmap = map[primitive.ObjectID]*tagconn -type tagconns struct { - sync.Mutex - tagconnsmap -} +type WebSocketMessageReceiver func(sender *Sender, messageType WebSocketMessageType, body io.Reader) type subhandler struct { - sync.Mutex - authCache *common.AuthCollection - conns map[primitive.ObjectID]*Richconn - aliases map[primitive.ObjectID]primitive.ObjectID - tags map[primitive.ObjectID]*tagconns - deliveryChan chan DeliveryMessage - url string - redisSync *redis.Client + authCache *gocommon.AuthCollection + redisMsgChanName string + redisCmdChanName string + redisSync *redis.Client + connInOutChan chan *wsconn + deliveryChan chan any + localDeliveryChan chan any + callReceiver WebSocketMessageReceiver + connWaitGroup sync.WaitGroup + region string } // WebsocketHandler : type WebsocketHandler struct { - authCaches map[string]*subhandler - RedisSync *redis.Client + authCaches map[string]*subhandler + RedisSync *redis.Client + receiverChain map[string][]WebSocketMessageReceiver } type wsConfig struct { SyncPipeline string `json:"ws_sync_pipeline"` } -func NewWebsocketHandler(authglobal common.AuthCollectionGlobal) (wsh *WebsocketHandler) { +func NewWebsocketHandler(authglobal gocommon.AuthCollectionGlobal) (wsh *WebsocketHandler) { + var config wsConfig + gocommon.LoadConfig(&config) + + redisSync, err := gocommon.NewRedisClient(config.SyncPipeline, 0) + if err != nil { + panic(err) + } + + // decoder := func(r io.Reader) *T { + // if r == nil { + // // 접속이 끊겼을 때. + // return nil + // } + // var m T + // dec := json.NewDecoder(r) + // if err := dec.Decode(&m); err != nil { + // logger.Println(err) + // } + + // // decoding 실패하더라도 빈 *T를 내보냄 + // return &m + // } + authCaches := make(map[string]*subhandler) for _, region := range authglobal.Regions() { sh := &subhandler{ - authCache: authglobal.Get(region), - conns: make(map[primitive.ObjectID]*Richconn), - aliases: make(map[primitive.ObjectID]primitive.ObjectID), - tags: make(map[primitive.ObjectID]*tagconns), - deliveryChan: make(chan DeliveryMessage, 1000), + authCache: authglobal.Get(region), + redisMsgChanName: fmt.Sprintf("_wsh_msg_%s", region), + redisCmdChanName: fmt.Sprintf("_wsh_cmd_%s", region), + redisSync: redisSync, + connInOutChan: make(chan *wsconn), + deliveryChan: make(chan any, 1000), + localDeliveryChan: make(chan any, 100), + region: region, } authCaches[region] = sh } - var config wsConfig - common.LoadConfig(&config) - - redisSync, err := common.NewRedisClient(config.SyncPipeline, 0) - if err != nil { - panic(err) - } - return &WebsocketHandler{ - authCaches: authCaches, - RedisSync: redisSync, + authCaches: authCaches, + RedisSync: redisSync, + receiverChain: make(map[string][]WebSocketMessageReceiver), } } -func (ws *WebsocketHandler) Destructor() { - if ws.RedisSync != nil { - ws.RedisSync.Del(context.Background(), ConnStateCacheKey) - } +func (ws *WebsocketHandler) RegisterReceiver(region string, receiver WebSocketMessageReceiver) { + ws.receiverChain[region] = append(ws.receiverChain[region], receiver) } -func (ws *WebsocketHandler) DeliveryChannel(region string) <-chan DeliveryMessage { - return ws.authCaches[region].deliveryChan -} - -func (ws *WebsocketHandler) Conn(region string, alias primitive.ObjectID) *Richconn { - if sh := ws.authCaches[region]; sh != nil { - return sh.conns[alias] - } - return nil -} - -func (ws *WebsocketHandler) JoinTag(region string, tag primitive.ObjectID, tid primitive.ObjectID, rc *Richconn, hint string) error { - if sh := ws.authCaches[region]; sh != nil { - sh.joinTag(tag, tid, rc, hint) - } - return nil -} - -func (ws *WebsocketHandler) LeaveTag(region string, tag primitive.ObjectID, tid primitive.ObjectID) error { - if sh := ws.authCaches[region]; sh != nil { - sh.leaveTag(tag, tid) - } - return nil -} - -func (ws *WebsocketHandler) SetStateInTag(region string, tag primitive.ObjectID, tid primitive.ObjectID, state string, hint string) error { - if sh := ws.authCaches[region]; sh != nil { - sh.setStateInTag(tag, tid, state, hint) - } - return nil -} - -func (ws *WebsocketHandler) BroadcastRaw(region string, tag primitive.ObjectID, raw []byte) { - if sh := ws.authCaches[region]; sh != nil { - if cs := sh.cloneTag(tag); len(cs) > 0 { - go func(raw []byte) { - for _, c := range cs { - if c != nil { - c.WriteBytes(raw) - } +func (ws *WebsocketHandler) Start(ctx context.Context) { + for region, sh := range ws.authCaches { + chain := ws.receiverChain[region] + if len(chain) == 0 { + sh.callReceiver = func(sender *Sender, messageType WebSocketMessageType, body io.Reader) {} + } else if len(chain) == 1 { + sh.callReceiver = chain[0] + } else { + sh.callReceiver = func(sender *Sender, messageType WebSocketMessageType, body io.Reader) { + for _, r := range chain { + r(sender, messageType, body) } - }(raw) + } } + + go sh.mainLoop(ctx) } } -func (ws *WebsocketHandler) Broadcast(region string, tag primitive.ObjectID, doc bson.M) { - raw, _ := json.Marshal(doc) - ws.BroadcastRaw(region, tag, raw) +func (ws *WebsocketHandler) Cleanup() { + for _, sh := range ws.authCaches { + sh.connWaitGroup.Wait() + } } -var onlineQueryScriptHash string - -func (ws *WebsocketHandler) RegisterHandlers(ctx context.Context, serveMux *http.ServeMux, prefix string) error { - ws.RedisSync.Del(context.Background(), ConnStateCacheKey) - - scriptHash, err := ws.RedisSync.ScriptLoad(context.Background(), connStateScript).Result() - if err != nil { - return err - } - onlineQueryScriptHash = scriptHash +func (ws *WebsocketHandler) RegisterHandlers(serveMux *http.ServeMux, prefix string) error { for region, sh := range ws.authCaches { if region == "default" { region = "" } - sh.url = common.MakeHttpHandlerPattern(prefix, region, "ws") - sh.redisSync = ws.RedisSync + url := gocommon.MakeHttpHandlerPattern(prefix, region, "ws") if *noSessionFlag { - serveMux.HandleFunc(sh.url, sh.upgrade_nosession) + serveMux.HandleFunc(url, sh.upgrade_nosession) } else { - serveMux.HandleFunc(sh.url, sh.upgrade) + serveMux.HandleFunc(url, sh.upgrade) } } return nil } -func (sh *subhandler) cloneTag(tag primitive.ObjectID) (out []*Richconn) { - sh.Lock() - cs := sh.tags[tag] - sh.Unlock() - - if cs == nil { - return nil +func (ws *WebsocketHandler) GetState(region string, accid primitive.ObjectID) string { + state, err := ws.RedisSync.Get(context.Background(), accid.Hex()).Result() + if err == redis.Nil { + return "" } - - cs.Lock() - defer cs.Unlock() - - out = make([]*Richconn, 0, len(cs.tagconnsmap)) - for _, c := range cs.tagconnsmap { - out = append(out, c.rc) - } - return + return state } -func (sh *subhandler) joinTag(tag primitive.ObjectID, tid primitive.ObjectID, rc *Richconn, hint string) { - sh.Lock() - cs := sh.tags[tag] - if cs == nil { - cs = &tagconns{ - tagconnsmap: make(map[primitive.ObjectID]*tagconn), - } - } - sh.Unlock() +func (ws *WebsocketHandler) SetState(region string, accid primitive.ObjectID, state string) { + ws.RedisSync.SetArgs(context.Background(), accid.Hex(), state, redis.SetArgs{Mode: "XX"}).Result() +} - cs.Lock() - states := make([]bson.M, 0, len(cs.tagconnsmap)) - for tid, conn := range cs.tagconnsmap { - states = append(states, bson.M{ - "_id": tid, - "_hint": hint, - "state": conn.state, - }) - } - - cs.tagconnsmap[tid] = &tagconn{rc: rc} - cs.Unlock() - - sh.Lock() - sh.tags[tag] = cs - sh.Unlock() - - if len(states) > 0 { - s, _ := json.Marshal(states) - rc.WriteBytes(s) +func (ws *WebsocketHandler) SendUpstreamMessage(region string, msg *UpstreamMessage) { + sh := ws.authCaches[region] + if sh != nil { + sh.localDeliveryChan <- msg } } -func (sh *subhandler) leaveTag(tag primitive.ObjectID, tid primitive.ObjectID) { - sh.Lock() - defer sh.Unlock() - - cs := sh.tags[tag] - if cs == nil { - return - } - - delete(cs.tagconnsmap, tid) - if len(cs.tagconnsmap) == 0 { - delete(sh.tags, tag) - } else { - sh.tags[tag] = cs - } -} - -func (sh *subhandler) setStateInTag(tag primitive.ObjectID, tid primitive.ObjectID, state string, hint string) { - sh.Lock() - cs := sh.tags[tag] - sh.Unlock() - - if cs == nil { - return - } - - cs.Lock() - defer cs.Unlock() - - if tagconn := cs.tagconnsmap[tid]; tagconn != nil { - tagconn.state = state - - var clone []*Richconn - for _, c := range cs.tagconnsmap { - clone = append(clone, c.rc) - } - raw, _ := json.Marshal(map[string]any{ - "_id": tid, - "_hint": hint, - "state": state, - }) - go func(raw []byte) { - for _, c := range clone { - c.WriteBytes(raw) - } - }(raw) - } -} - -func (wsh *WebsocketHandler) GetState(alias primitive.ObjectID) (string, error) { - state, err := wsh.RedisSync.EvalSha(context.Background(), onlineQueryScriptHash, []string{ - connStateCachePrefix + "*", alias.Hex(), - }).Result() - - if err != nil { - return "", err - } - - return state.(string), nil -} - -func (wsh *WebsocketHandler) IsOnline(alias primitive.ObjectID) (bool, error) { - state, err := wsh.GetState(alias) - if err != nil { - logger.Error("IsOnline failed. err :", err) - return false, err - } - return len(state) > 0, nil -} - -func (sh *subhandler) closeConn(accid primitive.ObjectID) { - sh.Lock() - defer sh.Unlock() - - if alias, ok := sh.aliases[accid]; ok { - if old := sh.conns[alias]; old != nil { - old.Close() +func (ws *WebsocketHandler) SendCloseMessage(region string, target string, text string) { + sh := ws.authCaches[region] + if sh != nil { + sh.localDeliveryChan <- &commandMessage{ + Cmd: commandType_WriteControl, + Args: []any{ + target, + int(websocket.CloseMessage), + websocket.FormatCloseMessage(websocket.CloseNormalClosure, text), + }, } } } -func (sh *subhandler) addConn(conn *Richconn, accid primitive.ObjectID) { - sh.Lock() - defer sh.Unlock() - - sh.conns[conn.alias] = conn - sh.aliases[accid] = conn.alias +func (ws *WebsocketHandler) EnterRoom(region string, room string, accid primitive.ObjectID) { + sh := ws.authCaches[region] + if sh != nil { + sh.localDeliveryChan <- &commandMessage{ + Cmd: commandType_JoinRoom, + Args: []any{room, accid}, + } + } } -func upgrade_core(sh *subhandler, conn *websocket.Conn, initState string, accid primitive.ObjectID, alias primitive.ObjectID) { - sh.closeConn(accid) +func (ws *WebsocketHandler) LeaveRoom(region string, room string, accid primitive.ObjectID) { + sh := ws.authCaches[region] + if sh != nil { + sh.localDeliveryChan <- &commandMessage{ + Cmd: commandType_LeaveRoom, + Args: []any{room, accid}, + } + } +} - newconn := sh.makeRichConn(alias, conn) - sh.addConn(newconn, accid) - sh.redisSync.HSet(context.Background(), ConnStateCacheKey, alias.Hex(), initState).Result() +func (sh *subhandler) mainLoop(ctx context.Context) { + defer func() { + s := recover() + if s != nil { + logger.Error(s) + } + }() - go func(c *Richconn, accid primitive.ObjectID, deliveryChan chan<- DeliveryMessage) { + // redis channel에서 유저가 보낸 메시지를 읽는 go rountine + go func() { + var pubsub *redis.PubSub for { - mt, p, err := c.ReadMessage() + if pubsub == nil { + pubsub = sh.redisSync.Subscribe(ctx, sh.redisMsgChanName, sh.redisCmdChanName) + } + raw, err := pubsub.ReceiveMessage(ctx) + if err == nil { + if raw.Channel == sh.redisMsgChanName { + var msg UpstreamMessage + if err := json.Unmarshal([]byte(raw.Payload), &msg); err == nil { + sh.deliveryChan <- &msg + } else { + logger.Println("decode UpstreamMessage failed :", err) + } + } else if raw.Channel == sh.redisCmdChanName { + var cmd commandMessage + if err := json.Unmarshal([]byte(raw.Payload), &cmd); err == nil { + sh.deliveryChan <- &cmd + } else { + logger.Println("decode UpstreamMessage failed :", err) + } + } + } else { + logger.Println("pubsub.ReceiveMessage failed :", err) + pubsub.Close() + pubsub = nil + + if ctx.Err() != nil { + break + } + } + } + }() + + entireConns := make(map[string]*wsconn) + rooms := make(map[string]*room) + findRoom := func(name string, create bool) *room { + room := rooms[name] + if room == nil && create { + room = makeRoom(name) + rooms[name] = room + room.start(ctx) + } + return room + } + + // 유저에게서 온 메세지, 소켓 연결/해체 처리 + for { + select { + case usermsg := <-sh.localDeliveryChan: + // 로컬에 connection이 있는지 먼저 확인해 보기 위한 채널 + // 없으면 publish한다. + switch usermsg := usermsg.(type) { + case *UpstreamMessage: + target := usermsg.Target + if target[0] == '@' { + accid := target[1:] + conn := entireConns[accid] + if conn != nil { + // 이 경우 아니면 publish 해야 함 + ds, _ := json.Marshal(DownstreamMessage{ + Alias: usermsg.Alias, + Body: usermsg.Body, + Tag: usermsg.Tag, + }) + + conn.WriteMessage(websocket.TextMessage, ds) + break + } + } + if bt, err := json.Marshal(usermsg); err == nil { + sh.redisSync.Publish(context.Background(), sh.redisMsgChanName, bt).Result() + } + + case *commandMessage: + if usermsg.Cmd == commandType_JoinRoom && len(usermsg.Args) == 2 { + roomName := usermsg.Args[0].(string) + accid := usermsg.Args[1].(primitive.ObjectID) + conn := entireConns[accid.Hex()] + if conn != nil { + findRoom(roomName, true).in(conn) + break + } + } else if usermsg.Cmd == commandType_LeaveRoom && len(usermsg.Args) == 2 { + roomName := usermsg.Args[0].(string) + accid := usermsg.Args[1].(primitive.ObjectID) + conn := entireConns[accid.Hex()] + if conn != nil { + if room := findRoom(roomName, false); room != nil { + room.out(conn) + break + } + } + } else if usermsg.Cmd == commandType_WriteControl && len(usermsg.Args) == 2 { + accid := usermsg.Args[0].(string) + conn := entireConns[accid] + if conn != nil { + conn.WriteControl(usermsg.Args[1].(int), usermsg.Args[2].([]byte), time.Time{}) + break + } + } + + // 위에서 break 안걸리면 나한테 없으므로 publish를 해야 함. 그러면 다른 호스트가 deliveryChan으로 받는다 + if bt, err := json.Marshal(usermsg); err == nil { + sh.redisSync.Publish(context.Background(), sh.redisCmdChanName, bt).Result() + } + } + + case usermsg := <-sh.deliveryChan: + switch usermsg := usermsg.(type) { + case *UpstreamMessage: + target := usermsg.Target + if target[0] == '#' { + // 룸에 브로드캐스팅 + roomName := target[1:] + if room := findRoom(roomName, false); room != nil { + room.broadcast(usermsg) + } + } else if target[0] == '@' { + accid := target[1:] + conn := entireConns[accid] + if conn != nil { + ds, _ := json.Marshal(DownstreamMessage{ + Alias: usermsg.Alias, + Body: usermsg.Body, + Tag: usermsg.Tag, + }) + conn.WriteMessage(websocket.TextMessage, ds) + } + } + + case *commandMessage: + if usermsg.Cmd == commandType_JoinRoom && len(usermsg.Args) == 2 { + roomName := usermsg.Args[0].(string) + accid := usermsg.Args[1].(primitive.ObjectID) + conn := entireConns[accid.Hex()] + if conn != nil { + findRoom(roomName, true).in(conn) + } + } else if usermsg.Cmd == commandType_LeaveRoom && len(usermsg.Args) == 2 { + roomName := usermsg.Args[0].(string) + accid := usermsg.Args[1].(primitive.ObjectID) + conn := entireConns[accid.Hex()] + if conn != nil { + if room := findRoom(roomName, false); room != nil { + room.out(conn) + } + } + } + + default: + logger.Println("usermsg is unknown type") + } + + case c := <-sh.connInOutChan: + if c.Conn == nil { + delete(entireConns, c.sender.Accid.Hex()) + for _, room := range rooms { + room.out(c) + } + sh.callReceiver(c.sender, Disconnected, nil) + } else { + entireConns[c.sender.Accid.Hex()] = c + sh.callReceiver(c.sender, Connected, nil) + } + } + } +} + +func upgrade_core(sh *subhandler, conn *websocket.Conn, accid primitive.ObjectID, alias string) { + newconn := &wsconn{ + Conn: conn, + sender: &Sender{ + Region: sh.region, + Alias: alias, + Accid: accid, + }, + } + sh.connInOutChan <- newconn + + sh.connWaitGroup.Add(1) + go func(c *wsconn, accid primitive.ObjectID, deliveryChan chan<- any) { + sh.redisSync.Set(context.Background(), accid.Hex(), "online", 0) + for { + messageType, r, err := c.NextReader() if err != nil { c.Close() break } - switch mt { - case websocket.BinaryMessage: - msg := DeliveryMessage{ - Alias: c.alias, - Body: p, - Conn: c, - } - deliveryChan <- msg - - case websocket.TextMessage: - msg := string(p) - opcodes := strings.Split(msg, ";") - for _, opcode := range opcodes { - if strings.HasPrefix(opcode, "ps:") { - sh.redisSync.HSet(context.Background(), ConnStateCacheKey, alias.Hex(), opcode[3:]).Result() - } else if strings.HasPrefix(opcode, "cmd:") { - cmd := opcode[4:] - msg := DeliveryMessage{ - Alias: c.alias, - Command: cmd, - Conn: c, - } - deliveryChan <- msg - } - } + if messageType == websocket.CloseMessage { + sh.callReceiver(c.sender, CloseMessage, r) + break + } + if messageType == websocket.TextMessage { + // 유저가 직접 보낸 메시지 + sh.callReceiver(c.sender, TextMessage, r) + } else if messageType == websocket.BinaryMessage { + sh.callReceiver(c.sender, BinaryMessage, r) } } + sh.redisSync.Del(context.Background(), accid.Hex()) + sh.connWaitGroup.Done() - sh.redisSync.HDel(context.Background(), ConnStateCacheKey, c.alias.Hex()).Result() - - sh.Lock() - delete(sh.conns, c.alias) - delete(sh.aliases, accid) - sh.Unlock() - - var funcs []func() - c.lock.Lock() - for _, f := range c.onClose { - funcs = append(funcs, f) - } - c.onClose = nil - c.lock.Unlock() - - for _, f := range funcs { - f() - } + c.Conn = nil + sh.connInOutChan <- c }(newconn, accid, sh.deliveryChan) } @@ -551,17 +506,14 @@ func (sh *subhandler) upgrade_nosession(w http.ResponseWriter, r *http.Request) return } - alias := accid + var alias string if v := r.Header.Get("AS-X-ALIAS"); len(v) > 0 { - alias = common.ParseObjectID(v) + alias = v + } else { + alias = accid.Hex() } - initState := r.Header.Get("As-X-Tavern-InitialState") - if len(initState) == 0 { - initState = "online" - } - - upgrade_core(sh, conn, initState, accid, alias) + upgrade_core(sh, conn, accid, alias) } func (sh *subhandler) upgrade(w http.ResponseWriter, r *http.Request) { @@ -597,24 +549,12 @@ func (sh *subhandler) upgrade(w http.ResponseWriter, r *http.Request) { return } - alias := accid + var alias string if v := r.Header.Get("AS-X-ALIAS"); len(v) > 0 { - alias = common.ParseObjectID(v) + alias = v + } else { + alias = accid.Hex() } - initState := r.Header.Get("As-X-Tavern-InitialState") - if len(initState) == 0 { - initState = "online" - } - - upgrade_core(sh, conn, initState, accid, alias) -} - -func (sh *subhandler) makeRichConn(alias primitive.ObjectID, conn *websocket.Conn) *Richconn { - rc := Richconn{ - Conn: conn, - alias: alias, - onClose: make(map[string]func()), - } - return &rc + upgrade_core(sh, conn, accid, alias) }