1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
use std::borrow::Cow;
use std::path::PathBuf;

use crate::store::{BlobEntry, BlobStoreInterface, BlobTable, BlobTableInterface, StoreError};
use directories::ProjectDirs;
use log::*;
use percent_encoding::{percent_decode_str, utf8_percent_encode, AsciiSet, CONTROLS};
use tokio::fs;

/// Defines the characters to escape in keys.
///
/// `/` and `\` are path separators (the latter on Windows); escaping them keeps every key a
/// single file name inside the store directory. `%` itself must be escaped as well: otherwise a
/// key that already contains a percent-sequence (e.g. `a%2Fb`) is stored under the same file name
/// as the key it decodes to (`a/b`), and listing keys (which percent-decodes file names) returns
/// the decoded form instead of the original key.
///
/// NOTE: changing this set changes the on-disk file name of keys containing the affected
/// characters; entries written before `%`/`\` were added keep their old names.
const ESCAPE_CHARS: &AsciiSet = &CONTROLS
    .add(b' ')
    .add(b'"')
    .add(b'%')
    .add(b'<')
    .add(b'>')
    .add(b'`')
    .add(b'\'')
    .add(b'/')
    .add(b'\\');

/// Blob store that keeps every entry as a separate file below a base directory.
#[derive(Debug, Clone)]
pub struct BlobStoreFS {
    /// Directory holding this store's files (for a table store, the table's sub-directory).
    basedir: PathBuf,
}

impl BlobStoreFS {
    /// Opens a file-system blob store rooted at `basedir`, creating the directory (and any
    /// missing parents) if it does not exist yet.
    ///
    /// # Errors
    ///
    /// Returns a [`StoreError`] if checking for or creating the directory fails.
    pub async fn new(basedir: PathBuf) -> Result<BlobStoreFS, StoreError> {
        if !fs::try_exists(&basedir).await? {
            fs::create_dir_all(&basedir).await?;
        }

        Ok(BlobStoreFS { basedir })
    }

    /// Opens a store in the platform's per-user data directory for application `dirname`
    /// (qualifier `rs.xmpp`), creating it if needed.
    ///
    /// # Errors
    ///
    /// Returns a [`StoreError`] if the platform data directory cannot be determined or the
    /// directory cannot be created.
    pub async fn from_standards(dirname: &str) -> Result<BlobStoreFS, StoreError> {
        // `ProjectDirs::from` yields `None` when the platform gives no usable home directory.
        // Report that through the `Result` this function already returns instead of panicking.
        let standards_dir = ProjectDirs::from("rs.xmpp", "", dirname).ok_or_else(|| {
            std::io::Error::new(
                std::io::ErrorKind::NotFound,
                "Failed to create project directory",
            )
        })?;

        let basedir = standards_dir.data_dir();

        info!("Using {} for data store.", basedir.display());
        Self::new(basedir.to_path_buf()).await
    }

    /// Returns the size in bytes of the blob described by `entry`.
    // Kept although currently unused, hence the `allow`.
    #[allow(dead_code)]
    async fn len(&self, entry: &BlobEntry) -> Result<u64, StoreError> {
        // Blobs are stored under their sanitized key. Joining the raw key would stat the wrong
        // file, and a key such as "/etc/passwd" would escape the store directory entirely.
        let path = self
            .basedir
            .join(entry.table)
            .join(Self::sanitize_key(&entry.key));
        let metadata = fs::metadata(path).await?;
        Ok(metadata.len())
    }

    /// Turns an arbitrary key into a single file name by percent-encoding it. Prevents path
    /// traversal in the file store.
    ///
    /// Characters in `ESCAPE_CHARS` (including `/`) are encoded so a key cannot name a path in
    /// another directory. The special names `.` and `..` contain no escaped characters but would
    /// resolve to the store directory or its parent once joined onto it, so their dots are
    /// encoded explicitly.
    pub fn sanitize_key(key: &str) -> String {
        match key {
            "." => String::from("%2E"),
            ".." => String::from("%2E%2E"),
            _ => utf8_percent_encode(key, ESCAPE_CHARS).to_string(),
        }
    }
}

#[async_trait::async_trait]
impl BlobStoreInterface for BlobStoreFS {
    /// Opens `table` as a separate store living in a sub-directory (named after the table) of
    /// this store's base directory. The sub-directory is created if it does not exist yet.
    async fn table(&self, table: &BlobTable) -> Result<Box<dyn BlobTableInterface>, StoreError> {
        let table_dir = self.basedir.join(table);
        let table_store = Self::new(table_dir).await?;
        Ok(Box::new(table_store))
    }
}

