150 lines
4.5 KiB
Rust
150 lines
4.5 KiB
Rust
use super::Result;
|
|
use anyhow::format_err;
|
|
use base64::{decode_config, encode_config, URL_SAFE_NO_PAD};
|
|
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
|
use serde_json::{self, Value};
|
|
|
|
/// An extensible header that provides only the algorithm field and allows additional fields to be
/// passed in via a struct that can be serialized and deserialized. Unlike the `Claims` struct,
/// there is no convenience type alias, because headers seem to vary much more in practice
/// depending on the application, whereas claims seem to be shared as a function of registered
/// and public claims.
///
/// The additional fields are not nested: `to_base64` merges them into the same top-level JSON
/// object as `alg`, and `from_base64` reads them back from that same object.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct Header<T = ()> {
    /// The signing algorithm, serialized under the `alg` key.
    pub alg: Algorithm,
    /// Application-specific header fields, if any.
    ///
    /// Skipped by the derived `Serialize` impl; `to_base64` serializes this value separately and
    /// flattens its fields into the top-level JSON object.
    #[serde(skip_serializing)]
    pub headers: Option<T>,
}
|
|
|
|
/// Supported algorithms, each representing a valid signature and digest combination.
|
|
#[derive(Debug, PartialEq, Serialize, Deserialize)]
|
|
pub enum Algorithm {
|
|
HS256,
|
|
HS384,
|
|
HS512,
|
|
RS256,
|
|
RS384,
|
|
RS512,
|
|
}
|
|
|
|
impl<T: Serialize + DeserializeOwned> Header<T> {
|
|
/// Decode from base64.
|
|
pub fn from_base64(raw: &str) -> Result<Header<T>> {
|
|
let data = decode_config(raw, URL_SAFE_NO_PAD)?;
|
|
let own: Header<T> = serde_json::from_slice(&data)?;
|
|
|
|
let headers: Option<T> = serde_json::from_slice(&data).ok();
|
|
|
|
Ok(Header {
|
|
alg: own.alg,
|
|
headers,
|
|
})
|
|
}
|
|
|
|
/// Encode to a string.
|
|
pub fn to_base64(&self) -> Result<String> {
|
|
if let Value::Object(mut own_map) = serde_json::to_value(&self)? {
|
|
match self.headers {
|
|
Some(ref headers) => {
|
|
if let Value::Object(extra_map) = serde_json::to_value(&headers)? {
|
|
own_map.extend(extra_map);
|
|
let s = serde_json::to_string(&own_map)?;
|
|
let enc = encode_config((&*s).as_bytes(), URL_SAFE_NO_PAD);
|
|
Ok(enc)
|
|
} else {
|
|
Err(format_err!("Could not access additional headers."))
|
|
}
|
|
}
|
|
None => {
|
|
let s = serde_json::to_string(&own_map)?;
|
|
let enc = encode_config((&*s).as_bytes(), URL_SAFE_NO_PAD);
|
|
Ok(enc)
|
|
}
|
|
}
|
|
} else {
|
|
Err(format_err!("Could not access default header."))
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<T> Default for Header<T> {
|
|
fn default() -> Header<T> {
|
|
Header {
|
|
alg: Algorithm::HS256,
|
|
headers: None,
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::{Algorithm, Header};
    use serde::{Deserialize, Serialize};

    /// Unpadded URL-safe base64 of `{"alg":"HS256"}`.
    const DEFAULT_ENC: &str = "eyJhbGciOiJIUzI1NiJ9";

    /// Unpadded URL-safe base64 of `{"alg":"HS256","kid":"1KSF3g","typ":"JWT"}`.
    const CUSTOM_ENC: &str = "eyJhbGciOiJIUzI1NiIsImtpZCI6IjFLU0YzZyIsInR5cCI6IkpXVCJ9";

    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    struct CustomHeaders {
        kid: String,
        typ: String,
    }

    /// The custom header values encoded in `CUSTOM_ENC`.
    fn custom_headers() -> CustomHeaders {
        CustomHeaders {
            kid: "1KSF3g".into(),
            typ: "JWT".into(),
        }
    }

    #[test]
    fn from_base64() {
        let header: Header<()> = Header::from_base64(DEFAULT_ENC).unwrap();

        assert_eq!(header.alg, Algorithm::HS256);
    }

    #[test]
    fn custom_from_base64() {
        let header: Header<CustomHeaders> = Header::from_base64(CUSTOM_ENC).unwrap();

        assert_eq!(header.alg, Algorithm::HS256);
        assert_eq!(header.headers, Some(custom_headers()));
    }

    #[test]
    fn from_base64_rejects_invalid_base64() {
        assert!(Header::<()>::from_base64("not base64!").is_err());
    }

    #[test]
    fn from_base64_requires_alg() {
        // `e30` is `{}`: valid base64 and valid JSON, but without an `alg` field.
        assert!(Header::<()>::from_base64("e30").is_err());
    }

    #[test]
    fn from_base64_rejects_unsupported_alg() {
        // `{"alg":"none"}` must not decode into any supported algorithm.
        assert!(Header::<()>::from_base64("eyJhbGciOiJub25lIn0").is_err());
    }

    #[test]
    fn to_base64() {
        let header: Header<()> = Header::default();

        assert_eq!(header.to_base64().unwrap(), DEFAULT_ENC);
    }

    #[test]
    fn custom_to_base64() {
        let header: Header<CustomHeaders> = Header {
            headers: Some(custom_headers()),
            ..Header::default()
        };

        assert_eq!(header.to_base64().unwrap(), CUSTOM_ENC);
    }

    #[test]
    fn roundtrip() {
        let header: Header<()> = Header::default();
        let encoded = header.to_base64().unwrap();

        assert_eq!(header, Header::from_base64(&encoded).unwrap());
    }

    #[test]
    fn roundtrip_custom() {
        let header: Header<CustomHeaders> = Header {
            alg: Algorithm::RS512,
            headers: Some(custom_headers()),
        };
        let encoded = header.to_base64().unwrap();

        assert_eq!(header, Header::from_base64(&encoded).unwrap());
    }
}
|