diff options
Diffstat (limited to 'tcpd/libcouriertls.c')
| -rw-r--r-- | tcpd/libcouriertls.c | 231 | 
1 files changed, 231 insertions, 0 deletions
| diff --git a/tcpd/libcouriertls.c b/tcpd/libcouriertls.c index 5105030..39c7d49 100644 --- a/tcpd/libcouriertls.c +++ b/tcpd/libcouriertls.c @@ -445,6 +445,9 @@ static void load_dh_params(SSL_CTX *ctx, const char *filename,  static int check_readable_file(const char *filename)  { +	if (!filename) +		return 0; +  	return (access(filename, R_OK) == 0) ? 1 : 0;  } @@ -729,6 +732,203 @@ static int server_cert_cb(ssl_handle ssl, int *ad, void *arg)  	return SSL_TLSEXT_ERR_OK;  } +#if HAVE_OPENSSL_ALPN + +/* +** Parse the TLS_ALPN setting, repeatedly invoking a callback with: +** 1) size of the protocol string +** 2) the protocol string +** 3) pass-through parameter. +*/ + +static void alpn_parse(const char *alpn, +		       void (*cb)(unsigned char l, +				  const char *p, +				  void *arg), +		       void *arg) +{ +	while (alpn && *alpn) +	{ +		const char *p; +		unsigned char l; + +		/* Commas and spaces are delimiters here */ + +		if (*alpn==',' || isspace(*alpn)) +		{ +			++alpn; +			continue; +		} + +		/* Look for the next comma, spaces, or end of string */ +		p=alpn; + +		while (*alpn && *alpn != ',' && !isspace(*alpn)) +			++alpn; + +		/* +		** We now have the character count and the string. +		** Defend against a bad setting by checking for overflow. +		*/ +		l=alpn - p; + +		if (l != alpn - p) +			continue; + +		if (*p != '@') /* Used in the gnutls version */ +			(*cb)(l, p, arg); +	} +} + +struct tls_avpn_to_protocol_list_info { +	void (*cb)(const char *, unsigned int, void *); +	void *arg; +}; + +/* +** Take the result of alpn_parse and construct a character stream, in ALPN +** protocol-list format. +*/ + +static void tls_avpn_to_protocol_list_cb(unsigned char l, +					 const char *p, +					 void *arg) +{ +	struct tls_avpn_to_protocol_list_info *ptr= +		(struct tls_avpn_to_protocol_list_info *)arg; + +	(*ptr->cb)((char *)&l, 1, ptr->arg); +	(*ptr->cb)(p, l, ptr->arg); +} + +static void tls_avpn_to_protocol_list(const char *alpn, +				      void (*cb)(const char *, unsigned int, +						 void *), +				      void *arg) +{ +	struct tls_avpn_to_protocol_list_info info; + +	info.cb=cb; +	info.arg=arg; + +	alpn_parse(alpn, tls_avpn_to_protocol_list_cb, &info); +} + +/* +** Create a discrete unsigned char buffer with the ALPN protocol-list +** string. +** +** First we use tls_avpn_to_protocol_list to count the size of the string, +** then malloc it, then create it for real-sies. +*/ + +struct alpn_info { +	unsigned char *p; +	unsigned int l; +}; + +static void alpn_proto_count(const char *ptr, unsigned int n, +			     void *arg) +{ +	struct alpn_info *info=(struct alpn_info *)arg; + +	info->l += n; +} + +static void alpn_proto_save(const char *ptr, unsigned int n, +			    void *arg) +{ +	struct alpn_info *info=(struct alpn_info *)arg; + +	memcpy(info->p, ptr, n); +	info->p+=n; +} + +static void parse_tls_alpn(const struct tls_info *info, +			   struct alpn_info *alpn) +{ +	const char *s=safe_getenv(info, "TLS_ALPN"); +	unsigned char *buffer; + +	alpn->p=NULL; +	alpn->l=0; + +	tls_avpn_to_protocol_list(s, alpn_proto_count, alpn); + +	if (alpn->l==0) +		return; + +	buffer=(unsigned char *)malloc(alpn->l); + +	if (!buffer) +	{ +		nonsslerror(info, "malloc"); +		exit(1); +	} +	alpn->p=buffer; +	tls_avpn_to_protocol_list(s, alpn_proto_save, alpn); +	alpn->p=buffer; +} + +struct tls_alpn_server_info { +	const unsigned char *in; +	unsigned char inl; +	int found; +}; + +static void alpn_proto_search(unsigned char l, +			      const char *p, +			      void *arg) +{ +	struct tls_alpn_server_info *info= +		(struct tls_alpn_server_info *)arg; + +	if (info->inl == l && +	    memcmp(info->in, p, l) == 0) +	{ +		info->found=1; +	} +} + +static int tls_alpn_server_cb(SSL *ssl, +			      const unsigned char **out, +			      unsigned char *outlen, +			      const unsigned char *in, +			      unsigned int inlen, +			      void *arg) +{ +	struct tls_info *info=(struct tls_info *)SSL_get_app_data(ssl); +	const char *s=safe_getenv(info, "TLS_ALPN"); + +	while (inlen) +	{ +		struct tls_alpn_server_info search_info; + +		if (*in >= inlen || *in == 0) /* Won't assume */ +			return SSL_TLSEXT_ERR_ALERT_FATAL; + +		search_info.in=in+1; +		search_info.inl=*in; +		search_info.found=0; + +		alpn_parse(s, alpn_proto_search, &search_info); +		if (search_info.found) +		{ +			*outlen=search_info.inl; +			*out=search_info.in; +			return SSL_TLSEXT_ERR_OK; +		} + +		inlen -= *in; +		--inlen; +		in += *in; +		++in; +	} +	return SSL_TLSEXT_ERR_ALERT_FATAL; +} + +#endif +  SSL_CTX *tls_create(int isserver, const struct tls_info *info)  {  	return tls_create_int(isserver, info, 0); @@ -963,10 +1163,41 @@ SSL_CTX *tls_create_int(int isserver, const struct tls_info *info,  	if (isserver)  	{ +#if HAVE_OPENSSL_ALPN +		struct alpn_info alpn; + +		parse_tls_alpn(info, &alpn); + +		if (alpn.p) +		{ +			free((char *)alpn.p); + +			SSL_CTX_set_alpn_select_cb(ctx, tls_alpn_server_cb, +						   NULL); +		} +#endif  		SSL_CTX_set_tlsext_servername_callback(ctx, server_cert_cb);  	}  	else  	{ +#if HAVE_OPENSSL_ALPN +		struct alpn_info alpn; + +		parse_tls_alpn(info, &alpn); + +		if (alpn.p) +		{ +			int ret=SSL_CTX_set_alpn_protos(ctx, alpn.p, alpn.l); +			free((char *)alpn.p); + +			if (ret) +			{ +				sslerror(info, "SSL_CTX_set_alpn_protos", -1); +				tls_destroy(ctx); +				return (NULL); +			} +		} +#endif  		SSL_CTX_set_client_cert_cb(ctx, client_cert_cb);  	}  	return (ctx); | 
