src/server.rs 13.2 K raw
1
use askama::Template;
2
use askama_web::WebTemplate;
3
use subtle::ConstantTimeEq;
4
use axum::{
5
    Form, Json, Router,
6
    extract::{Path, Request, State},
7
    http::{HeaderMap, StatusCode, header},
8
    middleware::{self, Next},
9
    response::{Html, IntoResponse, Redirect, Response},
10
    routing::{delete, get, post, put},
11
};
12
use rust_embed::Embed;
13
use serde::Deserialize;
14
use crate::db::{self, Db, Snippet};
15
use crate::highlight::Highlighter;
16
use std::collections::HashSet;
17
use std::sync::Arc;
18
19
/// Files embedded from the `assets/` directory; looked up by `serve_assets`
/// for requests to `/assets/{*path}`.
#[derive(Embed)]
#[folder = "assets/"]
struct Assets;
22
23
/// Files embedded from the `static/` directory; looked up by `serve_static`
/// for requests to `/static/{*path}`.
#[derive(Embed)]
#[folder = "static/"]
struct Static;
26
27
/// Runtime configuration read from environment variables by [`ServerConfig::from_env`].
#[derive(Clone)]
struct ServerConfig {
    /// Shared secret expected in the `x-api-key` header; `None` when unset or empty.
    api_key: Option<String>,
    /// Lower-cased endpoint names (e.g. `api_list`, or the keyword `all`) that
    /// require the API key.
    auth_endpoints: HashSet<String>,
    /// Maximum accepted snippet content length, in bytes.
    max_content_size: usize,
}

