package session import ( "context" "fmt" "time" "github.com/go-redis/redis/v8" "go.mongodb.org/mongo-driver/bson/primitive" "repositories.action2quare.com/ayo/gocommon" "repositories.action2quare.com/ayo/gocommon/logger" ) const ( communication_channel_name_prefix = "_sess_comm_chan_name" ) type sessionRedis struct { *Authorization expireAt time.Time } type provider_redis struct { redisClient *redis.Client deleteChannel string ttl time.Duration ctx context.Context } func newProviderWithRedis(ctx context.Context, redisUrl string, ttl time.Duration) (Provider, error) { redisClient, err := gocommon.NewRedisClient(redisUrl) if err != nil { return nil, err } return &provider_redis{ redisClient: redisClient, deleteChannel: fmt.Sprintf("%s_%d_d", communication_channel_name_prefix, redisClient.Options().DB), ttl: ttl, ctx: ctx, }, nil } func (p *provider_redis) New(input *Authorization) (string, error) { newsk := make_storagekey(input.Account) prefix := input.Account.Hex() sks, err := p.redisClient.Keys(p.ctx, prefix+"*").Result() if err != nil { logger.Println("session provider delete :", sks, err) return "", err } for _, sk := range sks { p.redisClient.HSet(p.ctx, sk, "inv", "true") p.redisClient.Publish(p.ctx, p.deleteChannel, sk).Result() } for { duplicated := false for _, sk := range sks { if sk == string(newsk) { duplicated = true break } } if !duplicated { break } newsk = make_storagekey(input.Account) } _, err = p.redisClient.HSet(p.ctx, string(newsk), input.ToStrings()).Result() if err != nil { return "", err } _, err = p.redisClient.Expire(p.ctx, string(newsk), p.ttl).Result() if err != nil { return "", err } pk := storagekey_to_publickey(newsk) return string(pk), err } func (p *provider_redis) RevokeAll(account primitive.ObjectID) error { prefix := account.Hex() sks, err := p.redisClient.Keys(p.ctx, prefix+"*").Result() if err != nil { logger.Println("session provider delete :", sks, err) return err } for _, sk := range sks { p.redisClient.HSet(p.ctx, sk, "inv", "true") p.redisClient.Publish(p.ctx, p.deleteChannel, sk).Result() } return nil } func (p *provider_redis) Query(pk string) (Authorization, error) { sk := publickey_to_storagekey(publickey(pk)) src, err := p.redisClient.HGetAll(p.ctx, string(sk)).Result() if err == redis.Nil { logger.Println("session provider query :", pk, err) return Authorization{}, nil } else if err != nil { logger.Println("session provider query :", pk, err) return Authorization{}, err } auth := MakeAuthrizationFromStringMap(src) return auth, nil } func (p *provider_redis) Touch(pk string) (bool, error) { sk := publickey_to_storagekey(publickey(pk)) ok, err := p.redisClient.Expire(p.ctx, string(sk), p.ttl).Result() if err == redis.Nil { // 이미 만료됨 logger.Println("session provider touch :", pk, err) return false, nil } else if err != nil { logger.Println("session provider touch :", pk, err) return false, err } return ok, nil } type consumer_redis struct { consumer_common[*sessionRedis] redisClient *redis.Client deleteChannel string } func newConsumerWithRedis(ctx context.Context, redisUrl string, ttl time.Duration) (Consumer, error) { redisClient, err := gocommon.NewRedisClient(redisUrl) if err != nil { return nil, err } deleteChannel := fmt.Sprintf("%s_%d_d", communication_channel_name_prefix, redisClient.Options().DB) sub := redisClient.Subscribe(ctx, deleteChannel) consumer := &consumer_redis{ consumer_common: consumer_common[*sessionRedis]{ ttl: ttl, ctx: ctx, stages: [2]*cache_stage[*sessionRedis]{make_cache_stage[*sessionRedis](), make_cache_stage[*sessionRedis]()}, startTime: time.Now(), }, redisClient: redisClient, deleteChannel: deleteChannel, } go func() { stageswitch := time.Now().Add(ttl) tickTimer := time.After(ttl) for { select { case <-ctx.Done(): return case <-tickTimer: consumer.changeStage() stageswitch = stageswitch.Add(ttl) tempttl := time.Until(stageswitch) tickTimer = time.After(tempttl) case msg := <-sub.Channel(): if msg == nil { return } if len(msg.Payload) == 0 { continue } switch msg.Channel { case deleteChannel: sk := storagekey(msg.Payload) old := consumer.delete(sk) if old != nil { for _, f := range consumer.onSessionInvalidated { f(old.Account) } } } } } }() return consumer, nil } func (c *consumer_redis) query_internal(sk storagekey) (*sessionRedis, error) { if old, deleted := c.stages[0].deleted[sk]; deleted { return old, nil } if old, deleted := c.stages[1].deleted[sk]; deleted { return old, nil } found, ok := c.stages[0].cache[sk] if !ok { found, ok = c.stages[1].cache[sk] } if ok { if time.Now().Before(found.expireAt) { // 만료전 세션 return found, nil } // 다른 Consumer가 Touch했을 수도 있으므로 redis에서 읽어본다. } payload, err := c.redisClient.HGetAll(c.ctx, string(sk)).Result() if err != nil && err != redis.Nil { logger.Println("consumer Query :", err) return nil, err } if len(payload) == 0 { return nil, nil } ttl, err := c.redisClient.TTL(c.ctx, string(sk)).Result() if err != nil { logger.Println("consumer Query :", err) return nil, err } if ttl < 0 { ttl = time.Duration(time.Hour * 24) } auth := MakeAuthrizationFromStringMap(payload) si := &sessionRedis{ Authorization: &auth, expireAt: time.Now().Add(ttl), } if auth.Invalidated() { c.stages[0].deleted[sk] = si } else { c.add_internal(sk, si) } return si, nil } func (c *consumer_redis) Query(pk string) (Authorization, error) { c.lock.Lock() defer c.lock.Unlock() sk := publickey_to_storagekey(publickey(pk)) si, err := c.query_internal(sk) if err != nil { logger.Println("session consumer query :", pk, err) return Authorization{}, err } if si == nil { logger.Println("session consumer query(si nil) :", pk, nil) return Authorization{}, nil } if time.Now().After(si.expireAt) { logger.Println("session consumer query(expired):", pk, nil) return Authorization{}, nil } return *si.Authorization, nil } func (c *consumer_redis) Touch(pk string) (Authorization, error) { c.lock.Lock() defer c.lock.Unlock() sk := publickey_to_storagekey(publickey(pk)) if _, deleted := c.stages[0].deleted[sk]; deleted { return Authorization{}, nil } if _, deleted := c.stages[1].deleted[sk]; deleted { return Authorization{}, nil } ok, err := c.redisClient.Expire(c.ctx, string(sk), c.ttl).Result() if err == redis.Nil { logger.Println("session consumer touch :", pk, err) return Authorization{}, nil } else if err != nil { logger.Println("session consumer touch :", pk, err) return Authorization{}, err } if ok { // redis에 살아있다. si, err := c.query_internal(sk) if err != nil { logger.Println("session consumer touch(ok) :", pk, err) return Authorization{}, err } if si == nil { logger.Println("session consumer touch(ok, si nil) :", pk) return Authorization{}, nil } return *si.Authorization, nil } return Authorization{}, nil } func (c *consumer_redis) Revoke(pk string) { sk := publickey_to_storagekey(publickey(pk)) c.redisClient.Del(c.ctx, string(sk)) c.redisClient.Publish(c.ctx, c.deleteChannel, string(sk)).Result() } func (c *consumer_redis) IsRevoked(accid primitive.ObjectID) bool { sk := make_storagekey(accid) c.lock.Lock() defer c.lock.Unlock() if _, deleted := c.stages[0].deleted[sk]; deleted { return true } if _, deleted := c.stages[1].deleted[sk]; deleted { return true } return false } func (c *consumer_redis) RegisterOnSessionInvalidated(cb func(primitive.ObjectID)) { c.onSessionInvalidated = append(c.onSessionInvalidated, cb) }