diff --git a/net/server/internal/sturn/attribute.go b/net/server/internal/sturn/attribute.go index 11741c0e..8b7477b3 100644 --- a/net/server/internal/sturn/attribute.go +++ b/net/server/internal/sturn/attribute.go @@ -22,21 +22,20 @@ func readAttribute(buf []byte, pos int) (error, *SturnAttribute, int) { return errors.New("invalid attribute buffer"), nil, 0 } - var intValue int32 - var strValue string + attr := &SturnAttribute{ atrType: atrType } if atrType == ATRRequestedTransport { if buf[pos + 5] != 0x00 || buf[pos + 6] != 0x00 || buf[pos + 7] != 0x00 { return errors.New("invalid attribute"), nil, 0 } - intValue = int32(buf[pos + 4]) + attr.intValue = int32(buf[pos + 4]) } else if atrType == ATRLifetime { - intValue = 256 * (256 * (256 * int32(buf[pos + 4]) + int32(buf[pos + 5])) + int32(buf[pos + 6])) + int32(buf[pos + 7]); + attr.intValue = 256 * (256 * (256 * int32(buf[pos + 4]) + int32(buf[pos + 5])) + int32(buf[pos + 6])) + int32(buf[pos + 7]); } else if atrType == ATRNonce { - strValue = string(buf[pos + 4:pos + 4+atrLength]); + attr.strValue = string(buf[pos + 4:pos + 4+atrLength]); } else if atrType == ATRUsername { - strValue = string(buf[pos + 4:pos + 4+atrLength]); + attr.strValue = string(buf[pos + 4:pos + 4+atrLength]); } else if atrType == ATRRealm { - strValue = string(buf[pos + 4:pos + 4+atrLength]); + attr.strValue = string(buf[pos + 4:pos + 4+atrLength]); } else if atrType == ATRMessageIntegrity { //fmt.Println("HANDLE: ATRMessageIntegrity"); } else if atrType == ATRMessageIntegritySha256 { @@ -50,29 +49,26 @@ func readAttribute(buf []byte, pos int) (error, *SturnAttribute, int) { if buf[pos + 4] != 0 || buf[pos + 5] != FAMIPv4 { return errors.New("unsupported protocol family"), nil, 0 } - strValue = "" - strValue += strconv.Itoa(int(buf[pos + 8] ^ 0x21)) - strValue += "." - strValue += strconv.Itoa(int(buf[pos + 9] ^ 0x12)) - strValue += "." - strValue += strconv.Itoa(int(buf[pos + 10] ^ 0xA4)) - strValue += "." - strValue += strconv.Itoa(int(buf[pos + 11] ^ 0x42)) - intValue = int32(buf[pos + 6] ^ 0x21) - intValue *= 256 - intValue += int32(buf[pos + 7] ^ 0x12) + attr.strValue = "" + attr.strValue += strconv.Itoa(int(buf[pos + 8] ^ 0x21)) + attr.strValue += "." + attr.strValue += strconv.Itoa(int(buf[pos + 9] ^ 0x12)) + attr.strValue += "." + attr.strValue += strconv.Itoa(int(buf[pos + 10] ^ 0xA4)) + attr.strValue += "." + attr.strValue += strconv.Itoa(int(buf[pos + 11] ^ 0x42)) + attr.intValue = int32(buf[pos + 6] ^ 0x21) + attr.intValue *= 256 + attr.intValue += int32(buf[pos + 7] ^ 0x12) } else if atrType == ATRData { + for i := 0; i < atrLength; i++ { + attr.binValue = append(attr.binValue, buf[pos + 4 + i]) + } } else { fmt.Println("UNKNOWN ATTRIBUTE", atrType); } - return nil, &SturnAttribute{ - atrType: atrType, - intValue: intValue, - strValue: strValue, - }, 4 + padLength; - - return nil, nil, 0 + return nil, attr, 4 + padLength } func writeAttribute(attribute *SturnAttribute, buf []byte, pos int) (error, int) { @@ -125,6 +121,43 @@ func writeAttribute(attribute *SturnAttribute, buf []byte, pos int) (error, int) buf[pos + 10] = byte((ip >> 8) % 256) ^ 0xA4 buf[pos + 11] = byte(ip % 256) ^ 0x42 return nil, 12 + } else if attribute.atrType == ATRXorPeerAddress { + if len(buf) - pos < 12 { + return errors.New("invalid buffer size"), 0 + } + ip := 0 + parts := strings.Split(attribute.strValue, "."); + for i := 0; i < 4; i++ { + val, _ := strconv.Atoi(parts[i]); + ip = (ip * 256) + val; + } + buf[pos + 1], buf[pos + 0] = setAttributeType(ATRXorPeerAddress); + buf[pos + 2] = 0x00 + buf[pos + 3] = 0x08 + buf[pos + 4] = 0x00 + buf[pos + 5] = FAMIPv4 + buf[pos + 6] = byte((attribute.intValue >> 8) % 256) ^ 0x21 + buf[pos + 7] = byte(attribute.intValue % 256) ^ 0x12 + buf[pos + 8] = byte((ip >> 24) % 256) ^ 0x21 + buf[pos + 9] = byte((ip >> 16) % 256) ^ 0x12 + buf[pos + 10] = byte((ip >> 8) % 256) ^ 0xA4 + buf[pos + 11] = byte(ip % 256) ^ 0x42 + return nil, 12 + } else if attribute.atrType == ATRData { + paddedLen := ((len(attribute.binValue) + 3) >> 2) << 2 + if len(buf) - pos < 4 + paddedLen { + return errors.New("invalid buffer size"), 0 + } + buf[pos + 1], buf[pos + 0] = setAttributeType(ATRData) + buf[pos + 2] = byte((len(attribute.binValue) >> 8) % 256) + buf[pos + 3] = byte(len(attribute.binValue) % 256) + for i := 0; i < len(attribute.binValue); i++ { + buf[pos + 4 + i] = attribute.binValue[i] + } + for i := len(attribute.binValue); i < paddedLen; i++ { + buf[pos + 4 + i] = 0x00 + } + return nil, 4 + paddedLen } else if attribute.atrType == ATRLifetime { if len(buf) - pos < 8 { return errors.New("invalid buffer size"), 0 diff --git a/net/server/internal/sturn/message.go b/net/server/internal/sturn/message.go index d96ce5fe..ba4405c1 100644 --- a/net/server/internal/sturn/message.go +++ b/net/server/internal/sturn/message.go @@ -91,8 +91,7 @@ func (s *Sturn) handleMessage(buf []byte, addr net.Addr) { err, msg := readMessage(buf); if err != nil { - fmt.Println(addr.String(), buf); - fmt.Println(err); + fmt.Println(err, addr.String(), buf); return } if msg == nil { @@ -113,7 +112,7 @@ func (s *Sturn) handleMessage(buf []byte, addr net.Addr) { s.handleCreatePermissionRequest(msg, addr); } else if msg.class == CLSIndication && msg.method == MEHSend { fmt.Println("stun/turn send"); - s.handleSendIndication(msg, addr); + s.handleSendIndication(msg, addr, buf); } else { fmt.Println("unsupported message", buf); } @@ -164,7 +163,7 @@ func (s *Sturn) handleCreatePermissionRequest(msg *SturnMessage, addr net.Addr) s.sync.Lock(); defer s.sync.Unlock(); - session, set := sturn.sessions[username.strValue] + _, set := sturn.sessions[username.strValue] if !set { fmt.Println("no session", addr.String()); s.sendRequestError(msg, addr, 401) @@ -172,7 +171,7 @@ func (s *Sturn) handleCreatePermissionRequest(msg *SturnMessage, addr net.Addr) } source := addr.String() - allocation, found := session.allocations[source] + allocation, found := s.allocations[source] if !found { fmt.Println("no allocation"); s.sendRequestError(msg, addr, 400) @@ -180,7 +179,6 @@ func (s *Sturn) handleCreatePermissionRequest(msg *SturnMessage, addr net.Addr) } allocation.permissions = append(allocation.permissions, permission.strValue); - fmt.Println("---> ", allocation.port, allocation.permissions); var attributes []SturnAttribute attributes = append(attributes, SturnAttribute{ @@ -202,8 +200,44 @@ func (s *Sturn) handleCreatePermissionRequest(msg *SturnMessage, addr net.Addr) return } -func (s *Sturn) handleSendIndication(msg *SturnMessage, addr net.Addr) { -// fmt.Println(addr.String(), msg); +func (s *Sturn) handleSendIndication(msg *SturnMessage, addr net.Addr, buf []byte) { + + peer := getAttribute(msg, ATRXorPeerAddress) + if peer == nil { + fmt.Println("no peer"); + return + } + + data := getAttribute(msg, ATRData) + if data == nil { + fmt.Println("no data"); + return + } + + s.sync.Lock(); + defer s.sync.Unlock(); + source := addr.String() + allocation, found := s.allocations[source] + if !found { + fmt.Println("no allocation"); + return + } + + for _, permission := range allocation.permissions { + if permission == peer.strValue { + address := fmt.Sprintf("%s:%d", peer.strValue, peer.intValue) + dst, err := net.ResolveUDPAddr("udp", address) + if err != nil { + fmt.Println("no resolve"); + return + } + + _, err = allocation.conn.WriteTo(data.binValue, dst) + if err != nil { + fmt.Println("write error"); + } + } + } } func (s *Sturn) handleBindingRequest(msg *SturnMessage, addr net.Addr) { @@ -250,7 +284,7 @@ func (s *Sturn) handleRefreshRequest(msg *SturnMessage, addr net.Addr) { return } -func setAllocation(source string, transaction []byte, response []byte, port int, conn net.PacketConn, session *SturnSession) (*SturnAllocation) { +func (s *Sturn) setAllocation(source string, transaction []byte, response []byte, port int, conn net.PacketConn, session *SturnSession) (*SturnAllocation) { allocation := &SturnAllocation{} allocation.port = port allocation.conn = conn @@ -259,12 +293,12 @@ func setAllocation(source string, transaction []byte, response []byte, port int, copy(allocation.transaction, transaction) allocation.response = make([]byte, len(response)) copy(allocation.response, response) - session.allocations[source] = allocation + s.allocations[source] = allocation return allocation } -func getAllocation(source string, transaction []byte, session *SturnSession) (*SturnAllocation, error) { - for _, allocation := range session.allocations { +func (s *Sturn) getAllocation(source string, transaction []byte, session *SturnSession) (*SturnAllocation, error) { + for _, allocation := range s.allocations { if allocation.source == source { if len(allocation.transaction) == len(transaction) { match := true @@ -301,7 +335,7 @@ func (s *Sturn) handleAllocateRequest(msg *SturnMessage, addr net.Addr) { return } - allocation, collision := getAllocation(addr.String(), msg.transaction, session) + allocation, collision := s.getAllocation(addr.String(), msg.transaction, session) if collision != nil { fmt.Println("5tuple collision", addr.String()) s.sendRequestError(msg, addr, 403) @@ -327,7 +361,6 @@ func (s *Sturn) handleAllocateRequest(msg *SturnMessage, addr net.Addr) { return } - fmt.Println("> ", relayPort, "< ", addr.String(), msg); address := strings.Split(addr.String(), ":") ip := address[0]; port, _ := strconv.Atoi(address[1]); @@ -363,7 +396,7 @@ func (s *Sturn) handleAllocateRequest(msg *SturnMessage, addr net.Addr) { if err != nil { fmt.Printf("failed to write stun response") } else { - allocation := setAllocation(addr.String(), msg.transaction, s.buf[:n], relayPort, conn, session) + allocation := s.setAllocation(addr.String(), msg.transaction, s.buf[:n], relayPort, conn, session) (*s.conn).WriteTo(s.buf[:n], addr) go s.relay(allocation); } @@ -380,25 +413,56 @@ func getAttribute(msg *SturnMessage, atrType int) (attr *SturnAttribute) { } func (s *Sturn) relay(allocation *SturnAllocation) { -fmt.Println("STARTED RELAY"); + data := make([]byte, SturnMaxSize) + buf := make([]byte, SturnMaxSize) for { - buf := make([]byte, SturnMaxSize) - n, addr, err := allocation.conn.ReadFrom(buf) + n, addr, err := allocation.conn.ReadFrom(data) if err != nil { fmt.Println(err) + // CLEANUP ALLOCATION return } - fmt.Println("GET REPLAY PACKET:", allocation.port, allocation.permissions, addr.String()); + fmt.Println("stun/turn relay"); s.sync.Lock(); - defer s.sync.Unlock(); - ip := strings.Split(addr.String(), ":") + split := strings.Split(addr.String(), ":") + ip := split[0] + port, _ := strconv.Atoi(split[1]); for _, permission := range allocation.permissions { - if permission == ip[0] { - fmt.Println("HANDLE PACKET", allocation.port, n, ip[0]); + if permission == ip { + + var attributes []SturnAttribute + attributes = append(attributes, SturnAttribute{ + atrType: ATRXorPeerAddress, + strValue: ip, + intValue: int32(port), + }) + attributes = append(attributes, SturnAttribute{ + atrType: ATRData, + binValue: data[:n], + }) + + relay := &SturnMessage{ + class: CLSIndication, + method: MEHData, + transaction: []byte{ 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22 }, + attributes: attributes, + }; + + err, l := writeMessage(relay, buf) + dst, err := net.ResolveUDPAddr("udp", allocation.source) + if err != nil { + fmt.Println("no resolve"); + } else { + _, err := allocation.conn.WriteTo(buf[:l], dst); + if err != nil { + fmt.Println("writeto failed"); + } + } } } + s.sync.Unlock(); } } diff --git a/net/server/internal/sturn/sturn.go b/net/server/internal/sturn/sturn.go index 83693c56..6bc03afc 100644 --- a/net/server/internal/sturn/sturn.go +++ b/net/server/internal/sturn/sturn.go @@ -30,7 +30,7 @@ type SturnAllocation struct { type SturnSession struct { user string auth string - allocations map[string]*SturnAllocation + relayPorts []int } type Sturn struct { @@ -47,6 +47,7 @@ type Sturn struct { relayCount int relayPorts map[int]bool relayIndex int + allocations map[string]*SturnAllocation } func Listen(port int, relayStart int, relayCount int) (error) { @@ -78,6 +79,7 @@ func Listen(port int, relayStart int, relayCount int) (error) { conn: &conn, buf: make([]byte, SturnMaxSize), sessions: make(map[string]*SturnSession), + allocations: make(map[string]*SturnAllocation), } go sturn.serve(conn); @@ -124,7 +126,6 @@ func TestSession() { session := &SturnSession{ user: "user", auth: "pass", - allocations: make(map[string]*SturnAllocation), } sturn.sessions["user"] = session } @@ -145,7 +146,6 @@ func (s *Sturn) addSession() (*SturnSession, error) { session := &SturnSession{ user: user, auth: hex.EncodeToString(authBin), - allocations: make(map[string]*SturnAllocation), } s.sessions[user] = session return session, nil diff --git a/net/server/internal/sturn/types.go b/net/server/internal/sturn/types.go index a1fbc0e0..9a2f2547 100644 --- a/net/server/internal/sturn/types.go +++ b/net/server/internal/sturn/types.go @@ -54,6 +54,7 @@ type SturnAttribute struct { byteValue byte strValue string intValue int32 + binValue []byte } type SturnMessage struct { @@ -144,6 +145,9 @@ func setMessageType(class int, method int) (byte, byte) { if class == CLSError && method == MEHCreatePermission { return 0x01, 0x18 } + if class == CLSIndication && method == MEHData { + return 0x00, 0x17 + } return 0x00, 0x00 }