Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 22 additions & 15 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ func (c *Conn) writePackets(ctx context.Context, pkts []*packet) error {
c.lock.Lock()
defer c.lock.Unlock()

var rawPackets [][]byte
rawPackets := make([][]byte, 0, len(pkts))

for _, pkt := range pkts {
if dtlsHandshake, ok := pkt.record.Content.(*handshake.Handshake); ok {
Expand Down Expand Up @@ -614,11 +614,16 @@ func (c *Conn) compactRawPackets(rawPackets [][]byte) [][]byte {
return rawPackets
}

combinedRawPackets := make([][]byte, 0)
currentCombinedRawPacket := make([]byte, 0)
combinedRawPackets := make([][]byte, 0, len(rawPackets))
var currentCombinedRawPacket []byte

for _, rawPacket := range rawPackets {
if len(currentCombinedRawPacket) > 0 && len(currentCombinedRawPacket)+len(rawPacket) >= c.maximumTransmissionUnit {
if len(currentCombinedRawPacket) == 0 && len(rawPacket) >= c.maximumTransmissionUnit {
combinedRawPackets = append(combinedRawPackets, rawPacket)

continue
} else if len(currentCombinedRawPacket) > 0 &&
len(currentCombinedRawPacket)+len(rawPacket) >= c.maximumTransmissionUnit {
combinedRawPackets = append(combinedRawPackets, currentCombinedRawPacket)
currentCombinedRawPacket = []byte{}
}
Expand Down Expand Up @@ -697,8 +702,6 @@ func (c *Conn) processPacket(pkt *packet) ([]byte, error) { //nolint:cyclop

//nolint:cyclop
func (c *Conn) processHandshakePacket(pkt *packet, dtlsHandshake *handshake.Handshake) ([][]byte, error) {
rawPackets := make([][]byte, 0)

handshakeFragments, err := c.fragmentHandshake(dtlsHandshake)
if err != nil {
return nil, err
Expand All @@ -708,6 +711,7 @@ func (c *Conn) processHandshakePacket(pkt *packet, dtlsHandshake *handshake.Hand
c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0))
}

rawPackets := make([][]byte, 0, len(handshakeFragments))
for _, handshakeFragment := range handshakeFragments {
seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1
if seq > recordlayer.MaxSequenceNumber {
Expand All @@ -733,12 +737,14 @@ func (c *Conn) processHandshakePacket(pkt *packet, dtlsHandshake *handshake.Hand
ConnectionID: c.state.remoteConnectionID,
SequenceNumber: pkt.record.Header.SequenceNumber,
}
rawPacket, err = cidHeader.Marshal()

rawPacket = make([]byte, cidHeader.MarshalSize()+len(rawInner))
_, err = cidHeader.MarshalTo(rawPacket)
if err != nil {
return nil, err
}
pkt.record.Header = *cidHeader
rawPacket = append(rawPacket, rawInner...)
copy(rawPacket[cidHeader.MarshalSize():], rawInner)
} else {
recordlayerHeader := &recordlayer.Header{
Version: pkt.record.Header.Version,
Expand All @@ -748,13 +754,14 @@ func (c *Conn) processHandshakePacket(pkt *packet, dtlsHandshake *handshake.Hand
SequenceNumber: seq,
}

rawPacket, err = recordlayerHeader.Marshal()
rawPacket = make([]byte, recordlayerHeader.MarshalSize()+len(handshakeFragment))
_, err = recordlayerHeader.MarshalTo(rawPacket)
if err != nil {
return nil, err
}

pkt.record.Header = *recordlayerHeader
rawPacket = append(rawPacket, handshakeFragment...)
copy(rawPacket[recordlayerHeader.MarshalSize():], handshakeFragment)
}

if pkt.shouldEncrypt {
Expand All @@ -777,8 +784,6 @@ func (c *Conn) fragmentHandshake(dtlsHandshake *handshake.Handshake) ([][]byte,
return nil, err
}

fragmentedHandshakes := make([][]byte, 0)

contentFragments := splitBytes(content, c.maximumTransmissionUnit)
if len(contentFragments) == 0 {
contentFragments = [][]byte{
Expand All @@ -787,6 +792,7 @@ func (c *Conn) fragmentHandshake(dtlsHandshake *handshake.Handshake) ([][]byte,
}

offset := 0
fragmentedHandshakes := make([][]byte, 0, len(contentFragments))
for _, contentFragment := range contentFragments {
contentFragmentLen := len(contentFragment)

Expand All @@ -800,12 +806,13 @@ func (c *Conn) fragmentHandshake(dtlsHandshake *handshake.Handshake) ([][]byte,

offset += contentFragmentLen

fragmentedHandshake, err := headerFragment.Marshal()
fragmentedHandshake := make([]byte, handshake.HeaderLength+len(contentFragment))
_, err := headerFragment.MarshalTo(fragmentedHandshake)
if err != nil {
return nil, err
}

fragmentedHandshake = append(fragmentedHandshake, contentFragment...)
copy(fragmentedHandshake[handshake.HeaderLength:], contentFragment)
fragmentedHandshakes = append(fragmentedHandshakes, fragmentedHandshake)
}

Expand Down Expand Up @@ -1014,7 +1021,7 @@ func (c *Conn) handleIncomingPacket(
if header.ContentType == protocol.ContentTypeConnectionID {
originalCID = true
ip := &recordlayer.InnerPlaintext{}
if err := ip.Unmarshal(buf[header.Size():]); err != nil { //nolint:govet
if err := ip.Unmarshal(buf[header.MarshalSize():]); err != nil { //nolint:govet
c.log.Debugf("unpacking inner plaintext failed: %s", err)

return false, false, nil, nil
Expand Down
2 changes: 2 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ var (
//nolint:err113
errFailedToAccessPoolReadBuffer = &InternalError{Err: errors.New("failed to access pool read buffer")}
//nolint:err113
errFailedToAccessPoolTimer = &InternalError{Err: errors.New("failed to access pool timer")}
//nolint:err113
errFragmentBufferOverflow = &InternalError{Err: errors.New("fragment buffer overflow")}

//nolint:err113
Expand Down
22 changes: 20 additions & 2 deletions handshaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,8 @@ func (s *handshakeFSM) Run(ctx context.Context, conn flightConn, initialState ha
close(s.closed)
}()
for {
s.cfg.log.Tracef("[handshake:%s] %s: %s", srvCliStr(s.state.isClient), s.currentFlight.String(), state.String())
s.cfg.log.Tracef("[handshake:%s] %s: %s",
srvCliStr(s.state.isClient), s.currentFlight.String(), state.String())
if s.cfg.onFlightState != nil {
s.cfg.onFlightState(s.currentFlight, state)
}
Expand Down Expand Up @@ -279,6 +280,15 @@ func (s *handshakeFSM) send(ctx context.Context, c flightConn) (handshakeState,
return handshakeWaiting, nil
}

var timerPool = sync.Pool{ //nolint:gochecknoglobals
New: func() any {
t := time.NewTimer(time.Millisecond)
t.Stop()

return t
},
}

func (s *handshakeFSM) wait(ctx context.Context, conn flightConn) (handshakeState, error) { //nolint:gocognit,cyclop
parse, errFlight := s.currentFlight.getFlightParser()
if errFlight != nil {
Expand All @@ -289,7 +299,15 @@ func (s *handshakeFSM) wait(ctx context.Context, conn flightConn) (handshakeStat
return handshakeErrored, errFlight
}

retransmitTimer := time.NewTimer(s.retransmitInterval)
retransmitTimer, ok := timerPool.Get().(*time.Timer)
if !ok {
return handshakeErrored, errFailedToAccessPoolTimer
}
defer func() {
retransmitTimer.Stop()
timerPool.Put(retransmitTimer)
}()
retransmitTimer.Reset(s.retransmitInterval)
for {
select {
case state := <-conn.recvHandshake():
Expand Down
11 changes: 6 additions & 5 deletions pkg/crypto/ciphersuite/cbc.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ func NewCBC(

// Encrypt encrypt a DTLS RecordLayer message.
func (c *CBC) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) {
payload := raw[pkt.Header.Size():]
raw = raw[:pkt.Header.Size()]
payload := raw[pkt.Header.MarshalSize():]
raw = raw[:pkt.Header.MarshalSize()]
blockSize := c.writeCBC.BlockSize()

// Generate + Append MAC
Expand Down Expand Up @@ -110,7 +110,8 @@ func (c *CBC) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error)
raw = append(raw, payload...)

// Update recordLayer size to include IV+MAC+Padding
binary.BigEndian.PutUint16(raw[pkt.Header.Size()-2:], uint16(len(raw)-pkt.Header.Size())) //nolint:gosec //G115
binary.BigEndian.PutUint16(raw[pkt.Header.MarshalSize()-2:],
uint16(len(raw)-pkt.Header.MarshalSize())) //nolint:gosec //G115

return raw, nil
}
Expand All @@ -123,7 +124,7 @@ func (c *CBC) Decrypt(header recordlayer.Header, in []byte) ([]byte, error) {
if err := header.Unmarshal(in); err != nil {
return nil, err
}
body := in[header.Size():]
body := in[header.MarshalSize():]

switch {
case header.ContentType == protocol.ContentTypeChangeCipherSpec:
Expand Down Expand Up @@ -171,7 +172,7 @@ func (c *CBC) Decrypt(header recordlayer.Header, in []byte) ([]byte, error) {
return nil, errInvalidMAC
}

return append(in[:header.Size()], body[:dataEnd]...), nil
return append(in[:header.MarshalSize()], body[:dataEnd]...), nil
}

func (c *CBC) hmac(
Expand Down
10 changes: 5 additions & 5 deletions pkg/crypto/ciphersuite/chacha20poly1305.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ func NewChaCha20Poly1305(localKey, localWriteIV, remoteKey, remoteWriteIV []byte

// Encrypt encrypts a DTLS RecordLayer message.
func (c *ChaCha20Poly1305) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) {
payload := raw[pkt.Header.Size():]
raw = raw[:pkt.Header.Size()]
payload := raw[pkt.Header.MarshalSize():]
raw = raw[:pkt.Header.MarshalSize()]

var nonce [chachaNonceLength]byte
copy(nonce[:], c.localWriteIV)
Expand Down Expand Up @@ -80,7 +80,7 @@ func (c *ChaCha20Poly1305) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]
copy(result, raw)
copy(result[len(raw):], encrypted)

binary.BigEndian.PutUint16(result[pkt.Header.Size()-2:], uint16(len(encrypted))) //nolint:gosec
binary.BigEndian.PutUint16(result[pkt.Header.MarshalSize()-2:], uint16(len(encrypted))) //nolint:gosec

return result, nil
}
Expand Down Expand Up @@ -108,7 +108,7 @@ func (c *ChaCha20Poly1305) Decrypt(header recordlayer.Header, in []byte) ([]byte
}

// NOTE: ChaCha20-Poly1305 has NO explicit nonce in the record
ciphertext := in[header.Size():]
ciphertext := in[header.MarshalSize():]

var additionalData []byte
if header.ContentType == protocol.ContentTypeConnectionID {
Expand All @@ -122,5 +122,5 @@ func (c *ChaCha20Poly1305) Decrypt(header recordlayer.Header, in []byte) ([]byte
return nil, fmt.Errorf("%w: %v", errDecryptPacket, err) //nolint:errorlint
}

return append(in[:header.Size()], plaintext...), nil
return append(in[:header.MarshalSize()], plaintext...), nil
}
6 changes: 3 additions & 3 deletions pkg/crypto/ciphersuite/chacha20poly1305_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func FuzzChaCha20Poly1305_RoundTrip(f *testing.F) {

var parsedHdr recordlayer.Header
require.NoError(t, parsedHdr.Unmarshal(dec))
got := dec[parsedHdr.Size():]
got := dec[parsedHdr.MarshalSize():]

require.Equal(t, plain, got)
})
Expand Down Expand Up @@ -117,7 +117,7 @@ func FuzzChaCha20Poly1305_Bidirectional_RoundTrip(f *testing.F) {
var parsedHdrA recordlayer.Header
require.NoError(t, parsedHdrA.Unmarshal(decAonB))

gotA := decAonB[parsedHdrA.Size():]
gotA := decAonB[parsedHdrA.MarshalSize():]
require.Equal(t, pA, gotA)

// B -> A
Expand All @@ -144,7 +144,7 @@ func FuzzChaCha20Poly1305_Bidirectional_RoundTrip(f *testing.F) {
var parsedHdrB recordlayer.Header
require.NoError(t, parsedHdrB.Unmarshal(decBonA))

gotB := decBonA[parsedHdrB.Size():]
gotB := decBonA[parsedHdrB.MarshalSize():]
require.Equal(t, pB, gotB)
})
}
25 changes: 13 additions & 12 deletions pkg/crypto/ciphersuite/ciphersuite.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ func newAEAD(

// encrypt encrypts a DTLS RecordLayer message.
func (a *aead) encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) {
payload := raw[pkt.Header.Size():]
raw = raw[:pkt.Header.Size()]
payload := raw[pkt.Header.MarshalSize():]
raw = raw[:pkt.Header.MarshalSize()]

// Get nonce buffer from pool
noncePtr := a.nonceBufferPool.Get().(*[]byte) // nolint:forcetypeassert
Expand All @@ -93,19 +93,20 @@ func (a *aead) encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error)
additionalData = generateAEADAdditionalData(&pkt.Header, len(payload))
}
finalSize := len(raw) + 8 + len(payload) + a.tagLength
r := make([]byte, finalSize)
copy(r, raw)
copy(r[len(raw):], nonce[4:])
out := make([]byte, finalSize)
copy(out, raw)
copy(out[len(raw):], nonce[4:])

a.localAEAD.Seal(r[len(raw)+8:len(raw)+8], nonce, payload, additionalData)
a.localAEAD.Seal(out[len(raw)+8:len(raw)+8], nonce, payload, additionalData)

// Update recordLayer size to include explicit nonce
binary.BigEndian.PutUint16(r[pkt.Header.Size()-2:], uint16(len(r)-pkt.Header.Size())) //nolint:gosec //G115
binary.BigEndian.PutUint16(out[pkt.Header.MarshalSize()-2:],
uint16(len(out)-pkt.Header.MarshalSize())) //nolint:gosec //G115

// Return nonce buffer to pool
a.nonceBufferPool.Put(noncePtr)

return r, nil
return out, nil
}

// decrypt decrypts a DTLS RecordLayer message.
Expand All @@ -117,7 +118,7 @@ func (a *aead) decrypt(header recordlayer.Header, in []byte) ([]byte, error) {
case header.ContentType == protocol.ContentTypeChangeCipherSpec:
// Nothing to encrypt with ChangeCipherSpec
return in, nil
case len(in) <= (8 + header.Size()):
case len(in) <= (8 + header.MarshalSize()):
return nil, errNotEnoughRoomForNonce
}

Expand All @@ -126,8 +127,8 @@ func (a *aead) decrypt(header recordlayer.Header, in []byte) ([]byte, error) {
nonce := *noncePtr

copy(nonce[:4], a.remoteWriteIV[:4])
copy(nonce[4:], in[header.Size():header.Size()+8])
out := in[header.Size()+8:]
copy(nonce[4:], in[header.MarshalSize():header.MarshalSize()+8])
out := in[header.MarshalSize()+8:]

var additionalData []byte
if header.ContentType == protocol.ContentTypeConnectionID {
Expand All @@ -146,7 +147,7 @@ func (a *aead) decrypt(header recordlayer.Header, in []byte) ([]byte, error) {
// Return nonce buffer to pool
a.nonceBufferPool.Put(noncePtr)

return append(in[:header.Size()], out...), nil
return append(in[:header.MarshalSize()], out...), nil
}

func generateAEADAdditionalData(h *recordlayer.Header, payloadLen int) []byte {
Expand Down
6 changes: 3 additions & 3 deletions pkg/crypto/ciphersuite/gcm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func FuzzGCM_RoundTrip(f *testing.F) {

var parsedHdr recordlayer.Header
require.NoError(t, parsedHdr.Unmarshal(dec))
got := dec[parsedHdr.Size():]
got := dec[parsedHdr.MarshalSize():]

require.Equal(t, plain, got)
})
Expand Down Expand Up @@ -117,7 +117,7 @@ func FuzzGCM_Bidirectional_RoundTrip(f *testing.F) {
var parsedHdrA recordlayer.Header
require.NoError(t, parsedHdrA.Unmarshal(decAonB))

gotA := decAonB[parsedHdrA.Size():]
gotA := decAonB[parsedHdrA.MarshalSize():]
require.Equal(t, pA, gotA)

// B -> A
Expand All @@ -144,7 +144,7 @@ func FuzzGCM_Bidirectional_RoundTrip(f *testing.F) {
var parsedHdrB recordlayer.Header
require.NoError(t, parsedHdrB.Unmarshal(decBonA))

gotB := decBonA[parsedHdrB.Size():]
gotB := decBonA[parsedHdrB.MarshalSize():]
require.Equal(t, pB, gotB)
})
}
21 changes: 20 additions & 1 deletion pkg/protocol/alert/alert.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,28 @@ func (a Alert) ContentType() protocol.ContentType {
return protocol.ContentTypeAlert
}

// MarshalSize returns the minimal buffer size required for MarshalTo.
func (a Alert) MarshalSize() int {
return 2
}

// Marshal returns the encoded alert.
func (a *Alert) Marshal() ([]byte, error) {
return []byte{byte(a.Level), byte(a.Description)}, nil
out := make([]byte, a.MarshalSize())
_, err := a.MarshalTo(out)

return out, err
}

// MarshalTo returns the encoded alert.
func (a *Alert) MarshalTo(out []byte) (int, error) {
if len(out) < a.MarshalSize() {
return 0, errBufferTooSmall
}
out[0] = byte(a.Level)
out[1] = byte(a.Description)

return 2, nil
}

// Unmarshal populates the alert from binary data.
Expand Down
Loading
Loading