diff --git a/net/server/internal/api_removeAgentToken.go b/net/server/internal/api_removeAgentToken.go index cf4c7898..aecf1ea2 100644 --- a/net/server/internal/api_removeAgentToken.go +++ b/net/server/internal/api_removeAgentToken.go @@ -13,16 +13,18 @@ func RemoveAgentToken(w http.ResponseWriter, r *http.Request) { // logout of all devices logoutMode := r.FormValue("all") == "true" - // parse authentication token - target, access, err := ParseToken(r.FormValue("agent")) - if err != nil { - ErrResponse(w, http.StatusBadRequest, err); - return - } - if logoutMode { + + // find account + account, code, err := ParamAgentToken(r, true) + if err != nil { + PrintMsg(r) + ErrResponse(w, code, err) + return + } + var sessions []store.Session - if err = store.DB.Where("account_id = ?", target, access).Find(&sessions).Error; err != nil { + if err = store.DB.Where("account_id = ?", account.GUID).Find(&sessions).Error; err != nil { ErrResponse(w, http.StatusInternalServerError, err); return; } @@ -43,7 +45,17 @@ func RemoveAgentToken(w http.ResponseWriter, r *http.Request) { ErrResponse(w, http.StatusInternalServerError, err) return } + + ClearStatus(account); } else { + + // parse authentication token + target, access, err := ParseToken(r.FormValue("agent")) + if err != nil { + ErrResponse(w, http.StatusBadRequest, err); + return + } + var session store.Session if err = store.DB.Where("account_id = ? AND token = ?", target, access).Find(&session).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { diff --git a/net/server/internal/api_status.go b/net/server/internal/api_status.go index 5dcf0301..0a2d1ee9 100644 --- a/net/server/internal/api_status.go +++ b/net/server/internal/api_status.go @@ -14,6 +14,7 @@ var wsSync sync.Mutex var wsExit = make(chan bool, 1) var statusListener = make(map[uint][]chan<- []byte) var revisionListener = make(map[uint][]chan<- []byte) +var disconnectListener = make(map[uint][]chan<- bool) var upgrader = websocket.Upgrader{} //Status handler for websocket connection @@ -92,6 +93,14 @@ func Status(w http.ResponseWriter, r *http.Request) { defer removeRevisionListener(session.Account.ID, c) } + // open channel for disconnection + d := make(chan bool) + defer close(d) + + // register channel for updates + addDisconnectListener(session.Account.ID, d) + defer removeDisconnectListener(session.Account.ID, d) + // start ping pong ticker ticker := time.NewTicker(60 * time.Second) defer ticker.Stop() @@ -109,6 +118,9 @@ func Status(w http.ResponseWriter, r *http.Request) { ErrMsg(err) return } + case <-d: + LogMsg("user discconection") + return case <-wsExit: LogMsg("exiting server") wsExit <- true @@ -205,6 +217,22 @@ func SetStatus(account *store.Account) { } } +//ClearStatus disconnects websockets from account +func ClearStatus(account *store.Account) { + + // lock access to statusListener + wsSync.Lock() + defer wsSync.Unlock(); + + // notify all disconnect listeners + chs, ok := disconnectListener[account.ID] + if ok { + for _, ch := range chs { + ch <- true + } + } +} + func addStatusListener(act uint, ch chan<- []byte) { // lock access to statusListener @@ -278,3 +306,41 @@ func removeRevisionListener(act uint, ch chan<- []byte) { } } } + +func addDisconnectListener(act uint, ch chan<- bool) { + + // lock access to disconnectListener + wsSync.Lock() + defer wsSync.Unlock() + + // add new listener to map + chs, ok := disconnectListener[act] + if ok { + disconnectListener[act] = append(chs, ch) + } else { + disconnectListener[act] = []chan<- bool{ch} + } +} + +func removeDisconnectListener(act uint, ch chan<- bool) { + + // lock access to revisionListener + wsSync.Lock() + defer wsSync.Unlock() + + // remove channel from map + chs, ok := disconnectListener[act] + if ok { + for i, c := range chs { + if ch == c { + if len(chs) == 1 { + delete(disconnectListener, act) + } else { + chs[i] = chs[len(chs)-1] + disconnectListener[act] = chs[:len(chs)-1] + } + } + } + } +} +