@@ -92,11 +92,12 @@ type Conn struct {
9292 cancelHandshaker func ()
9393 cancelHandshakeReader func ()
9494
95- fsm * handshakeFSM
95+ fsm handshakeFSM
9696
9797 replayProtectionWindow uint
9898
99- handshakeConfig * handshakeConfig
99+ handshakeConfig * handshakeConfig
100+ handshakeConfig13 * handshakeConfig13
100101}
101102
102103// createConn creates a new DTLS connection.
@@ -256,6 +257,14 @@ func createConn(
256257 },
257258 }
258259
260+ if config .version13 {
261+ handshakeConfig13 := & handshakeConfig13 {
262+ handshakeConfig : handshakeConfig ,
263+ }
264+ conn .handshakeConfig13 = handshakeConfig13
265+ conn .handshakeConfig = nil
266+ }
267+
259268 conn .setRemoteEpoch (0 )
260269 conn .setLocalEpoch (0 )
261270
@@ -284,7 +293,7 @@ func (c *Conn) Handshake() error {
284293//
285294// Most uses of this package need not call HandshakeContext explicitly: the
286295// first [Conn.Read] or [Conn.Write] will call it automatically.
287- func (c * Conn ) HandshakeContext (ctx context.Context ) error {
296+ func (c * Conn ) HandshakeContext (ctx context.Context ) error { //nolint:cyclop
288297 c .handshakeMutex .Lock ()
289298 defer c .handshakeMutex .Unlock ()
290299
@@ -298,6 +307,23 @@ func (c *Conn) HandshakeContext(ctx context.Context) error {
298307 c .handshakeDone = handshakeDone
299308 c .closeLock .Unlock ()
300309
310+ if c .isVersion13Enabled () {
311+ var initialFlight flightVal
312+ if c .state .isClient {
313+ initialFlight = flightVal (flight13_1 )
314+ } else {
315+ initialFlight = flightVal (flight13_0 )
316+ }
317+ initialFSMState := handshakePreparing
318+
319+ if err := c .handshake (ctx , initialFlight , initialFSMState ); err != nil {
320+ return err
321+ }
322+ c .log .Trace ("Handshake DTLS 1.3 Completed" )
323+
324+ return nil
325+ }
326+
301327 // rfc5246#section-7.4.3
302328 // In addition, the hash and signature algorithms MUST be compatible
303329 // with the key in the server's end-entity certificate.
@@ -330,7 +356,7 @@ func (c *Conn) HandshakeContext(ctx context.Context) error {
330356 initialFSMState = handshakePreparing
331357 }
332358 // Do handshake
333- if err := c .handshake (ctx , c . handshakeConfig , initialFlight , initialFSMState ); err != nil {
359+ if err := c .handshake (ctx , initialFlight , initialFSMState ); err != nil {
334360 return err
335361 }
336362
@@ -482,6 +508,8 @@ func (c *Conn) Write(payload []byte) (int, error) {
482508 return 0 , err
483509 }
484510
511+ //nolint:godox
512+ // TODO: check for version
485513 return len (payload ), c .writePackets (c .writeDeadline , []* packet {
486514 {
487515 record : & recordlayer.RecordLayer {
@@ -808,7 +836,7 @@ var poolReadBuffer = sync.Pool{ //nolint:gochecknoglobals
808836 },
809837}
810838
811- func (c * Conn ) readAndBuffer (ctx context.Context ) error { //nolint:cyclop
839+ func (c * Conn ) readAndBuffer (ctx context.Context ) error { //nolint:cyclop,gocognit
812840 bufptr , ok := poolReadBuffer .Get ().(* []byte )
813841 if ! ok {
814842 return errFailedToAccessPoolReadBuffer
@@ -828,6 +856,8 @@ func (c *Conn) readAndBuffer(ctx context.Context) error { //nolint:cyclop
828856
829857 var hasHandshake , isRetransmit bool
830858 for _ , p := range pkts {
859+ //nolint:godox
860+ // TODO: check version
831861 hs , rtx , alert , err := c .handleIncomingPacket (ctx , p , rAddr , true )
832862 if alert != nil {
833863 if alertErr := c .notify (ctx , alert .Level , alert .Description ); alertErr != nil {
@@ -875,6 +905,8 @@ func (c *Conn) handleQueuedPackets(ctx context.Context) error {
875905 c .lock .Unlock ()
876906
877907 for _ , p := range pkts {
908+ //nolint:godox
909+ // TODO: check version
878910 _ , _ , alert , err := c .handleIncomingPacket (ctx , p .data , p .rAddr , false ) // don't re-enqueue
879911 if alert != nil {
880912 if alertErr := c .notify (ctx , alert .Level , alert .Description ); alertErr != nil {
@@ -908,6 +940,17 @@ func (c *Conn) enqueueEncryptedPackets(packet addrPkt) bool {
908940 return false
909941}
910942
943+ // nolint:unused
944+ func (c * Conn ) handleIncomingPacket13 (
945+ ctx context.Context ,
946+ buf []byte ,
947+ rAddr net.Addr ,
948+ enqueue bool ,
949+ ) (bool , bool , * alert.Alert , error ) {
950+ // Placeholder function
951+ return false , false , nil , nil
952+ }
953+
911954//nolint:gocognit,gocyclo,cyclop,maintidx
912955func (c * Conn ) handleIncomingPacket (
913956 ctx context.Context ,
@@ -1129,17 +1172,29 @@ func (c *Conn) recvHandshake() <-chan recvHandshakeState {
11291172}
11301173
11311174func (c * Conn ) notify (ctx context.Context , level alert.Level , desc alert.Description ) error {
1132- if level == alert .Fatal && len (c .state .SessionID ) > 0 {
1133- // According to the RFC, we need to delete the stored session.
1134- // https://datatracker.ietf.org/doc/html/rfc5246#section-7.2
1135- if ss := c .fsm .cfg .sessionStore ; ss != nil {
1136- c .log .Tracef ("clean invalid session: %s" , c .state .SessionID )
1137- if err := ss .Del (c .sessionKey ()); err != nil {
1138- return err
1175+ if level == alert .Fatal && len (c .state .SessionID ) > 0 { //nolint:nestif
1176+ if c .isVersion13Enabled () {
1177+ // With compatibility mode for 1.3, CH uses a non-empty session_id
1178+ // https://datatracker.ietf.org/doc/html/rfc8446#appendix-D.4
1179+ if ss := c .fsm .(* handshakeFSM13 ).cfg .sessionStore ; ss != nil { //nolint:forcetypeassert
1180+ c .log .Tracef ("clean invalid session: %s" , c .state .SessionID )
1181+ if err := ss .Del (c .sessionKey ()); err != nil {
1182+ return err
1183+ }
1184+ }
1185+ } else {
1186+ // According to the RFC, we need to delete the stored session.
1187+ // https://datatracker.ietf.org/doc/html/rfc5246#section-7.2
1188+ if ss := c .fsm .(* handshakeFSM12 ).cfg .sessionStore ; ss != nil { //nolint:forcetypeassert
1189+ c .log .Tracef ("clean invalid session: %s" , c .state .SessionID )
1190+ if err := ss .Del (c .sessionKey ()); err != nil {
1191+ return err
1192+ }
11391193 }
11401194 }
11411195 }
11421196
1197+ // This should be updated with DTLS 1.3 record encoding.
11431198 return c .writePackets (ctx , []* packet {
11441199 {
11451200 record : & recordlayer.RecordLayer {
@@ -1169,20 +1224,41 @@ func (c *Conn) isHandshakeCompletedSuccessfully() bool {
11691224//nolint:cyclop,gocognit,contextcheck
11701225func (c * Conn ) handshake (
11711226 ctx context.Context ,
1172- cfg * handshakeConfig ,
11731227 initialFlight flightVal ,
11741228 initialState handshakeState ,
11751229) error {
1176- c .fsm = newHandshakeFSM (& c .state , c .handshakeCache , cfg , initialFlight )
1177-
11781230 done := make (chan struct {})
1179- ctxRead , cancelRead := context .WithCancel (context .Background ())
1180- cfg .onFlightState = func (_ flightVal , s handshakeState ) {
1181- if s == handshakeFinished && c .setHandshakeCompletedSuccessfully () {
1182- close (done )
1231+ if c .isVersion13Enabled () {
1232+ c .fsm = & handshakeFSM13 {
1233+ currentFlight : flightVal13 (initialFlight ),
1234+ state : & c .state ,
1235+ cache : c .handshakeCache ,
1236+ cfg : c .handshakeConfig13 ,
1237+ retransmitInterval : c .handshakeConfig13 .initialRetransmitInterval ,
1238+ closed : make (chan struct {}),
1239+ }
1240+ c .handshakeConfig13 .onFlightState13 = func (_ flightVal13 , s handshakeState ) {
1241+ if c .fsm .(* handshakeFSM13 ).currentFlight .isLastSendFlight () { //nolint:forcetypeassert
1242+ close (done )
1243+ }
1244+ }
1245+ } else {
1246+ c .fsm = & handshakeFSM12 {
1247+ currentFlight : initialFlight ,
1248+ state : & c .state ,
1249+ cache : c .handshakeCache ,
1250+ cfg : c .handshakeConfig ,
1251+ retransmitInterval : c .handshakeConfig .initialRetransmitInterval ,
1252+ closed : make (chan struct {}),
1253+ }
1254+ c .handshakeConfig .onFlightState = func (_ flightVal , s handshakeState ) {
1255+ if s == handshakeFinished && c .setHandshakeCompletedSuccessfully () {
1256+ close (done )
1257+ }
11831258 }
11841259 }
11851260
1261+ ctxRead , cancelRead := context .WithCancel (context .Background ())
11861262 ctxHs , cancel := context .WithCancel (context .Background ())
11871263
11881264 c .closeLock .Lock ()
@@ -1379,7 +1455,11 @@ func (c *Conn) sessionKey() []byte {
13791455 // As ServerName can be like 0.example.com, it's better to add
13801456 // delimiter character which is not allowed to be in
13811457 // neither address or domain name.
1382- return []byte (c .rAddr .String () + "_" + c .fsm .cfg .serverName )
1458+ if c .isVersion13Enabled () {
1459+ return []byte (c .rAddr .String () + "_" + c .fsm .(* handshakeFSM13 ).cfg .serverName ) //nolint:forcetypeassert
1460+ }
1461+
1462+ return []byte (c .rAddr .String () + "_" + c .fsm .(* handshakeFSM12 ).cfg .serverName ) //nolint:forcetypeassert
13831463 }
13841464
13851465 return c .state .SessionID
@@ -1406,3 +1486,7 @@ func (c *Conn) SetWriteDeadline(t time.Time) error {
14061486 // Write deadline is also fully managed by this layer.
14071487 return nil
14081488}
1489+
1490+ func (c * Conn ) isVersion13Enabled () bool {
1491+ return c .handshakeConfig13 != nil
1492+ }
0 commit comments