Refactor header claims (#1)

Simplify customization, bump to 2.0.0.
This commit is contained in:
Thomas Gideon 2017-03-07 14:03:24 -05:00 committed by GitHub
parent 9df2ac741e
commit 3c9fd6b13b
11 changed files with 522 additions and 292 deletions

View file

@ -1,129 +0,0 @@
use base64::{decode_config, encode_config, URL_SAFE_NO_PAD};
use Component;
use error::Error;
use serde::{Deserialize, Serialize};
use serde_json;
use serde_json::value::{Value};
use super::Result;
/// A default claim set, including the standard, or registered, claims and the ability to specify
/// your own as private claims.
#[derive(Debug, Default, PartialEq)]
pub struct Claims<T: Serialize + Deserialize> {
pub reg: Registered,
pub private: T
}
/// The registered claims from the spec.
#[derive(Debug, Default, PartialEq, Serialize, Deserialize)]
pub struct Registered {
pub iss: Option<String>,
pub sub: Option<String>,
pub aud: Option<String>,
pub exp: Option<u64>,
pub nbf: Option<u64>,
pub iat: Option<u64>,
pub jti: Option<String>,
}
impl<T: Serialize + Deserialize> Claims<T>{
/// Convenience factory method
pub fn new(reg: Registered, private: T) -> Claims<T> {
Claims {
reg: reg,
private: private
}
}
}
impl<T: Serialize + Deserialize> Component for Claims<T> {
/// This implementation simply parses the base64 data twice, each time applying it to the
/// registered and private claims.
fn from_base64(raw: &str) -> Result<Claims<T>> {
let data = decode_config(raw, URL_SAFE_NO_PAD)?;
let reg_claims: Registered = serde_json::from_slice(&data)?;
let pri_claims: T = serde_json::from_slice(&data)?;
Ok(Claims {
reg: reg_claims,
private: pri_claims
})
}
/// Renders both the registered and private claims into a single consolidated JSON
/// representation before encoding.
fn to_base64(&self) -> Result<String> {
if let Value::Object(mut reg_map) = serde_json::to_value(&self.reg)? {
if let Value::Object(pri_map) = serde_json::to_value(&self.private)? {
reg_map.extend(pri_map);
let s = serde_json::to_string(&reg_map)?;
let enc = encode_config((&*s).as_bytes(), URL_SAFE_NO_PAD);
Ok(enc)
} else {
Err(Error::Custom("Could not access registered claims.".to_owned()))
}
} else {
Err(Error::Custom("Could not access private claims.".to_owned()))
}
}
}
#[cfg(test)]
mod tests {
use std::default::Default;
use claims::{Claims, Registered};
use Component;
#[derive(Default, Debug, Serialize, Deserialize, PartialEq)]
struct EmptyClaim { }
#[derive(Default, Debug, Serialize, Deserialize, PartialEq)]
struct NonEmptyClaim {
user_id: String,
is_admin: bool,
first_name: Option<String>,
last_name: Option<String>
}
#[test]
fn from_base64() {
let enc = "eyJpc3MiOiJleGFtcGxlLmNvbSIsImV4cCI6MTMwMjMxOTEwMH0";
let claims: Claims<EmptyClaim> = Claims::from_base64(enc).unwrap();
assert_eq!(claims.reg.iss.unwrap(), "example.com");
assert_eq!(claims.reg.exp.unwrap(), 1302319100);
}
#[test]
fn multiple_types() {
let enc = "eyJpc3MiOiJleGFtcGxlLmNvbSIsImV4cCI6MTMwMjMxOTEwMH0";
let claims = Registered::from_base64(enc).unwrap();
assert_eq!(claims.iss.unwrap(), "example.com");
assert_eq!(claims.exp.unwrap(), 1302319100);
}
#[test]
fn roundtrip() {
let mut claims: Claims<EmptyClaim> = Default::default();
claims.reg.iss = Some("example.com".into());
claims.reg.exp = Some(1302319100);
let enc = claims.to_base64().unwrap();
assert_eq!(claims, Claims::from_base64(&*enc).unwrap());
}
#[test]
fn roundtrip_custom() {
let mut claims: Claims<NonEmptyClaim> = Default::default();
claims.reg.iss = Some("example.com".into());
claims.reg.exp = Some(1302319100);
claims.private.user_id = "123456".into();
claims.private.is_admin = false;
claims.private.first_name = Some("Random".into());
let enc = claims.to_base64().unwrap();
assert_eq!(claims, Claims::<NonEmptyClaim>::from_base64(&*enc).unwrap());
}
}

View file

@ -77,9 +77,6 @@ pub mod tests {
use std::fs::File;
use super::{sign, verify};
#[derive(Default, Debug, Serialize, Deserialize, PartialEq)]
struct EmptyClaim { }
#[test]
pub fn sign_data_hmac() {
let header = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9";

View file

@ -1,19 +1,21 @@
use base64::{encode_config, decode_config, URL_SAFE_NO_PAD};
use serde::{Serialize, Deserialize};
use serde_json::{self, Value};
use std::default::Default;
use Header;
/// A default Header providing the type, key id and algorithm fields.
use super::error::Error;
use super::Result;
/// A extensible Header that provides only algorithm field and allows for additional fields to be
/// passed in via a struct that can be serialized and deserialized. Unlike the Claims struct, there
/// is no convenience type alias because headers seem to vary much more greatly in practice
/// depending on the application whereas claims seem to be shared as a function of registerest and
/// public claims.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct DefaultHeader {
pub typ: Option<HeaderType>,
pub kid: Option<String>,
pub struct Header<T: Serialize + Deserialize> {
pub alg: Algorithm,
}
/// Default value for the header type field.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub enum HeaderType {
JWT,
#[serde(skip_serializing)]
pub headers: Option<T>,
}
/// Supported algorithms, each representing a valid signature and digest combination.
@ -24,56 +26,126 @@ pub enum Algorithm {
HS512,
RS256,
RS384,
RS512
RS512,
}
impl Default for DefaultHeader {
fn default() -> DefaultHeader {
DefaultHeader {
typ: Some(HeaderType::JWT),
kid: None,
alg: Algorithm::HS256,
impl<T: Serialize + Deserialize> Header<T> {
pub fn from_base64(raw: &str) -> Result<Header<T>> {
let data = decode_config(raw, URL_SAFE_NO_PAD)?;
let own: Header<T> = serde_json::from_slice(&data)?;
let headers: Option<T> = serde_json::from_slice(&data).ok();
Ok(Header {
alg: own.alg,
headers: headers,
})
}
/// Encode to a string.
pub fn to_base64(&self) -> Result<String> {
if let Value::Object(mut own_map) = serde_json::to_value(&self)? {
match self.headers {
Some(ref headers) => {
if let Value::Object(extra_map) = serde_json::to_value(&headers)? {
own_map.extend(extra_map);
let s = serde_json::to_string(&own_map)?;
let enc = encode_config((&*s).as_bytes(), URL_SAFE_NO_PAD);
Ok(enc)
} else {
Err(Error::Custom("Could not access additional headers.".to_owned()))
}
}
None => {
let s = serde_json::to_string(&own_map)?;
let enc = encode_config((&*s).as_bytes(), URL_SAFE_NO_PAD);
Ok(enc)
}
}
} else {
Err(Error::Custom("Could not access default header.".to_owned()))
}
}
}
/// Allow the rest of the library to access the configured algorithm without having to know the
/// specific type for the header.
impl Header for DefaultHeader {
fn alg(&self) -> &Algorithm {
&(self.alg)
impl<T: Serialize + Deserialize> Default for Header<T> {
fn default() -> Header<T> {
Header {
alg: Algorithm::HS256,
headers: None,
}
}
}
#[cfg(test)]
mod tests {
use Component;
use header::{
Algorithm,
DefaultHeader,
HeaderType,
};
use super::{Algorithm, Header};
#[derive(Debug, Serialize, Deserialize, PartialEq)]
struct CustomHeaders {
kid: String,
typ: String,
}
#[test]
fn from_base64() {
let enc = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9";
let header = DefaultHeader::from_base64(enc).unwrap();
let enc = "eyJhbGciOiJIUzI1NiJ9";
let header: Header<()> = Header::from_base64(enc).unwrap();
assert_eq!(header.typ.unwrap(), HeaderType::JWT);
assert_eq!(header.alg, Algorithm::HS256);
}
#[test]
fn custom_from_base64() {
let enc = "eyJhbGciOiJIUzI1NiIsImtpZCI6IjFLU0YzZyIsInR5cCI6IkpXVCJ9";
let header: Header<CustomHeaders> = Header::from_base64(enc).unwrap();
let enc = "eyJhbGciOiJSUzI1NiIsImtpZCI6IjFLU0YzZyJ9";
let header = DefaultHeader::from_base64(enc).unwrap();
let headers = header.headers.unwrap();
assert_eq!(headers.kid, "1KSF3g".to_string());
assert_eq!(headers.typ, "JWT".to_string());
assert_eq!(header.alg, Algorithm::HS256);
}
assert_eq!(header.kid.unwrap(), "1KSF3g".to_string());
assert_eq!(header.alg, Algorithm::RS256);
#[test]
fn to_base64() {
let enc = "eyJhbGciOiJIUzI1NiJ9";
let header: Header<()> = Default::default();
assert_eq!(enc, header.to_base64().unwrap());
}
#[test]
fn custom_to_base64() {
let enc = "eyJhbGciOiJIUzI1NiIsImtpZCI6IjFLU0YzZyIsInR5cCI6IkpXVCJ9";
let header: Header<CustomHeaders> = Header {
headers: Some(CustomHeaders {
kid: "1KSF3g".into(),
typ: "JWT".into(),
}),
..Default::default()
};
assert_eq!(enc, header.to_base64().unwrap());
}
#[test]
fn roundtrip() {
let header: DefaultHeader = Default::default();
let enc = Component::to_base64(&header).unwrap();
assert_eq!(header, DefaultHeader::from_base64(&*enc).unwrap());
let header: Header<()> = Default::default();
let enc = header.to_base64().unwrap();
assert_eq!(header, Header::from_base64(&*enc).unwrap());
}
#[test]
fn roundtrip_custom() {
let header: Header<CustomHeaders> = Header {
alg: Algorithm::RS512,
headers: Some(CustomHeaders {
kid: "1KSF3g".into(),
typ: "JWT".into(),
}),
};
let enc = header.to_base64().unwrap();
assert_eq!(header, Header::from_base64(&*enc).unwrap());
}
}

View file

@ -8,69 +8,45 @@ extern crate serde;
extern crate serde_derive;
extern crate serde_json;
use base64::{decode_config, encode_config, URL_SAFE_NO_PAD};
use serde::{Serialize, Deserialize};
pub use error::Error;
pub use header::DefaultHeader;
pub use header::Header;
pub use header::Algorithm;
pub use claims::Claims;
pub use claims::Registered;
pub use payload::{Payload, DefaultPayload};
pub mod error;
pub mod header;
pub mod claims;
mod header;
mod payload;
mod crypt;
pub type Result<T> = std::result::Result<T, Error>;
/// A convenient type that bins the same type parameter for the custom claims, an empty tuple, as
/// DefaultPayload so that the two aliases may be used together to reduce boilerplate when not
/// custom claims are needed.
pub type DefaultToken<H> = Token<H, ()>;
/// Main struct representing a JSON Web Token, composed of a header and a set of claims.
#[derive(Debug, Default)]
pub struct Token<H, C>
where H: Component, C: Component {
where H: Serialize + Deserialize + PartialEq,
C: Serialize + Deserialize + PartialEq
{
raw: Option<String>,
pub header: H,
pub claims: C,
}
/// Any header type must implement this trait so that signing and verification work.
pub trait Header {
fn alg(&self) -> &header::Algorithm;
}
/// Any header or claims type must implement this trait in order to serialize and deserialize
/// correctly.
pub trait Component: Sized {
fn from_base64(raw: &str) -> Result<Self>;
fn to_base64(&self) -> Result<String>;
}
/// Provide a default implementation that should work in almost all cases.
impl<T> Component for T
where T: Serialize + Deserialize + Sized {
/// Parse from a string.
fn from_base64(raw: &str) -> Result<T> {
let data = decode_config(raw, URL_SAFE_NO_PAD)?;
let s = String::from_utf8(data)?;
Ok(serde_json::from_str(&*s)?)
}
/// Encode to a string.
fn to_base64(&self) -> Result<String> {
let s = serde_json::to_string(&self)?;
let enc = encode_config((&*s).as_bytes(), URL_SAFE_NO_PAD);
Ok(enc)
}
pub header: Header<H>,
pub payload: Payload<C>,
}
/// Provide the ability to parse a token, verify it and sign/serialize it.
impl<H, C> Token<H, C>
where H: Component + Header, C: Component {
pub fn new(header: H, claims: C) -> Token<H, C> {
where H: Serialize + Deserialize + PartialEq,
C: Serialize + Deserialize + PartialEq
{
pub fn new(header: Header<H>, payload: Payload<C>) -> Token<H, C> {
Token {
raw: None,
header: header,
claims: claims,
payload: payload,
}
}
@ -80,8 +56,8 @@ impl<H, C> Token<H, C>
Ok(Token {
raw: Some(raw.into()),
header: Component::from_base64(pieces[0])?,
claims: Component::from_base64(pieces[1])?,
header: Header::from_base64(pieces[0])?,
payload: Payload::from_base64(pieces[1])?,
})
}
@ -96,44 +72,42 @@ impl<H, C> Token<H, C>
let sig = pieces[0];
let data = pieces[1];
Ok(crypt::verify(sig, data, key, &self.header.alg())?)
Ok(crypt::verify(sig, data, key, &self.header.alg)?)
}
/// Generate the signed token from a key with the specific algorithm as a url-safe, base64
/// string.
pub fn signed(&self, key: &[u8]) -> Result<String> {
let header = Component::to_base64(&self.header)?;
let claims = self.claims.to_base64()?;
let data = format!("{}.{}", header, claims);
pub fn sign(&self, key: &[u8]) -> Result<String> {
let header = self.header.to_base64()?;
let payload = self.payload.to_base64()?;
let data = format!("{}.{}", header, payload);
let sig = crypt::sign(&*data, key, &self.header.alg())?;
let sig = crypt::sign(&*data, key, &self.header.alg)?;
Ok(format!("{}.{}", data, sig))
}
}
impl<H, C> PartialEq for Token<H, C>
where H: Component + PartialEq, C: Component + PartialEq{
where H: Serialize + Deserialize + PartialEq,
C: Serialize + Deserialize + PartialEq
{
fn eq(&self, other: &Token<H, C>) -> bool {
self.header == other.header &&
self.claims == other.claims
self.header == other.header && self.payload == other.payload
}
}
#[cfg(test)]
mod tests {
use Claims;
use Token;
use {DefaultToken, Header};
use crypt::tests::load_pem;
use header::Algorithm::{HS256,RS512};
use header::DefaultHeader;
#[derive(Default, Debug, Serialize, Deserialize, PartialEq)]
struct EmptyClaim { }
use super::Algorithm::{HS256, RS512};
#[test]
pub fn raw_data() {
let raw = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ";
let token = Token::<DefaultHeader, Claims<EmptyClaim>>::parse(raw).unwrap();
let raw = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.\
eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.\
TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ";
let token = DefaultToken::<()>::parse(raw).unwrap();
{
assert_eq!(token.header.alg, HS256);
@ -143,10 +117,10 @@ mod tests {
#[test]
pub fn roundtrip_hmac() {
let token: Token<DefaultHeader, Claims<EmptyClaim>> = Default::default();
let token: DefaultToken<()> = Default::default();
let key = "secret".as_bytes();
let raw = token.signed(key).unwrap();
let same = Token::parse(&*raw).unwrap();
let raw = token.sign(key).unwrap();
let same = DefaultToken::parse(&*raw).unwrap();
assert_eq!(token, same);
assert!(same.verify(key).unwrap());
@ -154,16 +128,11 @@ mod tests {
#[test]
pub fn roundtrip_rsa() {
let token: Token<DefaultHeader, Claims<EmptyClaim>> = Token {
header: DefaultHeader {
alg: RS512,
..Default::default()
},
..Default::default()
};
let header: Header<()> = Header { alg: RS512, ..Default::default() };
let token = DefaultToken { header: header, ..Default::default() };
let private_key = load_pem("./examples/privateKey.pem").unwrap();
let raw = token.signed(private_key.as_bytes()).unwrap();
let same = Token::parse(&*raw).unwrap();
let raw = token.sign(private_key.as_bytes()).unwrap();
let same = DefaultToken::parse(&*raw).unwrap();
assert_eq!(token, same);
let public_key = load_pem("./examples/publicKey.pub").unwrap();

170
src/payload.rs Normal file
View file

@ -0,0 +1,170 @@
use base64::{decode_config, encode_config, URL_SAFE_NO_PAD};
use error::Error;
use serde::{Deserialize, Serialize};
use serde_json;
use serde_json::value::Value;
use super::Result;
/// A default claim set, including the standard, or registered, claims and the ability to specify
/// your own as custom claims.
#[derive(Debug, Serialize, Deserialize, Default, PartialEq)]
pub struct Payload<T: Serialize + Deserialize + PartialEq> {
#[serde(skip_serializing_if = "Option::is_none")]
pub iss: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub sub: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub aud: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub exp: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub nbf: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub iat: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub jti: Option<String>,
#[serde(skip_serializing)]
pub claims: Option<T>,
}
/// A convenient type alias that assumes the standard claims are sufficient, the empty tuple type
/// satisfies Claims' generic parameter as simply and clearly as possible.
pub type DefaultPayload = Payload<()>;
impl<T: Serialize + Deserialize + PartialEq> Payload<T> {
/// This implementation simply parses the base64 data twice, first parsing out the standard
/// claims then any custom claims, assigning the latter into a copy of the former before
/// returning registered and custom claims.
pub fn from_base64(raw: &str) -> Result<Payload<T>> {
let data = decode_config(raw, URL_SAFE_NO_PAD)?;
let claims: Payload<T> = serde_json::from_slice(&data)?;
let custom: Option<T> = serde_json::from_slice(&data).ok();
Ok(Payload {
iss: claims.iss,
sub: claims.sub,
aud: claims.aud,
exp: claims.exp,
nbf: claims.nbf,
iat: claims.iat,
jti: claims.jti,
claims: custom,
})
}
/// Renders both the standard and custom claims into a single consolidated JSON representation
/// before encoding.
pub fn to_base64(&self) -> Result<String> {
if let Value::Object(mut claims_map) = serde_json::to_value(&self)? {
match self.claims {
Some(ref custom) => {
if let Value::Object(custom_map) = serde_json::to_value(&custom)? {
claims_map.extend(custom_map);
let s = serde_json::to_string(&claims_map)?;
let enc = encode_config((&*s).as_bytes(), URL_SAFE_NO_PAD);
Ok(enc)
} else {
Err(Error::Custom("Could not access custom claims.".to_owned()))
}
}
None => {
let s = serde_json::to_string(&claims_map)?;
let enc = encode_config((&*s).as_bytes(), URL_SAFE_NO_PAD);
return Ok(enc);
}
}
} else {
Err(Error::Custom("Could not access standard claims.".to_owned()))
}
}
}
#[cfg(test)]
mod tests {
use std::default::Default;
use super::{Payload, DefaultPayload};
#[derive(Default, Debug, Serialize, Deserialize, PartialEq)]
struct CustomClaims {
user_id: String,
is_admin: bool,
first_name: Option<String>,
last_name: Option<String>,
}
#[test]
fn from_base64() {
let enc = "eyJhdWQiOiJsb2dpbl9zZXJ2aWNlIiwiZXhwIjoxMzAyMzE5MTAwLCJpYXQiOjEzMDIzMTcxMDAsImlzcyI6ImV4YW1wbGUuY29tIiwibmJmIjoxMzAyMzE3MTAwLCJzdWIiOiJSYW5kb20gVXNlciJ9";
let payload: DefaultPayload = Payload::from_base64(enc).unwrap();
assert_eq!(payload, create_default());
}
#[test]
fn custom_from_base64() {
let enc = "eyJleHAiOjEzMDIzMTkxMDAsImZpcnN0X25hbWUiOiJSYW5kb20iLCJpYXQiOjEzMDIzMTcxMDAsImlzX2FkbWluIjpmYWxzZSwiaXNzIjoiZXhhbXBsZS5jb20iLCJsYXN0X25hbWUiOiJVc2VyIiwidXNlcl9pZCI6IjEyMzQ1NiJ9";
let payload: Payload<CustomClaims> = Payload::from_base64(enc).unwrap();
assert_eq!(payload, create_custom());
}
#[test]
fn to_base64() {
let enc = "eyJhdWQiOiJsb2dpbl9zZXJ2aWNlIiwiZXhwIjoxMzAyMzE5MTAwLCJpYXQiOjEzMDIzMTcxMDAsImlzcyI6ImV4YW1wbGUuY29tIiwibmJmIjoxMzAyMzE3MTAwLCJzdWIiOiJSYW5kb20gVXNlciJ9";
let payload = create_default();
assert_eq!(enc, payload.to_base64().unwrap());
}
#[test]
fn custom_to_base64() {
let enc = "eyJleHAiOjEzMDIzMTkxMDAsImZpcnN0X25hbWUiOiJSYW5kb20iLCJpYXQiOjEzMDIzMTcxMDAsImlzX2FkbWluIjpmYWxzZSwiaXNzIjoiZXhhbXBsZS5jb20iLCJsYXN0X25hbWUiOiJVc2VyIiwidXNlcl9pZCI6IjEyMzQ1NiJ9";
let payload = create_custom();
assert_eq!(enc, payload.to_base64().unwrap());
}
#[test]
fn roundtrip() {
let payload = create_default();
let enc = payload.to_base64().unwrap();
assert_eq!(payload, Payload::from_base64(&*enc).unwrap());
}
#[test]
fn roundtrip_custom() {
let payload = create_custom();
let enc = payload.to_base64().unwrap();
assert_eq!(payload, Payload::<CustomClaims>::from_base64(&*enc).unwrap());
}
fn create_default() -> DefaultPayload {
DefaultPayload {
aud: Some("login_service".into()),
iat: Some(1302317100),
iss: Some("example.com".into()),
exp: Some(1302319100),
nbf: Some(1302317100),
sub: Some("Random User".into()),
..Default::default()
}
}
fn create_custom() -> Payload<CustomClaims> {
Payload {
iss: Some("example.com".into()),
iat: Some(1302317100),
exp: Some(1302319100),
claims: Some(CustomClaims {
user_id: "123456".into(),
is_admin: false,
first_name: Some("Random".into()),
last_name: Some("User".into()),
}),
..Default::default()
}
}
}