Rust Learning from Zero (22) —— SSL certificates management and distribution with Rust

Screenshot of rcert

Last night I was planning to do some web development, and while setting up my new VPS I realized how tedious it is to either copy and paste SSL certs around by hand or deploy Certbot to issue new ones for every machine. My life would be easier with some sort of SSL certificate management and distribution system.

I didn't actually search for this on Google, but I'm sure such tools already exist. Nevertheless, I want to write one myself, because it has been a long time since I last wrote something in Rust.

The design and usage are pretty straightforward. It is a client/server application: the server listens on a specific port, the client can put, get and list SSL certificates, and the certs are stored in Redis. Well, actually I could store the certs in a HashMap or another in-memory data structure, save them on local or remote disks, use an object storage system like Amazon S3, or even put them on a blockchain!

Technically, I could save these certs anywhere, in any way I like. What really matters is which operations I want to perform on them, because different requirements lead to different storage implementations. If I could guarantee that the server program never crashes and the server never has downtime, then of course I could keep everything in memory. If the server already has an automatic backup service and I want the certs backed up along with everything else, then saving them on disk makes sense. If I want to set a TTL on them, Redis provides that without me implementing expiration myself. These assumptions and requirements differ from person to person. Scale also affects the choice: with a few thousand certs, any approach works; with hundreds of millions, performance starts to matter; and if the number goes into the trillions, the whole system needs careful design.
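To keep that choice open, the storage could sit behind a small trait, so that a Redis, disk or S3 backend can be swapped in later. The sketch below is not rcert's code: `CertStore`, `CertPair` and `MemoryStore` are hypothetical names, and only the `domain|version` key layout mirrors the Redis keys rcert uses. It implements the in-memory `HashMap` option, which is simple but loses every cert when the process exits.

```rust
use std::collections::HashMap;

/// A stored certificate pair (PEM text).
#[derive(Clone, Debug, PartialEq)]
pub struct CertPair {
    pub cert: String,
    pub key: String,
}

/// Hypothetical storage abstraction: Redis, disk or S3 backends
/// could implement the same three operations.
pub trait CertStore {
    fn put(&mut self, domain: &str, version: &str, pair: CertPair);
    fn get(&self, domain: &str, version: &str) -> Option<CertPair>;
    /// Every stored (domain, version) pair, sorted.
    fn list(&self) -> Vec<(String, String)>;
}

/// In-memory backend: fast and simple, but nothing survives a restart.
#[derive(Default)]
pub struct MemoryStore {
    // Keyed by "domain|version", the same layout rcert uses in Redis
    entries: HashMap<String, CertPair>,
}

impl CertStore for MemoryStore {
    fn put(&mut self, domain: &str, version: &str, pair: CertPair) {
        self.entries.insert(format!("{}|{}", domain, version), pair);
    }

    fn get(&self, domain: &str, version: &str) -> Option<CertPair> {
        self.entries.get(&format!("{}|{}", domain, version)).cloned()
    }

    fn list(&self) -> Vec<(String, String)> {
        let mut out: Vec<(String, String)> = self
            .entries
            .keys()
            .filter_map(|k| k.split_once('|'))
            .map(|(d, v)| (d.to_string(), v.to_string()))
            .collect();
        out.sort();
        out
    }
}

fn main() {
    let mut store = MemoryStore::default();
    store.put(
        "ryza.moe",
        "a1b2c3d4e5",
        CertPair {
            cert: "-----BEGIN CERTIFICATE-----...".to_string(),
            key: "-----BEGIN PRIVATE KEY-----...".to_string(),
        },
    );
    println!("{:?}", store.list());
}
```

A disk or Redis backend would implement the same trait, so the request handlers would not need to change when the storage does.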

But for personal use and the operations described above, Redis is enough. Before firing up the IDE, I wrote down the commands I expected to use.

# Run as server
# Listen on all interfaces on port 12345 with access secret himitsu, using the given Redis connection URL
rcert server -l 0.0.0.0:12345 --secret himitsu --redis redis://127.0.0.1/

# Put SSL certs to rcert service
# Endpoint at 1.2.3.4:12345, access secret is himitsu
# Domain name (no checking) is ryza.moe
# Along with the path to public cert and private key
#
# If it succeeds, the program prints the version of the saved certs (e.g., a1b2c3d4e5)
rcert put -s 1.2.3.4:12345 --secret himitsu -d ryza.moe -c fullchain.pem -k privkey.pem

# Get SSL certs from rcert service
# Endpoint at 1.2.3.4:12345, access secret is himitsu
# Domain name (no checking) is ryza.moe, certs version is a1b2c3d4e5
# Along with the path to save public cert and private key
rcert get -s 1.2.3.4:12345 --secret himitsu -d ryza.moe -v a1b2c3d4e5 -c fullchain.pem -k privkey.pem

# List all SSL certs
rcert list -s 1.2.3.4:12345 --secret himitsu
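Each of these commands becomes exactly one request on the wire: the client serializes an `RcertRequest` to JSON, encrypts it, hex-encodes the ciphertext and sends it as a single newline-terminated line, and the server replies with JSON. The std-only sketch below covers just the hex-and-newline framing; `frame_line`, `hex_val` and `unframe_line` are hypothetical helper names, and rcert itself uses the `hex` crate for this step.

```rust
/// Hex-encode a payload and terminate it with '\n', so one request fits on one line.
fn frame_line(payload: &[u8]) -> String {
    let mut line: String = payload.iter().map(|b| format!("{:02x}", b)).collect();
    line.push('\n');
    line
}

/// Value of a single ASCII hex digit, or None for any other byte.
fn hex_val(c: u8) -> Option<u8> {
    match c {
        b'0'..=b'9' => Some(c - b'0'),
        b'a'..=b'f' => Some(c - b'a' + 10),
        b'A'..=b'F' => Some(c - b'A' + 10),
        _ => None,
    }
}

/// Reverse of frame_line: drop the trailing '\n' only if it is there, then hex-decode.
/// Returns None for odd-length or non-hex input.
fn unframe_line(line: &[u8]) -> Option<Vec<u8>> {
    let hex = line.strip_suffix(b"\n").unwrap_or(line);
    if hex.len() % 2 != 0 {
        return None;
    }
    hex.chunks(2)
        .map(|pair| Some((hex_val(pair[0])? << 4) | hex_val(pair[1])?))
        .collect()
}

fn main() {
    let line = frame_line(b"{\"method\":\"list\"}");
    print!("{}", line);
    let payload = unframe_line(line.as_bytes()).expect("valid frame");
    println!("{}", String::from_utf8(payload).unwrap());
}
```

Because the newline is removed only when present, a peer that closes the connection without a trailing `\n` still yields a complete hex string instead of losing its last digit.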

Enough talk, here is the code: https://magic.ryza.moe/ryza/rcert

use async_std::{
    prelude::*,
    task,
    net::{TcpListener,ToSocketAddrs,TcpStream},
    io::BufReader,
    fs::{self, File},
    path::Path,
};
use clap::{Arg, App, ArgMatches};
use futures::{select, FutureExt};
use hex::{encode, decode};
use ring::{aead::*, digest::{Context, Digest, SHA256}, pbkdf2::*};
use serde::{Deserialize, Serialize};
use std::num::NonZeroU32;
use redis::RedisResult;

type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;

#[derive(Serialize, Deserialize, Debug)]
struct RcertRequest {
    domain: String,
    method: String,
    cert: Option<String>,
    key: Option<String>,
    id: Option<u32>,
    ver: Option<String>
}

#[derive(Serialize, Deserialize, Debug)]
struct PutResponse {
    domain: String,
    id: u32,
    sha256: String
}

#[derive(Serialize, Deserialize, Debug)]
struct GetResponse {
    domain: String,
    cert: Option<String>,
    key: Option<String>,
    ver: String
}

#[derive(Serialize, Deserialize, Debug)]
struct ListResponse {
    certs: Vec<GetResponse>
}

pub struct RingAeadNonceSequence {
    nonce: [u8; NONCE_LEN],
}

impl RingAeadNonceSequence {
    fn new() -> RingAeadNonceSequence {
        RingAeadNonceSequence {
            nonce: [0u8; NONCE_LEN],
        }
    }
}

impl NonceSequence for RingAeadNonceSequence {
    fn advance(&mut self) -> std::result::Result<Nonce, ring::error::Unspecified> {
        let nonce = Nonce::assume_unique_for_key(self.nonce);
        increase_nonce(&mut self.nonce);
        Ok(nonce)
    }
}

pub fn increase_nonce(nonce: &mut [u8]) {
    for i in nonce {
        if std::u8::MAX == *i {
            *i = 0;
        } else {
            *i += 1;
            return;
        }
    }
}

// ---------------- cli fn ----------------

fn main() -> Result<()> {
    let args = parse_arg();
    if let Some(server_config) = args.subcommand_matches("server") {
        // rcert server -l 0.0.0.0:7536 --secret himitsu --redis redis://127.0.0.1/
        let address = server_config.value_of("listen").unwrap();
        let secret = server_config.value_of("secret").unwrap();
        let redis_url = server_config.value_of("redis").unwrap();
        let fut = accept_loop(address, secret.to_string(), redis_url.to_string());
        // Report failures such as an address that cannot be bound
        if let Err(e) = task::block_on(fut) {
            eprintln!("[ERROR] {}", e);
        }
    } else if let Some(client_config) = args.subcommand_matches("put") {
        // rcert put -s 127.0.0.1:7536 --secret himitsu -d ryza.moe -k privkey.pem -c fullchain.pem
        let server = client_config.value_of("server").unwrap();
        let secret = client_config.value_of("secret").unwrap();
        let domain = client_config.value_of("domain").unwrap();
        let cert = client_config.value_of("cert").unwrap();
        let key = client_config.value_of("key").unwrap();
        println!("[INFO] Uploading cert {} privkey {} to {}...", cert, key, server);
        match task::block_on(put_ssl_cert(server, secret, domain, cert, key)) {
            Ok(_) => {},
            Err(e) => eprintln!("[ERROR] {}", e.to_string()),
        };
    } else if let Some(client_config) = args.subcommand_matches("get") {
        // rcert get -s 127.0.0.1:7536 --secret himitsu -d ryza.moe -v a1b2c3d4e5 -k privkey.pem -c fullchain.pem
        let server = client_config.value_of("server").unwrap();
        let secret = client_config.value_of("secret").unwrap();
        let domain = client_config.value_of("domain").unwrap();
        let version = client_config.value_of("version").unwrap();
        let cert = client_config.value_of("cert").unwrap();
        let key = client_config.value_of("key").unwrap();
        println!("[INFO] Retrieving SSL certificate of {} ver {} from {}...", domain, version, server);
        match task::block_on(get_ssl_cert(server, secret, domain, version, cert, key)) {
            Ok(_) => {},
            Err(e) => eprintln!("[ERROR] {}", e.to_string()),
        };
    } else if let Some(client_config) = args.subcommand_matches("list") {
        // rcert list -s 127.0.0.1:7536 --secret himitsu
        let server = client_config.value_of("server").unwrap();
        let secret = client_config.value_of("secret").unwrap();
        println!("[INFO] Querying all certs on {}...", server);
        match task::block_on(list_ssl_cert(server, secret)) {
            Ok(_) => {},
            Err(e) => eprintln!("[ERROR] {}", e.to_string()),
        };
    } else {
        eprintln!("[ERROR] no subcommand given, please see --help");
    }

    Ok(())
}

fn parse_arg() -> ArgMatches {
    App::new("rcert")
        .version("1.0")
        .author("Ryza <[email protected]>")
        .about("Store SSL Certificates")
        // rcert server -l 0.0.0.0:7536 --secret himitsu --redis redis://127.0.0.1/
        .subcommand(App::new("server")
            .about("Runs as server")
            .version("1.0")
            .author("Ryza <[email protected]>")
            .arg(Arg::new("listen")
                .short('l').long("listen")
                .value_name("ADDRESS")
                .required(true)
                .about("listen on IP"))
            .arg(Arg::new("secret")
                .long("secret")
                .value_name("SECRET")
                .required(true)
                .about("Access secret"))
            .arg(Arg::new("redis")
                .long("redis")
                .value_name("REDIS")
                .required(true)
                .about("Redis connection URL")))
        // rcert put -s 127.0.0.1:7536 --secret himitsu -d ryza.moe -k privkey.pem -c fullchain.crt
        .subcommand(App::new("put")
            .about("Put SSL certificates")
            .version("1.0")
            .author("Ryza <[email protected]>")
            .arg(Arg::new("server")
                .short('s').long("server")
                .value_name("SERVER")
                .required(true)
                .about("rcert server"))
            .arg(Arg::new("secret")
                .long("secret")
                .value_name("SECRET")
                .required(true)
                .about("Access secret"))
            .arg(Arg::new("domain")
                .short('d').long("domain")
                .value_name("DOMAIN")
                .required(true)
                .about("Domain name of SSL certificate"))
            .arg(Arg::new("key")
                .short('k').long("key")
                .value_name("KEY")
                .required(true)
                .about("Private key of SSL certificate"))
            .arg(Arg::new("cert")
                .short('c').long("cert")
                .value_name("CERT")
                .required(true)
                .about("Public cert of SSL certificate")))
        // rcert get -s 127.0.0.1:7536 --secret himitsu -d ryza.moe -v a1b2c3d4e5 -k privkey.pem -c fullchain.crt
        .subcommand(App::new("get")
            .about("Get SSL certificates")
            .version("1.0")
            .author("Ryza <[email protected]>")
            .arg(Arg::new("server")
                .short('s').long("server")
                .value_name("SERVER")
                .required(true)
                .about("rcert server"))
            .arg(Arg::new("secret")
                .long("secret")
                .value_name("SECRET")
                .required(true)
                .about("Access secret"))
            .arg(Arg::new("domain")
                .short('d')
                .long("domain")
                .value_name("DOMAIN")
                .required(true)
                .about("Domain name of SSL certificate"))
            .arg(Arg::new("version")
                .short('v').long("ver")
                .value_name("VERSION")
                .required(true)
                .about("SSL certificate version"))
            .arg(Arg::new("key")
                .short('k').long("key")
                .value_name("KEY")
                .required(true)
                .about("Path to save private key of SSL certificate"))
            .arg(Arg::new("cert")
                .short('c').long("cert")
                .value_name("CERT")
                .required(true)
                .about("Path to save public cert of SSL certificate")))
        // rcert list -s 127.0.0.1:7536 --secret himitsu
        .subcommand(App::new("list")
            .about("List all SSL certificates")
            .version("1.0")
            .author("Ryza <[email protected]>")
            .arg(Arg::new("server")
                .short('s').long("server")
                .value_name("SERVER")
                .required(true)
                .about("rcert server"))
            .arg(Arg::new("secret")
                .long("secret")
                .value_name("SECRET")
                .required(true)
                .about("Access secret")))
        .get_matches()
}

// ---------------- server-side fn ----------------

async fn accept_loop(addr: &str, secret: String, redis_url: String) -> Result<()> {
    let listener = TcpListener::bind(addr).await?;
    let mut incoming = listener.incoming();
    println!("[OK] listening at {}", addr);
    while let Some(stream) = incoming.next().await {
        let stream = stream?;
        println!("[OK] accepting from: {}", stream.peer_addr()?);
        spawn_and_log_error(connection_loop(stream, secret.clone(), redis_url.clone()));
    }
    Ok(())
}

fn spawn_and_log_error<F>(fut: F) -> task::JoinHandle<()>
    where
        F: Future<Output = Result<()>> + Send + 'static,
{
    task::spawn(async move {
        if let Err(e) = fut.await {
            eprintln!("[ERROR] {}", e)
        }
    })
}

async fn connection_loop(stream: TcpStream, secret: String, redis_url: String) -> Result<()> {
    let mut reader = BufReader::new(&stream);
    let mut buf : Vec<u8> = Vec::new();
    let num_bytes = reader.read_until(b'\n', &mut buf).await?;
    if num_bytes == 0 { return Ok(()) }
    // Strip the trailing newline, but only if the peer actually sent one
    if buf.last() == Some(&b'\n') {
        buf.pop();
    }

    let hex_string = String::from_utf8(buf)?;
    let data = decode(hex_string)?;
    let decrypted = decrypt_data(data, secret)?;
    let trimmed_data = decrypted.trim_matches('\0');
    let req = serde_json::from_str::<RcertRequest>(trimmed_data)?;
    match &req.method[..] {
        "put"  => handle_put_req(stream, redis_url, req, trimmed_data).await?,
        "get"  => handle_get_req(stream, redis_url,req).await?,
        "list" => handle_list_req(stream, redis_url).await?,
        _ => eprintln!("[ERROR] unknown method from peer {}", stream.peer_addr()?),
    };

    Ok(())
}

async fn handle_put_req(mut stream: TcpStream, redis_url: String, put_req: RcertRequest, trimmed_data: &str) -> Result<()> {
    let trimmed_data_len = trimmed_data.len();
    let sha256 = encode(sha256_digest(trimmed_data.as_bytes(), trimmed_data_len)?.as_ref());
    let sha256 = String::from(&sha256[0..10]);

    let client = redis::Client::open(redis_url)?;
    let mut con = client.get_async_connection().await?;
    // Store the whole request JSON under "domain|version"
    let _: () = redis::cmd("SET")
        .arg(&[format!("{}|{}", put_req.domain, sha256), trimmed_data.to_string()])
        .query_async(&mut con)
        .await?;

    let resp = PutResponse {
        domain: put_req.domain.to_string(),
        id: put_req.id.unwrap_or(0),
        sha256: sha256.clone()
    };
    let json_resp = serde_json::to_string(&resp)?;
    stream.write_all(json_resp.as_ref()).await?;
    println!("[OK] put {}|{}", put_req.domain, sha256);
    Ok(())
}

async fn handle_get_req(mut stream: TcpStream, redis_url: String, get_req: RcertRequest) -> Result<()> {
    if let Some(ver) = get_req.ver {
        let client = redis::Client::open(redis_url)?;
        let mut con = client.get_async_connection().await?;
        let result: RedisResult<String> = redis::cmd("GET")
            .arg(&[format!("{}|{}", get_req.domain, ver)])
            .query_async(&mut con)
            .await;
        match result {
            Ok(result) => {
                let redis_data = serde_json::from_str::<RcertRequest>(&result)?;
                let resp = GetResponse {
                    domain: get_req.domain.clone(),
                    cert: redis_data.cert,
                    key: redis_data.key,
                    ver: ver.clone()
                };
                let json_resp = serde_json::to_string(&resp)?;
                stream.write_all(json_resp.as_ref()).await?;
                println!("[OK] get {}|{}", get_req.domain, ver);
                Ok(())
            },
            Err(e) => Err(Box::new(std::io::Error::new(std::io::ErrorKind::Other, e.to_string())))
        }
    } else {
        Err(Box::new(std::io::Error::new(std::io::ErrorKind::InvalidInput, "No ver param found")))
    }
}

async fn handle_list_req(mut stream: TcpStream, redis_url: String) -> Result<()> {
    let client = redis::Client::open(redis_url)?;
    let mut con = client.get_async_connection().await?;
    let result: RedisResult<Vec<String>> = redis::cmd("KEYS")
        .arg(&["*|*"])
        .query_async(&mut con)
        .await;
    if let Ok(keys) = result {
        let certs = keys.into_iter()
            .filter_map(|s| {
                let parts = s.split("|").collect::<Vec<&str>>();
                if parts.len() == 2 {
                    if parts[0].len() > 0 && parts[1].len() == 10 {
                        return Some(GetResponse {
                            domain: parts[0].to_string(),
                            cert: None,
                            key: None,
                            ver: parts[1].to_string()
                        })
                    }
                }
                None
            }).collect::<Vec<GetResponse>>();
        let number_certs = certs.len();
        let json_resp = serde_json::to_string(&ListResponse{certs})?;
        stream.write_all(json_resp.as_ref()).await?;
        println!("[OK] list {} certs available", number_certs);
        Ok(())
    } else {
        Err(Box::new(std::io::Error::new(std::io::ErrorKind::Other, result.err().unwrap().to_string())))
    }
}

// ---------------- client-side fn ----------------

async fn put_ssl_cert<P: AsRef<Path>>(server: impl ToSocketAddrs, secret: &str, domain: &str, cert: P, key: P) -> Result<()> {
    let cert = fs::read_to_string(cert).await?;
    let key = fs::read_to_string(key).await?;
    let data = serde_json::to_string(&RcertRequest {
        domain: domain.to_string(),
        method: "put".to_string(),
        cert: Some(cert),
        key: Some(key),
        id: Some(0),
        ver: None
    })?;
    let encrypted = encrypt_data(data, secret.to_string())?;

    let stream = TcpStream::connect(server).await?;
    let (reader, mut writer) = (&stream, &stream);
    let mut server_resp = BufReader::new(reader).lines().fuse();
    writer.write_all(encrypted.as_bytes()).await?;
    writer.write_all(b"\n").await?;

    loop {
        select! {
            line = server_resp.next().fuse() => match line {
                Some(line) => {
                    let line = line?;
                    let resp = serde_json::from_str::<PutResponse>(&line)?;
                    println!("[OK] put {}|{}", resp.domain, resp.sha256);
                    break;
                },
                None => break,
            },
        }
    }
    Ok(())
}

async fn get_ssl_cert<P: AsRef<Path>>(server: impl ToSocketAddrs, secret: &str, domain: &str, version: &str, cert: P, key: P) -> Result<()> {
    let data = serde_json::to_string(&RcertRequest {
        domain: domain.to_string(),
        method: "get".to_string(),
        cert: None,
        key: None,
        id: None,
        ver: Some(version.to_string())
    })?;
    let encrypted = encrypt_data(data, secret.to_string())?;

    let stream = TcpStream::connect(server).await?;
    let (reader, mut writer) = (&stream, &stream);
    let mut server_resp = BufReader::new(reader).lines().fuse();
    writer.write_all(encrypted.as_bytes()).await?;
    writer.write_all(b"\n").await?;

    let mut server_responded = false;
    let mut cert_content = String::new();
    let mut key_content = String::new();
    loop {
        select! {
            line = server_resp.next().fuse() => match line {
                Some(line) => {
                    server_responded = true;
                    let line = line?;
                    let resp = serde_json::from_str::<GetResponse>(&line)?;
                    if let Some(redis_cert_content) = resp.cert {
                        if let Some(redis_key_content) = resp.key {
                            cert_content = redis_cert_content;
                            key_content = redis_key_content;
                            break;
                        } else {
                            eprintln!("[ERROR] No corresponding private key found");
                            return Ok(());
                        }
                    } else {
                        eprintln!("[ERROR] No corresponding public certificate found");
                        return Ok(());
                    }
                },
                None => break,
            },
        }
    }

    if !server_responded {
        eprintln!("[ERROR] no response from server");
    } else {
        let mut cert_file = File::create(cert).await?;
        cert_file.write_all(cert_content.as_bytes()).await?;

        let mut key_file = File::create(key).await?;
        key_file.write_all(key_content.as_bytes()).await?;
        println!("[OK] get {}|{}", domain, version);
    }
    Ok(())
}

async fn list_ssl_cert(server: impl ToSocketAddrs, secret: &str) -> Result<()> {
    let data = serde_json::to_string(&RcertRequest {
        domain: String::new(),
        method: "list".to_string(),
        cert: None,
        key: None,
        id: None,
        ver: None
    })?;
    let encrypted = encrypt_data(data, secret.to_string())?;

    let stream = TcpStream::connect(server).await?;
    let (reader, mut writer) = (&stream, &stream);
    let mut server_resp = BufReader::new(reader).lines().fuse();
    writer.write_all(encrypted.as_bytes()).await?;
    writer.write_all(b"\n").await?;

    loop {
        select! {
            line = server_resp.next().fuse() => match line {
                Some(line) => {
                    let line = line?;
                    let resp = serde_json::from_str::<ListResponse>(&line)?;
                    println!("[OK] {} certs available", resp.certs.len());
                    for cert in resp.certs {
                        println!("{}|{}", cert.domain, cert.ver);
                    }
                    break;
                },
                None => break,
            },
        }
    }
    Ok(())
}

// ---------------- common fn ----------------

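fn encrypt_data(data: String, secret: String) -> Result<String> {
    // Fixed salt: the same secret always derives the same key
    let salt = [0, 1, 2, 3, 4, 5, 6, 7];

    // Derive a 256-bit ChaCha20-Poly1305 key from the shared secret
    let mut key = [0; 32];
    derive(PBKDF2_HMAC_SHA512, NonZeroU32::new(100).unwrap(), &salt, secret.as_ref(), &mut key);

    // seal_in_place_append_tag() appends the authentication tag itself,
    // so the plaintext needs no extra zero padding
    let mut in_out = data.into_bytes();
    let unbound_key = UnboundKey::new(&CHACHA20_POLY1305, &key)
        .map_err(|_| std::io::Error::new(std::io::ErrorKind::Other, "Invalid encryption key"))?;
    // A fresh nonce sequence per call means every message uses the all-zero nonce
    let mut sealing_key = SealingKey::new(unbound_key, RingAeadNonceSequence::new());
    sealing_key
        .seal_in_place_append_tag(Aad::empty(), &mut in_out)
        .map_err(|_| std::io::Error::new(std::io::ErrorKind::Other, "Cannot encrypt data"))?;
    Ok(encode(in_out))
}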

fn decrypt_data(mut data: Vec<u8>, secret: String) -> Result<String> {
    let salt = [0, 1, 2, 3, 4, 5, 6, 7];

    // Derive the same key as encrypt_data()
    let mut key = [0; 32];
    derive(PBKDF2_HMAC_SHA512, NonZeroU32::new(100).unwrap(), &salt, secret.as_ref(), &mut key);
    let unbound_key = UnboundKey::new(&CHACHA20_POLY1305, &key)
        .map_err(|_| std::io::Error::new(std::io::ErrorKind::Other, "Invalid encryption key"))?;
    let mut opening_key = OpeningKey::new(unbound_key, RingAeadNonceSequence::new());
    // open_in_place() verifies the tag and returns the plaintext without it
    match opening_key.open_in_place(Aad::empty(), &mut data) {
        Ok(decrypted) => Ok(String::from_utf8(decrypted.to_vec())?),
        Err(_e) => Err(Box::new(std::io::Error::new(std::io::ErrorKind::InvalidData, "Cannot decode data"))),
    }
}

fn sha256_digest(data: &[u8], len: usize) -> Result<Digest> {
    let mut context = Context::new(&SHA256);
    context.update(&data[..len]);
    Ok(context.finish())
}
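A caveat about `encrypt_data` and `decrypt_data` above: both create a fresh `RingAeadNonceSequence` for every message, so every message is sealed with the same key and the same all-zero nonce. With ChaCha20-Poly1305, reusing a (key, nonce) pair undermines both confidentiality and integrity. A common fix is to pick a random 12-byte nonce per message and send it in front of the ciphertext. The std-only sketch below shows only that framing; `join_nonce` and `split_nonce` are hypothetical helpers, and the nonce is passed in rather than generated.

```rust
/// Nonce length for ChaCha20-Poly1305 (12 bytes, same as ring's aead::NONCE_LEN).
const NONCE_LEN: usize = 12;

/// Build the wire message as [nonce | ciphertext-with-tag].
fn join_nonce(nonce: [u8; NONCE_LEN], ciphertext: &[u8]) -> Vec<u8> {
    let mut out = Vec::with_capacity(NONCE_LEN + ciphertext.len());
    out.extend_from_slice(&nonce);
    out.extend_from_slice(ciphertext);
    out
}

/// Split a received message into (nonce, ciphertext-with-tag).
/// Returns None when the message is too short to hold a nonce.
fn split_nonce(message: &[u8]) -> Option<([u8; NONCE_LEN], &[u8])> {
    if message.len() < NONCE_LEN {
        return None;
    }
    let (head, rest) = message.split_at(NONCE_LEN);
    let mut nonce = [0u8; NONCE_LEN];
    nonce.copy_from_slice(head);
    Some((nonce, rest))
}

fn main() {
    // In a real client these 12 bytes would come from a CSPRNG
    let nonce = [0x42u8; NONCE_LEN];
    let message = join_nonce(nonce, b"ciphertext+tag");
    let (received_nonce, body) = split_nonce(&message).expect("message long enough");
    assert_eq!(received_nonce, nonce);
    println!("{} byte body after a {} byte nonce", body.len(), NONCE_LEN);
}
```

In ring 0.16 this pairs naturally with `LessSafeKey`, whose `seal_in_place_append_tag` and `open_in_place` take an explicit `Nonce` (built with `Nonce::assume_unique_for_key`), and the random bytes can come from `ring::rand::SystemRandom` through the `SecureRandom::fill` method.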
