From af30c552ac4aaa9a59ba8e5e56b0b08122fd7b1f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ond=C5=99ej=20Hru=C5=A1ka?= <ondra@ondrovo.com>
Date: Tue, 31 Dec 2019 17:41:07 +0100
Subject: [PATCH] move code from tap to session's from_request, add tap_mut and
 make tap read-only

---
 src/lib.rs | 91 ++++++++++++++++++++++++++----------------------------
 1 file changed, 44 insertions(+), 47 deletions(-)

diff --git a/src/lib.rs b/src/lib.rs
index 177a68f..87fe7fb 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -55,22 +55,13 @@ impl<D> SessionStore<D>
 #[derive(PartialEq, Hash, Clone, Debug)]
 struct SessionID(String);
 
-impl<'a, 'r> FromRequest<'a, 'r> for &'a SessionID {
-    type Error = ();
+impl SessionID {
+    fn as_str(&self) -> &str {
+        self.0.as_str()
+    }
 
-    fn from_request(request: &'a Request<'r>) -> Outcome<Self, (Status, Self::Error), ()> {
-        Outcome::Success(request.local_cache(|| {
-            if let Some(cookie) = request.cookies().get(SESSION_COOKIE) {
-                SessionID(cookie.value().to_string()) // FIXME avoid cloning (cow?)
-            } else {
-                SessionID(
-                    rand::thread_rng()
-                        .sample_iter(&rand::distributions::Alphanumeric)
-                        .take(16)
-                        .collect(),
-                )
-            }
-        }))
+    fn to_string(&self) -> String {
+        self.0.clone()
     }
 }
 
@@ -93,9 +84,11 @@ impl<'a, 'r, D> FromRequest<'a, 'r> for Session<'a, D>
     type Error = ();
 
     fn from_request(request: &'a Request<'r>) -> Outcome<Self, (Status, Self::Error), ()> {
+        let store : State<SessionStore<D>> = request.guard().unwrap();
         Outcome::Success(Session {
             id: request.local_cache(|| {
-                if let Some(cookie) = request.cookies().get(SESSION_COOKIE) {
+                // Resolve session ID
+                let id = if let Some(cookie) = request.cookies().get(SESSION_COOKIE) {
                     SessionID(cookie.value().to_string())
                 } else {
                     SessionID(
@@ -104,9 +97,34 @@ impl<'a, 'r, D> FromRequest<'a, 'r> for Session<'a, D>
                             .take(SESSION_ID_LEN)
                             .collect(),
                     )
-                }
+                };
+
+                let new_expiration = Instant::now().add(store.lifespan);
+                let mut wg = store.inner.write();
+                match wg.get_mut(id.as_str()) {
+                    Some(ses) => {
+                        // Check expiration
+                        if ses.expires <= Instant::now() {
+                            ses.data = D::default();
+                        }
+                        // Update expiry timestamp
+                        ses.expires = new_expiration;
+                    },
+                    None => {
+                        // New session
+                        wg.insert(
+                            id.to_string(),
+                            SessionInstance {
+                                data: D::default(),
+                                expires: new_expiration,
+                            }
+                        );
+                    }
+                };
+
+                id
             }),
-            store: request.guard().unwrap(),
+            store,
         })
     }
 }
@@ -130,42 +148,21 @@ impl<'a, D> Session<'a, D>
 
     /// Set the session object to its default state
     pub fn reset(&self) {
-        self.tap(|m| {
+        self.tap_mut(|m| {
             *m = D::default();
         })
     }
 
-    /// Renew the session without changing any data
-    pub fn renew(&self) {
-        self.tap(|_| ())
+    pub fn tap<T>(&self, func: impl FnOnce(&D) -> T) -> T {
+        let rg = self.store.inner.read();
+        let instance = rg.get(self.id.as_str()).unwrap();
+        func(&instance.data)
     }
 
-    /// Run a closure with a mutable reference to the session object.
-    /// The closure's return value is send to the caller.
-    pub fn tap<T>(&self, func: impl FnOnce(&mut D) -> T) -> T {
+    pub fn tap_mut<T>(&self, func: impl FnOnce(&mut D) -> T) -> T {
         let mut wg = self.store.inner.write();
-        if let Some(instance) = wg.get_mut(&self.id.0) {
-            // wipe session data if expired
-            if instance.expires <= Instant::now() {
-                instance.data = D::default();
-            }
-            // update expiry timestamp
-            instance.expires = Instant::now().add(self.store.lifespan);
-
-            func(&mut instance.data)
-        } else {
-            // no object in the store yet, start fresh
-            let mut data = D::default();
-            let result = func(&mut data);
-            wg.insert(
-                self.id.0.clone(),
-                SessionInstance {
-                    data,
-                    expires: Instant::now().add(self.store.lifespan),
-                },
-            );
-            result
-        }
+        let instance = wg.get_mut(self.id.as_str()).unwrap();
+        func(&mut instance.data)
     }
 }