Add contree test

master
kts of kettek (POWERQWACK) 2021-02-20 06:25:20 -08:00
parent b3b66d414e
commit c046e55a15
5 changed files with 265 additions and 0 deletions

72
contree/conn.go 100644
View File

@ -0,0 +1,72 @@
package main
import (
"net"
"sync"
"time"
)
// Reference: https://github.com/pion/sctp/blob/master/association_test.go
// Since UDP is connectionless, as a server, it doesn't know how to reply
// simply using the `Write` method. So, to make it work, `disconnectedPacketConn`
// will infer the last packet that it reads as the reply address for `Write`
type disconnectedPacketConn struct { // nolint: unused
mu sync.RWMutex
rAddr net.Addr
pConn net.PacketConn
}
// Read
func (c *disconnectedPacketConn) Read(p []byte) (int, error) {
i, rAddr, err := c.pConn.ReadFrom(p)
if err != nil {
return 0, err
}
c.mu.Lock()
c.rAddr = rAddr
c.mu.Unlock()
return i, err
}
// Write writes len(p) bytes from p to the DTLS connection
func (c *disconnectedPacketConn) Write(p []byte) (n int, err error) {
return c.pConn.WriteTo(p, c.RemoteAddr())
}
// Close closes the conn and releases any Read calls
func (c *disconnectedPacketConn) Close() error {
return c.pConn.Close()
}
// LocalAddr is a stub
func (c *disconnectedPacketConn) LocalAddr() net.Addr {
if c.pConn != nil {
return c.pConn.LocalAddr()
}
return nil
}
// RemoteAddr is a stub
func (c *disconnectedPacketConn) RemoteAddr() net.Addr {
c.mu.RLock()
defer c.mu.RUnlock()
return c.rAddr
}
// SetDeadline is a stub
func (c *disconnectedPacketConn) SetDeadline(t time.Time) error {
return nil
}
// SetReadDeadline is a stub
func (c *disconnectedPacketConn) SetReadDeadline(t time.Time) error {
return nil
}
// SetWriteDeadline is a stub
func (c *disconnectedPacketConn) SetWriteDeadline(t time.Time) error {
return nil
}

86
contree/dialer.go 100644
View File

@ -0,0 +1,86 @@
package main
import (
"encoding/json"
"fmt"
"log"
"net"
"time"
"github.com/pion/logging"
"github.com/pion/sctp"
)
func dial(address string) {
// Dial the target.
dialConn, err := net.Dial("udp", address)
if err != nil {
log.Panic(err)
}
defer func() {
if closeErr := dialConn.Close(); closeErr != nil {
panic(err)
}
}()
fmt.Println("Dialed UDP")
// Create the client.
config := sctp.Config{
NetConn: dialConn,
LoggerFactory: logging.NewDefaultLoggerFactory(),
}
client, err := sctp.Client(config)
if err != nil {
log.Panic(err)
}
defer func() {
if closeErr := client.Close(); closeErr != nil {
panic(err)
}
}()
fmt.Println("Created client")
// Create the stream.
stream, err := client.OpenStream(0, sctp.PayloadTypeWebRTCString)
if err != nil {
log.Panic(err)
}
defer func() {
if closeErr := stream.Close(); closeErr != nil {
panic(err)
}
}()
stream.SetReliabilityParams(false, sctp.ReliabilityTypeReliable, 10)
fmt.Println("Created stream")
// Writer
go func() {
enc := json.NewEncoder(stream)
msgNum := 1
for {
time.Sleep(2 * time.Second)
enc.Encode(Message{
seq: msgNum,
msg: "from dialer",
})
msgNum++
}
}()
// Reader
dec := json.NewDecoder(stream)
log.Println("dailer: Made a JSON stream")
for {
var msg Message
err := dec.Decode(&msg)
if err == nil {
fmt.Printf("dialer: Got msg %+v\n", msg)
if msg.msg == "bye" {
return
}
} else {
fmt.Printf("dialer: Got error %+v\n", err)
}
}
}

View File

@ -0,0 +1,76 @@
package main
import (
"encoding/json"
"fmt"
"log"
"net"
"github.com/pion/logging"
"github.com/pion/sctp"
)
func listen(ip net.IP, port int) {
addr := net.UDPAddr{
IP: ip,
Port: port,
}
mainConn, err := net.ListenUDP("udp", &addr)
if err != nil {
log.Panic(err)
}
defer mainConn.Close()
fmt.Println("Created listener")
config := sctp.Config{
NetConn: &disconnectedPacketConn{pConn: mainConn},
LoggerFactory: logging.NewDefaultLoggerFactory(),
}
for {
server, err := sctp.Server(config)
if err != nil {
log.Panic(err)
}
defer server.Close()
fmt.Println("Created server")
stream, err := server.AcceptStream()
if err != nil {
log.Panic(err)
}
stream.SetReliabilityParams(false, sctp.ReliabilityTypeReliable, 10)
go runListenStream(stream)
}
}
func runListenStream(stream *sctp.Stream) {
defer func() {
if closeErr := stream.Close(); closeErr != nil {
panic(closeErr)
}
}()
dec := json.NewDecoder(stream)
enc := json.NewEncoder(stream)
log.Println("Made a JSON stream")
msgNum := 1
for {
var msg Message
err := dec.Decode(&msg)
if err == nil {
fmt.Printf("Got msg %+v\n", msg)
if msg.msg == "bye" {
return
} else {
enc.Encode(Message{
seq: msgNum,
msg: fmt.Sprintf("from listener to msg %d", msg.seq),
})
msgNum++
}
} else {
fmt.Printf("Got error %+v\n", err)
}
}
}

25
contree/main.go 100644
View File

@ -0,0 +1,25 @@
package main
import (
"flag"
"fmt"
"net"
)
func main() {
var address string
var port int
var shouldDial bool
flag.StringVar(&address, "address", "127.0.0.1", "Address to dial or listen")
flag.IntVar(&port, "port", 10300, "Port to dial or listen")
flag.BoolVar(&shouldDial, "dial", false, "Whether to dial target address or not")
flag.Parse()
if shouldDial {
dial(fmt.Sprintf("%s:%d", address, port))
} else {
listen(net.ParseIP(address), port)
}
}

View File

@ -0,0 +1,6 @@
package main
type Message struct {
seq int
msg string
}