diff options
Diffstat (limited to 'tcpd/libcouriertls.c')
| -rw-r--r-- | tcpd/libcouriertls.c | 228 |
1 files changed, 228 insertions, 0 deletions
diff --git a/tcpd/libcouriertls.c b/tcpd/libcouriertls.c index 874a0d1..39c7d49 100644 --- a/tcpd/libcouriertls.c +++ b/tcpd/libcouriertls.c @@ -732,6 +732,203 @@ static int server_cert_cb(ssl_handle ssl, int *ad, void *arg) return SSL_TLSEXT_ERR_OK; } +#if HAVE_OPENSSL_ALPN + +/* +** Parse the TLS_ALPN setting, repeatedly invoking a callback with: +** 1) size of the protocol string +** 2) the protocol string +** 3) pass-through parameter. +*/ + +static void alpn_parse(const char *alpn, + void (*cb)(unsigned char l, + const char *p, + void *arg), + void *arg) +{ + while (alpn && *alpn) + { + const char *p; + unsigned char l; + + /* Commas and spaces are delimiters here */ + + if (*alpn==',' || isspace(*alpn)) + { + ++alpn; + continue; + } + + /* Look for the next comma, spaces, or end of string */ + p=alpn; + + while (*alpn && *alpn != ',' && !isspace(*alpn)) + ++alpn; + + /* + ** We now have the character count and the string. + ** Defend against a bad setting by checking for overflow. + */ + l=alpn - p; + + if (l != alpn - p) + continue; + + if (*p != '@') /* Used in the gnutls version */ + (*cb)(l, p, arg); + } +} + +struct tls_avpn_to_protocol_list_info { + void (*cb)(const char *, unsigned int, void *); + void *arg; +}; + +/* +** Take the result of alpn_parse and construct a character stream, in ALPN +** protocol-list format. +*/ + +static void tls_avpn_to_protocol_list_cb(unsigned char l, + const char *p, + void *arg) +{ + struct tls_avpn_to_protocol_list_info *ptr= + (struct tls_avpn_to_protocol_list_info *)arg; + + (*ptr->cb)((char *)&l, 1, ptr->arg); + (*ptr->cb)(p, l, ptr->arg); +} + +static void tls_avpn_to_protocol_list(const char *alpn, + void (*cb)(const char *, unsigned int, + void *), + void *arg) +{ + struct tls_avpn_to_protocol_list_info info; + + info.cb=cb; + info.arg=arg; + + alpn_parse(alpn, tls_avpn_to_protocol_list_cb, &info); +} + +/* +** Create a discrete unsigned char buffer with the ALPN protocol-list +** string. +** +** First we use tls_avpn_to_protocol_list to count the size of the string, +** then malloc it, then create it for real-sies. +*/ + +struct alpn_info { + unsigned char *p; + unsigned int l; +}; + +static void alpn_proto_count(const char *ptr, unsigned int n, + void *arg) +{ + struct alpn_info *info=(struct alpn_info *)arg; + + info->l += n; +} + +static void alpn_proto_save(const char *ptr, unsigned int n, + void *arg) +{ + struct alpn_info *info=(struct alpn_info *)arg; + + memcpy(info->p, ptr, n); + info->p+=n; +} + +static void parse_tls_alpn(const struct tls_info *info, + struct alpn_info *alpn) +{ + const char *s=safe_getenv(info, "TLS_ALPN"); + unsigned char *buffer; + + alpn->p=NULL; + alpn->l=0; + + tls_avpn_to_protocol_list(s, alpn_proto_count, alpn); + + if (alpn->l==0) + return; + + buffer=(unsigned char *)malloc(alpn->l); + + if (!buffer) + { + nonsslerror(info, "malloc"); + exit(1); + } + alpn->p=buffer; + tls_avpn_to_protocol_list(s, alpn_proto_save, alpn); + alpn->p=buffer; +} + +struct tls_alpn_server_info { + const unsigned char *in; + unsigned char inl; + int found; +}; + +static void alpn_proto_search(unsigned char l, + const char *p, + void *arg) +{ + struct tls_alpn_server_info *info= + (struct tls_alpn_server_info *)arg; + + if (info->inl == l && + memcmp(info->in, p, l) == 0) + { + info->found=1; + } +} + +static int tls_alpn_server_cb(SSL *ssl, + const unsigned char **out, + unsigned char *outlen, + const unsigned char *in, + unsigned int inlen, + void *arg) +{ + struct tls_info *info=(struct tls_info *)SSL_get_app_data(ssl); + const char *s=safe_getenv(info, "TLS_ALPN"); + + while (inlen) + { + struct tls_alpn_server_info search_info; + + if (*in >= inlen || *in == 0) /* Won't assume */ + return SSL_TLSEXT_ERR_ALERT_FATAL; + + search_info.in=in+1; + search_info.inl=*in; + search_info.found=0; + + alpn_parse(s, alpn_proto_search, &search_info); + if (search_info.found) + { + *outlen=search_info.inl; + *out=search_info.in; + return SSL_TLSEXT_ERR_OK; + } + + inlen -= *in; + --inlen; + in += *in; + ++in; + } + return SSL_TLSEXT_ERR_ALERT_FATAL; +} + +#endif + SSL_CTX *tls_create(int isserver, const struct tls_info *info) { return tls_create_int(isserver, info, 0); @@ -966,10 +1163,41 @@ SSL_CTX *tls_create_int(int isserver, const struct tls_info *info, if (isserver) { +#if HAVE_OPENSSL_ALPN + struct alpn_info alpn; + + parse_tls_alpn(info, &alpn); + + if (alpn.p) + { + free((char *)alpn.p); + + SSL_CTX_set_alpn_select_cb(ctx, tls_alpn_server_cb, + NULL); + } +#endif SSL_CTX_set_tlsext_servername_callback(ctx, server_cert_cb); } else { +#if HAVE_OPENSSL_ALPN + struct alpn_info alpn; + + parse_tls_alpn(info, &alpn); + + if (alpn.p) + { + int ret=SSL_CTX_set_alpn_protos(ctx, alpn.p, alpn.l); + free((char *)alpn.p); + + if (ret) + { + sslerror(info, "SSL_CTX_set_alpn_protos", -1); + tls_destroy(ctx); + return (NULL); + } + } +#endif SSL_CTX_set_client_cert_cb(ctx, client_cert_cb); } return (ctx); |
