diff --git a/net/server/internal/sturn/message.go b/net/server/internal/sturn/message.go index ec5e9612..e00bc663 100644 --- a/net/server/internal/sturn/message.go +++ b/net/server/internal/sturn/message.go @@ -98,21 +98,15 @@ func (s *Sturn) handleMessage(buf []byte, addr net.Addr) { } if msg.class == CLSRequest && msg.method == MEHBinding { - err := s.handleBindingRequest(msg, addr); - if err != nil { - fmt.Println(err); - } + s.handleBindingRequest(msg, addr); } else if msg.class == CLSRequest && msg.method == MEHAllocate { - err := s.handleAllocateRequest(msg, addr); - if err != nil { - fmt.Println(err); - } + s.handleAllocateRequest(msg, addr); } else { fmt.Println("unsupported message", buf); } } -func (s *Sturn) handleBindingRequest(msg *SturnMessage, addr net.Addr) (error) { +func (s *Sturn) handleBindingRequest(msg *SturnMessage, addr net.Addr) { address := strings.Split(addr.String(), ":") ip := address[0]; @@ -136,11 +130,61 @@ func (s *Sturn) handleBindingRequest(msg *SturnMessage, addr net.Addr) (error) { } else { (*s.conn).WriteTo(s.buf[:n], addr); } - return nil + return } -func (s *Sturn) handleAllocateRequest(msg *SturnMessage, addr net.Addr) (error) { - fmt.Println("ALLOCATE REQUEST"); - return nil +func (s *Sturn) sendAllocateError(msg *SturnMessage, addr net.Addr) { + var attributes []SturnAttribute + attributes = append(attributes, SturnAttribute{ + atrType: ATRErrorCode, + intValue: 400, + }) + attributes = append(attributes, SturnAttribute{ + atrType: ATRNonce, + strValue: "", + }) + attributes = append(attributes, SturnAttribute{ + atrType: ATRRealm, + strValue: "databag", + }) + response := &SturnMessage{ + class: CLSError, + method: MEHAllocate, + transaction: msg.transaction, + attributes: attributes, + }; + err, n := writeMessage(response, s.buf); + if err != nil { + fmt.Printf("failed to write stun response"); + } else { + (*s.conn).WriteTo(s.buf[:n], addr); + } } +func (s *Sturn) handleAllocateRequest(msg *SturnMessage, addr net.Addr) { + + username := getAttribute(msg, ATRUsername); + if username == nil { + s.sendAllocateError(msg, addr); + return; + } + + port, err := s.getRelayPort(); + if err != nil { + fmt.Println(err); + s.sendAllocateError(msg, addr) + return + } + + fmt.Println("ALLOCATE REQUEST", msg, port); + return +} + +func getAttribute(msg *SturnMessage, atrType int) (attr *SturnAttribute) { + for _, attribute := range msg.attributes { + if attribute.atrType == ATRUsername { + attr = &attribute; + } + } + return +} diff --git a/net/server/internal/sturn/sturn.go b/net/server/internal/sturn/sturn.go index ce0a590b..349a828c 100644 --- a/net/server/internal/sturn/sturn.go +++ b/net/server/internal/sturn/sturn.go @@ -29,16 +29,18 @@ type Sturn struct { sessionId int sessions map[string]*SturnSession closing bool - port uint - relayStart uint - relayEnd uint + port int conn *net.PacketConn closed chan bool buf []byte publicIp string + relayStart int + relayCount int + relayPorts map[int]bool + relayIndex int } -func Listen(port uint, relayStart uint, relayEnd uint) (error) { +func Listen(port int, relayStart int, relayCount int) (error) { if (sturn != nil) { (*sturn.conn).Close() @@ -52,12 +54,18 @@ func Listen(port uint, relayStart uint, relayEnd uint) (error) { return err } + relayPorts := make(map[int]bool) + for i := 0; i < relayCount; i++ { + relayPorts[i] = true + } + sturn := &Sturn{ sessionId: 0, closing: false, port: port, relayStart: relayStart, - relayEnd: relayEnd, + relayCount: relayCount, + relayPorts: relayPorts, conn: &conn, buf: make([]byte, SturnMaxSize), } @@ -137,3 +145,24 @@ func (s *Sturn) addSession() (*SturnSession, error) { return session, nil } +func (s *Sturn) getRelayPort() (int, error) { + s.sync.Lock(); + defer s.sync.Unlock(); + s.relayIndex += 1; + for i := 0; i < s.relayCount; i++ { + key := (i + s.relayIndex) % s.relayCount; + if s.relayPorts[key] { + s.relayPorts[key] = false + return s.relayStart + key, nil + } + } + return 0, errors.New("no available relay port") +} + +func (s *Sturn) setRelayPort(port int) { + s.sync.Lock() + defer s.sync.Unlock(); + key := port - s.relayStart + s.relayPorts[key] = true +} + diff --git a/net/server/main.go b/net/server/main.go index 71cf71b4..bc751cf4 100644 --- a/net/server/main.go +++ b/net/server/main.go @@ -19,7 +19,7 @@ func main() { origins := handlers.AllowedOrigins([]string{"*"}) methods := handlers.AllowedMethods([]string{"GET", "HEAD", "POST", "PUT", "DELETE", "OPTIONS"}) - sturn.Listen(5001, 5002, 5101) + sturn.Listen(5001, 5002, 99) sturn.TestSession() args := os.Args