package wshandler import ( "context" "encoding/hex" "encoding/json" "fmt" "io" "net/http" "os" "strings" "sync" common "repositories.action2quare.com/ayo/gocommon" "repositories.action2quare.com/ayo/gocommon/flagx" "repositories.action2quare.com/ayo/gocommon/logger" "github.com/go-redis/redis/v8" "github.com/gorilla/websocket" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" ) var noSessionFlag = flagx.Bool("nosession", false, "nosession=[true|false]") const ( connStateCachePrefix = "conn_state_" connStateScript = ` local hosts = redis.call('keys',KEYS[1]) for index, key in ipairs(hosts) do local ok = redis.call('hexists', key, KEYS[2]) if ok == 1 then return redis.call('hget', key, KEYS[2]) end end return "" ` ) var ConnStateCacheKey = func() string { hn, _ := os.Hostname() return connStateCachePrefix + hn }() type Richconn struct { *websocket.Conn lock sync.Mutex alias primitive.ObjectID tags []string onClose map[string]func() } func (rc *Richconn) AddTag(name, val string) { rc.lock.Lock() defer rc.lock.Unlock() prefix := name + "=" for i, tag := range rc.tags { if strings.HasPrefix(tag, prefix) { rc.tags[i] = prefix + val return } } rc.tags = append(rc.tags, prefix+val) } func (rc *Richconn) GetTag(name string) string { rc.lock.Lock() defer rc.lock.Unlock() prefix := name + "=" for _, tag := range rc.tags { if strings.HasPrefix(tag, prefix) { return tag[len(prefix):] } } return "" } func (rc *Richconn) RemoveTag(name string, val string) { rc.lock.Lock() defer rc.lock.Unlock() whole := fmt.Sprintf("%s=%s", name, val) for i, tag := range rc.tags { if tag == whole { if i == 0 && len(rc.tags) == 1 { rc.tags = nil } else { lastidx := len(rc.tags) - 1 if i < lastidx { rc.tags[i] = rc.tags[lastidx] } rc.tags = rc.tags[:lastidx] } return } } } func (rc *Richconn) RegistOnCloseFunc(name string, f func()) { rc.lock.Lock() defer rc.lock.Unlock() if rc.onClose == nil { f() return } rc.onClose[name] = f } func (rc *Richconn) HasOnCloseFunc(name string) bool { rc.lock.Lock() defer rc.lock.Unlock() if rc.onClose == nil { return false } _, ok := rc.onClose[name] return ok } func (rc *Richconn) UnregistOnCloseFunc(name string) (out func()) { rc.lock.Lock() defer rc.lock.Unlock() if rc.onClose == nil { return } out = rc.onClose[name] delete(rc.onClose, name) return } func (rc *Richconn) WriteBytes(data []byte) error { rc.lock.Lock() defer rc.lock.Unlock() return rc.WriteMessage(websocket.TextMessage, data) } type DeliveryMessage struct { Alias primitive.ObjectID Body []byte Command string Conn *Richconn } func (dm *DeliveryMessage) Parse(out any) error { return json.Unmarshal(dm.Body, out) } func (dm *DeliveryMessage) MarshalBinary() (data []byte, err error) { return append(dm.Alias[:], dm.Body...), nil } func (dm *DeliveryMessage) UnmarshalBinary(data []byte) error { copy(dm.Alias[:], data[:12]) dm.Body = data[12:] return nil } type tagconn struct { rc *Richconn state string } type tagconnsmap = map[primitive.ObjectID]*tagconn type tagconns struct { sync.Mutex tagconnsmap } type subhandler struct { sync.Mutex authCache *common.AuthCollection conns map[primitive.ObjectID]*Richconn aliases map[primitive.ObjectID]primitive.ObjectID tags map[primitive.ObjectID]*tagconns deliveryChan chan DeliveryMessage url string redisSync *redis.Client } // WebsocketHandler : type WebsocketHandler struct { authCaches map[string]*subhandler RedisSync *redis.Client } type wsConfig struct { SyncPipeline string `json:"ws_sync_pipeline"` } func NewWebsocketHandler(authglobal common.AuthCollectionGlobal) (wsh *WebsocketHandler) { authCaches := make(map[string]*subhandler) for _, region := range authglobal.Regions() { sh := &subhandler{ authCache: authglobal.Get(region), conns: make(map[primitive.ObjectID]*Richconn), aliases: make(map[primitive.ObjectID]primitive.ObjectID), tags: make(map[primitive.ObjectID]*tagconns), deliveryChan: make(chan DeliveryMessage, 1000), } authCaches[region] = sh } var config wsConfig common.LoadConfig(&config) redisSync, err := common.NewRedisClient(config.SyncPipeline, 0) if err != nil { panic(err) } return &WebsocketHandler{ authCaches: authCaches, RedisSync: redisSync, } } func (ws *WebsocketHandler) Destructor() { if ws.RedisSync != nil { ws.RedisSync.Del(context.Background(), ConnStateCacheKey) } } func (ws *WebsocketHandler) DeliveryChannel(region string) <-chan DeliveryMessage { return ws.authCaches[region].deliveryChan } func (ws *WebsocketHandler) Conn(region string, alias primitive.ObjectID) *Richconn { if sh := ws.authCaches[region]; sh != nil { return sh.conns[alias] } return nil } func (ws *WebsocketHandler) JoinTag(region string, tag primitive.ObjectID, tid primitive.ObjectID, rc *Richconn, hint string) error { if sh := ws.authCaches[region]; sh != nil { sh.joinTag(tag, tid, rc, hint) } return nil } func (ws *WebsocketHandler) LeaveTag(region string, tag primitive.ObjectID, tid primitive.ObjectID) error { if sh := ws.authCaches[region]; sh != nil { sh.leaveTag(tag, tid) } return nil } func (ws *WebsocketHandler) SetStateInTag(region string, tag primitive.ObjectID, tid primitive.ObjectID, state string, hint string) error { if sh := ws.authCaches[region]; sh != nil { sh.setStateInTag(tag, tid, state, hint) } return nil } func (ws *WebsocketHandler) BroadcastRaw(region string, tag primitive.ObjectID, raw []byte) { if sh := ws.authCaches[region]; sh != nil { if cs := sh.cloneTag(tag); len(cs) > 0 { go func(raw []byte) { for _, c := range cs { if c != nil { c.WriteBytes(raw) } } }(raw) } } } func (ws *WebsocketHandler) Broadcast(region string, tag primitive.ObjectID, doc bson.M) { raw, _ := json.Marshal(doc) ws.BroadcastRaw(region, tag, raw) } var onlineQueryScriptHash string func (ws *WebsocketHandler) RegisterHandlers(ctx context.Context, serveMux *http.ServeMux, prefix string) error { ws.RedisSync.Del(context.Background(), ConnStateCacheKey) scriptHash, err := ws.RedisSync.ScriptLoad(context.Background(), connStateScript).Result() if err != nil { return err } onlineQueryScriptHash = scriptHash for region, sh := range ws.authCaches { if region == "default" { region = "" } sh.url = common.MakeHttpHandlerPattern(prefix, region, "ws") sh.redisSync = ws.RedisSync if *noSessionFlag { serveMux.HandleFunc(sh.url, sh.upgrade_nosession) } else { serveMux.HandleFunc(sh.url, sh.upgrade) } } return nil } func (sh *subhandler) cloneTag(tag primitive.ObjectID) (out []*Richconn) { sh.Lock() cs := sh.tags[tag] sh.Unlock() if cs == nil { return nil } cs.Lock() defer cs.Unlock() out = make([]*Richconn, 0, len(cs.tagconnsmap)) for _, c := range cs.tagconnsmap { out = append(out, c.rc) } return } func (sh *subhandler) joinTag(tag primitive.ObjectID, tid primitive.ObjectID, rc *Richconn, hint string) { sh.Lock() cs := sh.tags[tag] if cs == nil { cs = &tagconns{ tagconnsmap: make(map[primitive.ObjectID]*tagconn), } } sh.Unlock() cs.Lock() states := make([]bson.M, 0, len(cs.tagconnsmap)) for tid, conn := range cs.tagconnsmap { states = append(states, bson.M{ "_id": tid, "_hint": hint, "state": conn.state, }) } cs.tagconnsmap[tid] = &tagconn{rc: rc} cs.Unlock() sh.Lock() sh.tags[tag] = cs sh.Unlock() if len(states) > 0 { s, _ := json.Marshal(states) rc.WriteBytes(s) } } func (sh *subhandler) leaveTag(tag primitive.ObjectID, tid primitive.ObjectID) { sh.Lock() defer sh.Unlock() cs := sh.tags[tag] if cs == nil { return } delete(cs.tagconnsmap, tid) if len(cs.tagconnsmap) == 0 { delete(sh.tags, tag) } else { sh.tags[tag] = cs } } func (sh *subhandler) setStateInTag(tag primitive.ObjectID, tid primitive.ObjectID, state string, hint string) { sh.Lock() cs := sh.tags[tag] sh.Unlock() if cs == nil { return } cs.Lock() defer cs.Unlock() if tagconn := cs.tagconnsmap[tid]; tagconn != nil { tagconn.state = state var clone []*Richconn for _, c := range cs.tagconnsmap { clone = append(clone, c.rc) } raw, _ := json.Marshal(map[string]any{ "_id": tid, "_hint": hint, "state": state, }) go func(raw []byte) { for _, c := range clone { c.WriteBytes(raw) } }(raw) } } func (wsh *WebsocketHandler) GetState(alias primitive.ObjectID) (string, error) { state, err := wsh.RedisSync.EvalSha(context.Background(), onlineQueryScriptHash, []string{ connStateCachePrefix + "*", alias.Hex(), }).Result() if err != nil { return "", err } return state.(string), nil } func (wsh *WebsocketHandler) IsOnline(alias primitive.ObjectID) (bool, error) { state, err := wsh.GetState(alias) if err != nil { logger.Error("IsOnline failed. err :", err) return false, err } return len(state) > 0, nil } func (sh *subhandler) closeConn(accid primitive.ObjectID) { sh.Lock() defer sh.Unlock() if alias, ok := sh.aliases[accid]; ok { if old := sh.conns[alias]; old != nil { old.Close() } } } func (sh *subhandler) addConn(conn *Richconn, accid primitive.ObjectID) { sh.Lock() defer sh.Unlock() sh.conns[conn.alias] = conn sh.aliases[accid] = conn.alias } func upgrade_core(sh *subhandler, conn *websocket.Conn, initState string, accid primitive.ObjectID, alias primitive.ObjectID) { sh.closeConn(accid) newconn := sh.makeRichConn(alias, conn) sh.addConn(newconn, accid) sh.redisSync.HSet(context.Background(), ConnStateCacheKey, alias.Hex(), initState).Result() go func(c *Richconn, accid primitive.ObjectID, deliveryChan chan<- DeliveryMessage) { for { mt, p, err := c.ReadMessage() if err != nil { c.Close() break } switch mt { case websocket.BinaryMessage: msg := DeliveryMessage{ Alias: c.alias, Body: p, Conn: c, } deliveryChan <- msg case websocket.TextMessage: msg := string(p) opcodes := strings.Split(msg, ";") for _, opcode := range opcodes { if strings.HasPrefix(opcode, "ps:") { sh.redisSync.HSet(context.Background(), ConnStateCacheKey, alias.Hex(), opcode[3:]).Result() } else if strings.HasPrefix(opcode, "cmd:") { cmd := opcode[4:] msg := DeliveryMessage{ Alias: c.alias, Command: cmd, Conn: c, } deliveryChan <- msg } } } } sh.redisSync.HDel(context.Background(), ConnStateCacheKey, c.alias.Hex()).Result() sh.Lock() delete(sh.conns, c.alias) delete(sh.aliases, accid) sh.Unlock() var funcs []func() c.lock.Lock() for _, f := range c.onClose { funcs = append(funcs, f) } c.onClose = nil c.lock.Unlock() for _, f := range funcs { f() } }(newconn, accid, sh.deliveryChan) } func (sh *subhandler) upgrade_nosession(w http.ResponseWriter, r *http.Request) { // 클라이언트 접속 defer func() { s := recover() if s != nil { logger.Error(s) } io.Copy(io.Discard, r.Body) r.Body.Close() }() auth := strings.Split(r.Header.Get("Authorization"), " ") if len(auth) != 2 { w.WriteHeader(http.StatusBadRequest) return } if auth[0] != "Editor" { w.WriteHeader(http.StatusBadRequest) return } temp, err := hex.DecodeString(auth[1]) if err != nil { w.WriteHeader(http.StatusBadRequest) return } if len(temp) != len(primitive.NilObjectID) { w.WriteHeader(http.StatusBadRequest) return } raw := (*[12]byte)(temp) accid := primitive.ObjectID(*raw) var upgrader = websocket.Upgrader{} // use default options conn, err := upgrader.Upgrade(w, r, nil) if err != nil { w.WriteHeader(http.StatusInternalServerError) return } alias := accid if v := r.Header.Get("AS-X-ALIAS"); len(v) > 0 { alias = common.ParseObjectID(v) } initState := r.Header.Get("As-X-Tavern-InitialState") if len(initState) == 0 { initState = "online" } upgrade_core(sh, conn, initState, accid, alias) } func (sh *subhandler) upgrade(w http.ResponseWriter, r *http.Request) { // 클라이언트 접속 defer func() { s := recover() if s != nil { logger.Error(s) } io.Copy(io.Discard, r.Body) r.Body.Close() }() sk := r.Header.Get("AS-X-SESSION") auth := strings.Split(r.Header.Get("Authorization"), " ") if len(auth) != 2 { //TODO : 클라이언트는 BadRequest를 받으면 로그인 화면으로 돌아가야 한다. w.WriteHeader(http.StatusBadRequest) return } authtoken := auth[1] accid, success := sh.authCache.IsValid(sk, authtoken) if !success { w.WriteHeader(http.StatusUnauthorized) return } var upgrader = websocket.Upgrader{} // use default options conn, err := upgrader.Upgrade(w, r, nil) if err != nil { w.WriteHeader(http.StatusInternalServerError) return } alias := accid if v := r.Header.Get("AS-X-ALIAS"); len(v) > 0 { alias = common.ParseObjectID(v) } initState := r.Header.Get("As-X-Tavern-InitialState") if len(initState) == 0 { initState = "online" } upgrade_core(sh, conn, initState, accid, alias) } func (sh *subhandler) makeRichConn(alias primitive.ObjectID, conn *websocket.Conn) *Richconn { rc := Richconn{ Conn: conn, alias: alias, onClose: make(map[string]func()), } return &rc }