Skip to content

Commit 585504e

Browse files
committed
Add parsing of flight 0
1 parent f3ea931 commit 585504e

7 files changed

Lines changed: 178 additions & 31 deletions

File tree

conn.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ func (c *Conn) HandshakeContext(ctx context.Context) error { //nolint:cyclop
311311
c.closeLock.Unlock()
312312

313313
if c.isVersion13Enabled() {
314-
c.state.version = protocol.Version1_3
314+
c.state.localVersion = protocol.Version1_3
315315
var initialFlight flightVal
316316
if c.state.isClient {
317317
initialFlight = flightVal(flight13_1)
@@ -327,7 +327,7 @@ func (c *Conn) HandshakeContext(ctx context.Context) error { //nolint:cyclop
327327

328328
return nil
329329
}
330-
c.state.version = protocol.Version1_2
330+
c.state.localVersion = protocol.Version1_2
331331

332332
// rfc5246#section-7.4.3
333333
// In addition, the hash and signature algorithms MUST be compatible
@@ -1178,7 +1178,7 @@ func (c *Conn) recvHandshake() <-chan recvHandshakeState {
11781178

11791179
func (c *Conn) notify(ctx context.Context, level alert.Level, desc alert.Description) error {
11801180
if level == alert.Fatal && len(c.state.SessionID) > 0 { //nolint:nestif
1181-
if c.state.version == protocol.Version1_2 {
1181+
if c.state.localVersion == protocol.Version1_2 {
11821182
// According to the RFC, we need to delete the stored session.
11831183
// https://datatracker.ietf.org/doc/html/rfc5246#section-7.2
11841184
if ss := c.fsm.(*handshakeFSM12).cfg.sessionStore; ss != nil { //nolint:forcetypeassert
@@ -1224,7 +1224,7 @@ func (c *Conn) handshake(
12241224
initialState handshakeState,
12251225
) error {
12261226
done := make(chan struct{})
1227-
if c.state.version == protocol.Version1_3 {
1227+
if c.state.localVersion == protocol.Version1_3 {
12281228
c.fsm = &handshakeFSM13{
12291229
currentFlight: flightVal13(initialFlight),
12301230
state: &c.state,
@@ -1455,7 +1455,7 @@ func (c *Conn) sessionKey() []byte {
14551455
// As ServerName can be like 0.example.com, it's better to add
14561456
// delimiter character which is not allowed to be in
14571457
// neither address or domain name.
1458-
if c.state.version == protocol.Version1_3 {
1458+
if c.state.localVersion == protocol.Version1_3 {
14591459
return []byte(c.rAddr.String() + "_" + c.fsm.(*handshakeFSM13).cfg.serverName) //nolint:forcetypeassert
14601460
}
14611461

errors.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,8 @@ var (
108108
//nolint:err113
109109
errUnsupportedProtocolVersion = &FatalError{Err: errors.New("unsupported protocol version")}
110110
//nolint:err113
111+
errInvalidProtocolVersionState = &FatalError{Err: errors.New("invalid protocol version in state")}
112+
//nolint:err113
111113
errPSKAndIdentityMustBeSetForClient = &FatalError{
112114
Err: errors.New("PSK and PSK Identity Hint must both be set for client"),
113115
}
@@ -219,6 +221,10 @@ var (
219221
errNilOnConnectionAttempt = &FatalError{
220222
Err: errors.New("on connection attempt option requires a non-nil callback"),
221223
}
224+
//nolint:err113
225+
errInvalidGroupInKeyShare = &FatalError{
226+
Err: errors.New("groups offered in the key share extension must be included in the supported groups extension"),
227+
}
222228
)
223229

224230
// FatalError indicates that the DTLS connection is no longer available.

flighthandler_13.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,19 @@ type flightGenerator13 func(flightConn, *State, *handshakeCache, *handshakeConfi
2323

2424
//nolint:unused
2525
func (f flightVal13) getFlightParser13() (flightParser13, error) {
26-
return nil, errFlightUnimplemented13
26+
switch f {
27+
case flight13_0:
28+
return flight13_0Parse, nil
29+
default:
30+
return nil, errFlightUnimplemented13
31+
}
2732
}
2833

2934
//nolint:unused
3035
func (f flightVal13) getFlightGenerator13() (gen flightGenerator13, retransmit bool, err error) {
3136
switch f {
3237
case flight13_0:
33-
return flight13_0Generate, true, nil
38+
return flight0Generate, true, nil
3439
case flight13_1:
3540
return flight13_1Generate, true, nil
3641
default:

flighthandlers_client_13.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ func flight13_1Generate(
4141
_ *handshakeCache,
4242
cfg *handshakeConfig,
4343
) ([]*packet, *alert.Alert, error) {
44-
if state.version != protocol.Version1_3 {
45-
return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errUnsupportedProtocolVersion
44+
if state.localVersion != protocol.Version1_3 {
45+
return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errInvalidProtocolVersionState
4646
}
4747
var zeroEpoch uint16
4848
state.localEpoch.Store(zeroEpoch)

flighthandlers_server_13.go

Lines changed: 136 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,14 @@
44
package dtls
55

66
import (
7-
"crypto/rand"
7+
"context"
8+
"slices"
89

10+
"github.com/pion/dtls/v3/pkg/crypto/elliptic"
11+
"github.com/pion/dtls/v3/pkg/protocol"
912
"github.com/pion/dtls/v3/pkg/protocol/alert"
13+
"github.com/pion/dtls/v3/pkg/protocol/extension"
14+
"github.com/pion/dtls/v3/pkg/protocol/handshake"
1015
)
1116

1217
// we'll add the flight handlers for the DTLS 1.3 server here.
@@ -33,31 +38,144 @@ import (
3338
// | Flight 4c |
3439
// +-----------+
3540

36-
func flight13_0Generate(
41+
//nolint:cyclop,gocognit
42+
func flight13_0Parse(
43+
_ context.Context,
3744
_ flightConn,
3845
state *State,
39-
_ *handshakeCache,
46+
cache *handshakeCache,
4047
cfg *handshakeConfig,
41-
) ([]*packet, *alert.Alert, error) { //nolint:unparam
42-
// Initialize
43-
if !cfg.insecureSkipHelloVerify {
44-
state.cookie = make([]byte, cookieLength)
45-
if _, err := rand.Read(state.cookie); err != nil {
46-
return nil, nil, err
48+
) (flightVal13, *alert.Alert, error) {
49+
if state.localVersion != protocol.Version1_3 {
50+
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errInvalidProtocolVersionState
51+
}
52+
seq, msgs, ok := cache.fullPullMap(0, state.cipherSuite,
53+
handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
54+
)
55+
if !ok {
56+
// No valid message received. Keep reading
57+
return 0, nil, nil
58+
}
59+
60+
// Connection Identifiers must be negotiated afresh on session resumption.
61+
// https://datatracker.ietf.org/doc/html/rfc9146#name-the-connection_id-extension
62+
state.setLocalConnectionID(nil)
63+
state.remoteConnectionID = nil
64+
65+
state.handshakeRecvSequence = seq
66+
67+
var clientHello *handshake.MessageClientHello
68+
69+
// Validate type
70+
if clientHello, ok = msgs[handshake.TypeClientHello].(*handshake.MessageClientHello); !ok {
71+
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
72+
}
73+
74+
if !clientHello.Version.Equal(protocol.Version1_2) {
75+
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion
76+
}
77+
78+
state.remoteRandom = clientHello.Random
79+
80+
cipherSuites := []CipherSuite{}
81+
for _, id := range clientHello.CipherSuiteIDs {
82+
if id == renegotiationInfoSCSV {
83+
state.remoteSupportsRenegotiation = true
84+
85+
continue
4786
}
87+
if c := cipherSuiteForID(CipherSuiteID(id), cfg.customCipherSuites); c != nil {
88+
cipherSuites = append(cipherSuites, c)
89+
}
90+
}
91+
92+
// Check for DTLS 1.3 cipher suites?
93+
if state.cipherSuite, ok = findMatchingCipherSuite(cipherSuites, cfg.localCipherSuites); !ok {
94+
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errCipherSuiteNoIntersection
4895
}
4996

50-
var zeroEpoch uint16
51-
state.localEpoch.Store(zeroEpoch)
52-
state.remoteEpoch.Store(zeroEpoch)
53-
if len(cfg.ellipticCurves) < 1 {
54-
return nil, nil, errEmptyEllipticCurves
97+
for _, val := range clientHello.Extensions {
98+
switch ext := val.(type) {
99+
case *extension.SupportedEllipticCurves:
100+
if len(ext.EllipticCurves) == 0 {
101+
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errNoSupportedEllipticCurves
102+
}
103+
state.remoteGroups = ext.EllipticCurves
104+
case *extension.UseSRTP:
105+
profile, ok := findMatchingSRTPProfile(cfg.localSRTPProtectionProfiles, ext.ProtectionProfiles)
106+
if !ok {
107+
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errServerNoMatchingSRTPProfile
108+
}
109+
state.setSRTPProtectionProfile(profile)
110+
state.remoteSRTPMasterKeyIdentifier = ext.MasterKeyIdentifier
111+
case *extension.UseExtendedMasterSecret:
112+
if cfg.extendedMasterSecret != DisableExtendedMasterSecret {
113+
state.extendedMasterSecret = true
114+
}
115+
case *extension.ServerName:
116+
state.serverName = ext.ServerName // remote server name
117+
case *extension.RenegotiationInfo:
118+
state.remoteSupportsRenegotiation = true
119+
case *extension.ALPN:
120+
state.peerSupportedProtocols = ext.ProtocolNameList
121+
case *extension.ConnectionID:
122+
// Only set connection ID to be sent if server supports connection
123+
// IDs.
124+
if cfg.connectionIDGenerator != nil {
125+
state.remoteConnectionID = ext.CID
126+
}
127+
case *extension.SignatureAlgorithmsCert:
128+
// Store the client's certificate signature schemes for later validation
129+
state.remoteCertSignatureSchemes = ext.SignatureHashAlgorithms
130+
case *extension.SupportedVersions:
131+
state.remoteVersions = ext.Versions
132+
case *extension.KeyShare:
133+
state.remoteKeyEntries = ext.ClientShares
134+
}
135+
}
136+
137+
if !slices.Contains(state.remoteVersions, protocol.Version1_3) {
138+
// nolint:godox
139+
// TODO: This should actually handover the state machine to DTLS 1.2
140+
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errInvalidProtocolVersionState
141+
}
142+
143+
// If the client doesn't support connection IDs, the server should not
144+
// expect one to be sent.
145+
if state.remoteConnectionID == nil {
146+
state.setLocalConnectionID(nil)
147+
}
148+
149+
if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret {
150+
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errServerRequiredButNoClientEMS
151+
}
152+
153+
if state.localKeypair == nil {
154+
var err error
155+
state.localKeypair, err = elliptic.GenerateKeypair(state.namedCurve)
156+
if err != nil {
157+
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err
158+
}
159+
}
160+
161+
nextFlight := flight13_2
162+
163+
var groups []elliptic.Curve
164+
for _, entry := range state.remoteKeyEntries {
165+
// Clients MUST NOT offer any KeyShareEntry values
166+
// for groups not listed in the client's "supported_groups" extension.
167+
// Servers MAY check for violations of these rules and abort the
168+
// handshake with an "illegal_parameter" alert if one is violated.
169+
if !slices.Contains(state.remoteGroups, entry.Group) {
170+
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, errInvalidGroupInKeyShare
171+
}
172+
groups = append(groups, entry.Group)
55173
}
56-
state.namedCurve = cfg.ellipticCurves[0]
174+
state.namedCurve, _ = findMatchingGroup(groups, cfg.ellipticCurves)
57175

58-
if err := state.localRandom.Populate(); err != nil {
59-
return nil, nil, err
176+
if cfg.insecureSkipHelloVerify {
177+
nextFlight = flight13_4
60178
}
61179

62-
return nil, nil, nil
180+
return nextFlight, nil, nil
63181
}

state.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,11 @@ type State struct {
7575
peerSupportedProtocols []string
7676
NegotiatedProtocol string
7777

78-
version protocol.Version
78+
localVersion protocol.Version
79+
remoteVersions []protocol.Version
7980
localKeyEntries []extension.KeyShareEntry
8081
remoteKeyEntries []extension.KeyShareEntry //nolint:unused
82+
remoteGroups []elliptic.Curve
8183
}
8284

8385
type serializedState struct {
@@ -123,7 +125,7 @@ func (s *State) serialize() (*serializedState, error) {
123125
remoteRnd := s.remoteRandom.MarshalFixed()
124126

125127
epoch := s.getLocalEpoch()
126-
version := uint16(s.version.Major)<<8 + uint16(s.version.Minor)
128+
version := uint16(s.localVersion.Major)<<8 + uint16(s.localVersion.Minor)
127129

128130
return &serializedState{
129131
LocalEpoch: s.getLocalEpoch(),
@@ -191,7 +193,7 @@ func (s *State) deserialize(serialized serializedState) {
191193

192194
major := uint8((serialized.version & 0xff00) >> 8) //nolint:gosec
193195
minor := uint8(serialized.version & 0xff) //nolint:gosec
194-
s.version = protocol.Version{Major: major, Minor: minor}
196+
s.localVersion = protocol.Version{Major: major, Minor: minor}
195197
}
196198

197199
func (s *State) initCipherSuite() error {

util.go

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33

44
package dtls
55

6-
import "slices"
6+
import (
7+
"slices"
8+
9+
"github.com/pion/dtls/v3/pkg/crypto/elliptic"
10+
)
711

812
func findMatchingSRTPProfile(a, b []SRTPProtectionProfile) (SRTPProtectionProfile, bool) {
913
for _, aProfile := range a {
@@ -15,6 +19,18 @@ func findMatchingSRTPProfile(a, b []SRTPProtectionProfile) (SRTPProtectionProfil
1519
return 0, false
1620
}
1721

22+
func findMatchingGroup(a, b []elliptic.Curve) (elliptic.Curve, bool) {
23+
for _, aGroup := range a {
24+
for _, bGroup := range b {
25+
if aGroup == bGroup {
26+
return aGroup, true
27+
}
28+
}
29+
}
30+
31+
return 0, false
32+
}
33+
1834
func findMatchingCipherSuite(a, b []CipherSuite) (CipherSuite, bool) {
1935
for _, aSuite := range a {
2036
for _, bSuite := range b {

0 commit comments

Comments
 (0)