diff --git a/logger/logger.go b/logger/logger.go index 1e2b959..7aa2b31 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -7,6 +7,7 @@ import ( "log" "os" "path" + "runtime" "runtime/debug" "strings" ) @@ -101,3 +102,44 @@ func Panicln(v ...interface{}) { errlogger.Output(2, string(debug.Stack())) panic(s) } + +type errWithCallstack struct { + inner error + frames []*runtime.Frame +} + +func (ecs *errWithCallstack) Error() string { + if ecs.frames == nil { + return ecs.inner.Error() + } + + out := make([]string, 0, len(ecs.frames)+1) + out = append(out, ecs.inner.Error()) + for i := len(ecs.frames) - 1; i >= 0; i-- { + frame := ecs.frames[i] + out = append(out, fmt.Sprintf("%s\n\t%s:%d", frame.Function, frame.File, frame.Line)) + } + + return strings.Join(out, "\n") +} + +func ErrorWithCallStack(err error) error { + var frames []*runtime.Frame + + if recur, ok := err.(*errWithCallstack); ok { + err = recur.inner + frames = recur.frames + } + + pc, _, _, ok := runtime.Caller(1) + if ok { + curframes := runtime.CallersFrames([]uintptr{pc}) + f, _ := curframes.Next() + frames = append(frames, &f) + } + + return &errWithCallstack{ + inner: err, + frames: frames, + } +} diff --git a/server.go b/server.go index 3762acb..2bfdcb4 100644 --- a/server.go +++ b/server.go @@ -1,6 +1,7 @@ package gocommon import ( + "bytes" "context" "encoding/gob" "encoding/json" @@ -699,22 +700,68 @@ func MakeHttpApiHandler[T any](receiver *T, receiverName string) HttpApiHandler } } -type HttpApiHandlerContainer struct { - methods map[string]apiFuncType +type HttpApiBroker struct { + methods map[string]apiFuncType + methods_dup map[string][]apiFuncType } -func (hc *HttpApiHandlerContainer) RegisterApiHandler(receiver HttpApiHandler) { +type bufferReadCloser struct { + *bytes.Reader +} + +func (buff *bufferReadCloser) Close() error { return nil } + +type readOnlyResponseWriter struct { + inner http.ResponseWriter + statusCode int +} + +func (w *readOnlyResponseWriter) Header() http.Header { + return w.inner.Header() +} + +func (w *readOnlyResponseWriter) Write(in []byte) (int, error) { + logger.Println("readOnlyResponseWriter cannot write") + return len(in), nil +} + +func (w *readOnlyResponseWriter) WriteHeader(statusCode int) { + w.statusCode = statusCode +} + +func (hc *HttpApiBroker) AddHandler(receiver HttpApiHandler) { if hc.methods == nil { hc.methods = make(map[string]apiFuncType) + hc.methods_dup = make(map[string][]apiFuncType) } for k, v := range receiver.methods { logger.Println("http api registered :", k) - hc.methods[k] = v + + hc.methods_dup[k] = append(hc.methods_dup[k], v) + if len(hc.methods_dup[k]) > 1 { + chain := hc.methods_dup[k] + hc.methods[k] = func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + defer r.Body.Close() + + wrap := &readOnlyResponseWriter{inner: w, statusCode: 200} + for _, f := range chain { + r.Body = &bufferReadCloser{bytes.NewReader(body)} + f(wrap, r) + } + + if wrap.statusCode != 200 { + w.WriteHeader(wrap.statusCode) + } + } + } else { + hc.methods[k] = v + } } } -func (hc *HttpApiHandlerContainer) Call(funcname string, w http.ResponseWriter, r *http.Request) { +func (hc *HttpApiBroker) Call(funcname string, w http.ResponseWriter, r *http.Request) { if found := hc.methods[funcname]; found != nil { found(w, r) } else { diff --git a/session/impl_redis.go b/session/impl_redis.go index fa5a775..1097bdf 100644 --- a/session/impl_redis.go +++ b/session/impl_redis.go @@ -53,18 +53,23 @@ func (p *provider_redis) New(input *Authorization) (string, error) { if err != nil { return "", err } + pk := storagekey_to_publickey(sk) - return string(storagekey_to_publickey(sk)), err + logger.Println("session provider new :", sk, pk) + + return string(pk), err } func (p *provider_redis) Delete(account primitive.ObjectID) error { prefix := account.Hex() sks, err := p.redisClient.Keys(p.ctx, prefix+"*").Result() if err != nil { + logger.Println("session provider delete :", sks, err) return err } for _, sk := range sks { + logger.Println("session provider delete :", sk) p.redisClient.Del(p.ctx, sk).Result() p.redisClient.Publish(p.ctx, p.deleteChannel, sk).Result() } @@ -76,29 +81,38 @@ func (p *provider_redis) Query(pk string) (Authorization, error) { sk := publickey_to_storagekey(publickey(pk)) payload, err := p.redisClient.Get(p.ctx, string(sk)).Result() if err == redis.Nil { + logger.Println("session provider query :", pk, err) return Authorization{}, nil } else if err != nil { + logger.Println("session provider query :", pk, err) return Authorization{}, err } var auth Authorization if err := json.Unmarshal([]byte(payload), &auth); err != nil { + logger.Println("session provider query :", pk, err) return Authorization{}, err } + logger.Println("session provider query :", pk, auth) + return auth, nil } func (p *provider_redis) Touch(pk string) (bool, error) { sk := publickey_to_storagekey(publickey(pk)) ok, err := p.redisClient.Expire(p.ctx, string(sk), p.ttl).Result() + logger.Println("session provider touch :", pk) + if err == redis.Nil { // 이미 만료됨 + logger.Println("session consumer touch :", pk, err) return false, nil } else if err != nil { - logger.Println("consumer Touch :", err) + logger.Println("session consumer touch :", pk, err) return false, err } + logger.Println("session consumer touch :", pk) return ok, nil } @@ -223,16 +237,20 @@ func (c *consumer_redis) Query(pk string) (Authorization, error) { sk := publickey_to_storagekey(publickey(pk)) si, _, err := c.query_internal(sk) if err != nil { + logger.Println("session consumer query :", pk, err) return Authorization{}, err } if si == nil { + logger.Println("session consumer query(si nil) :", pk, nil) return Authorization{}, nil } if time.Now().After(si.expireAt) { + logger.Println("session consumer query(expired):", pk, nil) return Authorization{}, nil } + logger.Println("session consumer query :", pk) return *si.Authorization, nil } @@ -244,9 +262,11 @@ func (c *consumer_redis) Touch(pk string) (Authorization, error) { sk := publickey_to_storagekey(publickey(pk)) ok, err := c.redisClient.Expire(c.ctx, string(sk), c.ttl).Result() if err == redis.Nil { + logger.Println("session consumer touch :", pk, err) + return Authorization{}, nil } else if err != nil { - logger.Println("consumer Touch :", err) + logger.Println("session consumer touch :", pk, err) return Authorization{}, err } @@ -254,10 +274,12 @@ func (c *consumer_redis) Touch(pk string) (Authorization, error) { // redis에 살아있다. si, added, err := c.query_internal(sk) if err != nil { + logger.Println("session consumer touch(ok) :", pk, err) return Authorization{}, err } if si == nil { + logger.Println("session consumer touch(ok, si nil) :", pk) return Authorization{}, nil } @@ -267,8 +289,10 @@ func (c *consumer_redis) Touch(pk string) (Authorization, error) { c.add_internal(sk, si) } + logger.Println("session consumer touch(ok) :", pk) return *si.Authorization, nil } + logger.Println("session consumer touch(!ok) :", pk) return Authorization{}, nil } diff --git a/wshandler/api_handler.go b/wshandler/api_handler.go index f977f4b..adeba0a 100644 --- a/wshandler/api_handler.go +++ b/wshandler/api_handler.go @@ -4,8 +4,10 @@ import ( "encoding/json" "io" "reflect" + "strings" "unsafe" + "github.com/gorilla/websocket" "repositories.action2quare.com/ayo/gocommon/logger" ) @@ -15,11 +17,13 @@ const ( ) type apiFuncType func(ApiCallContext) +type connFuncType func(*websocket.Conn, *Sender) type WebsocketApiHandler struct { - methods map[string]apiFuncType - connfunc apiFuncType - disconnfunc apiFuncType + methods map[string]apiFuncType + connfunc connFuncType + disconnfunc connFuncType + originalReceiverName string } type ApiCallContext struct { @@ -35,92 +39,133 @@ func MakeWebsocketApiHandler[T any](receiver *T, receiverName string) WebsocketA receiverName = tp.Elem().Name() } - var connfunc apiFuncType - var disconnfunc apiFuncType + var connfunc connFuncType + var disconnfunc connFuncType for i := 0; i < tp.NumMethod(); i++ { method := tp.Method(i) - if method.Type.NumIn() != 2 { - continue - } - if method.Type.In(0) != tp { continue } - if method.Type.In(1) != reflect.TypeOf((*ApiCallContext)(nil)).Elem() { - continue - } - - funcptr := method.Func.Pointer() - p1 := unsafe.Pointer(&funcptr) - p2 := unsafe.Pointer(&p1) - testfunc := (*func(*T, ApiCallContext))(p2) - if method.Name == ClientConnected { - connfunc = func(ctx ApiCallContext) { - (*testfunc)(receiver, ctx) + if method.Type.NumIn() != 3 { + continue + } + if method.Type.In(1) != reflect.TypeOf((*websocket.Conn)(nil)) { + continue + } + if method.Type.In(2) != reflect.TypeOf((*Sender)(nil)) { + continue + } + funcptr := method.Func.Pointer() + p1 := unsafe.Pointer(&funcptr) + p2 := unsafe.Pointer(&p1) + connfuncptr := (*func(*T, *websocket.Conn, *Sender))(p2) + + connfunc = func(c *websocket.Conn, s *Sender) { + (*connfuncptr)(receiver, c, s) } } else if method.Name == ClientDisconnected { - disconnfunc = func(ctx ApiCallContext) { - (*testfunc)(receiver, ctx) + if method.Type.NumIn() != 3 { + continue + } + if method.Type.In(1) != reflect.TypeOf((*websocket.Conn)(nil)) { + continue + } + if method.Type.In(2) != reflect.TypeOf((*Sender)(nil)) { + continue + } + funcptr := method.Func.Pointer() + p1 := unsafe.Pointer(&funcptr) + p2 := unsafe.Pointer(&p1) + disconnfuncptr := (*func(*T, *websocket.Conn, *Sender))(p2) + + disconnfunc = func(c *websocket.Conn, s *Sender) { + (*disconnfuncptr)(receiver, c, s) } } else { + if method.Type.NumIn() != 2 { + continue + } + if method.Type.In(1) != reflect.TypeOf((*ApiCallContext)(nil)).Elem() { + continue + } + + funcptr := method.Func.Pointer() + p1 := unsafe.Pointer(&funcptr) + p2 := unsafe.Pointer(&p1) + apifuncptr := (*func(*T, ApiCallContext))(p2) methods[receiverName+"."+method.Name] = func(ctx ApiCallContext) { - (*testfunc)(receiver, ctx) + (*apifuncptr)(receiver, ctx) } } } return WebsocketApiHandler{ - methods: methods, - connfunc: connfunc, - disconnfunc: disconnfunc, + methods: methods, + connfunc: connfunc, + disconnfunc: disconnfunc, + originalReceiverName: tp.Elem().Name(), } } type WebsocketApiBroker struct { methods map[string]apiFuncType - connFuncs []apiFuncType - disconnFuncs []apiFuncType + methods_dup map[string][]apiFuncType + connFuncs []connFuncType + disconnFuncs []connFuncType } func (hc *WebsocketApiBroker) AddHandler(receiver WebsocketApiHandler) { if hc.methods == nil { hc.methods = make(map[string]apiFuncType) + hc.methods_dup = make(map[string][]apiFuncType) } for k, v := range receiver.methods { - logger.Println("http api registered :", k) - hc.methods[k] = v + ab := strings.Split(k, ".") + logger.Printf("websocket api registered : %s.%s -> %s\n", receiver.originalReceiverName, ab[1], k) + + hc.methods_dup[k] = append(hc.methods_dup[k], v) + if len(hc.methods_dup[k]) > 1 { + chain := hc.methods_dup[k] + hc.methods[k] = func(ctx ApiCallContext) { + for _, f := range chain { + f(ctx) + } + } + } else { + hc.methods[k] = v + } } if receiver.connfunc != nil { + logger.Printf("websocket api registered : %s.ClientConnected\n", receiver.originalReceiverName) hc.connFuncs = append(hc.connFuncs, receiver.connfunc) } if receiver.disconnfunc != nil { // disconnfunc은 역순 - hc.disconnFuncs = append([]apiFuncType{receiver.disconnfunc}, hc.disconnFuncs...) + logger.Printf("websocket api registered : %s.ClientDisconnected\n", receiver.originalReceiverName) + hc.disconnFuncs = append([]connFuncType{receiver.disconnfunc}, hc.disconnFuncs...) + } +} + +func (hc *WebsocketApiBroker) ClientConnected(c *wsconn) { + for _, v := range hc.connFuncs { + v(c.Conn, c.sender) + } +} + +func (hc *WebsocketApiBroker) ClientDisconnected(c *wsconn) { + for _, v := range hc.disconnFuncs { + v(c.Conn, c.sender) } } func (hc *WebsocketApiBroker) Call(callby *Sender, funcname string, r io.Reader) { - if funcname == ClientConnected { - for _, v := range hc.connFuncs { - v(ApiCallContext{ - CallBy: callby, - Arguments: nil, - }) - } - } else if funcname == ClientDisconnected { - for _, v := range hc.disconnFuncs { - v(ApiCallContext{ - CallBy: callby, - Arguments: nil, - }) - } - } else if found := hc.methods[funcname]; found != nil { + if found := hc.methods[funcname]; found != nil { var args []any if r != nil { dec := json.NewDecoder(r) diff --git a/wshandler/wshandler.go b/wshandler/wshandler.go index f4e4755..4eeafce 100644 --- a/wshandler/wshandler.go +++ b/wshandler/wshandler.go @@ -85,6 +85,8 @@ type send_msg_queue_elem struct { } type WebsocketHandler struct { + WebsocketApiBroker + redisMsgChanName string redisCmdChanName string redisSync *redis.Client @@ -93,7 +95,6 @@ type WebsocketHandler struct { localDeliveryChan chan any sendMsgChan chan send_msg_queue_elem - wsApiBroker WebsocketApiBroker connWaitGroup sync.WaitGroup sessionConsumer session.Consumer } @@ -121,7 +122,7 @@ func NewWebsocketHandler(consumer session.Consumer, redisUrl string) (*Websocket redisSync, err := gocommon.NewRedisClient(redisUrl) if err != nil { - return nil, err + return nil, logger.ErrorWithCallStack(err) } sendchan := make(chan send_msg_queue_elem, 1000) @@ -153,10 +154,6 @@ func NewWebsocketHandler(consumer session.Consumer, redisUrl string) (*Websocket }, nil } -func (ws *WebsocketHandler) RegisterApiHandler(handler WebsocketApiHandler) { - ws.wsApiBroker.AddHandler(handler) -} - func (ws *WebsocketHandler) Start(ctx context.Context) { ws.connWaitGroup.Add(1) go ws.mainLoop(ctx) @@ -260,7 +257,7 @@ func (ws *WebsocketHandler) mainLoop(ctx context.Context) { defer func() { for _, conn := range entireConns { - ws.wsApiBroker.Call(conn.sender, ClientDisconnected, nil) + ws.Call(conn.sender, ClientDisconnected, nil) conn.Close() } }() @@ -414,10 +411,12 @@ func (ws *WebsocketHandler) mainLoop(ctx context.Context) { case c := <-ws.connInOutChan: if c.Conn == nil { delete(entireConns, c.sender.Accid.Hex()) - go ws.wsApiBroker.Call(c.sender, ClientDisconnected, nil) + logger.Println("ClientDisconnected :", c.sender.Alias) + go ws.ClientDisconnected(c) } else { entireConns[c.sender.Accid.Hex()] = c - go ws.wsApiBroker.Call(c.sender, ClientConnected, nil) + logger.Println("ClientConnected :", c.sender.Alias) + go ws.ClientConnected(c) } } } @@ -451,7 +450,7 @@ func upgrade_core(ws *WebsocketHandler, conn *websocket.Conn, accid primitive.Ob r.Read(size[:]) cmd := make([]byte, size[0]) r.Read(cmd) - ws.wsApiBroker.Call(newconn.sender, string(cmd), r) + ws.Call(newconn.sender, string(cmd), r) } } ws.connWaitGroup.Done()