force disconnection when logging out of all devices

This commit is contained in:
Pierre Balzack 2023-06-27 22:54:19 -07:00
parent 766bc52965
commit 4b52c477f8
2 changed files with 86 additions and 8 deletions

View File

@ -13,16 +13,18 @@ func RemoveAgentToken(w http.ResponseWriter, r *http.Request) {
// logout of all devices // logout of all devices
logoutMode := r.FormValue("all") == "true" logoutMode := r.FormValue("all") == "true"
// parse authentication token
target, access, err := ParseToken(r.FormValue("agent"))
if err != nil {
ErrResponse(w, http.StatusBadRequest, err);
return
}
if logoutMode { if logoutMode {
// find account
account, code, err := ParamAgentToken(r, true)
if err != nil {
PrintMsg(r)
ErrResponse(w, code, err)
return
}
var sessions []store.Session var sessions []store.Session
if err = store.DB.Where("account_id = ?", target, access).Find(&sessions).Error; err != nil { if err = store.DB.Where("account_id = ?", account.GUID).Find(&sessions).Error; err != nil {
ErrResponse(w, http.StatusInternalServerError, err); ErrResponse(w, http.StatusInternalServerError, err);
return; return;
} }
@ -43,7 +45,17 @@ func RemoveAgentToken(w http.ResponseWriter, r *http.Request) {
ErrResponse(w, http.StatusInternalServerError, err) ErrResponse(w, http.StatusInternalServerError, err)
return return
} }
ClearStatus(account);
} else { } else {
// parse authentication token
target, access, err := ParseToken(r.FormValue("agent"))
if err != nil {
ErrResponse(w, http.StatusBadRequest, err);
return
}
var session store.Session var session store.Session
if err = store.DB.Where("account_id = ? AND token = ?", target, access).Find(&session).Error; err != nil { if err = store.DB.Where("account_id = ? AND token = ?", target, access).Find(&session).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {

View File

@ -14,6 +14,7 @@ var wsSync sync.Mutex
var wsExit = make(chan bool, 1) var wsExit = make(chan bool, 1)
var statusListener = make(map[uint][]chan<- []byte) var statusListener = make(map[uint][]chan<- []byte)
var revisionListener = make(map[uint][]chan<- []byte) var revisionListener = make(map[uint][]chan<- []byte)
var disconnectListener = make(map[uint][]chan<- bool)
var upgrader = websocket.Upgrader{} var upgrader = websocket.Upgrader{}
//Status handler for websocket connection //Status handler for websocket connection
@ -92,6 +93,14 @@ func Status(w http.ResponseWriter, r *http.Request) {
defer removeRevisionListener(session.Account.ID, c) defer removeRevisionListener(session.Account.ID, c)
} }
// open channel for disconnection
d := make(chan bool)
defer close(d)
// register channel for updates
addDisconnectListener(session.Account.ID, d)
defer removeDisconnectListener(session.Account.ID, d)
// start ping pong ticker // start ping pong ticker
ticker := time.NewTicker(60 * time.Second) ticker := time.NewTicker(60 * time.Second)
defer ticker.Stop() defer ticker.Stop()
@ -109,6 +118,9 @@ func Status(w http.ResponseWriter, r *http.Request) {
ErrMsg(err) ErrMsg(err)
return return
} }
case <-d:
LogMsg("user discconection")
return
case <-wsExit: case <-wsExit:
LogMsg("exiting server") LogMsg("exiting server")
wsExit <- true wsExit <- true
@ -205,6 +217,22 @@ func SetStatus(account *store.Account) {
} }
} }
//ClearStatus disconnects websockets from account
func ClearStatus(account *store.Account) {
// lock access to statusListener
wsSync.Lock()
defer wsSync.Unlock();
// notify all disconnect listeners
chs, ok := disconnectListener[account.ID]
if ok {
for _, ch := range chs {
ch <- true
}
}
}
func addStatusListener(act uint, ch chan<- []byte) { func addStatusListener(act uint, ch chan<- []byte) {
// lock access to statusListener // lock access to statusListener
@ -278,3 +306,41 @@ func removeRevisionListener(act uint, ch chan<- []byte) {
} }
} }
} }
func addDisconnectListener(act uint, ch chan<- bool) {
// lock access to disconnectListener
wsSync.Lock()
defer wsSync.Unlock()
// add new listener to map
chs, ok := disconnectListener[act]
if ok {
disconnectListener[act] = append(chs, ch)
} else {
disconnectListener[act] = []chan<- bool{ch}
}
}
func removeDisconnectListener(act uint, ch chan<- bool) {
// lock access to revisionListener
wsSync.Lock()
defer wsSync.Unlock()
// remove channel from map
chs, ok := disconnectListener[act]
if ok {
for i, c := range chs {
if ch == c {
if len(chs) == 1 {
delete(disconnectListener, act)
} else {
chs[i] = chs[len(chs)-1]
disconnectListener[act] = chs[:len(chs)-1]
}
}
}
}
}