src/server.rs 11.2 K raw
1
use askama::Template;
2
use askama_web::WebTemplate;
3
use subtle::ConstantTimeEq;
4
use axum::{
5
    Form, Json, Router,
6
    extract::{Path, Request, State},
7
    http::{HeaderMap, StatusCode, header},
8
    middleware::{self, Next},
9
    response::{Html, IntoResponse, Redirect, Response},
10
    routing::{delete, get, post, put},
11
};
12
use rust_embed::Embed;
13
use serde::Deserialize;
14
use crate::db::{self, Db, Snippet};
15
use crate::highlight::Highlighter;
16
use std::collections::HashSet;
17
use std::sync::Arc;
18
19
#[derive(Embed)]
20
#[folder = "assets/"]
21
struct Assets;
22
23
#[derive(Embed)]
24
#[folder = "static/"]
25
struct Static;
26
27
/// Authentication settings read from the environment at startup.
#[derive(Clone)]
struct ServerConfig {
    /// Secret that clients must send in the `x-api-key` header.
    /// `None` when `SIPP_API_KEY` is unset or blank.
    api_key: Option<String>,
    /// Lower-cased endpoint names (e.g. `api_list`, or `all`) that require the API key.
    auth_endpoints: HashSet<String>,
}

impl ServerConfig {
    /// Endpoints that require the API key when `SIPP_AUTH_ENDPOINTS` is not set.
    const DEFAULT_AUTH_ENDPOINTS: [&'static str; 3] = ["api_delete", "api_list", "api_update"];

    /// Builds the configuration from `SIPP_API_KEY` and `SIPP_AUTH_ENDPOINTS`.
    ///
    /// A variable that is missing (or not valid Unicode) is treated as unset.
    fn from_env() -> Self {
        Self::from_values(
            std::env::var("SIPP_API_KEY").ok(),
            std::env::var("SIPP_AUTH_ENDPOINTS").ok(),
        )
    }

    /// Builds the configuration from raw variable values (`None` = unset).
    ///
    /// `auth_endpoints` semantics:
    /// - unset: [`Self::DEFAULT_AUTH_ENDPOINTS`];
    /// - `none` (any case, surrounding whitespace ignored): no endpoint requires a key;
    /// - otherwise: a comma-separated list of names, trimmed and lower-cased, with empty
    ///   entries (from `""` or a trailing comma) dropped.
    fn from_values(api_key: Option<String>, auth_endpoints: Option<String>) -> Self {
        // A blank key must not count as configured: an empty `x-api-key` header would
        // otherwise compare equal to it and pass authentication.
        let api_key = api_key.filter(|key| !key.trim().is_empty());

        let auth_endpoints = match auth_endpoints {
            None => Self::DEFAULT_AUTH_ENDPOINTS
                .iter()
                .map(|name| name.to_string())
                .collect(),
            Some(val) if val.trim().eq_ignore_ascii_case("none") => HashSet::new(),
            Some(val) => val
                .split(',')
                .map(|name| name.trim().to_lowercase())
                .filter(|name| !name.is_empty())
                .collect(),
        };

        ServerConfig { api_key, auth_endpoints }
    }

    /// Returns `true` if the endpoint `name` (e.g. `"api_list"`) must present the API key,
    /// either because it is listed explicitly or because `all` is listed.
    fn requires_auth(&self, name: &str) -> bool {
        self.auth_endpoints.contains("all") || self.auth_endpoints.contains(name)
    }
}
48
49
#[derive(Clone)]
50
struct AppState {
51
    db: Db,
52
    highlighter: Arc<Highlighter>,
53
    server_config: ServerConfig,
54
}
55
56
#[derive(Template)]
57
#[template(path = "index.html")]
58
struct IndexTemplate;
59
60
#[derive(Template)]
61
#[template(path = "snippet.html")]
62
struct SnippetTemplate {
63
    name: String,
64
    content: String,
65
    highlighted_content: String,
66
}
67
68
#[derive(Deserialize)]
69
struct CreateSnippetForm {
70
    name: String,
71
    content: String,
72
}
73
74
async fn index() -> WebTemplate<IndexTemplate> {
75
    WebTemplate(IndexTemplate)
76
}
77
78
async fn view_snippet(
79
    State(state): State<AppState>,
80
    Path(short_id): Path<String>,
81
) -> Result<WebTemplate<SnippetTemplate>, (StatusCode, Html<String>)> {
82
    match db::get_snippet_by_short_id(&state.db, &short_id) {
83
        Ok(Some(snippet)) => {
84
            let highlighted_content = state.highlighter.highlight(&snippet.name, &snippet.content);
85
            Ok(WebTemplate(SnippetTemplate {
86
                name: snippet.name,
87
                content: snippet.content,
88
                highlighted_content,
89
            }))
90
        }
91
        Ok(None) => Err((
92
            StatusCode::NOT_FOUND,
93
            Html("<h1>Snippet not found</h1>".to_string()),
94
        )),
95
        Err(_) => Err((
96
            StatusCode::INTERNAL_SERVER_ERROR,
97
            Html("<h1>Internal server error</h1>".to_string()),
98
        )),
99
    }
100
}
101
102
async fn create_snippet(
103
    State(state): State<AppState>,
104
    Form(form): Form<CreateSnippetForm>,
105
) -> Result<Redirect, (StatusCode, Html<String>)> {
106
    match db::create_snippet(&state.db, &form.name, &form.content) {
107
        Ok(snippet) => Ok(Redirect::to(&format!("/s/{}", snippet.short_id))),
108
        Err(_) => Err((
109
            StatusCode::INTERNAL_SERVER_ERROR,
110
            Html("<h1>Internal server error</h1>".to_string()),
111
        )),
112
    }
113
}
114
115
async fn require_api_key(
116
    State(state): State<AppState>,
117
    headers: HeaderMap,
118
    request: Request,
119
    next: Next,
120
) -> Result<Response, (StatusCode, Json<serde_json::Value>)> {
121
    let server_key = match &state.server_config.api_key {
122
        Some(k) => k,
123
        None => return Err((
124
            StatusCode::FORBIDDEN,
125
            Json(serde_json::json!({"error": "No API key configured on server"})),
126
        )),
127
    };
128
    let provided = headers
129
        .get("x-api-key")
130
        .and_then(|v| v.to_str().ok());
131
    match provided {
132
        Some(k) if k.as_bytes().ct_eq(server_key.as_bytes()).into() => Ok(next.run(request).await),
133
        _ => Err((
134
            StatusCode::UNAUTHORIZED,
135
            Json(serde_json::json!({"error": "Invalid or missing API key"})),
136
        )),
137
    }
138
}
139
140
async fn api_list_snippets(
141
    State(state): State<AppState>,
142
) -> Result<Json<Vec<Snippet>>, (StatusCode, Json<serde_json::Value>)> {
143
    match db::get_all_snippets(&state.db) {
144
        Ok(snippets) => Ok(Json(snippets)),
145
        Err(_) => Err((StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": "Internal server error"})))),
146
    }
147
}
148
149
async fn api_get_snippet(
150
    State(state): State<AppState>,
151
    Path(short_id): Path<String>,
152
) -> Result<Json<Snippet>, (StatusCode, Json<serde_json::Value>)> {
153
    match db::get_snippet_by_short_id(&state.db, &short_id) {
154
        Ok(Some(snippet)) => Ok(Json(snippet)),
155
        Ok(None) => Err((StatusCode::NOT_FOUND, Json(serde_json::json!({"error": "Snippet not found"})))),
156
        Err(_) => Err((StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": "Internal server error"})))),
157
    }
158
}
159
160
#[derive(Deserialize)]
161
struct ApiCreateSnippet {
162
    name: String,
163
    content: String,
164
}
165
166
async fn api_create_snippet(
167
    State(state): State<AppState>,
168
    Json(body): Json<ApiCreateSnippet>,
169
) -> Result<(StatusCode, Json<Snippet>), (StatusCode, Json<serde_json::Value>)> {
170
    match db::create_snippet(&state.db, &body.name, &body.content) {
171
        Ok(snippet) => Ok((StatusCode::CREATED, Json(snippet))),
172
        Err(_) => Err((StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": "Internal server error"})))),
173
    }
174
}
175
176
async fn api_delete_snippet(
177
    State(state): State<AppState>,
178
    Path(short_id): Path<String>,
179
) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
180
    match db::delete_snippet_by_short_id(&state.db, &short_id) {
181
        Ok(true) => Ok(Json(serde_json::json!({"deleted": true}))),
182
        Ok(false) => Err((StatusCode::NOT_FOUND, Json(serde_json::json!({"error": "Snippet not found"})))),
183
        Err(_) => Err((StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": "Internal server error"})))),
184
    }
185
}
186
187
async fn api_update_snippet(
188
    State(state): State<AppState>,
189
    Path(short_id): Path<String>,
190
    Json(body): Json<ApiCreateSnippet>,
191
) -> Result<Json<Snippet>, (StatusCode, Json<serde_json::Value>)> {
192
    match db::update_snippet_by_short_id(&state.db, &short_id, &body.name, &body.content) {
193
        Ok(Some(snippet)) => Ok(Json(snippet)),
194
        Ok(None) => Err((StatusCode::NOT_FOUND, Json(serde_json::json!({"error": "Snippet not found"})))),
195
        Err(_) => Err((StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": "Internal server error"})))),
196
    }
197
}
198
199
fn build_api_routes(state: &AppState) -> Router<AppState> {
200
    let config = &state.server_config;
201
202
    let auth_layer = middleware::from_fn_with_state(state.clone(), require_api_key);
203
204
    // /api/snippets — GET (api_list) and POST (api_create)
205
    let list_authed = config.requires_auth("api_list");
206
    let create_authed = config.requires_auth("api_create");
207
208
    // /api/snippets/{short_id} — GET (api_get), PUT (api_update), and DELETE (api_delete)
209
    let get_authed = config.requires_auth("api_get");
210
    let update_authed = config.requires_auth("api_update");
211
    let delete_authed = config.requires_auth("api_delete");
212
213
    // Build authed router
214
    let mut authed = Router::new();
215
    if list_authed {
216
        authed = authed.route("/api/snippets", get(api_list_snippets));
217
    }
218
    if create_authed {
219
        authed = authed.route("/api/snippets", post(api_create_snippet));
220
    }
221
    if get_authed {
222
        authed = authed.route("/api/snippets/{short_id}", get(api_get_snippet));
223
    }
224
    if update_authed {
225
        authed = authed.route("/api/snippets/{short_id}", put(api_update_snippet));
226
    }
227
    if delete_authed {
228
        authed = authed.route("/api/snippets/{short_id}", delete(api_delete_snippet));
229
    }
230
    let authed = authed.route_layer(auth_layer);
231
232
    // Build open router
233
    let mut open = Router::new();
234
    if !list_authed {
235
        open = open.route("/api/snippets", get(api_list_snippets));
236
    }
237
    if !create_authed {
238
        open = open.route("/api/snippets", post(api_create_snippet));
239
    }
240
    if !get_authed {
241
        open = open.route("/api/snippets/{short_id}", get(api_get_snippet));
242
    }
243
    if !update_authed {
244
        open = open.route("/api/snippets/{short_id}", put(api_update_snippet));
245
    }
246
    if !delete_authed {
247
        open = open.route("/api/snippets/{short_id}", delete(api_delete_snippet));
248
    }
249
250
    authed.merge(open)
251
}
252
253
/// Guesses a `Content-Type` value from the file extension of an embedded file path.
///
/// The extension is taken from the final path component only and matched
/// case-insensitively, so `LOGO.PNG` is served as `image/png`. A path with no
/// extension (e.g. `LICENSE`, or a bare `css`) or an unknown extension falls back
/// to `application/octet-stream`.
fn mime_from_path(path: &str) -> &'static str {
    let ext = std::path::Path::new(path)
        .extension()
        .and_then(|e| e.to_str())
        .map(|e| e.to_ascii_lowercase())
        .unwrap_or_default();

    match ext.as_str() {
        "css" => "text/css",
        "js" | "mjs" => "application/javascript",
        "html" | "htm" => "text/html",
        "png" => "image/png",
        "ico" => "image/x-icon",
        "svg" => "image/svg+xml",
        "woff" => "font/woff",
        "woff2" => "font/woff2",
        "ttf" => "font/ttf",
        "otf" => "font/otf",
        "json" | "webmanifest" | "map" => "application/json",
        "jpg" | "jpeg" => "image/jpeg",
        "gif" => "image/gif",
        "webp" => "image/webp",
        "txt" => "text/plain",
        "wasm" => "application/wasm",
        "xml" => "application/xml",
        _ => "application/octet-stream",
    }
}
270
271
async fn serve_assets(Path(path): Path<String>) -> Response {
272
    match Assets::get(&path) {
273
        Some(file) => {
274
            let mime = mime_from_path(&path);
275
            ([(header::CONTENT_TYPE, mime)], file.data).into_response()
276
        }
277
        None => StatusCode::NOT_FOUND.into_response(),
278
    }
279
}
280
281
async fn serve_static(Path(path): Path<String>) -> Response {
282
    match Static::get(&path) {
283
        Some(file) => {
284
            let mime = mime_from_path(&path);
285
            ([(header::CONTENT_TYPE, mime)], file.data).into_response()
286
        }
287
        None => StatusCode::NOT_FOUND.into_response(),
288
    }
289
}
290
291
pub async fn run(host: String, port: u16) {
292
    dotenvy::dotenv().ok();
293
294
    let server_config = ServerConfig::from_env();
295
296
    // Validate endpoint names
297
    let known = ["api_list", "api_create", "api_get", "api_update", "api_delete", "all", "none"];
298
    for name in &server_config.auth_endpoints {
299
        if !known.contains(&name.as_str()) {
300
            eprintln!("Warning: unknown auth endpoint name '{}' in SIPP_AUTH_ENDPOINTS", name);
301
        }
302
    }
303
304
    if !server_config.auth_endpoints.is_empty() && server_config.api_key.is_none() {
305
        eprintln!("Warning: SIPP_AUTH_ENDPOINTS is set but SIPP_API_KEY is not configured");
306
    }
307
308
    if server_config.auth_endpoints.is_empty() {
309
        println!("Auth: disabled (no endpoints require authentication)");
310
    } else {
311
        let names: Vec<&str> = server_config.auth_endpoints.iter().map(|s| s.as_str()).collect();
312
        println!("Auth: enabled for endpoints: {}", names.join(", "));
313
    }
314
315
    let state = AppState {
316
        db: db::init_db().expect("Failed to initialize database"),
317
        highlighter: Arc::new(Highlighter::new()),
318
        server_config,
319
    };
320
321
    let api_routes = build_api_routes(&state);
322
323
    let app = Router::new()
324
        .route("/", get(index))
325
        .route("/s/{short_id}", get(view_snippet))
326
        .route("/snippets", post(create_snippet))
327
        .merge(api_routes)
328
        .route("/assets/{*path}", get(serve_assets))
329
        .route("/static/{*path}", get(serve_static))
330
        .with_state(state);
331
332
    let addr = format!("{}:{}", host, port);
333
    let listener = tokio::net::TcpListener::bind(&addr)
334
        .await
335
        .unwrap_or_else(|_| panic!("Failed to bind to {}", addr));
336
337
    println!("Server running at http://{}:{}", host, port);
338
339
    axum::serve(listener, app)
340
        .await
341
        .expect("Failed to start server");
342
}