From cde08fe7886fea81955b3e50cfe76f504788b83e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ond=C5=99ej=20Hru=C5=A1ka?= <ondra@ondrovo.com>
Date: Tue, 31 Dec 2019 20:31:38 +0100
Subject: [PATCH] add examples, automatic expired removal, better
 configurability

---
 examples/dog_list/main.rs      |  70 ++++++++
 examples/visit_counter/main.rs |  72 +++++++++
 src/lib.rs                     | 286 +++++++++++++++++++++++----------
 3 files changed, 343 insertions(+), 85 deletions(-)
 create mode 100644 examples/dog_list/main.rs
 create mode 100644 examples/visit_counter/main.rs

diff --git a/examples/dog_list/main.rs b/examples/dog_list/main.rs
new file mode 100644
index 0000000..976ece7
--- /dev/null
+++ b/examples/dog_list/main.rs
@@ -0,0 +1,70 @@
+#![feature(proc_macro_hygiene, decl_macro)]
+#[macro_use]
+extern crate rocket;
+
+use rocket::response::content::Html;
+use rocket::response::Redirect;
+use rocket::request::Form;
+
+type Session<'a> = rocket_session::Session<'a, Vec<String>>;
+
+fn main() {
+    rocket::ignite()
+        .attach(Session::fairing())
+        .mount("/", routes![index, add, remove])
+        .launch();
+}
+
+#[get("/")]
+fn index(session: Session) -> Html<String> {
+    let mut page = String::new();
+    page.push_str(r#"
+            <!DOCTYPE html>
+            <h1>My Dogs</h1>
+
+            <form method="POST" action="/add">
+            Add Dog: <input type="text" name="name"> <input type="submit" value="Add">
+            </form>
+
+            <ul>
+        "#);
+
+    session.tap(|sess| {
+        for (n, dog) in sess.iter().enumerate() {
+            page.push_str(&format!(r#"
+                <li>&#x1F436; {} <a href="/remove/{}">Remove</a></li>
+            "#, dog, n));
+        }
+    });
+
+    page.push_str(r#"
+            </ul>
+        "#);
+
+    Html(page)
+}
+
+#[derive(FromForm)]
+struct AddForm {
+    name: String,
+}
+
+#[post("/add", data="<dog>")]
+fn add(session: Session, dog : Form<AddForm>) -> Redirect {
+    session.tap(move |sess| {
+        sess.push(dog.into_inner().name);
+    });
+
+    Redirect::found("/")
+}
+
+#[get("/remove/<dog>")]
+fn remove(session: Session, dog : usize) -> Redirect {
+    session.tap(|sess| {
+        if dog < sess.len() {
+            sess.remove(dog);
+        }
+    });
+
+    Redirect::found("/")
+}
diff --git a/examples/visit_counter/main.rs b/examples/visit_counter/main.rs
new file mode 100644
index 0000000..b4e29ae
--- /dev/null
+++ b/examples/visit_counter/main.rs
@@ -0,0 +1,72 @@
+//! This demo is a page visit counter, with a custom cookie name, length, and expiry time.
+//!
+//! The expiry time is set to 10 seconds to illustrate how a session is cleared if inactive.
+
+#![feature(proc_macro_hygiene, decl_macro)]
+#[macro_use]
+extern crate rocket;
+
+use std::time::Duration;
+use rocket::response::content::Html;
+
+#[derive(Default, Clone)]
+struct SessionData {
+    visits1: usize,
+    visits2: usize,
+}
+
+// It's convenient to define a type alias:
+type Session<'a> = rocket_session::Session<'a, SessionData>;
+
+fn main() {
+    rocket::ignite()
+        .attach(Session::fairing()
+            // 10 seconds of inactivity until session expires
+            // (wait 10s and refresh, the numbers will reset)
+            .with_lifetime(Duration::from_secs(10))
+            // custom cookie name and length
+            .with_cookie_name("my_cookie")
+            .with_cookie_len(20)
+        )
+        .mount("/", routes![index, about])
+        .launch();
+}
+
+#[get("/")]
+fn index(session: Session) -> Html<String> {
+    // Here we build the entire response inside the 'tap' closure.
+
+    // While inside, the session is locked to parallel changes, e.g.
+    // from a different browser tab.
+    session.tap(|sess| {
+        sess.visits1 += 1;
+
+        Html(format!(r##"
+                <!DOCTYPE html>
+                <h1>Home</h1>
+                <a href="/">Refresh</a> &bull; <a href="/about/">go to About</a>
+                <p>Visits: home {}, about {}</p>
+            "##,
+            sess.visits1,
+            sess.visits2
+        ))
+    })
+}
+
+#[get("/about")]
+fn about(session: Session) -> Html<String> {
+    // Here we return a value from the tap function and use it below
+    let count = session.tap(|sess| {
+        sess.visits2 += 1;
+        sess.visits2
+    });
+
+    Html(format!(r##"
+            <!DOCTYPE html>
+            <h1>About</h1>
+            <a href="/about">Refresh</a> &bull; <a href="/">go home</a>
+            <p>Page visits: {}</p>
+        "##,
+        count
+    ))
+}
diff --git a/src/lib.rs b/src/lib.rs
index 87fe7fb..b23a2cd 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,4 +1,4 @@
-use parking_lot::RwLock;
+use parking_lot::{RwLock, RwLockUpgradableReadGuard, Mutex};
 use rand::Rng;
 
 use rocket::{
@@ -12,60 +12,100 @@ use std::collections::HashMap;
 use std::marker::PhantomData;
 use std::ops::Add;
 use std::time::{Duration, Instant};
+use std::borrow::Cow;
+use std::fmt::{Display, Formatter, self};
 
-const SESSION_COOKIE: &str = "SESSID";
-const SESSION_ID_LEN: usize = 16;
-
-/// Session, as stored in the sessions store
+/// Session store (shared state)
 #[derive(Debug)]
-struct SessionInstance<D>
+pub struct SessionStore<D>
     where
         D: 'static + Sync + Send + Default,
 {
-    /// Data object
-    data: D,
-    /// Expiry
-    expires: Instant,
+    /// The internally mutable map of sessions
+    inner: RwLock<StoreInner<D>>,
+    // Session config
+    config: SessionConfig,
 }
 
-/// Session store (shared state)
-#[derive(Default, Debug)]
-pub struct SessionStore<D>
-    where
-        D: 'static + Sync + Send + Default,
-{
-    /// The internaly mutable map of sessions
-    inner: RwLock<HashMap<String, SessionInstance<D>>>,
+/// Session config object
+#[derive(Debug, Clone)]
+struct SessionConfig {
     /// Sessions lifespan
     lifespan: Duration,
+    /// Session cookie name
+    cookie_name: Cow<'static, str>,
+    /// Session cookie path
+    cookie_path: Cow<'static, str>,
+    /// Session ID character length
+    cookie_len: usize,
 }
 
-impl<D> SessionStore<D>
+impl Default for SessionConfig {
+    fn default() -> Self {
+        Self {
+            lifespan: Duration::from_secs(3600),
+            cookie_name: "rocket_session".into(),
+            cookie_path: "/".into(),
+            cookie_len: 16,
+        }
+    }
+}
+
+/// Mutable object stored inside SessionStore behind a RwLock
+#[derive(Debug)]
+struct StoreInner<D>
+    where
+        D: 'static + Sync + Send + Default {
+    sessions: HashMap<String, Mutex<SessionInstance<D>>>,
+    last_expiry_sweep: Instant,
+}
+
+impl<D> Default for StoreInner<D>
+    where
+        D: 'static + Sync + Send + Default {
+    fn default() -> Self {
+        Self {
+            sessions: Default::default(),
+            // the first expiry sweep is scheduled one lifetime from start-up
+            last_expiry_sweep: Instant::now(),
+        }
+    }
+}
+
+/// Session, as stored in the sessions store
+#[derive(Debug)]
+struct SessionInstance<D>
     where
         D: 'static + Sync + Send + Default,
 {
-    /// Remove all expired sessions
-    pub fn remove_expired(&self) {
-        let now = Instant::now();
-        self.inner.write().retain(|_k, v| v.expires > now);
-    }
+    /// Data object
+    data: D,
+    /// Expiry
+    expires: Instant,
 }
 
 /// Session ID newtype for rocket's "local_cache"
-#[derive(PartialEq, Hash, Clone, Debug)]
+#[derive(Clone, Debug)]
 struct SessionID(String);
 
 impl SessionID {
     fn as_str(&self) -> &str {
         self.0.as_str()
     }
+}
 
-    fn to_string(&self) -> String {
-        self.0.clone()
+impl Display for SessionID {
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+        f.write_str(&self.0)
     }
 }
 
 /// Session instance
+///
+/// To access the active session, simply add it as an argument to a route function.
+///
+/// Sessions are started, restored, or expired in the `FromRequest::from_request()` method
+/// when a `Session` is prepared for one of the route functions.
 #[derive(Debug)]
 pub struct Session<'a, D>
     where
@@ -84,45 +124,76 @@ impl<'a, 'r, D> FromRequest<'a, 'r> for Session<'a, D>
     type Error = ();
 
     fn from_request(request: &'a Request<'r>) -> Outcome<Self, (Status, Self::Error), ()> {
-        let store : State<SessionStore<D>> = request.guard().unwrap();
+        let store: State<SessionStore<D>> = request.guard().unwrap();
         Outcome::Success(Session {
             id: request.local_cache(|| {
+                let store_ug = store.inner.upgradable_read();
+
                 // Resolve session ID
-                let id = if let Some(cookie) = request.cookies().get(SESSION_COOKIE) {
-                    SessionID(cookie.value().to_string())
+                let id = if let Some(cookie) = request.cookies().get(&store.config.cookie_name) {
+                    Some(SessionID(cookie.value().to_string()))
                 } else {
-                    SessionID(
-                        rand::thread_rng()
-                            .sample_iter(&rand::distributions::Alphanumeric)
-                            .take(SESSION_ID_LEN)
-                            .collect(),
-                    )
+                    None
                 };
 
-                let new_expiration = Instant::now().add(store.lifespan);
-                let mut wg = store.inner.write();
-                match wg.get_mut(id.as_str()) {
-                    Some(ses) => {
-                        // Check expiration
-                        if ses.expires <= Instant::now() {
-                            ses.data = D::default();
-                        }
-                        // Update expiry timestamp
-                        ses.expires = new_expiration;
-                    },
-                    None => {
-                        // New session
-                        wg.insert(
-                            id.to_string(),
-                            SessionInstance {
-                                data: D::default(),
-                                expires: new_expiration,
-                            }
-                        );
+                let expires = Instant::now().add(store.config.lifespan);
+
+                if let Some(m) = id.as_ref()
+                    .and_then(|token| store_ug.sessions.get(token.as_str()))
+                {
+                    // --- ID obtained from a cookie && session found in the store ---
+
+                    let mut inner = m.lock();
+                    if inner.expires <= Instant::now() {
+                        // Session expired, reuse the ID but drop data.
+                        inner.data = D::default();
                     }
-                };
 
-                id
+                    // Session is extended by making a request with valid ID
+                    inner.expires = expires;
+
+                    id.unwrap()
+                } else {
+                    // --- ID missing or session not found ---
+
+                    // Get exclusive write access to the map
+                    let mut store_wg = RwLockUpgradableReadGuard::upgrade(store_ug);
+
+                    // This branch runs less often, and we already have write access,
+                    // let's check if any sessions expired. We don't want to hog memory
+                    // forever by abandoned sessions (e.g. when a client lost their cookie)
+
+                    // Throttle by lifespan - e.g. sweep every hour
+                    if store_wg.last_expiry_sweep.elapsed() > store.config.lifespan {
+                        let now = Instant::now();
+                        store_wg.sessions
+                            .retain(|_k, v| v.lock().expires > now);
+
+                        store_wg.last_expiry_sweep = now;
+                    }
+
+                    // Find a new unique ID - we are still safely inside the write guard
+                    let new_id = SessionID(loop {
+                        let token: String = rand::thread_rng()
+                            .sample_iter(&rand::distributions::Alphanumeric)
+                            .take(store.config.cookie_len)
+                            .collect();
+
+                        if !store_wg.sessions.contains_key(&token) {
+                            break token;
+                        }
+                    });
+
+                    store_wg.sessions.insert(
+                        new_id.to_string(),
+                        Mutex::new(SessionInstance {
+                            data: Default::default(),
+                            expires,
+                        }),
+                    );
+
+                    new_id
+                }
             }),
             store,
         })
@@ -133,46 +204,90 @@ impl<'a, D> Session<'a, D>
     where
         D: 'static + Sync + Send + Default,
 {
-    /// Get the fairing object
-    pub fn fairing(lifespan: Duration) -> impl Fairing {
-        SessionFairing::<D> {
-            lifespan,
-            _phantom: PhantomData,
-        }
+    /// Create the session fairing.
+    ///
+    /// You can configure the session store by calling chained methods on the returned value
+    /// before passing it to `rocket.attach()`
+    pub fn fairing() -> SessionFairing<D> {
+        SessionFairing::<D>::new()
     }
 
-    /// Access the session store
-    pub fn get_store(&self) -> &SessionStore<D> {
-        &self.store
-    }
-
-    /// Set the session object to its default state
-    pub fn reset(&self) {
-        self.tap_mut(|m| {
+    /// Clear session data (replace the value with default)
+    pub fn clear(&self) {
+        self.tap(|m| {
             *m = D::default();
         })
     }
 
-    pub fn tap<T>(&self, func: impl FnOnce(&D) -> T) -> T {
-        let rg = self.store.inner.read();
-        let instance = rg.get(self.id.as_str()).unwrap();
-        func(&instance.data)
-    }
+    /// Access the session's data using a closure.
+    ///
+    /// The closure is called with the data value as a mutable argument,
+    /// and can return any value to be is passed up to the caller.
+    pub fn tap<T>(&self, func: impl FnOnce(&mut D) -> T) -> T {
+        // Use a read guard, so other already active sessions are not blocked
+        // from accessing the store. New incoming clients may be blocked until
+        // the tap() call finishes
+        let store_rg = self.store.inner.read();
+
+        // Unlock the session's mutex.
+        // Expiry was checked and prolonged at the beginning of the request
+        let mut instance = store_rg.sessions.get(self.id.as_str())
+            .expect("Session data unexpectedly missing")
+            .lock();
 
-    pub fn tap_mut<T>(&self, func: impl FnOnce(&mut D) -> T) -> T {
-        let mut wg = self.store.inner.write();
-        let instance = wg.get_mut(self.id.as_str()).unwrap();
         func(&mut instance.data)
     }
 }
 
 /// Fairing struct
-struct SessionFairing<D>
+#[derive(Default)]
+pub struct SessionFairing<D>
     where
         D: 'static + Sync + Send + Default,
 {
-    lifespan: Duration,
-    _phantom: PhantomData<D>,
+    config: SessionConfig,
+    phantom: PhantomData<D>,
+}
+
+impl<D> SessionFairing<D>
+    where
+        D: 'static + Sync + Send + Default
+{
+    fn new() -> Self {
+        Self::default()
+    }
+
+    /// Set session lifetime (expiration time).
+    ///
+    /// Call on the fairing before passing it to `rocket.attach()`
+    pub fn with_lifetime(mut self, time: Duration) -> Self {
+        self.config.lifespan = time;
+        self
+    }
+
+    /// Set session cookie name and length
+    ///
+    /// Call on the fairing before passing it to `rocket.attach()`
+    pub fn with_cookie_name(mut self, name: impl Into<Cow<'static, str>>) -> Self {
+        self.config.cookie_name = name.into();
+        self
+    }
+
+    /// Set session cookie name and length
+    ///
+    /// Call on the fairing before passing it to `rocket.attach()`
+    pub fn with_cookie_len(mut self, length: usize) -> Self {
+        self.config.cookie_len = length;
+        self
+    }
+
+    /// Set session cookie name and length
+    ///
+    /// Call on the fairing before passing it to `rocket.attach()`
+    pub fn with_cookie_path(mut self, path: impl Into<Cow<'static, str>>) -> Self {
+        self.config.cookie_path = path.into();
+        self
+    }
 }
 
 impl<D> Fairing for SessionFairing<D>
@@ -190,7 +305,7 @@ impl<D> Fairing for SessionFairing<D>
         // install the store singleton
         Ok(rocket.manage(SessionStore::<D> {
             inner: Default::default(),
-            lifespan: self.lifespan,
+            config: self.config.clone(),
         }))
     }
 
@@ -199,7 +314,8 @@ impl<D> Fairing for SessionFairing<D>
         let session = request.local_cache(|| SessionID("".to_string()));
 
         if !session.0.is_empty() {
-            response.adjoin_header(Cookie::build(SESSION_COOKIE, session.0.clone()).finish());
+            response.adjoin_header(Cookie::build(self.config.cookie_name.clone(), session.to_string())
+                .path("/").finish());
         }
     }
 }