Skip to content

Commit e7ffc5e

Browse files
committed
(squash): Add WIP arch and FSM for DTLS 1.3
1 parent 620d642 commit e7ffc5e

12 files changed

Lines changed: 619 additions & 34 deletions

config.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,10 @@ type Config struct { //nolint:dupl
236236

237237
// ListenConfig used to create the underlying listener socket.
238238
listenConfig net.ListenConfig
239+
240+
// version13
241+
// WIP experimental feature, see https://github.com/pion/dtls/issues/188
242+
version13 bool
239243
}
240244

241245
func (c *Config) includeCertificateSuites() bool {

conn.go

Lines changed: 104 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,12 @@ type Conn struct {
9292
cancelHandshaker func()
9393
cancelHandshakeReader func()
9494

95-
fsm *handshakeFSM
95+
fsm handshakeFSM
9696

9797
replayProtectionWindow uint
9898

99-
handshakeConfig *handshakeConfig
99+
handshakeConfig *handshakeConfig
100+
handshakeConfig13 *handshakeConfig13
100101
}
101102

102103
// createConn creates a new DTLS connection.
@@ -256,6 +257,14 @@ func createConn(
256257
},
257258
}
258259

260+
if config.version13 {
261+
handshakeConfig13 := &handshakeConfig13{
262+
handshakeConfig: handshakeConfig,
263+
}
264+
conn.handshakeConfig13 = handshakeConfig13
265+
conn.handshakeConfig = nil
266+
}
267+
259268
conn.setRemoteEpoch(0)
260269
conn.setLocalEpoch(0)
261270

