3333
3434#include "dtls.h"
3535
36- #define DTLS_PSK_ID_LEN 16
37- #define DTLS_PSK_KEY_LEN 16
38-
3936static const uint32_t dtls_magic = 'D' << 24 | 't' << 16 | 'L' << 8 | 's' ;
4037
4138struct queue_item {
@@ -56,18 +53,8 @@ struct sol_socket_dtls {
5653 const void * data ;
5754 struct sol_vector queue ;
5855 } read , write ;
59- };
60-
61- /* Both `struct cred_item` and `struct creds` should be in its own file when
62- * these things are not hardcoded anymore. */
63- struct cred_item {
64- char * id ;
65- char * psk ;
66- };
6756
68- struct creds {
69- struct sol_vector items ;
70- char * id ;
57+ const struct sol_socket_dtls_credential_cb * credentials ;
7158};
7259
7360static bool encrypt_payload (struct sol_socket_dtls * s );
@@ -144,17 +131,22 @@ session_from_linkaddr(const struct sol_network_link_addr *addr,
144131 return to_sockaddr (addr , & session -> addr .sa , & session -> size ) >= 0 ;
145132}
146133
134+ static void
135+ clear_queue_item (struct queue_item * item )
136+ {
137+ sol_util_secure_clear_memory (item -> buffer .data , item -> buffer .capacity );
138+ sol_buffer_fini (& item -> buffer );
139+ sol_util_secure_clear_memory (item , sizeof (* item ));
140+ }
141+
147142static void
148143clear_queue (struct sol_vector * vec )
149144{
150145 struct queue_item * item ;
151146 uint16_t idx ;
152147
153- SOL_VECTOR_FOREACH_IDX (vec , item , idx ) {
154- sol_util_secure_clear_memory (item -> buffer .data , item -> buffer .capacity );
155- sol_buffer_fini (& item -> buffer );
156- sol_util_secure_clear_memory (item , sizeof (* item ));
157- }
148+ SOL_VECTOR_FOREACH_IDX (vec , item , idx )
149+ clear_queue_item (item );
158150
159151 sol_vector_clear (vec );
160152}
@@ -207,10 +199,7 @@ static int
207199remove_item_from_vector (struct sol_vector * vec , struct queue_item * item ,
208200 int retval )
209201{
210- sol_util_secure_clear_memory (item -> buffer .data , item -> buffer .capacity );
211- sol_buffer_fini (& item -> buffer );
212-
213- sol_util_secure_clear_memory (item , sizeof (* item ));
202+ clear_queue_item (item );
214203 sol_vector_del (vec , 0 );
215204
216205 return retval ;
@@ -315,7 +304,7 @@ init_dtls_if_needed(void)
315304{
316305 static bool initialized = false;
317306
318- if (!initialized ) {
307+ if (SOL_UNLIKELY ( !initialized ) ) {
319308 dtls_init ();
320309 initialized = true;
321310 SOL_DBG ("TinyDTLS initialized" );
@@ -474,9 +463,9 @@ call_user_write_cb(void *data, struct sol_socket *wrapped)
474463 SOL_DBG ("Data encrypted, should have been passed to the "
475464 "wrapped socket" );
476465 return true;
477- } else {
478- SOL_DBG ("Could not encrypt payload" );
479466 }
467+
468+ SOL_DBG ("Could not encrypt payload" );
480469 }
481470
482471 return false;
@@ -510,10 +499,8 @@ retransmit_timer_enable(struct sol_socket_dtls *s, clock_time_t next)
510499{
511500 SOL_DBG ("Next DTLS retransmission will happen in %u seconds" , next );
512501
513- if (s -> retransmit_timeout ) {
502+ if (s -> retransmit_timeout )
514503 sol_timeout_del (s -> retransmit_timeout );
515- s -> retransmit_timeout = NULL ;
516- }
517504
518505 s -> retransmit_timeout = sol_timeout_add (next * 1000 , retransmit_timer_cb ,
519506 socket );
@@ -592,12 +579,13 @@ handle_dtls_event(struct dtls_context_t *ctx, session_t *session,
592579 msg = "unknown_event" ;
593580
594581 if (level == DTLS_ALERT_LEVEL_WARNING ) {
595- SOL_WRN ("\n\nDTLS warning for socket %p: %s\n\n " , socket , msg );
582+ SOL_WRN ("DTLS warning for socket %p: %s" , socket , msg );
596583 } else if (level == DTLS_ALERT_LEVEL_FATAL ) {
597584 /* FIXME: What to do here? Destroy the wrapped socket? Renegotiate? */
598- SOL_ERR ("\n\nDTLS fatal error for socket %p: %s\n\n " , socket , msg );
585+ SOL_ERR ("DTLS fatal error for socket %p: %s" , socket , msg );
599586 } else {
600- SOL_DBG ("\n\nTLS session changed for socket %p: %s\n\n" , socket , msg );
587+ SOL_DBG ("TLS session changed for socket %p: %s" , socket , msg );
588+
601589 if (code == DTLS_EVENT_CONNECTED ) {
602590 struct queue_item * item ;
603591 uint16_t idx ;
@@ -610,6 +598,7 @@ handle_dtls_event(struct dtls_context_t *ctx, session_t *session,
610598 continue ;
611599
612600 (void )dtls_write (socket -> context , & session , item -> buffer .data , item -> buffer .used );
601+ clear_queue_item (item );
613602 }
614603 clear_queue (& socket -> write .queue );
615604 }
@@ -646,142 +635,66 @@ sol_socket_dtls_set_on_write(struct sol_socket *socket, bool (*cb)(void *data, s
646635 return sol_socket_set_on_write (s -> wrapped , call_user_write_cb , socket );
647636}
648637
649- static const char *
650- creds_find_psk (const struct creds * creds , const char * desc , size_t desc_len )
651- {
652- struct cred_item * iter ;
653- uint16_t idx ;
654-
655- SOL_DBG ("Looking for PSK with ID=%.*s" , (int )desc_len , desc );
656-
657- SOL_VECTOR_FOREACH_IDX (& creds -> items , iter , idx ) {
658- if (!memcmp (desc , iter -> id , desc_len )) /* timingsafe_bcmp()? */
659- return iter -> psk ;
660- }
661-
662- return NULL ;
663- }
664-
665- static bool
666- creds_add (struct creds * creds , const char * id , size_t id_len ,
667- const char * psk , size_t psk_len )
668- {
669- struct cred_item * item ;
670- char * psk_stored ;
671-
672- psk_stored = creds_find_psk (creds , id , id_len );
673- if (psk_stored ) {
674- if (!memcmp (psk_stored , psk , psk_len ))
675- return true;
676-
677- SOL_WRN ("Attempting to add PSK for ID=%.*s, but it's already"
678- " registered and different from the supplied key" ,
679- (int )id_len , id );
680- return false;
681- }
682-
683- item = sol_vector_append (& creds -> items );
684- SOL_NULL_CHECK (item , false);
685-
686- item -> id = strndup (id , id_len );
687- SOL_NULL_CHECK_GOTO (item -> id , no_id );
688-
689- item -> psk = strndup (psk , psk_len );
690- SOL_NULL_CHECK_GOTO (item -> psk , no_psk );
691-
692- return true;
693-
694- no_psk :
695- sol_util_secure_clear_memory (item -> id , strlen (id ));
696- free (item -> id );
697- no_id :
698- sol_util_secure_clear_memory (item , sizeof (* item ));
699- sol_vector_del_last (& creds -> items );
700-
701- return false;
702- }
703-
704- static void
705- creds_clear (struct creds * creds )
706- {
707- struct cred_item * iter ;
708- uint16_t idx ;
709-
710- SOL_VECTOR_FOREACH_IDX (& creds -> items , iter , idx ) {
711- sol_util_secure_clear_memory (iter -> id , DTLS_PSK_ID_LEN );
712- sol_util_secure_clear_memory (iter -> psk , DTLS_PSK_KEY_LEN );
713-
714- free (iter -> id );
715- free (iter -> psk );
716- }
717- sol_vector_clear (& creds -> items );
718-
719- sol_util_secure_clear_memory (creds -> id , strlen (creds -> id ));
720- free (creds -> id );
721-
722- sol_util_secure_clear_memory (creds , sizeof (* creds ));
723- }
724-
725- static bool
726- creds_init (struct creds * creds )
727- {
728- creds -> id = strdup ("1111111111111111" );
729- if (!creds -> id )
730- return false;
731-
732- sol_vector_init (& creds -> items , sizeof (struct cred_item ));
733-
734- /* FIXME: Load this information from a secure storage area somehow. */
735- if (!creds_add (creds , "1111111111111111" , DTLS_PSK_ID_LEN , "AAAAAAAAAAAAAAAA" , DTLS_PSK_KEY_LEN )) {
736- creds_clear (creds );
737- return false;
738- }
739-
740- return true;
741- }
742-
743638static int
744639get_psk_info (struct dtls_context_t * ctx , const session_t * session ,
745640 dtls_credentials_type_t type , const char * desc , size_t desc_len ,
746641 char * result , size_t result_len )
747642{
748- struct creds creds ;
643+ struct sol_socket_dtls * socket = dtls_get_app_data (ctx );
644+ ssize_t len ;
645+ void * creds ;
749646 int r = -1 ;
750647
751- if (!creds_init (& creds )) {
752- SOL_WRN ("Could not obtain PSK credentials" );
648+ SOL_NULL_CHECK (socket -> credentials ,
649+ dtls_alert_fatal_create (DTLS_ALERT_INTERNAL_ERROR ));
650+ SOL_NULL_CHECK (socket -> credentials -> init ,
651+ dtls_alert_fatal_create (DTLS_ALERT_INTERNAL_ERROR ));
652+ SOL_NULL_CHECK (socket -> credentials -> clear ,
653+ dtls_alert_fatal_create (DTLS_ALERT_INTERNAL_ERROR ));
654+ SOL_NULL_CHECK (socket -> credentials -> get_psk ,
655+ dtls_alert_fatal_create (DTLS_ALERT_INTERNAL_ERROR ));
656+ SOL_NULL_CHECK (socket -> credentials -> get_id ,
657+ dtls_alert_fatal_create (DTLS_ALERT_INTERNAL_ERROR ));
658+
659+ creds = socket -> credentials -> init (socket -> credentials -> data );
660+ if (!creds ) {
661+ SOL_WRN ("Could not initialize credential storage" );
753662 return dtls_alert_fatal_create (DTLS_ALERT_INTERNAL_ERROR );
754663 }
755664
756665 if (type == DTLS_PSK_IDENTITY || type == DTLS_PSK_HINT ) {
757666 SOL_DBG ("Server asked for PSK %s with %zu bytes, have %d" ,
758667 type == DTLS_PSK_IDENTITY ? "identity" : "hint" ,
759- result_len , DTLS_PSK_ID_LEN );
668+ result_len , SOL_DTLS_PSK_ID_LEN );
760669
761- if (result && result_len >= DTLS_PSK_ID_LEN ) {
762- memcpy (result , creds .id , DTLS_PSK_ID_LEN );
763- r = DTLS_PSK_ID_LEN ;
764- } else {
765- SOL_DBG ("Not enough space to write PSK" );
670+ len = socket -> credentials -> get_id (creds , result , result_len );
671+ if (len != SOL_DTLS_PSK_ID_LEN ) {
672+ SOL_DBG ("Not enough space to write key ID" );
766673 r = dtls_alert_fatal_create (DTLS_ALERT_INTERNAL_ERROR );
674+ } else {
675+ r = (int )len ;
767676 }
768677 } else if (type != DTLS_PSK_KEY ) {
769678 SOL_WRN ("Expecting request for PSK, got something else instead (got %d, expected %d)" ,
770679 type , DTLS_PSK_KEY );
771680 r = dtls_alert_fatal_create (DTLS_ALERT_INTERNAL_ERROR );
772- } else if (!desc || desc_len < DTLS_PSK_KEY_LEN ) {
773- SOL_WRN ("Expecting PSK key but no space to write it (got %zu, have %d)" ,
774- desc_len , DTLS_PSK_KEY_LEN );
775- r = dtls_alert_fatal_create (DTLS_ALERT_ILLEGAL_PARAMETER );
776681 } else {
777- const char * psk = creds_find_psk (& creds , desc , desc_len );
778- if (psk ) {
779- memcpy (result , psk , DTLS_PSK_KEY_LEN );
780- r = DTLS_PSK_KEY_LEN ;
682+ len = socket -> credentials -> get_psk (creds ,
683+ SOL_STR_SLICE_STR (desc , desc_len ), result , result_len );
684+ if (len != SOL_DTLS_PSK_KEY_LEN ) {
685+ if (len < 0 )
686+ SOL_WRN ("Expecting PSK key but no space to write it (need %d, got %zd <%s>)" ,
687+ SOL_DTLS_PSK_KEY_LEN , len , sol_util_strerrora (- len ));
688+ else
689+ SOL_WRN ("Expecting PSK key but no space to write it (need %d, got %zd)" ,
690+ SOL_DTLS_PSK_KEY_LEN , len );
691+ r = dtls_alert_fatal_create (DTLS_ALERT_ILLEGAL_PARAMETER );
692+ } else {
693+ r = (int )len ;
781694 }
782695 }
783696
784- creds_clear ( & creds );
697+ socket -> credentials -> clear ( creds );
785698 return r ;
786699}
787700
@@ -824,6 +737,7 @@ sol_socket_dtls_wrap_socket(struct sol_socket *to_wrap)
824737
825738 socket -> read .cb = NULL ;
826739 socket -> write .cb = NULL ;
740+ socket -> credentials = NULL ;
827741 socket -> retransmit_timeout = NULL ;
828742 socket -> wrapped = to_wrap ;
829743 socket -> base .impl = & impl ;
@@ -894,3 +808,16 @@ sol_socket_dtls_prf_keyblock(struct sol_socket *s,
894808
895809 return 0 ;
896810}
811+
812+ int
813+ sol_socket_dtls_set_credentials_callbacks (struct sol_socket * s ,
814+ const struct sol_socket_dtls_credential_cb * cb )
815+ {
816+ struct sol_socket_dtls * socket = (struct sol_socket_dtls * )s ;
817+
818+ SOL_INT_CHECK (socket -> dtls_magic , != dtls_magic , - EINVAL );
819+
820+ socket -> credentials = cb ;
821+
822+ return 0 ;
823+ }
0 commit comments