Skip to main content

haste_server/auth_n/certificates/providers/local.rs

1use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
2use haste_config::Config;
3use haste_fhir_model::r4::generated::terminology::IssueType;
4use haste_fhir_operation_error::OperationOutcomeError;
5use rand::rngs::OsRng;
6use rsa::{
7    RsaPrivateKey,
8    pkcs1::{DecodeRsaPrivateKey, EncodeRsaPrivateKey, EncodeRsaPublicKey},
9    pkcs8::LineEnding,
10    traits::PublicKeyParts,
11};
12use sha1::{Digest, Sha1};
13use std::{path::Path, sync::Arc};
14use walkdir::{DirEntry, WalkDir};
15
16use crate::{
17    ServerEnvironmentVariables,
18    auth_n::certificates::{
19        JSONWebKey, JSONWebKeyAlgorithm, JSONWebKeySet, JSONWebKeyType,
20        traits::{CertificationProvider, DecodingKey, EncodingKey},
21    },
22};
23
/// Derives a key id (`kid`) from a certificate path: the file stem up to its first `_`.
///
/// `k1_2024-01-31.pem` yields `"k1"`; a stem containing no `_` is returned whole.
///
/// # Panics
/// Panics if the path has no file stem or the stem is not valid UTF-8.
fn derive_kid(cert_path: &Path) -> String {
    let stem: &str = cert_path
        .file_stem()
        .and_then(|os_stem| os_stem.to_str())
        .unwrap();

    let kid = match stem.split_once('_') {
        Some((head, _rest)) => head,
        None => stem,
    };
    kid.to_owned()
}
33
34fn get_sorted_private_cert_paths(config: &dyn Config<ServerEnvironmentVariables>) -> Vec<DirEntry> {
35    let certificate_dir = config
36        .get(ServerEnvironmentVariables::CertificationDir)
37        .unwrap();
38    let cert_dir: &Path = Path::new(&certificate_dir);
39    let walker = WalkDir::new(cert_dir).into_iter();
40    let mut entries = walker
41        .filter_map(|e| e.ok())
42        .filter(|e| e.metadata().unwrap().is_file())
43        .filter(|e| e.file_name().to_str().unwrap().ends_with(".pem"))
44        .collect::<Vec<DirEntry>>();
45
46    entries.sort_by(|a, b| {
47        let a_chunks = Path::file_stem(a.path())
48            .unwrap()
49            .to_str()
50            .unwrap()
51            .split("_")
52            .collect::<Vec<&str>>();
53        let b_chunks = Path::file_stem(b.path())
54            .unwrap()
55            .to_str()
56            .unwrap()
57            .split("_")
58            .collect::<Vec<&str>>();
59
60        let date_a =
61            chrono::NaiveDate::parse_from_str(a_chunks.get(1).unwrap(), "%Y-%m-%d").unwrap();
62        let date_b =
63            chrono::NaiveDate::parse_from_str(b_chunks.get(1).unwrap(), "%Y-%m-%d").unwrap();
64
65        // latest first.
66        date_b.cmp(&date_a)
67    });
68
69    entries
70}
71
72fn create_jwk_set(
73    certificate_entries: &Vec<DirEntry>,
74) -> Result<JSONWebKeySet, OperationOutcomeError> {
75    let mut jsonweb_key_set = JSONWebKeySet { keys: vec![] };
76
77    for certification_entry in certificate_entries.iter() {
78        let cert_path = certification_entry.path();
79        let rsa_private =
80            RsaPrivateKey::from_pkcs1_pem(&std::fs::read_to_string(cert_path).unwrap()).unwrap();
81        let rsa_public_key = rsa_private.to_public_key();
82
83        let mut hasher = Sha1::new();
84        hasher.update(rsa_public_key.to_pkcs1_der().unwrap().as_bytes());
85        let x5t = hasher.finalize();
86
87        jsonweb_key_set.keys.push(JSONWebKey {
88            kid: derive_kid(cert_path),
89            alg: JSONWebKeyAlgorithm::RS256,
90            kty: JSONWebKeyType::RSA,
91            e: URL_SAFE_NO_PAD.encode(&rsa_public_key.e().clone().to_bytes_be()),
92            n: URL_SAFE_NO_PAD.encode(&rsa_public_key.n().clone().to_bytes_be()),
93            x5t: Some(URL_SAFE_NO_PAD.encode(&x5t)),
94        });
95    }
96
97    Ok(jsonweb_key_set)
98}
99
100fn create_decoding_keys(
101    certificate_entries: &Vec<DirEntry>,
102) -> Result<Vec<DecodingKey>, OperationOutcomeError> {
103    let mut decoding_keys = vec![];
104
105    for certification_entry in certificate_entries.iter() {
106        let cert_path = certification_entry.path();
107        let rsa_private =
108            RsaPrivateKey::from_pkcs1_pem(&std::fs::read_to_string(cert_path).unwrap()).unwrap();
109
110        let rsa_public_key = rsa_private.to_public_key();
111
112        let decoding_key = jsonwebtoken::DecodingKey::from_rsa_pem(
113            rsa_public_key
114                .to_pkcs1_pem(LineEnding::default())
115                .unwrap()
116                .as_bytes(),
117        )
118        .unwrap();
119
120        decoding_keys.push(DecodingKey {
121            kid: derive_kid(cert_path),
122            decoding_key,
123        });
124    }
125
126    Ok(decoding_keys)
127}
128
129/// Latest key is first. this is set by date_b.cmp(&date_a) in get_sorted_private_cert_paths
130fn get_encoding_keys(
131    certificate_entries: &Vec<DirEntry>,
132) -> Result<Vec<EncodingKey>, OperationOutcomeError> {
133    let mut encoding_keys = vec![];
134
135    for certification_entry in certificate_entries.iter() {
136        let cert_path = certification_entry.path();
137        let encoding_key =
138            jsonwebtoken::EncodingKey::from_rsa_pem(&std::fs::read(cert_path).unwrap()).unwrap();
139
140        encoding_keys.push(EncodingKey {
141            kid: derive_kid(cert_path),
142            encoding_key,
143        });
144    }
145
146    Ok(encoding_keys)
147}
148
149fn create_certifications_if_needed(
150    config: &dyn Config<ServerEnvironmentVariables>,
151) -> Result<(), OperationOutcomeError> {
152    let certificate_dir = config
153        .get(ServerEnvironmentVariables::CertificationDir)
154        .unwrap();
155    let cert_dir: &Path = Path::new(&certificate_dir);
156
157    let private_key_files = get_sorted_private_cert_paths(config);
158
159    // If no private key than write.
160    if private_key_files.is_empty() {
161        let mut rng = OsRng;
162        let bits: usize = 2048;
163
164        // Use rfc 3339 format for date. Same as time_rotating.id.
165        let date = chrono::Utc::now();
166        let date2 = date + chrono::Days::new(5);
167
168        let private_key_file_name1 = format!("k1_{}.pem", date.format("%Y-%m-%d"));
169        let private_key_file_name2 = format!("k2_{}.pem", date2.format("%Y-%m-%d"));
170
171        let priv_key1 = RsaPrivateKey::new(&mut rng, bits).expect("failed to generate a key");
172        let priv_key2 = RsaPrivateKey::new(&mut rng, bits).expect("failed to generate a key");
173
174        std::fs::create_dir_all(cert_dir).unwrap();
175        std::fs::write(
176            cert_dir.join(private_key_file_name1),
177            priv_key1.to_pkcs1_pem(LineEnding::default()).unwrap(),
178        )
179        .map_err(|e| OperationOutcomeError::fatal(IssueType::Exception(None), e.to_string()))?;
180        std::fs::write(
181            cert_dir.join(private_key_file_name2),
182            priv_key2.to_pkcs1_pem(LineEnding::default()).unwrap(),
183        )
184        .map_err(|e| OperationOutcomeError::fatal(IssueType::Exception(None), e.to_string()))?;
185    }
186
187    Ok(())
188}
189
190pub struct LocalCertifications {
191    decoding_key: Arc<Vec<DecodingKey>>,
192    encoding_keys: Arc<Vec<EncodingKey>>,
193    jwk_set: Arc<JSONWebKeySet>,
194}
195
196impl LocalCertifications {
197    pub fn new(
198        config: &dyn Config<ServerEnvironmentVariables>,
199    ) -> Result<Self, OperationOutcomeError> {
200        create_certifications_if_needed(config)?;
201
202        let private_certificate_entries = get_sorted_private_cert_paths(config);
203
204        Ok(LocalCertifications {
205            decoding_key: Arc::new(create_decoding_keys(&private_certificate_entries)?),
206            encoding_keys: Arc::new(get_encoding_keys(&private_certificate_entries)?),
207            jwk_set: Arc::new(create_jwk_set(&private_certificate_entries)?),
208        })
209    }
210}
211
212impl CertificationProvider for LocalCertifications {
213    fn decoding_key<'a>(&'a self, kid: &str) -> Result<&'a DecodingKey, OperationOutcomeError> {
214        self.decoding_key
215            .iter()
216            .find(|d| d.kid == kid)
217            .ok_or_else(|| {
218                OperationOutcomeError::error(
219                    IssueType::Exception(None),
220                    format!("No decoding key found for kid: '{}'", kid),
221                )
222            })
223    }
224
225    fn encoding_key<'a>(&'a self) -> Result<&'a EncodingKey, OperationOutcomeError> {
226        self.encoding_keys.first().ok_or_else(|| {
227            OperationOutcomeError::error(
228                IssueType::Exception(None),
229                "No encoding key available".to_string(),
230            )
231        })
232    }
233
234    fn jwk_set(&self) -> Arc<JSONWebKeySet> {
235        self.jwk_set.clone()
236    }
237}