@@ -284,7 +293,7 @@ func (c *Conn) Handshake() error {
284293
//
285294
// Most uses of this package need not call HandshakeContext explicitly: the
286295
// first [Conn.Read] or [Conn.Write] will call it automatically.
287-
func (c *Conn) HandshakeContext(ctx context.Context) error {
296+
func (c *Conn) HandshakeContext(ctx context.Context) error { //nolint:cyclop
288297
c.handshakeMutex.Lock()
289298
defer c.handshakeMutex.Unlock()
290299

@@ -298,6 +307,23 @@ func (c *Conn) HandshakeContext(ctx context.Context) error {
298307
c.handshakeDone = handshakeDone
299308
c.closeLock.Unlock()
300309

310+
if c.isVersion13Enabled() {
311+
var initialFlight flightVal
312+
if c.state.isClient {
313+
initialFlight = flightVal(flight13_1)
314+
} else {
315+
initialFlight = flightVal(flight13_0)
316+
}
317+
initialFSMState := handshakePreparing
318+
319+
if err := c.handshake(ctx, initialFlight, initialFSMState); err != nil {
320+
return err
321+
}
322+
c.log.Trace("Handshake DTLS 1.3 Completed")
323+
324+
return nil
325+
}
326+
301327
// rfc5246#section-7.4.3
302328
// In addition, the hash and signature algorithms MUST be compatible
303329
// with the key in the server's end-entity certificate.
@@ -330,7 +356,7 @@ func (c *Conn) HandshakeContext(ctx context.Context) error {
330356
initialFSMState = handshakePreparing
331357
}
332358
// Do handshake
333-
if err := c.handshake(ctx, c.handshakeConfig, initialFlight, initialFSMState); err != nil {
359+
if err := c.handshake(ctx, initialFlight, initialFSMState); err != nil {
334360
return err
335361
}
336362

@@ -482,6 +508,8 @@ func (c *Conn) Write(payload []byte) (int, error) {
482508
return 0, err
483509
}
484510

511+
//nolint:godox
512+
// TODO: check for version
485513
return len(payload), c.writePackets(c.writeDeadline, []*packet{
486514
{
487515
record: &recordlayer.RecordLayer{
@@ -808,7 +836,7 @@ var poolReadBuffer = sync.Pool{ //nolint:gochecknoglobals
808836
},
809837
}
810838

811-
func (c *Conn) readAndBuffer(ctx context.Context) error { //nolint:cyclop
839+
func (c *Conn) readAndBuffer(ctx context.Context) error { //nolint:cyclop,gocognit
812840
bufptr, ok := poolReadBuffer.Get().(*[]byte)
813841
if !ok {
814842
return errFailedToAccessPoolReadBuffer
@@ -828,6 +856,8 @@ func (c *Conn) readAndBuffer(ctx context.Context) error { //nolint:cyclop
828856

829857
var hasHandshake, isRetransmit bool
830858
for _, p := range pkts {
859+
//nolint:godox
860+
// TODO: check version
831861
hs, rtx, alert, err := c.handleIncomingPacket(ctx, p, rAddr, true)
832862
if alert != nil {
833863
if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
@@ -875,6 +905,8 @@ func (c *Conn) handleQueuedPackets(ctx context.Context) error {
875905
c.lock.Unlock()
876906

877907
for _, p := range pkts {
908+
//nolint:godox
909+
// TODO: check version
878910
_, _, alert, err := c.handleIncomingPacket(ctx, p.data, p.rAddr, false) // don't re-enqueue
879911
if alert != nil {
880912
if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
@@ -908,6 +940,17 @@ func (c *Conn) enqueueEncryptedPackets(packet addrPkt) bool {
908940
return false
909941
}
910942

943+
// nolint:unused
944+
func (c *Conn) handleIncomingPacket13(
945+
ctx context.Context,
946+
buf []byte,
947+
rAddr net.Addr,
948+
enqueue bool,
949+
) (bool, bool, *alert.Alert, error) {
950+
// Placeholder function
951+
return false, false, nil, nil
952+
}
953+
911954
//nolint:gocognit,gocyclo,cyclop,maintidx
912955
func (c *Conn) handleIncomingPacket(
913956
ctx context.Context,
@@ -1129,17 +1172,29 @@ func (c *Conn) recvHandshake() <-chan recvHandshakeState {
11291172
}
11301173

11311174
func (c *Conn) notify(ctx context.Context, level alert.Level, desc alert.Description) error {
1132-
if level == alert.Fatal && len(c.state.SessionID) > 0 {
1133-
// According to the RFC, we need to delete the stored session.
1134-
// https://datatracker.ietf.org/doc/html/rfc5246#section-7.2
1135-
if ss := c.fsm.cfg.sessionStore; ss != nil {
1136-
c.log.Tracef("clean invalid session: %s", c.state.SessionID)
1137-
if err := ss.Del(c.sessionKey()); err != nil {
1138-
return err
1175+
if level == alert.Fatal && len(c.state.SessionID) > 0 { //nolint:nestif
1176+
if c.isVersion13Enabled() {
1177+
// With compatibility mode for 1.3, CH uses a non-empty session_id
1178+
// https://datatracker.ietf.org/doc/html/rfc8446#appendix-D.4
1179+
if ss := c.fsm.(*handshakeFSM13).cfg.sessionStore; ss != nil { //nolint:forcetypeassert
1180+
c.log.Tracef("clean invalid session: %s", c.state.SessionID)
1181+
if err := ss.Del(c.sessionKey()); err != nil {
1182+
return err
1183+
}
1184+
}
1185+
} else {
1186+
// According to the RFC, we need to delete the stored session.
1187+
// https://datatracker.ietf.org/doc/html/rfc5246#section-7.2
1188+
if ss := c.fsm.(*handshakeFSM12).cfg.sessionStore; ss != nil { //nolint:forcetypeassert
1189+
c.log.Tracef("clean invalid session: %s", c.state.SessionID)
1190+
if err := ss.Del(c.sessionKey()); err != nil {
1191+
return err
1192+
}
11391193
}
11401194
}
11411195
}
11421196

1197+
// This should be updated with DTLS 1.3 record encoding.
11431198
return c.writePackets(ctx, []*packet{
11441199
{
11451200
record: &recordlayer.RecordLayer{
@@ -1169,20 +1224,41 @@ func (c *Conn) isHandshakeCompletedSuccessfully() bool {
11691224
//nolint:cyclop,gocognit,contextcheck
11701225
func (c *Conn) handshake(
11711226
ctx context.Context,
1172-
cfg *handshakeConfig,
11731227
initialFlight flightVal,
11741228
initialState handshakeState,
11751229
) error {
1176-
c.fsm = newHandshakeFSM(&c.state, c.handshakeCache, cfg, initialFlight)
1177-
11781230
done := make(chan struct{})
1179-
ctxRead, cancelRead := context.WithCancel(context.Background())
1180-
cfg.onFlightState = func(_ flightVal, s handshakeState) {
1181-
if s == handshakeFinished && c.setHandshakeCompletedSuccessfully() {
1182-
close(done)
1231+
if c.isVersion13Enabled() {
1232+
c.fsm = &handshakeFSM13{
1233+
currentFlight: flightVal13(initialFlight),
1234+
state: &c.state,
1235+
cache: c.handshakeCache,
1236+
cfg: c.handshakeConfig13,
1237+
retransmitInterval: c.handshakeConfig13.initialRetransmitInterval,
1238+
closed: make(chan struct{}),
1239+
}
1240+
c.handshakeConfig13.onFlightState13 = func(_ flightVal13, s handshakeState) {
1241+
if c.fsm.(*handshakeFSM13).currentFlight.isLastSendFlight() { //nolint:forcetypeassert
1242+
close(done)
1243+
}
1244+
}
1245+
} else {
1246+
c.fsm = &handshakeFSM12{
1247+
currentFlight: initialFlight,
1248+
state: &c.state,
1249+
cache: c.handshakeCache,
1250+
cfg: c.handshakeConfig,
1251+
retransmitInterval: c.handshakeConfig.initialRetransmitInterval,
1252+
closed: make(chan struct{}),
1253+
}
1254+
c.handshakeConfig.onFlightState = func(_ flightVal, s handshakeState) {
1255+
if s == handshakeFinished && c.setHandshakeCompletedSuccessfully() {
1256+
close(done)
1257+
}
11831258
}
11841259
}
11851260

1261+
ctxRead, cancelRead := context.WithCancel(context.Background())
11861262
ctxHs, cancel := context.WithCancel(context.Background())
11871263

11881264
c.closeLock.Lock()
@@ -1379,7 +1455,11 @@ func (c *Conn) sessionKey() []byte {
13791455
// As ServerName can be like 0.example.com, it's better to add
13801456
// delimiter character which is not allowed to be in
13811457
// neither address or domain name.
1382-
return []byte(c.rAddr.String() + "_" + c.fsm.cfg.serverName)
1458+
if c.isVersion13Enabled() {
1459+
return []byte(c.rAddr.String() + "_" + c.fsm.(*handshakeFSM13).cfg.serverName) //nolint:forcetypeassert
1460+
}
1461+
1462+
return []byte(c.rAddr.String() + "_" + c.fsm.(*handshakeFSM12).cfg.serverName) //nolint:forcetypeassert
13831463
}
13841464

13851465
return c.state.SessionID
@@ -1406,3 +1486,7 @@ func (c *Conn) SetWriteDeadline(t time.Time) error {
14061486
// Write deadline is also fully managed by this layer.
14071487
return nil
14081488
}
1489+
1490+
func (c *Conn) isVersion13Enabled() bool {
1491+
return c.handshakeConfig13 != nil
1492+
}

conn_test.go

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3127,10 +3127,20 @@ func TestEllipticCurveConfiguration(t *testing.T) {
31273127
assert.True(t, ok, "Failed to default Elliptic curves")
31283128

31293129
if len(test.ConfigCurves) != 0 {
3130-
assert.Equal(t, len(test.HandshakeCurves), len(server.fsm.cfg.ellipticCurves), "Failed to configure Elliptic curves")
3130+
assert.Equal(
3131+
t,
3132+
len(test.HandshakeCurves),
3133+
len(server.fsm.(*handshakeFSM12).cfg.ellipticCurves), //nolint:forcetypeassert
3134+
"Failed to configure Elliptic curves",
3135+
)
31313136

31323137
for i, c := range test.ConfigCurves {
3133-
assert.Equal(t, c, server.fsm.cfg.ellipticCurves[i], "Failed to maintain Elliptic curve order")
3138+
assert.Equal(
3139+
t,
3140+
c,
3141+
server.fsm.(*handshakeFSM12).cfg.ellipticCurves[i], //nolint:forcetypeassert
3142+
"Failed to maintain Elliptic curve order",
3143+
)
31343144
}
31353145
}
31363146

@@ -3500,3 +3510,50 @@ func TestCloseWithoutHandshake(t *testing.T) {
35003510
assert.NoError(t, err)
35013511
assert.NoError(t, server.Close())
35023512
}
3513+
3514+
// WIP! Tests if DTLS 1.3 handshake flow is enabled and the correct error is returned.
3515+
func TestDTLS13Config(t *testing.T) {
3516+
ca, cb := dpipe.Pipe()
3517+
3518+
// Setup client
3519+
clientCert, err := selfsign.GenerateSelfSigned()
3520+
assert.NoError(t, err)
3521+
3522+
clientcfg, err := buildClientConfig(
3523+
WithCertificates(clientCert),
3524+
WithInsecureSkipVerify(true),
3525+
withVersion13(true),
3526+
)
3527+
3528+
assert.NoError(t, err)
3529+
3530+
client, err := Client(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), clientcfg)
3531+
assert.NoError(t, err)
3532+
defer func() {
3533+
_ = client.Close()
3534+
}()
3535+
3536+
_, ok := client.ConnectionState()
3537+
assert.False(t, ok)
3538+
3539+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
3540+
defer cancel()
3541+
errorChannel := make(chan error)
3542+
go func() {
3543+
errC := client.HandshakeContext(ctx)
3544+
errorChannel <- errC
3545+
}()
3546+
3547+
// Setup server, ignore error
3548+
server, _ := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{}, true)
3549+
assert.NoError(t, err)
3550+
3551+
defer func() {
3552+
_ = server.Close()
3553+
}()
3554+
3555+
err = <-errorChannel
3556+
if err.Error() == errFlightUnimplemented13.Error() {
3557+
return
3558+
}
3559+
}

errors.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,10 @@ var (
129129
//nolint:err113
130130
errInvalidFlight = &InternalError{Err: errors.New("invalid flight number")}
131131
//nolint:err113
132+
errFlightUnimplemented13 = &InternalError{Err: errors.New("unimplemeted DTLS 1.3 flight")}
133+
//nolint:err113
134+
errStateUnimplemented13 = &InternalError{Err: errors.New("unimplemeted DTLS 1.3 handshake state")}
135+
//nolint:err113
132136
errKeySignatureGenerateUnimplemented = &InternalError{
133137
Err: errors.New("unable to generate key signature, unimplemented"),
134138
}

0 commit comments

Comments
 (0)