#[async_trait::async_trait]
impl<'a> BlobTableInterface<'a> for BlobStoreFS {
    async fn has(&self, key: &str) -> Result<bool, StoreError> {
        let key = Self::sanitize_key(key);
        let path = self.basedir.join(key);
        fs::try_exists(path).await.map_err(|e| StoreError::from(e))
    }

    async fn get(&self, key: &str) -> Result<Cow<'a, [u8]>, StoreError> {
        let key = Self::sanitize_key(key);
        let path = self.basedir.join(key);
        debug!("Reading data from {}", path.display());
        fs::read(path)
            .await
            .map_err(|e| StoreError::from(e))
            .map(|v| Cow::from(v))
    }

    // TODO: make write more resilient to power loss (write to temp file + rename)
    async fn set(&mut self, key: &str, value: &[u8]) -> Result<(), StoreError> {
        let key = Self::sanitize_key(key);
        let path = self.basedir.join(key);
        debug!("Saving data to {}", path.display());
        fs::write(path, value)
            .await
            .map_err(|e| StoreError::from(e))
    }

    async fn delete(&mut self, key: &str) -> Result<(), StoreError> {
        let key = Self::sanitize_key(key);
        let path = self.basedir.join(key);
        fs::remove_file(path).await?;
        Ok(())
    }

    async fn delete_all(&mut self) -> Result<(), StoreError> {
        fs::remove_dir_all(&self.basedir).await?;
        fs::create_dir_all(&self.basedir).await?;
        Ok(())
    }

    async fn list(&self) -> Result<Vec<String>, StoreError> {
        let mut l: Vec<String> = vec![];

        let mut entries = fs::read_dir(&self.basedir).await?;
        while let Some(entry) = entries.next_entry().await? {
            // Unwrapping is safe because key was valid UTF-8 when it was created
            let key = percent_decode_str(entry.file_name().to_str().unwrap())
                .decode_utf8()
                .unwrap()
                .to_string();
            l.push(key);
        }

        Ok(l)
    }
}

#[cfg(test)]
mod tests {
    use crate::store::{BlobEntry, BlobStoreFS, BlobStoreInterface, TABLE_AVATAR};
    use temp_dir::TempDir;

    // Each test binds its `TempDir` guard to a local that lives for the whole test. The guard
    // removes the directory when dropped. Calling `path()` on a temporary guard would delete the
    // directory immediately, leaving the store to recreate it and nothing to clean it up.

    #[tokio::test]
    async fn blob_store_fs_entry() {
        let temp_dir = TempDir::new().unwrap();
        let mut store = BlobStoreFS::new(temp_dir.path().to_path_buf())
            .await
            .unwrap();
        let entry = BlobEntry::new(TABLE_AVATAR, "foo".to_string());
        store.set_in_table(&entry, &vec![1, 2, 3, 4]).await.unwrap();
        assert_eq!(
            *vec!(1, 2, 3, 4),
            *store.get_in_table(&entry).await.unwrap()
        );
    }

    #[tokio::test]
    async fn blob_store_fs_table() {
        let temp_dir = TempDir::new().unwrap();
        let store = BlobStoreFS::new(temp_dir.path().to_path_buf())
            .await
            .unwrap();
        let mut avatars = store.table(&TABLE_AVATAR).await.unwrap();
        avatars.set("foo", &vec![1, 2, 3, 4]).await.unwrap();
        assert_eq!(*vec!(1, 2, 3, 4), *avatars.get("foo").await.unwrap());
    }

    #[tokio::test]
    async fn blob_store_fs_escape() {
        let temp_dir = TempDir::new().unwrap();
        let store = BlobStoreFS::new(temp_dir.path().to_path_buf())
            .await
            .unwrap();
        let mut avatars = store.table(&TABLE_AVATAR).await.unwrap();
        avatars.set("foo", &vec![1, 2, 3, 4]).await.unwrap();
        avatars.set("/etc/passwd", &vec![1, 3, 1, 2]).await.unwrap();

        let mut expected = vec!["foo".to_string(), "/etc/passwd".to_string()];
        expected.sort_unstable();

        let mut found = avatars.list().await.unwrap();
        found.sort_unstable();

        assert_eq!(expected, found);
        assert_eq!(
            *vec!(1, 3, 1, 2),
            *avatars.get("/etc/passwd").await.unwrap()
        );
    }
}