diff --git a/config.go b/config.go index 982a22ef..91154282 100644 --- a/config.go +++ b/config.go @@ -186,6 +186,7 @@ type Config struct { //nolint:dupl // InsecureSkipVerifyHello, if true and when acting as server, allow client to // skip hello verify phase and receive ServerHello after initial ClientHello. // This have implication on DoS attack resistance. + // For DTLS 1.3 this skips the HelloRetryRequest message InsecureSkipVerifyHello bool // ConnectionIDGenerator generates connection identifiers that should be diff --git a/conn.go b/conn.go index 6e28a714..5d80a992 100644 --- a/conn.go +++ b/conn.go @@ -20,6 +20,7 @@ import ( "github.com/pion/dtls/v3/pkg/crypto/signaturehash" "github.com/pion/dtls/v3/pkg/protocol" "github.com/pion/dtls/v3/pkg/protocol/alert" + "github.com/pion/dtls/v3/pkg/protocol/extension" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/dtls/v3/pkg/protocol/recordlayer" "github.com/pion/logging" @@ -92,7 +93,7 @@ type Conn struct { cancelHandshaker func() cancelHandshakeReader func() - fsm *handshakeFSM + fsm handshakeFSM replayProtectionWindow uint @@ -192,12 +193,12 @@ func createConn( } minVersion := config.minVersion - if !minVersion.Equal(protocol.Version1_2) || !minVersion.Equal(protocol.Version1_3) { + if !minVersion.Equal(protocol.Version1_3) { minVersion = protocol.Version1_2 } - maxVersion := config.minVersion - if !maxVersion.Equal(protocol.Version1_2) || !maxVersion.Equal(protocol.Version1_3) { + maxVersion := config.maxVersion + if !maxVersion.Equal(protocol.Version1_3) { maxVersion = protocol.Version1_2 } @@ -296,7 +297,7 @@ func (c *Conn) Handshake() error { // // Most uses of this package need not call HandshakeContext explicitly: the // first [Conn.Read] or [Conn.Write] will call it automatically. -func (c *Conn) HandshakeContext(ctx context.Context) error { +func (c *Conn) HandshakeContext(ctx context.Context) error { //nolint:cyclop c.handshakeMutex.Lock() defer c.handshakeMutex.Unlock() @@ -321,36 +322,91 @@ func (c *Conn) HandshakeContext(ctx context.Context) error { c.handshakeConfig.localCipherSuites = filterCipherSuitesForCertificate(cert, c.handshakeConfig.localCipherSuites) } - var initialFlight flightVal - var initialFSMState handshakeState - - if c.handshakeConfig.resumeState != nil { //nolint:nestif - if c.state.isClient { - initialFlight = flight5 - } else { - initialFlight = flight6 - } - initialFSMState = handshakeFinished - - c.state = *c.handshakeConfig.resumeState - } else { - if c.state.isClient { - initialFlight = flight1 - } else { - initialFlight = flight0 - } - initialFSMState = handshakePreparing + initialFlight, initialFSMState, initialFlights, postFSMSetup, err := c.prepareHandshakeStart(ctx) + if err != nil { + return err } - // Do handshake - if err := c.handshake(ctx, c.handshakeConfig, initialFlight, initialFSMState); err != nil { + + if err := c.handshake(ctx, initialFlight, initialFSMState, initialFlights, postFSMSetup); err != nil { return err } - c.log.Trace("Handshake Completed") + if c.state.localVersion == protocol.Version1_3 { + c.log.Trace("Handshake DTLS 1.3 Completed") + } else { + c.log.Trace("Handshake Completed") + } return nil } +// prepareHandshakeStart negotiates the DTLS version and decides how the FSM should start. +// +// There are three modes for the version: +// - DTLS 1.2 only +// - DTLS 1.3 only +// - Dual-stack (this mode sends or read handshake messages without starting a FSM) +// +// []*packet holds the ClientHello already sent by the dual-stack, the caller +// must seed the FSM with it. +// nolint:cyclop +func (c *Conn) prepareHandshakeStart( + ctx context.Context, +) (flightVal, handshakeState, []*packet, func(context.Context), error) { + switch { + // DTLS 1.2 only + case c.state.isClient && c.handshakeConfig.maxVersion == protocol.Version1_2: + c.state.localVersion = protocol.Version1_2 + if c.handshakeConfig.resumeState != nil { + c.state = *c.handshakeConfig.resumeState + + return flight5, handshakeFinished, nil, nil, nil + } + + return flight1, handshakePreparing, nil, nil, nil + case !c.state.isClient && c.handshakeConfig.maxVersion == protocol.Version1_2: + c.state.localVersion = protocol.Version1_2 + if c.handshakeConfig.resumeState != nil { + c.state = *c.handshakeConfig.resumeState + + return flight6, handshakeFinished, nil, nil, nil + } + + return flight0, handshakePreparing, nil, nil, nil + + // DTLS 1.3 only + case c.state.isClient && c.handshakeConfig.minVersion == protocol.Version1_3: + c.state.localVersion = protocol.Version1_3 + + return flightVal(flight13_1), handshakePreparing, nil, nil, nil + case !c.state.isClient && c.handshakeConfig.minVersion == protocol.Version1_3: + c.state.localVersion = protocol.Version1_3 + + return flightVal(flight13_0), handshakePreparing, nil, nil, nil + + // Dual-stack + // This mode sends or read handshake messages to decide version without starting a FSM + case c.state.isClient: + initialFlights, err := c.negotiateVersionClient(ctx) + if err != nil { + return 0, 0, nil, nil, err + } + + primer := func(ctx context.Context) { + go c.primeHandshakeRecv(ctx) + } + + return flightVal(flight13_1), handshakeWaiting, initialFlights, primer, nil + default: + err := c.negotiateVersionServer(ctx) + if err != nil { + return 0, 0, nil, nil, err + } + + return flightVal(flight13_0), handshakePreparing, nil, nil, nil + } +} + // Dial connects to the given network address and establishes a DTLS connection on top. // // Deprecated: Use DialWithOptions instead. @@ -494,6 +550,8 @@ func (c *Conn) Write(payload []byte) (int, error) { return 0, err } + //nolint:godox + // TODO: check for version return len(payload), c.writePackets(c.writeDeadline, []*packet{ { record: &recordlayer.RecordLayer{ @@ -820,7 +878,7 @@ var poolReadBuffer = sync.Pool{ //nolint:gochecknoglobals }, } -func (c *Conn) readAndBuffer(ctx context.Context) error { //nolint:cyclop +func (c *Conn) readAndBuffer(ctx context.Context) error { //nolint:cyclop,gocognit bufptr, ok := poolReadBuffer.Get().(*[]byte) if !ok { return errFailedToAccessPoolReadBuffer @@ -840,6 +898,8 @@ func (c *Conn) readAndBuffer(ctx context.Context) error { //nolint:cyclop var hasHandshake, isRetransmit bool for _, p := range pkts { + //nolint:godox + // TODO: check version hs, rtx, alert, err := c.handleIncomingPacket(ctx, p, rAddr, true) if alert != nil { if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil { @@ -887,6 +947,8 @@ func (c *Conn) handleQueuedPackets(ctx context.Context) error { c.lock.Unlock() for _, p := range pkts { + //nolint:godox + // TODO: check version _, _, alert, err := c.handleIncomingPacket(ctx, p.data, p.rAddr, false) // don't re-enqueue if alert != nil { if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil { @@ -920,6 +982,17 @@ func (c *Conn) enqueueEncryptedPackets(packet addrPkt) bool { return false } +// nolint:unused +func (c *Conn) handleIncomingPacket13( + ctx context.Context, + buf []byte, + rAddr net.Addr, + enqueue bool, +) (bool, bool, *alert.Alert, error) { + // Placeholder function + return false, false, nil, nil +} + //nolint:gocognit,gocyclo,cyclop,maintidx func (c *Conn) handleIncomingPacket( ctx context.Context, @@ -1141,17 +1214,20 @@ func (c *Conn) recvHandshake() <-chan recvHandshakeState { } func (c *Conn) notify(ctx context.Context, level alert.Level, desc alert.Description) error { - if level == alert.Fatal && len(c.state.SessionID) > 0 { - // According to the RFC, we need to delete the stored session. - // https://datatracker.ietf.org/doc/html/rfc5246#section-7.2 - if ss := c.fsm.cfg.sessionStore; ss != nil { - c.log.Tracef("clean invalid session: %s", c.state.SessionID) - if err := ss.Del(c.sessionKey()); err != nil { - return err + if level == alert.Fatal && len(c.state.SessionID) > 0 { //nolint:nestif + if c.state.localVersion == protocol.Version1_2 { + // According to the RFC, we need to delete the stored session. + // https://datatracker.ietf.org/doc/html/rfc5246#section-7.2 + if ss := c.fsm.(*handshakeFSM12).cfg.sessionStore; ss != nil { //nolint:forcetypeassert + c.log.Tracef("clean invalid session: %s", c.state.SessionID) + if err := ss.Del(c.sessionKey()); err != nil { + return err + } } } } + // This should be updated with DTLS 1.3 record encoding. return c.writePackets(ctx, []*packet{ { record: &recordlayer.RecordLayer{ @@ -1178,23 +1254,283 @@ func (c *Conn) isHandshakeCompletedSuccessfully() bool { return c.handshakeCompletedSuccessfully.Load() } -//nolint:cyclop,gocognit,contextcheck +func (c *Conn) negotiateVersionServer(ctx context.Context) error { + for { + if err := c.readAndBufferNoFSM(ctx); err != nil { + return err + } + if ok, err := c.pickVersionFromClientHello(); err != nil { + return err + } else if ok { + return nil + } + // ClientHello not yet (fully) received; keep reading. + } +} + +//nolint:cyclop +func (c *Conn) negotiateVersionClient(ctx context.Context) ([]*packet, error) { + pkts, dtlsAlert, err := flight13_1Generate(c, &c.state, c.handshakeCache, c.handshakeConfig) + if dtlsAlert != nil { + if alertErr := c.notify(ctx, dtlsAlert.Level, dtlsAlert.Description); alertErr != nil && err == nil { + err = alertErr + } + } + if err != nil { + return nil, err + } + + c.stampHandshakeSequence(pkts) + if err := c.writePackets(ctx, pkts); err != nil { + return nil, err + } + + for { + if err := c.readAndBufferNoFSM(ctx); err != nil { + return nil, err + } + if ok, err := c.pickVersionFromServerResponse(); err != nil { + return nil, err + } else if ok { + return pkts, nil + } + // ServerHello or HelloVerifyRequest not yet (fully) received; keep reading. + } +} + +// pickVersionFromClientHello inspects the handshake cache for incoming +// ClientHello and, if found, sets localVersion and remoteVersions. +// Returns true once the version can be decided. +func (c *Conn) pickVersionFromClientHello() (bool, error) { + _, msgs, ok := c.handshakeCache.fullPullMap(0, c.state.cipherSuite, + handshakeCachePullRule{handshake.TypeClientHello, c.handshakeConfig.initialEpoch, true, false}, + ) + if !ok { + return false, nil + } + ch, ok := msgs[handshake.TypeClientHello].(*handshake.MessageClientHello) + if !ok { + return false, nil + } + + var remote []protocol.Version + seenSupportedVersions := false + for _, e := range ch.Extensions { + if sv, ok := e.(*extension.SupportedVersions); ok { //nolint:govet + seenSupportedVersions = true + remote = sv.Versions + + break + } + } + if !seenSupportedVersions { + remote = []protocol.Version{ch.Version} + } + + chosen, ok := selectVersion(remote, c.handshakeConfig.minVersion, c.handshakeConfig.maxVersion) + if !ok { + return false, errNoCommonProtocolVersion + } + + c.state.remoteVersions = remote + c.state.localVersion = chosen + + return true, nil +} + +// pickVersionFromServerResponse inspects the handshake cache for the server's +// response to our ClientHello and, if found, sets localVersion and +// remoteVersions. Returns true once the version can be pinned down. +// +// Handling: +// - ServerHello with supported_versions: finds match (1.2 or 1.3). +// - ServerHello without supported_versions: fall back to ServerHello.Version. +// - HelloVerifyRequest (1.2 cookie request): version is 1.2. +func (c *Conn) pickVersionFromServerResponse() (bool, error) { + if sh, ok := c.findCachedServerMessage(handshake.TypeServerHello).(*handshake.MessageServerHello); ok { + var remote []protocol.Version + seenSupportedVersions := false + for _, e := range sh.Extensions { + if sv, ok := e.(*extension.SupportedVersions); ok { + seenSupportedVersions = true + remote = sv.Versions + + break + } + } + if !seenSupportedVersions { + remote = []protocol.Version{sh.Version} + } + + chosen, ok := selectVersion(remote, c.handshakeConfig.minVersion, c.handshakeConfig.maxVersion) + if !ok { + return false, errNoCommonProtocolVersion + } + c.state.remoteVersions = remote + c.state.localVersion = chosen + + return true, nil + } + + if hvr, ok := c.findCachedServerMessage(handshake.TypeHelloVerifyRequest).(*handshake.MessageHelloVerifyRequest); ok { + c.state.localVersion = protocol.Version1_2 + remote := []protocol.Version{hvr.Version} + chosen, ok := selectVersion(remote, c.handshakeConfig.minVersion, c.handshakeConfig.maxVersion) + if !ok { + return false, errNoCommonProtocolVersion + } + c.state.remoteVersions = remote + c.state.localVersion = chosen + + return true, nil + } + + return false, nil +} + +// findCachedServerMessage pulls the most recent handshake message of the +// given type sent by the peer from the cache, if any. +func (c *Conn) findCachedServerMessage(t handshake.Type) handshake.Message { + _, msgs, ok := c.handshakeCache.fullPullMap(0, c.state.cipherSuite, + handshakeCachePullRule{t, c.handshakeConfig.initialEpoch, false, true}, + ) + if !ok { + return nil + } + + return msgs[t] +} + +// stampHandshakeSequence assigns the DTLS message_sequence to each handshake +// record in pkts, using and advancing state.handshakeSendSequence. This is +// the subset of handshakeFSM.prepare()'s bookkeeping that generated dual-stack +// packets need before being passed to writePackets. +func (c *Conn) stampHandshakeSequence(pkts []*packet) { + epoch := c.handshakeConfig.initialEpoch + for _, p := range pkts { + p.record.Header.Epoch += epoch + if h, ok := p.record.Content.(*handshake.Handshake); ok { + h.Header.MessageSequence = uint16(c.state.handshakeSendSequence) //nolint:gosec // G115 + c.state.handshakeSendSequence++ + } + } +} + +// primeHandshakeRecv sends a single recvHandshakeState to the FSM so that its +// wait state parses messages already pushed into handshakeCache during the +// dual-stack version negotiation mode. Without this, the FSM would block until +// its retransmit timer fires, since readAndBufferNoFSM does not signal. +// The send blocks until the FSM reaches wait() or the handshake is torn down. +func (c *Conn) primeHandshakeRecv(ctx context.Context) { + s := recvHandshakeState{ + done: make(chan struct{}), + isRetransmit: false, + } + select { + case c.handshakeRecv <- s: + select { + case <-s.done: + case <-ctx.Done(): + case <-c.fsm.Done(): + } + case <-ctx.Done(): + case <-c.fsm.Done(): + } +} + +// readAndBufferNoFSM is a variant of readAndBuffer used during the dual-stack +// version negotiation phase. It reads a datagram and pushes any handshake +// fragments into handshakeCache, but does not signal an FSM (there is none +// yet) or wait for its Done channel. +func (c *Conn) readAndBufferNoFSM(ctx context.Context) error { //nolint:cyclop + bufptr, ok := poolReadBuffer.Get().(*[]byte) + if !ok { + return errFailedToAccessPoolReadBuffer + } + defer poolReadBuffer.Put(bufptr) + + b := *bufptr + i, rAddr, err := c.nextConn.ReadFromContext(ctx, b) + if err != nil { + return netError(err) + } + + pkts, err := recordlayer.ContentAwareUnpackDatagram(b[:i], len(c.state.getLocalConnectionID())) + if err != nil { + return err + } + + for _, p := range pkts { + // nolint:godox + // TODO: check version + _, _, alert, err := c.handleIncomingPacket(ctx, p, rAddr, true) + if alert != nil { + if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil { + if err == nil { + err = alertErr + } + } + } + + var e *alertError + if errors.As(err, &e) && e.IsFatalOrCloseNotify() { + return e + } + if err != nil { + return err + } + } + + return nil +} + +//nolint:gocyclo,cyclop,gocognit,contextcheck func (c *Conn) handshake( ctx context.Context, - cfg *handshakeConfig, initialFlight flightVal, initialState handshakeState, + initialFlights []*packet, + postFSMSetup func(context.Context), ) error { - c.fsm = newHandshakeFSM(&c.state, c.handshakeCache, cfg, initialFlight) - done := make(chan struct{}) - ctxRead, cancelRead := context.WithCancel(context.Background()) - cfg.onFlightState = func(_ flightVal, s handshakeState) { - if s == handshakeFinished && c.setHandshakeCompletedSuccessfully() { - close(done) + if c.state.localVersion == protocol.Version1_3 { + c.fsm = &handshakeFSM13{ + currentFlight: flightVal13(initialFlight), + flights: initialFlights, + retransmit: initialFlights != nil, + state: &c.state, + cache: c.handshakeCache, + cfg: c.handshakeConfig, + retransmitInterval: c.handshakeConfig.initialRetransmitInterval, + closed: make(chan struct{}), + } + c.handshakeConfig.onFlightState13 = func(_ flightVal13, s handshakeState) { + // The ACK for the last flights has been received and we are in a Finished state. + // nolint:godox + // TODO: should be moved to FSM. + if s == handshakeFinished && c.setHandshakeCompletedSuccessfully() { + close(done) + } + } + } else { + c.fsm = &handshakeFSM12{ + currentFlight: initialFlight, + flights: initialFlights, + retransmit: initialFlights != nil, + state: &c.state, + cache: c.handshakeCache, + cfg: c.handshakeConfig, + retransmitInterval: c.handshakeConfig.initialRetransmitInterval, + closed: make(chan struct{}), + } + c.handshakeConfig.onFlightState = func(_ flightVal, s handshakeState) { + if s == handshakeFinished && c.setHandshakeCompletedSuccessfully() { + close(done) + } } } + ctxRead, cancelRead := context.WithCancel(context.Background()) ctxHs, cancel := context.WithCancel(context.Background()) c.closeLock.Lock() @@ -1219,6 +1555,11 @@ func (c *Conn) handshake( } } }() + + if postFSMSetup != nil { + postFSMSetup(ctxHs) + } + go func() { defer func() { if c.isHandshakeCompletedSuccessfully() { @@ -1391,7 +1732,11 @@ func (c *Conn) sessionKey() []byte { // As ServerName can be like 0.example.com, it's better to add // delimiter character which is not allowed to be in // neither address or domain name. - return []byte(c.rAddr.String() + "_" + c.fsm.cfg.serverName) + if c.state.localVersion == protocol.Version1_3 { + return []byte(c.rAddr.String() + "_" + c.fsm.(*handshakeFSM13).cfg.serverName) //nolint:forcetypeassert + } + + return []byte(c.rAddr.String() + "_" + c.fsm.(*handshakeFSM12).cfg.serverName) //nolint:forcetypeassert } return c.state.SessionID diff --git a/conn_test.go b/conn_test.go index e7e33b05..8bf675d6 100644 --- a/conn_test.go +++ b/conn_test.go @@ -3127,10 +3127,20 @@ func TestEllipticCurveConfiguration(t *testing.T) { assert.True(t, ok, "Failed to default Elliptic curves") if len(test.ConfigCurves) != 0 { - assert.Equal(t, len(test.HandshakeCurves), len(server.fsm.cfg.ellipticCurves), "Failed to configure Elliptic curves") + assert.Equal( + t, + len(test.HandshakeCurves), + len(server.fsm.(*handshakeFSM12).cfg.ellipticCurves), //nolint:forcetypeassert + "Failed to configure Elliptic curves", + ) for i, c := range test.ConfigCurves { - assert.Equal(t, c, server.fsm.cfg.ellipticCurves[i], "Failed to maintain Elliptic curve order") + assert.Equal( + t, + c, + server.fsm.(*handshakeFSM12).cfg.ellipticCurves[i], //nolint:forcetypeassert + "Failed to maintain Elliptic curve order", + ) } } @@ -3500,3 +3510,198 @@ func TestCloseWithoutHandshake(t *testing.T) { assert.NoError(t, err) assert.NoError(t, server.Close()) } + +// WIP! Tests if DTLS 1.3 handshake flow is enabled and the correct error is returned. +func TestDTLS13Enabled(t *testing.T) { + ca, cb := dpipe.Pipe() + + // Setup client + clientCert, err := selfsign.GenerateSelfSigned() + assert.NoError(t, err) + + clientcfg, err := buildClientConfig( + WithCertificates(clientCert), + WithInsecureSkipVerify(true), + ) + + assert.NoError(t, err) + + clientcfg.minVersion = protocol.Version1_3 + clientcfg.maxVersion = protocol.Version1_3 + + client, err := Client(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), clientcfg) + assert.NoError(t, err) + defer func() { + _ = client.Close() + }() + + _, ok := client.ConnectionState() + assert.False(t, ok) + + ctxClient, cancelClient := context.WithTimeout(context.Background(), 5*time.Second) + defer cancelClient() + errorChannel := make(chan error) + go func() { + errC := client.HandshakeContext(ctxClient) + errorChannel <- errC + }() + + err = <-errorChannel + assert.Error(t, err) + assert.ErrorIs(t, err, errStateUnimplemented13) + + // Setup server + serverCert, err := selfsign.GenerateSelfSigned() + assert.NoError(t, err) + + servercfg, err := buildServerConfig( + WithCertificates(serverCert), + WithInsecureSkipVerify(true), + ) + + assert.NoError(t, err) + + servercfg.minVersion = protocol.Version1_3 + servercfg.maxVersion = protocol.Version1_3 + + server, err := Server(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), servercfg) + assert.NoError(t, err) + defer func() { + _ = server.Close() + }() + + _, ok = server.ConnectionState() + assert.False(t, ok) + + ctxServer, cancelServer := context.WithTimeout(context.Background(), 5*time.Second) + defer cancelServer() + go func() { + errS := server.HandshakeContext(ctxServer) + errorChannel <- errS + }() + err = <-errorChannel + assert.Error(t, err) + assert.ErrorIs(t, err, errStateUnimplemented13) +} + +// WIP! Tests if the dual stack mode client managed to negotiate a version successfully. +func TestDTLSDualStackClient(t *testing.T) { + defer test.CheckRoutines(t)() + defer test.TimeOut(time.Second * 10).Stop() + + // Setup client + clientCert, err := selfsign.GenerateSelfSigned() + assert.NoError(t, err) + + clientcfg, err := buildClientConfig( + WithCertificates(clientCert), + WithInsecureSkipVerify(true), + ) + + assert.NoError(t, err) + + clientcfg.minVersion = protocol.Version1_2 + clientcfg.maxVersion = protocol.Version1_3 + + // Setup server + serverCert, err := selfsign.GenerateSelfSigned() + assert.NoError(t, err) + + servercfg, err := buildServerConfig( + WithCertificates(serverCert), + WithInsecureSkipVerify(true), + ) + + assert.NoError(t, err) + + servercfg.minVersion = protocol.Version1_2 + servercfg.maxVersion = protocol.Version1_2 + + testDTLSDualStack(t, *clientcfg, *servercfg) +} + +// WIP! Tests if the dual stack mode server managed to negotiate a version successfully. +func TestDTLSDualStackServer(t *testing.T) { + defer test.CheckRoutines(t)() + defer test.TimeOut(time.Second * 10).Stop() + + // Setup client + clientCert, err := selfsign.GenerateSelfSigned() + assert.NoError(t, err) + + clientcfg, err := buildClientConfig( + WithCertificates(clientCert), + WithInsecureSkipVerify(true), + ) + + assert.NoError(t, err) + + clientcfg.minVersion = protocol.Version1_2 + clientcfg.maxVersion = protocol.Version1_2 + + // Setup server + serverCert, err := selfsign.GenerateSelfSigned() + assert.NoError(t, err) + + servercfg, err := buildServerConfig( + WithCertificates(serverCert), + WithInsecureSkipVerify(true), + ) + + assert.NoError(t, err) + + servercfg.minVersion = protocol.Version1_2 + servercfg.maxVersion = protocol.Version1_3 + + testDTLSDualStack(t, *clientcfg, *servercfg) +} + +// WIP! Tests if the dual stack mode managed to negotiate a version successfully. +func testDTLSDualStack(t *testing.T, clientCfg Config, serverCfg Config) { + t.Helper() + ca, cb := dpipe.Pipe() + + client, err := Client(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &clientCfg) + assert.NoError(t, err) + defer func() { + _ = client.Close() + }() + + _, ok := client.ConnectionState() + assert.False(t, ok) + + ctxClient, cancelClient := context.WithTimeout(context.Background(), 5*time.Second) + defer cancelClient() + errorChannel := make(chan error, 2) + + server, err := Server(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &serverCfg) + assert.NoError(t, err) + defer func() { + _ = server.Close() + }() + + _, ok = server.ConnectionState() + assert.False(t, ok) + + ctxServer, cancelServer := context.WithTimeout(context.Background(), 5*time.Second) + defer cancelServer() + + go func() { + errC := client.HandshakeContext(ctxClient) + errorChannel <- errC + }() + + go func() { + errS := server.HandshakeContext(ctxServer) + errorChannel <- errS + }() + + err = <-errorChannel + assert.NoError(t, err) + + err = <-errorChannel + assert.NoError(t, err) + + assert.NoError(t, server.Close()) + assert.NoError(t, client.Close()) +} diff --git a/errors.go b/errors.go index 0db0de67..ea8046a9 100644 --- a/errors.go +++ b/errors.go @@ -108,6 +108,9 @@ var ( //nolint:err113 errUnsupportedProtocolVersion = &FatalError{Err: errors.New("unsupported protocol version")} //nolint:err113 + errNoCommonProtocolVersion = &FatalError{Err: errors.New("no common DTLS version between peer and local")} + errInvalidProtocolVersionState = &FatalError{Err: errors.New("invalid protocol version in state")} + //nolint:err113 errPSKAndIdentityMustBeSetForClient = &FatalError{ Err: errors.New("PSK and PSK Identity Hint must both be set for client"), } @@ -128,6 +131,10 @@ var ( //nolint:err113 errInvalidFlight = &InternalError{Err: errors.New("invalid flight number")} + //nolint:err113,unused + errFlightUnimplemented13 = &InternalError{Err: errors.New("unimplemented DTLS 1.3 flight")} + //nolint:err113 + errStateUnimplemented13 = &InternalError{Err: errors.New("unimplemented DTLS 1.3 handshake state")} //nolint:err113 errKeySignatureGenerateUnimplemented = &InternalError{ Err: errors.New("unable to generate key signature, unimplemented"), @@ -215,6 +222,10 @@ var ( errNilOnConnectionAttempt = &FatalError{ Err: errors.New("on connection attempt option requires a non-nil callback"), } + //nolint:err113 + errInvalidGroupInKeyShare = &FatalError{ + Err: errors.New("groups offered in the key share extension must be included in the supported groups extension"), + } ) // FatalError indicates that the DTLS connection is no longer available. diff --git a/flight_13.go b/flight_13.go new file mode 100644 index 00000000..396a28f1 --- /dev/null +++ b/flight_13.go @@ -0,0 +1,194 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +type flightVal13 uint8 + +/* +// [RFC9147 Section-5.7] + +Full DTLS Handshake (with Cookie Exchange): + +Client Server + + +----------+ + Waiting | Flight 0 | + +----------+ + + +----------+ + ClientHello | Flight 1 | + --------> +----------+ + + +----------+ + <-------- HelloRetryRequest | Flight 2 | + + cookie +----------+ + + + +----------+ +ClientHello | Flight 3 | + + cookie --------> +----------+ + + + + ServerHello + {EncryptedExtensions} +----------+ + {CertificateRequest*} | Flight 4 | + {Certificate*} +----------+ + {CertificateVerify*} + {Finished} + <-------- [Application Data*] + + + + {Certificate*} +----------+ + {CertificateVerify*} | Flight 5 | + {Finished} --------> +----------+ + [Application Data] + +----------+ + <-------- [ACK] | Flight 6 | + [Application Data*] +----------+ + + [Application Data] <-------> [Application Data] + + + + +Resumption and PSK Handshake (without Cookie Exchange): + +Client Server + + ClientHello +-----------+ + + pre_shared_key | Flight 3a | + + psk_key_exchange_modes +-----------+ + + key_share* --------> + + + ServerHello + + pre_shared_key +-----------+ + + key_share* | Flight 4a | + {EncryptedExtensions} +-----------+ + <-------- {Finished} + [Application Data*] + +-----------+ + {Finished} --------> | Flight 5a | + [Application Data*] +-----------+ + + +-----------+ + <-------- [ACK] | Flight 6a | + [Application Data*] +-----------+ + + [Application Data] <-------> [Application Data] + + +Zero-RTT Handshake: + +Client Server + + ClientHello + + early_data + + psk_key_exchange_modes +-----------+ + + key_share* | Flight 3b | + + pre_shared_key +-----------+ + (Application Data*) --------> + + ServerHello + + pre_shared_key + + key_share* +-----------+ + {EncryptedExtensions} | Flight 4b | + {Finished} +-----------+ + <-------- [Application Data*] + + + +-----------+ + {Finished} --------> | Flight 5b | + [Application Data*] +-----------+ + + +-----------+ + <-------- [ACK] | Flight 6b | + [Application Data*] +-----------+ + + [Application Data] <-------> [Application Data] + + +NewSessionTicket Message: + +Client Server + + +-----------+ + <-------- [NewSessionTicket] | Flight 4c | + +-----------+ + + +-----------+ +[ACK] --------> | Flight 5c | + +-----------+ +*/ + +const ( + flight13_0 flightVal13 = iota + 1 + flight13_1 + flight13_2 + flight13_3 + flight13_3a + flight13_3b + flight13_4 + flight13_4a + flight13_4b + flight13_4c + flight13_5 + flight13_5a + flight13_5b + flight13_5c + flight13_6 + flight13_6a + flight13_6b +) + +func (f flightVal13) String() string { //nolint:cyclop + switch f { + case flight13_0: + return "Flight13 0" + case flight13_1: + return "Flight13 1" + case flight13_2: + return "Flight13 2" + case flight13_3: + return "Flight13 3" + case flight13_3a: + return "Flight13 3a" + case flight13_3b: + return "Flight13 3b" + case flight13_4: + return "Flight13 4" + case flight13_4a: + return "Flight13 4a" + case flight13_4b: + return "Flight13 4b" + case flight13_4c: + return "Flight13 4c" + case flight13_5: + return "Flight13 5" + case flight13_5a: + return "Flight13 5a" + case flight13_5b: + return "Flight13 5b" + case flight13_5c: + return "Flight13 5c" + case flight13_6: + return "Flight13 6" + case flight13_6a: + return "Flight13 6a" + case flight13_6b: + return "Flight13 6b" + default: + return "Invalid Flight" + } +} + +func (f flightVal13) isLastSendFlight() bool { // nolint: unused + return f == flight13_6 || f == flight13_6a || f == flight13_6b || f == flight13_5c +} + +func (f flightVal13) isLastRecvFlight() bool { // nolint: unused + return f == flight13_5 || f == flight13_5a || f == flight13_5b || f == flight13_4c +} diff --git a/flighthandler_13.go b/flighthandler_13.go new file mode 100644 index 00000000..f24d12eb --- /dev/null +++ b/flighthandler_13.go @@ -0,0 +1,43 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "context" + + "github.com/pion/dtls/v3/pkg/protocol/alert" +) + +// Parse received handshakes and return next flightVal. +type flightParser13 func( //nolint:unused + context.Context, + flightConn, + *State, + *handshakeCache, + *handshakeConfig, +) (flightVal13, *alert.Alert, error) + +//nolint:unused +type flightGenerator13 func(flightConn, *State, *handshakeCache, *handshakeConfig) ([]*packet, *alert.Alert, error) + +//nolint:unused +func (f flightVal13) getFlightParser13() (flightParser13, error) { + switch f { + case flight13_0: + return flight13_0Parse, nil + default: + return nil, errFlightUnimplemented13 + } +} + +func (f flightVal13) getFlightGenerator13() (gen flightGenerator13, retransmit bool, err error) { + switch f { + case flight13_0: + return flight13_0Generate, true, nil + case flight13_1: + return flight13_1Generate, true, nil + default: + return nil, false, errFlightUnimplemented13 + } +} diff --git a/flighthandlers_client_13.go b/flighthandlers_client_13.go new file mode 100644 index 00000000..e5c7ffba --- /dev/null +++ b/flighthandlers_client_13.go @@ -0,0 +1,238 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "bytes" + "context" + + "github.com/pion/dtls/v3/pkg/crypto/elliptic" + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/alert" + "github.com/pion/dtls/v3/pkg/protocol/extension" + "github.com/pion/dtls/v3/pkg/protocol/handshake" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" +) + +// we'll add the flight handlers for the DTLS 1.3 client here. +// +// +----------+ +// | Flight 1 | +// | Flight 3 | +// | Flight 5 | +// +----------+ +// +// +-----------+ +// | Flight 3a | +// | Flight 5a | +// +-----------+ +// +// +-----------+ +// | Flight 3b | +// | Flight 5b | +// +-----------+ +// +// +-----------+ +// | Flight 5c | +// +-----------+ + +// nolint:unused +func flight13_1Parse( + ctx context.Context, + conn flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) (flightVal13, *alert.Alert, error) { + seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, + handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, true}, + ) + if !ok { + // No valid message received. Keep reading + return 0, nil, nil + } + + sh, ok := msgs[handshake.TypeServerHello].(*handshake.MessageServerHello) + if !ok { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil + } + + randomBytes := sh.Random.MarshalFixed() + if !bytes.Equal(randomBytes[:], handshake.HelloRetryRequestRandom()) { + // Flight1 and flight2 were skipped. + // Parse as flight3. + return flight13_3Parse(ctx, conn, state, cache, cfg) + } + // Handle HelloRetryRequest + + if !sh.Version.Equal(protocol.Version1_0) && !sh.Version.Equal(protocol.Version1_2) { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion + } + + // nolint:godox + // TODO: negotiate minimial set of extensions necessary for the client + // to generate a correct CH pair. As with the ServerHello, a + // HelloRetryRequest MUST NOT contain any extensions that were not first + // offered by the client in its ClientHello, with the exception of + // optionally the "cookie" extension + for _, val := range sh.Extensions { + switch ext := val.(type) { + case *extension.SupportedVersions: + // nolint:godox + // TODO: negotiate version + state.remoteVersions = ext.Versions + case *extension.CookieExt: + state.cookie = ext.Cookie + case *extension.KeyShare: + state.remoteKeyEntries = ext.ClientShares + } + } + + state.handshakeRecvSequence = seq + + return flight13_3, nil, nil +} + +//nolint:unused +func flight13_3Parse( + ctx context.Context, + conn flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) (flightVal13, *alert.Alert, error) { + return 0, nil, errFlightUnimplemented13 +} + +//nolint:cyclop +func flight13_1Generate( + _ flightConn, + state *State, + _ *handshakeCache, + cfg *handshakeConfig, +) ([]*packet, *alert.Alert, error) { + var zeroEpoch uint16 + state.localEpoch.Store(zeroEpoch) + state.remoteEpoch.Store(zeroEpoch) + if len(cfg.ellipticCurves) < 1 { + return nil, nil, errEmptyEllipticCurves + } + state.namedCurve = cfg.ellipticCurves[0] + state.cookie = nil + + if err := state.localRandom.Populate(); err != nil { + return nil, nil, err + } + + if cfg.helloRandomBytesGenerator != nil { + state.localRandom.RandomBytes = cfg.helloRandomBytesGenerator() + } + + extensions := []extension.Extension{} + + if cfg.extendedMasterSecret == RequestExtendedMasterSecret || + cfg.extendedMasterSecret == RequireExtendedMasterSecret { + extensions = append(extensions, &extension.UseExtendedMasterSecret{ + Supported: true, + }) + } + + extensions = append(extensions, &extension.RenegotiationInfo{ + RenegotiatedConnection: 0, + }) + + var setEllipticCurveCryptographyClientHelloExtensions bool + for _, c := range cfg.localCipherSuites { + if c.ECC() { + setEllipticCurveCryptographyClientHelloExtensions = true + + break + } + } + + if setEllipticCurveCryptographyClientHelloExtensions { + extensions = append(extensions, []extension.Extension{ + &extension.SupportedEllipticCurves{ + EllipticCurves: cfg.ellipticCurves, + }, + &extension.SupportedPointFormats{ + PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed}, + }, + }...) + } + + if len(cfg.supportedProtocols) > 0 { + extensions = append(extensions, &extension.ALPN{ProtocolNameList: cfg.supportedProtocols}) + } + + var entries []extension.KeyShareEntry + for _, group := range cfg.ellipticCurves { + keypair, err := elliptic.GenerateKeypair(group) + if err != nil { + return nil, nil, err + } + entries = append(entries, extension.KeyShareEntry{ + Group: keypair.Curve, KeyExchange: keypair.PublicKey, + }) + } + state.localKeyEntries = entries + extensions = append(extensions, &extension.KeyShare{ + ClientShares: entries, + }) + + extensions = append(extensions, &extension.SupportedVersions{ + Versions: supportedVersionsRange(cfg.minVersion, cfg.maxVersion), + }) + + if len(cfg.localCertSignatureSchemes) > 0 { + extensions = append(extensions, &extension.SignatureAlgorithmsCert{ + SignatureHashAlgorithms: cfg.localCertSignatureSchemes, + }) + } + + if len(cfg.serverName) > 0 { + extensions = append(extensions, &extension.ServerName{ServerName: cfg.serverName}) + } + + if len(cfg.localSRTPProtectionProfiles) > 0 { + extensions = append(extensions, &extension.UseSRTP{ + ProtectionProfiles: cfg.localSRTPProtectionProfiles, + MasterKeyIdentifier: cfg.localSRTPMasterKeyIdentifier, + }) + } + + // connection ID + + // Pre_shared_key must be last extension + + clientHello := &handshake.MessageClientHello{ + Version: protocol.Version1_2, + SessionID: state.SessionID, + Cookie: state.cookie, + Random: state.localRandom, + // Add DTLS 1.3 ciphersuites + CipherSuiteIDs: cipherSuiteIDs(cfg.localCipherSuites), + CompressionMethods: defaultCompressionMethods(), + Extensions: extensions, + } + + var content handshake.Handshake + + if cfg.clientHelloMessageHook != nil { + content = handshake.Handshake{Message: cfg.clientHelloMessageHook(*clientHello)} + } else { + content = handshake.Handshake{Message: clientHello} + } + + return []*packet{ + { + record: &recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Version: protocol.Version1_2, + }, + Content: &content, + }, + }, + }, nil, nil +} diff --git a/flighthandlers_server_13.go b/flighthandlers_server_13.go new file mode 100644 index 00000000..1450923d --- /dev/null +++ b/flighthandlers_server_13.go @@ -0,0 +1,211 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "context" + "crypto/rand" + "slices" + + "github.com/pion/dtls/v3/pkg/crypto/elliptic" + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/alert" + "github.com/pion/dtls/v3/pkg/protocol/extension" + "github.com/pion/dtls/v3/pkg/protocol/handshake" +) + +// we'll add the flight handlers for the DTLS 1.3 server here. +// +// Flight0 +// +// +----------+ +// | Flight 2 | +// | Flight 4 | +// | Flight 6 | +// +----------+ +// +// +-----------+ +// | Flight 4a | +// | Flight 6a | +// +-----------+ +// +// +-----------+ +// | Flight 4b | +// | Flight 6b | +// +-----------+ +// +// +-----------+ +// | Flight 4c | +// +-----------+ + +//nolint:cyclop,gocognit,gocyclo,unused +func flight13_0Parse( + _ context.Context, + _ flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) (flightVal13, *alert.Alert, error) { + if state.localVersion != protocol.Version1_3 { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errInvalidProtocolVersionState + } + seq, msgs, ok := cache.fullPullMap(0, state.cipherSuite, + handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, + ) + if !ok { + // No valid message received. Keep reading + return 0, nil, nil + } + + // Connection Identifiers must be negotiated afresh on session resumption. + // https://datatracker.ietf.org/doc/html/rfc9146#name-the-connection_id-extension + state.setLocalConnectionID(nil) + state.remoteConnectionID = nil + + state.handshakeRecvSequence = seq + + var clientHello *handshake.MessageClientHello + + // Validate type + if clientHello, ok = msgs[handshake.TypeClientHello].(*handshake.MessageClientHello); !ok { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil + } + + if !clientHello.Version.Equal(protocol.Version1_2) { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion + } + + state.remoteRandom = clientHello.Random + + cipherSuites := []CipherSuite{} + for _, id := range clientHello.CipherSuiteIDs { + if id == renegotiationInfoSCSV { + state.remoteSupportsRenegotiation = true + + continue + } + if c := cipherSuiteForID(CipherSuiteID(id), cfg.customCipherSuites); c != nil { + cipherSuites = append(cipherSuites, c) + } + } + + // nolint:godox + // TODO: check for DTLS 1.3 cipher suites + if state.cipherSuite, ok = findMatchingCipherSuite(cipherSuites, cfg.localCipherSuites); !ok { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errCipherSuiteNoIntersection + } + + for _, val := range clientHello.Extensions { + switch ext := val.(type) { + case *extension.SupportedEllipticCurves: + if len(ext.EllipticCurves) == 0 { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errNoSupportedEllipticCurves + } + state.remoteGroups = ext.EllipticCurves + case *extension.UseSRTP: + profile, ok := findMatchingSRTPProfile(cfg.localSRTPProtectionProfiles, ext.ProtectionProfiles) + if !ok { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errServerNoMatchingSRTPProfile + } + state.setSRTPProtectionProfile(profile) + state.remoteSRTPMasterKeyIdentifier = ext.MasterKeyIdentifier + case *extension.UseExtendedMasterSecret: + if cfg.extendedMasterSecret != DisableExtendedMasterSecret { + state.extendedMasterSecret = true + } + case *extension.ServerName: + state.serverName = ext.ServerName // remote server name + case *extension.RenegotiationInfo: + state.remoteSupportsRenegotiation = true + case *extension.ALPN: + state.peerSupportedProtocols = ext.ProtocolNameList + case *extension.ConnectionID: + // Only set connection ID to be sent if server supports connection + // IDs. + if cfg.connectionIDGenerator != nil { + state.remoteConnectionID = ext.CID + } + case *extension.SignatureAlgorithmsCert: + // Store the client's certificate signature schemes for later validation + state.remoteCertSignatureSchemes = ext.SignatureHashAlgorithms + case *extension.SupportedVersions: + state.remoteVersions = ext.Versions + case *extension.KeyShare: + state.remoteKeyEntries = ext.ClientShares + } + } + + if !slices.Contains(state.remoteVersions, protocol.Version1_3) { + // nolint:godox + // TODO: This should actually handover the state machine to DTLS 1.2 + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errInvalidProtocolVersionState + } + + // If the client doesn't support connection IDs, the server should not + // expect one to be sent. + if state.remoteConnectionID == nil { + state.setLocalConnectionID(nil) + } + + if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errServerRequiredButNoClientEMS + } + + if state.localKeypair == nil { + var err error + state.localKeypair, err = elliptic.GenerateKeypair(state.namedCurve) + if err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err + } + } + + nextFlight := flight13_2 + + var groups []elliptic.Curve + for _, entry := range state.remoteKeyEntries { + // Clients MUST NOT offer any KeyShareEntry values + // for groups not listed in the client's "supported_groups" extension. + // Servers MAY check for violations of these rules and abort the + // handshake with an "illegal_parameter" alert if one is violated. + if !slices.Contains(state.remoteGroups, entry.Group) { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, errInvalidGroupInKeyShare + } + groups = append(groups, entry.Group) + } + state.namedCurve, _ = findMatchingGroup(groups, cfg.ellipticCurves) + + if cfg.insecureSkipHelloVerify { + nextFlight = flight13_4 + } + + return nextFlight, nil, nil +} + +// nolint:unparam +func flight13_0Generate( + _ flightConn, + state *State, + _ *handshakeCache, + cfg *handshakeConfig, +) ([]*packet, *alert.Alert, error) { + if !cfg.insecureSkipHelloVerify { + state.cookie = make([]byte, cookieLength) + if _, err := rand.Read(state.cookie); err != nil { + return nil, nil, err + } + } + + var zeroEpoch uint16 + state.localEpoch.Store(zeroEpoch) + state.remoteEpoch.Store(zeroEpoch) + if len(cfg.ellipticCurves) < 1 { + return nil, nil, errEmptyEllipticCurves + } + + if err := state.localRandom.Populate(); err != nil { + return nil, nil, err + } + + return nil, nil, nil +} diff --git a/handshaker.go b/handshaker.go index 425934d4..faf51530 100644 --- a/handshaker.go +++ b/handshaker.go @@ -82,7 +82,7 @@ func (s handshakeState) String() string { } } -type handshakeFSM struct { +type handshakeFSM12 struct { currentFlight flightVal flights []*packet retransmit bool @@ -140,6 +140,8 @@ type handshakeConfig struct { minVersion protocol.Version maxVersion protocol.Version + + onFlightState13 func(flightVal13, handshakeState) } type flightConn interface { @@ -171,11 +173,11 @@ func srvCliStr(isClient bool) string { return "server" } -func newHandshakeFSM( +func newHandshakeFSM12( s *State, cache *handshakeCache, cfg *handshakeConfig, initialFlight flightVal, -) *handshakeFSM { - return &handshakeFSM{ +) *handshakeFSM12 { + return &handshakeFSM12{ currentFlight: initialFlight, state: s, cache: cache, @@ -185,7 +187,17 @@ func newHandshakeFSM( } } -func (s *handshakeFSM) Run(ctx context.Context, conn flightConn, initialState handshakeState) error { +type handshakeFSM interface { + Done() <-chan struct{} + Run(ctx context.Context, conn flightConn, initialState handshakeState) error + finish(ctx context.Context, c flightConn) (handshakeState, error) + prepare(ctx context.Context, conn flightConn) (handshakeState, error) + send(ctx context.Context, c flightConn) (handshakeState, error) + wait(ctx context.Context, conn flightConn) (handshakeState, error) +} + +//nolint:dupl +func (s *handshakeFSM12) Run(ctx context.Context, conn flightConn, initialState handshakeState) error { state := initialState defer func() { close(s.closed) @@ -214,11 +226,12 @@ func (s *handshakeFSM) Run(ctx context.Context, conn flightConn, initialState ha } } -func (s *handshakeFSM) Done() <-chan struct{} { +func (s *handshakeFSM12) Done() <-chan struct{} { return s.closed } -func (s *handshakeFSM) prepare(ctx context.Context, conn flightConn) (handshakeState, error) { +//nolint:dupl +func (s *handshakeFSM12) prepare(ctx context.Context, conn flightConn) (handshakeState, error) { s.flights = nil // Prepare flights var ( @@ -266,7 +279,7 @@ func (s *handshakeFSM) prepare(ctx context.Context, conn flightConn) (handshakeS return handshakeSending, nil } -func (s *handshakeFSM) send(ctx context.Context, c flightConn) (handshakeState, error) { +func (s *handshakeFSM12) send(ctx context.Context, c flightConn) (handshakeState, error) { // Send flights if err := c.writePackets(ctx, s.flights); err != nil { return handshakeErrored, err @@ -279,7 +292,7 @@ func (s *handshakeFSM) send(ctx context.Context, c flightConn) (handshakeState, return handshakeWaiting, nil } -func (s *handshakeFSM) wait(ctx context.Context, conn flightConn) (handshakeState, error) { //nolint:gocognit,cyclop +func (s *handshakeFSM12) wait(ctx context.Context, conn flightConn) (handshakeState, error) { //nolint:gocognit,cyclop parse, errFlight := s.currentFlight.getFlightParser() if errFlight != nil { if alertErr := conn.notify(ctx, alert.Fatal, alert.InternalError); alertErr != nil { @@ -351,7 +364,7 @@ func (s *handshakeFSM) wait(ctx context.Context, conn flightConn) (handshakeStat } } -func (s *handshakeFSM) finish(ctx context.Context, c flightConn) (handshakeState, error) { +func (s *handshakeFSM12) finish(ctx context.Context, c flightConn) (handshakeState, error) { select { case state := <-c.recvHandshake(): close(state.done) diff --git a/handshaker_13.go b/handshaker_13.go new file mode 100644 index 00000000..50204835 --- /dev/null +++ b/handshaker_13.go @@ -0,0 +1,169 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "context" + "time" + + "github.com/pion/dtls/v3/pkg/protocol/alert" + "github.com/pion/dtls/v3/pkg/protocol/handshake" +) + +// [RFC9147 Section-5.8.1] +// +-----------+ +// | PREPARING | +// +----------> | | +// | | | +// | +-----------+ +// | | +// | | Buffer next flight +// | | +// | \|/ +// | +-----------+ +// | | | +// | | SENDING |<------------------+ +// | | | | +// | +-----------+ | +// Receive | | | +// next | | Send flight or partial | +// flight | | flight | +// | | | +// | | Set retransmit timer | +// | \|/ | +// | +-----------+ | +// | | | | +// +------------| WAITING |-------------------+ +// | +----->| | Timer expires | +// | | +-----------+ | +// | | | | | | +// | | | | | | +// | +----------+ | +--------------------+ +// | Receive record | Read retransmit or ACK +// Receive | (Maybe Send ACK) | +// last | | +// flight | | Receive ACK +// | | for last flight +// \|/ | +// | +// +-----------+ | +// | | <---------+ +// | FINISHED | +// | | +// +-----------+ +// | /|\ +// | | +// | | +// +---+ +// +// Server read retransmit +// Retransmit ACK + +type handshakeFSM13 struct { + currentFlight flightVal13 + flights []*packet //nolint:unused + retransmit bool //nolint:unused + retransmitInterval time.Duration + state *State + cache *handshakeCache + cfg *handshakeConfig + closed chan struct{} +} + +//nolint:dupl +func (s *handshakeFSM13) Run(ctx context.Context, conn flightConn, initialState handshakeState) error { + state := initialState + defer func() { + close(s.closed) + }() + for { + s.cfg.log.Tracef("[handshake13:%s] %s: %s", srvCliStr(s.state.isClient), s.currentFlight.String(), state.String()) + // nolint:godox + // TODO:: refactor callback, see discussion in https://github.com/pion/dtls/pull/738#discussion_r3131501159 + if s.cfg.onFlightState13 != nil { + s.cfg.onFlightState13(s.currentFlight, state) + } + var err error + switch state { + case handshakePreparing: + state, err = s.prepare(ctx, conn) + case handshakeSending: + state, err = s.send(ctx, conn) + case handshakeWaiting: + state, err = s.wait(ctx, conn) + case handshakeFinished: + state, err = s.finish(ctx, conn) + default: + return errInvalidFSMTransition + } + if err != nil { + return err + } + } +} + +func (s *handshakeFSM13) Done() <-chan struct{} { + return s.closed +} + +//nolint:dupl +func (s *handshakeFSM13) prepare(ctx context.Context, conn flightConn) (handshakeState, error) { + s.flights = nil + // Prepare flights + var ( + dtlsAlert *alert.Alert + err error + pkts []*packet + ) + gen, retransmit, errFlight := s.currentFlight.getFlightGenerator13() + if errFlight != nil { + err = errFlight + dtlsAlert = &alert.Alert{Level: alert.Fatal, Description: alert.InternalError} + } else { + pkts, dtlsAlert, err = gen(conn, s.state, s.cache, s.cfg) + s.retransmit = retransmit + } + if dtlsAlert != nil { + if alertErr := conn.notify(ctx, dtlsAlert.Level, dtlsAlert.Description); alertErr != nil { + if err != nil { + err = alertErr + } + } + } + if err != nil { + return handshakeErrored, err + } + + s.flights = pkts + epoch := s.cfg.initialEpoch + nextEpoch := epoch + for _, p := range s.flights { + p.record.Header.Epoch += epoch + if p.record.Header.Epoch > nextEpoch { + nextEpoch = p.record.Header.Epoch + } + if h, ok := p.record.Content.(*handshake.Handshake); ok { + h.Header.MessageSequence = uint16(s.state.handshakeSendSequence) //nolint:gosec // G115 + s.state.handshakeSendSequence++ + } + } + if epoch != nextEpoch { + s.cfg.log.Tracef("[handshake13:%s] -> changeCipherSpec (epoch: %d)", srvCliStr(s.state.isClient), nextEpoch) + conn.setLocalEpoch(nextEpoch) + } + + return handshakeSending, nil +} + +func (s *handshakeFSM13) send(ctx context.Context, c flightConn) (handshakeState, error) { + return handshakeErrored, errStateUnimplemented13 +} + +func (s *handshakeFSM13) wait(ctx context.Context, conn flightConn) (handshakeState, error) { + return handshakeErrored, errStateUnimplemented13 +} + +func (s *handshakeFSM13) finish(ctx context.Context, c flightConn) (handshakeState, error) { + return handshakeErrored, errStateUnimplemented13 +} diff --git a/handshaker_test.go b/handshaker_test.go index 82e8cbc8..61189428 100644 --- a/handshaker_test.go +++ b/handshaker_test.go @@ -289,7 +289,7 @@ func TestHandshaker(t *testing.T) { //nolint:gocyclo,cyclop,maintidx initialRetransmitInterval: nonZeroRetransmitInterval, } - fsm := newHandshakeFSM(&ca.state, ca.handshakeCache, cfg, flight1) + fsm := newHandshakeFSM12(&ca.state, ca.handshakeCache, cfg, flight1) err := fsm.Run(ctx, ca, handshakePreparing) switch { case errors.Is(err, context.Canceled): @@ -322,7 +322,7 @@ func TestHandshaker(t *testing.T) { //nolint:gocyclo,cyclop,maintidx initialRetransmitInterval: nonZeroRetransmitInterval, } - fsm := newHandshakeFSM(&cb.state, cb.handshakeCache, cfg, flight0) + fsm := newHandshakeFSM12(&cb.state, cb.handshakeCache, cfg, flight0) err := fsm.Run(ctx, cb, handshakePreparing) switch { case errors.Is(err, context.Canceled): diff --git a/pkg/protocol/handshake/handshake.go b/pkg/protocol/handshake/handshake.go index 6d40a5a7..16f0ce64 100644 --- a/pkg/protocol/handshake/handshake.go +++ b/pkg/protocol/handshake/handshake.go @@ -29,6 +29,18 @@ const ( TypeFinished Type = 20 ) +// HelloRetryRequestRandom is set as the Random value of a ServerHello +// to signal that the message is actually a HelloRetryRequest. +// See RFC 8446 Section 4.1.3. +func HelloRetryRequestRandom() []byte { + return []byte{ + 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, + 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91, + 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, + 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C, + } +} + // String returns the string representation of this type. func (t Type) String() string { //nolint:cyclop switch t { diff --git a/state.go b/state.go index c5fbfcf7..6ec5b8c1 100644 --- a/state.go +++ b/state.go @@ -12,6 +12,8 @@ import ( "github.com/pion/dtls/v3/pkg/crypto/elliptic" "github.com/pion/dtls/v3/pkg/crypto/prf" "github.com/pion/dtls/v3/pkg/crypto/signaturehash" + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/extension" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/transport/v4/replaydetector" ) @@ -72,6 +74,16 @@ type State struct { peerSupportedProtocols []string NegotiatedProtocol string + + // localVersion is the DTLS version we intend to speak on this connection. + localVersion protocol.Version + // remoteVersions are the DTLS versions advertised by the peer + remoteVersions []protocol.Version + // localKeyEntries are the DTLS 1.3 KeyShareEntry values generated locally + // and sent in the ClientHello's key_share extension. + localKeyEntries []extension.KeyShareEntry + remoteKeyEntries []extension.KeyShareEntry //nolint:unused + remoteGroups []elliptic.Curve } type serializedState struct { diff --git a/util.go b/util.go index 7b5cb0e5..727cde41 100644 --- a/util.go +++ b/util.go @@ -3,7 +3,52 @@ package dtls -import "slices" +import ( + "slices" + + "github.com/pion/dtls/v3/pkg/crypto/elliptic" + "github.com/pion/dtls/v3/pkg/protocol" +) + +// supportedVersionsRange returns the supported DTLS versions from maxVersion +// down to minVersion, in preference order (newest first). Only DTLS 1.2 and +// 1.3 are emitted. +func supportedVersionsRange(minVersion, maxVersion protocol.Version) []protocol.Version { + ordered := []protocol.Version{protocol.Version1_3, protocol.Version1_2} + out := make([]protocol.Version, 0, len(ordered)) + for _, v := range ordered { + if versionAtLeast(v, minVersion) && versionAtMost(v, maxVersion) { + out = append(out, v) + } + } + + return out +} + +// selectVersion picks the highest-preference version from remote that is +// within the local [minVersion, maxVersion] range. Returns false if there +// is no intersection. +func selectVersion( + remote []protocol.Version, + minVersion, maxVersion protocol.Version, +) (protocol.Version, bool) { + for _, v := range remote { + if versionAtLeast(v, minVersion) && versionAtMost(v, maxVersion) { + return v, true + } + } + + return protocol.Version{}, false +} + +func versionAtLeast(v, lo protocol.Version) bool { + // DTLS encodes newer versions as numerically smaller Minor bytes + return v.Minor <= lo.Minor +} + +func versionAtMost(v, hi protocol.Version) bool { + return v.Minor >= hi.Minor +} func findMatchingSRTPProfile(a, b []SRTPProtectionProfile) (SRTPProtectionProfile, bool) { for _, aProfile := range a { @@ -15,6 +60,18 @@ func findMatchingSRTPProfile(a, b []SRTPProtectionProfile) (SRTPProtectionProfil return 0, false } +func findMatchingGroup(a, b []elliptic.Curve) (elliptic.Curve, bool) { + for _, aGroup := range a { + for _, bGroup := range b { + if aGroup == bGroup { + return aGroup, true + } + } + } + + return 0, false +} + func findMatchingCipherSuite(a, b []CipherSuite) (CipherSuite, bool) { for _, aSuite := range a { for _, bSuite := range b {