impl ServerConfig {
    /// Endpoints protected when `SIPP_AUTH_ENDPOINTS` is not set at all.
    const DEFAULT_AUTH_ENDPOINTS: [&'static str; 3] = ["api_delete", "api_list", "api_update"];

    /// Content size limit (bytes) used when `SIPP_MAX_CONTENT_SIZE` is unset or invalid.
    const DEFAULT_MAX_CONTENT_SIZE: usize = 512_000;

    /// Builds the configuration from the process environment.
    ///
    /// - `SIPP_API_KEY`: the API key. An empty value is treated as unset, because
    ///   an empty key would otherwise be matched by a request that sends an empty
    ///   `x-api-key` header.
    /// - `SIPP_AUTH_ENDPOINTS`: comma-separated endpoint names (case-insensitive),
    ///   `all`, or `none`. Blank entries (e.g. from a trailing comma) are ignored.
    ///   When the variable is unset, [`Self::DEFAULT_AUTH_ENDPOINTS`] are protected.
    /// - `SIPP_MAX_CONTENT_SIZE`: content size limit in bytes; falls back to
    ///   [`Self::DEFAULT_MAX_CONTENT_SIZE`] when unset or not a valid `usize`.
    fn from_env() -> Self {
        let api_key = std::env::var("SIPP_API_KEY")
            .ok()
            .filter(|key| !key.is_empty());

        let auth_endpoints = match std::env::var("SIPP_AUTH_ENDPOINTS") {
            Ok(val) if val.trim().eq_ignore_ascii_case("none") => HashSet::new(),
            Ok(val) => val
                .split(',')
                .map(|s| s.trim().to_lowercase())
                // "a,,b" or a trailing comma must not register an empty name.
                .filter(|s| !s.is_empty())
                .collect(),
            Err(_) => Self::DEFAULT_AUTH_ENDPOINTS
                .iter()
                .map(|s| s.to_string())
                .collect(),
        };

        let max_content_size = std::env::var("SIPP_MAX_CONTENT_SIZE")
            .ok()
            .and_then(|v| v.parse().ok())
            .unwrap_or(Self::DEFAULT_MAX_CONTENT_SIZE);

        ServerConfig {
            api_key,
            auth_endpoints,
            max_content_size,
        }
    }

    /// Returns `true` if the endpoint called `name` must be guarded by the API
    /// key, either because it is listed explicitly or because `all` is listed.
    fn requires_auth(&self, name: &str) -> bool {
        self.auth_endpoints.contains("all") || self.auth_endpoints.contains(name)
    }
}
53
54
/// Shared state handed to every handler through axum's `State` extractor.
#[derive(Clone)]
struct AppState {
    /// Database handle passed to the `db::*` query functions.
    db: Db,
    /// Syntax highlighter; held in an `Arc` so cloning the state shares it
    /// rather than copying it.
    highlighter: Arc<Highlighter>,
    /// Configuration loaded once at startup.
    server_config: ServerConfig,
}
60
61
/// Landing page rendered from `index.html`; served at `/` by `index`.
#[derive(Template)]
#[template(path = "index.html")]
struct IndexTemplate;
64
65
/// Admin page rendered from `admin.html`; served at `/admin` by `admin`.
#[derive(Template)]
#[template(path = "admin.html")]
struct AdminTemplate;
68
69
/// Snippet page rendered from `snippet.html`; returned by `view_snippet` for
/// non-CLI clients.
#[derive(Template)]
#[template(path = "snippet.html")]
struct SnippetTemplate {
    /// Snippet name (also passed to the highlighter together with the content).
    name: String,
    /// Raw snippet content.
    content: String,
    /// Output of `Highlighter::highlight`; whether the template inserts it raw or
    /// escaped is decided in `snippet.html` (not visible here) — TODO confirm.
    highlighted_content: String,
}
76
77
/// URL-encoded form body accepted by `POST /snippets` (`create_snippet`).
#[derive(Deserialize)]
struct CreateSnippetForm {
    /// Snippet name.
    name: String,
    /// Snippet content; its byte length is checked against `max_content_size`.
    content: String,
}
82
83
async fn index() -> WebTemplate<IndexTemplate> {
84
    WebTemplate(IndexTemplate)
85
}
86
87
async fn admin() -> WebTemplate<AdminTemplate> {
88
    WebTemplate(AdminTemplate)
89
}
90
91
fn is_cli_user_agent(headers: &HeaderMap) -> bool {
92
    headers
93
        .get(header::USER_AGENT)
94
        .and_then(|v| v.to_str().ok())
95
        .map(|ua| {
96
            let ua = ua.to_lowercase();
97
            ua.starts_with("curl/") || ua.starts_with("wget/") || ua.starts_with("httpie/")
98
        })
99
        .unwrap_or(false)
100
}
101
102
async fn view_snippet(
103
    State(state): State<AppState>,
104
    Path(short_id): Path<String>,
105
    headers: HeaderMap,
106
) -> Result<Response, (StatusCode, Html<String>)> {
107
    match db::get_snippet_by_short_id(&state.db, &short_id) {
108
        Ok(Some(snippet)) => {
109
            if is_cli_user_agent(&headers) {
110
                Ok((
111
                    [(header::CONTENT_TYPE, "text/plain; charset=utf-8")],
112
                    snippet.content,
113
                )
114
                    .into_response())
115
            } else {
116
                let highlighted_content =
117
                    state.highlighter.highlight(&snippet.name, &snippet.content);
118
                Ok(WebTemplate(SnippetTemplate {
119
                    name: snippet.name,
120
                    content: snippet.content,
121
                    highlighted_content,
122
                })
123
                .into_response())
124
            }
125
        }
126
        Ok(None) => Err((
127
            StatusCode::NOT_FOUND,
128
            Html("<h1>Snippet not found</h1>".to_string()),
129
        )),
130
        Err(_) => Err((
131
            StatusCode::INTERNAL_SERVER_ERROR,
132
            Html("<h1>Internal server error</h1>".to_string()),
133
        )),
134
    }
135
}
136
137
async fn create_snippet(
138
    State(state): State<AppState>,
139
    Form(form): Form<CreateSnippetForm>,
140
) -> Result<Redirect, (StatusCode, Html<String>)> {
141
    if form.content.len() > state.server_config.max_content_size {
142
        return Err((
143
            StatusCode::PAYLOAD_TOO_LARGE,
144
            Html(format!(
145
                "<h1>Content too large</h1><p>Maximum size is {} bytes</p>",
146
                state.server_config.max_content_size
147
            )),
148
        ));
149
    }
150
    match db::create_snippet(&state.db, &form.name, &form.content) {
151
        Ok(snippet) => Ok(Redirect::to(&format!("/s/{}", snippet.short_id))),
152
        Err(_) => Err((
153
            StatusCode::INTERNAL_SERVER_ERROR,
154
            Html("<h1>Internal server error</h1>".to_string()),
155
        )),
156
    }
157
}
158
159
async fn require_api_key(
160
    State(state): State<AppState>,
161
    headers: HeaderMap,
162
    request: Request,
163
    next: Next,
164
) -> Result<Response, (StatusCode, Json<serde_json::Value>)> {
165
    let server_key = match &state.server_config.api_key {
166
        Some(k) => k,
167
        None => return Err((
168
            StatusCode::FORBIDDEN,
169
            Json(serde_json::json!({"error": "No API key configured on server"})),
170
        )),
171
    };
172
    let provided = headers
173
        .get("x-api-key")
174
        .and_then(|v| v.to_str().ok());
175
    match provided {
176
        Some(k) if k.as_bytes().ct_eq(server_key.as_bytes()).into() => Ok(next.run(request).await),
177
        _ => Err((
178
            StatusCode::UNAUTHORIZED,
179
            Json(serde_json::json!({"error": "Invalid or missing API key"})),
180
        )),
181
    }
182
}
183
184
async fn api_list_snippets(
185
    State(state): State<AppState>,
186
) -> Result<Json<Vec<Snippet>>, (StatusCode, Json<serde_json::Value>)> {
187
    match db::get_all_snippets(&state.db) {
188
        Ok(snippets) => Ok(Json(snippets)),
189
        Err(_) => Err((StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": "Internal server error"})))),
190
    }
191
}
192
193
async fn api_get_snippet(
194
    State(state): State<AppState>,
195
    Path(short_id): Path<String>,
196
) -> Result<Json<Snippet>, (StatusCode, Json<serde_json::Value>)> {
197
    match db::get_snippet_by_short_id(&state.db, &short_id) {
198
        Ok(Some(snippet)) => Ok(Json(snippet)),
199
        Ok(None) => Err((StatusCode::NOT_FOUND, Json(serde_json::json!({"error": "Snippet not found"})))),
200
        Err(_) => Err((StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": "Internal server error"})))),
201
    }
202
}
203
204
/// JSON request body shared by `POST /api/snippets` (`api_create_snippet`) and
/// `PUT /api/snippets/{short_id}` (`api_update_snippet`).
#[derive(Deserialize)]
struct ApiCreateSnippet {
    /// Snippet name.
    name: String,
    /// Snippet content; its byte length is checked against `max_content_size`.
    content: String,
}
209
210
async fn api_create_snippet(
211
    State(state): State<AppState>,
212
    Json(body): Json<ApiCreateSnippet>,
213
) -> Result<(StatusCode, Json<Snippet>), (StatusCode, Json<serde_json::Value>)> {
214
    if body.content.len() > state.server_config.max_content_size {
215
        return Err((
216
            StatusCode::PAYLOAD_TOO_LARGE,
217
            Json(serde_json::json!({
218
                "error": format!("Content too large. Maximum size is {} bytes", state.server_config.max_content_size)
219
            })),
220
        ));
221
    }
222
    match db::create_snippet(&state.db, &body.name, &body.content) {
223
        Ok(snippet) => Ok((StatusCode::CREATED, Json(snippet))),
224
        Err(_) => Err((StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": "Internal server error"})))),
225
    }
226
}
227
228
async fn api_delete_snippet(
229
    State(state): State<AppState>,
230
    Path(short_id): Path<String>,
231
) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
232
    match db::delete_snippet_by_short_id(&state.db, &short_id) {
233
        Ok(true) => Ok(Json(serde_json::json!({"deleted": true}))),
234
        Ok(false) => Err((StatusCode::NOT_FOUND, Json(serde_json::json!({"error": "Snippet not found"})))),
235
        Err(_) => Err((StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": "Internal server error"})))),
236
    }
237
}
238
239
async fn api_update_snippet(
240
    State(state): State<AppState>,
241
    Path(short_id): Path<String>,
242
    Json(body): Json<ApiCreateSnippet>,
243
) -> Result<Json<Snippet>, (StatusCode, Json<serde_json::Value>)> {
244
    if body.content.len() > state.server_config.max_content_size {
245
        return Err((
246
            StatusCode::PAYLOAD_TOO_LARGE,
247
            Json(serde_json::json!({
248
                "error": format!("Content too large. Maximum size is {} bytes", state.server_config.max_content_size)
249
            })),
250
        ));
251
    }
252
    match db::update_snippet_by_short_id(&state.db, &short_id, &body.name, &body.content) {
253
        Ok(Some(snippet)) => Ok(Json(snippet)),
254
        Ok(None) => Err((StatusCode::NOT_FOUND, Json(serde_json::json!({"error": "Snippet not found"})))),
255
        Err(_) => Err((StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": "Internal server error"})))),
256
    }
257
}
258
259
fn build_api_routes(state: &AppState) -> Router<AppState> {
260
    let config = &state.server_config;
261
262
    let auth_layer = middleware::from_fn_with_state(state.clone(), require_api_key);
263
264
    // /api/snippets — GET (api_list) and POST (api_create)
265
    let list_authed = config.requires_auth("api_list");
266
    let create_authed = config.requires_auth("api_create");
267
268
    // /api/snippets/{short_id} — GET (api_get), PUT (api_update), and DELETE (api_delete)
269
    let get_authed = config.requires_auth("api_get");
270
    let update_authed = config.requires_auth("api_update");
271
    let delete_authed = config.requires_auth("api_delete");
272
273
    // Build authed router
274
    let mut authed = Router::new();
275
    if list_authed {
276
        authed = authed.route("/api/snippets", get(api_list_snippets));
277
    }
278
    if create_authed {
279
        authed = authed.route("/api/snippets", post(api_create_snippet));
280
    }
281
    if get_authed {
282
        authed = authed.route("/api/snippets/{short_id}", get(api_get_snippet));
283
    }
284
    if update_authed {
285
        authed = authed.route("/api/snippets/{short_id}", put(api_update_snippet));
286
    }
287
    if delete_authed {
288
        authed = authed.route("/api/snippets/{short_id}", delete(api_delete_snippet));
289
    }
290
    let authed = authed.route_layer(auth_layer);
291
292
    // Build open router
293
    let mut open = Router::new();
294
    if !list_authed {
295
        open = open.route("/api/snippets", get(api_list_snippets));
296
    }
297
    if !create_authed {
298
        open = open.route("/api/snippets", post(api_create_snippet));
299
    }
300
    if !get_authed {
301
        open = open.route("/api/snippets/{short_id}", get(api_get_snippet));
302
    }
303
    if !update_authed {
304
        open = open.route("/api/snippets/{short_id}", put(api_update_snippet));
305
    }
306
    if !delete_authed {
307
        open = open.route("/api/snippets/{short_id}", delete(api_delete_snippet));
308
    }
309
310
    authed.merge(open)
311
}
312
313
/// Picks a `Content-Type` for an embedded file based on its extension.
///
/// The extension is taken from the final path component and matched
/// case-insensitively (`LOGO.PNG` is `image/png`). Paths without a recognised
/// extension fall back to `application/octet-stream`.
fn mime_from_path(path: &str) -> &'static str {
    // Spelled out in full because `Path` in this module is axum's extractor.
    let ext = std::path::Path::new(path)
        .extension()
        .and_then(|ext| ext.to_str())
        .unwrap_or("")
        .to_ascii_lowercase();

    match ext.as_str() {
        "css" => "text/css",
        "js" | "mjs" => "application/javascript",
        "html" | "htm" => "text/html",
        "txt" => "text/plain",
        "xml" => "application/xml",
        "png" => "image/png",
        "jpg" | "jpeg" => "image/jpeg",
        "gif" => "image/gif",
        "webp" => "image/webp",
        "avif" => "image/avif",
        "ico" => "image/x-icon",
        "svg" => "image/svg+xml",
        "woff" => "font/woff",
        "woff2" => "font/woff2",
        "ttf" => "font/ttf",
        "otf" => "font/otf",
        "json" | "map" => "application/json",
        // Registered media type for Web App Manifests.
        "webmanifest" => "application/manifest+json",
        // Browsers require this exact type for `WebAssembly.instantiateStreaming`.
        "wasm" => "application/wasm",
        _ => "application/octet-stream",
    }
}
330
331
async fn serve_assets(Path(path): Path<String>) -> Response {
332
    match Assets::get(&path) {
333
        Some(file) => {
334
            let mime = mime_from_path(&path);
335
            ([(header::CONTENT_TYPE, mime)], file.data).into_response()
336
        }
337
        None => StatusCode::NOT_FOUND.into_response(),
338
    }
339
}
340
341
async fn serve_static(Path(path): Path<String>) -> Response {
342
    match Static::get(&path) {
343
        Some(file) => {
344
            let mime = mime_from_path(&path);
345
            ([(header::CONTENT_TYPE, mime)], file.data).into_response()
346
        }
347
        None => StatusCode::NOT_FOUND.into_response(),
348
    }
349
}
350
351
pub async fn run(host: String, port: u16) {
352
    dotenvy::dotenv().ok();
353
354
    let server_config = ServerConfig::from_env();
355
356
    // Validate endpoint names
357
    let known = ["api_list", "api_create", "api_get", "api_update", "api_delete", "all", "none"];
358
    for name in &server_config.auth_endpoints {
359
        if !known.contains(&name.as_str()) {
360
            eprintln!("Warning: unknown auth endpoint name '{}' in SIPP_AUTH_ENDPOINTS", name);
361
        }
362
    }
363
364
    if !server_config.auth_endpoints.is_empty() && server_config.api_key.is_none() {
365
        eprintln!("Warning: SIPP_AUTH_ENDPOINTS is set but SIPP_API_KEY is not configured");
366
    }
367
368
    if server_config.auth_endpoints.is_empty() {
369
        println!("Auth: disabled (no endpoints require authentication)");
370
    } else {
371
        let names: Vec<&str> = server_config.auth_endpoints.iter().map(|s| s.as_str()).collect();
372
        println!("Auth: enabled for endpoints: {}", names.join(", "));
373
    }
374
375
    println!("Max content size: {} bytes", server_config.max_content_size);
376
377
    let state = AppState {
378
        db: db::init_db().expect("Failed to initialize database"),
379
        highlighter: Arc::new(Highlighter::new()),
380
        server_config,
381
    };
382
383
    let api_routes = build_api_routes(&state);
384
385
    let app = Router::new()
386
        .route("/", get(index))
387
        .route("/admin", get(admin))
388
        .route("/s/{short_id}", get(view_snippet))
389
        .route("/snippets", post(create_snippet))
390
        .merge(api_routes)
391
        .route("/assets/{*path}", get(serve_assets))
392
        .route("/static/{*path}", get(serve_static))
393
        .with_state(state);
394
395
    let addr = format!("{}:{}", host, port);
396
    let listener = tokio::net::TcpListener::bind(&addr)
397
        .await
398
        .unwrap_or_else(|_| panic!("Failed to bind to {}", addr));
399
400
    println!("Server running at http://{}:{}", host, port);
401
402
    axum::serve(listener, app)
403
        .await
404
        .expect("Failed to start server");
405
}