chore: server auth config 2e9e63d9
Steve · 2026-02-19 11:49 2 file(s) · +121 −18
Cargo.toml +1 −0
27 27
rpassword = "5"
28 28
open = "5.3.3"
29 29
rust-embed = "8"
30 +
dotenvy = "0.15"
src/server.rs +120 −18
1 1
use askama::Template;
2 2
use askama_web::WebTemplate;
3 3
use axum::{
4 -
    Json, Router,
5 -
    extract::{Form, Path, State},
4 +
    Form, Json, Router,
5 +
    extract::{Path, Request, State},
6 6
    http::{HeaderMap, StatusCode, header},
7 +
    middleware::{self, Next},
7 8
    response::{Html, IntoResponse, Redirect, Response},
8 -
    routing::{get, post},
9 +
    routing::{delete, get, post},
9 10
};
10 11
use rust_embed::Embed;
11 12
use serde::Deserialize;
12 13
use crate::db::{self, Db, Snippet};
13 14
use crate::highlight::Highlighter;
15 +
use std::collections::HashSet;
14 16
use std::sync::Arc;
15 17
16 18
#[derive(Embed)]
22 24
struct Static;
23 25
24 26
#[derive(Clone)]
27 +
struct ServerConfig {
28 +
    api_key: Option<String>,
29 +
    auth_endpoints: HashSet<String>,
30 +
}
31 +
32 +
impl ServerConfig {
33 +
    fn from_env() -> Self {
34 +
        let api_key = std::env::var("SIPP_API_KEY").ok();
35 +
        let auth_endpoints = match std::env::var("SIPP_AUTH_ENDPOINTS") {
36 +
            Ok(val) if val.trim().eq_ignore_ascii_case("none") => HashSet::new(),
37 +
            Ok(val) => val.split(',').map(|s| s.trim().to_lowercase()).collect(),
38 +
            Err(_) => HashSet::new(),
39 +
        };
40 +
        ServerConfig { api_key, auth_endpoints }
41 +
    }
42 +
43 +
    fn requires_auth(&self, name: &str) -> bool {
44 +
        self.auth_endpoints.contains("all") || self.auth_endpoints.contains(name)
45 +
    }
46 +
}
47 +
48 +
#[derive(Clone)]
25 49
struct AppState {
26 50
    db: Db,
27 51
    highlighter: Arc<Highlighter>,
28 -
    api_key: Option<String>,
52 +
    server_config: ServerConfig,
29 53
}
30 54
31 55
#[derive(Template)]
86 110
    Redirect::to(&format!("/s/{}", snippet.short_id))
87 111
}
88 112
89 -
fn check_api_key(state: &AppState, headers: &HeaderMap) -> Result<(), (StatusCode, Json<serde_json::Value>)> {
90 -
    let server_key = match &state.api_key {
113 +
async fn require_api_key(
114 +
    State(state): State<AppState>,
115 +
    headers: HeaderMap,
116 +
    request: Request,
117 +
    next: Next,
118 +
) -> Result<Response, (StatusCode, Json<serde_json::Value>)> {
119 +
    let server_key = match &state.server_config.api_key {
91 120
        Some(k) => k,
92 -
        None => return Err((StatusCode::FORBIDDEN, Json(serde_json::json!({"error": "No API key configured on server"})))),
121 +
        None => return Err((
122 +
            StatusCode::FORBIDDEN,
123 +
            Json(serde_json::json!({"error": "No API key configured on server"})),
124 +
        )),
93 125
    };
94 126
    let provided = headers
95 127
        .get("x-api-key")
96 128
        .and_then(|v| v.to_str().ok());
97 129
    match provided {
98 -
        Some(k) if k == server_key => Ok(()),
99 -
        _ => Err((StatusCode::UNAUTHORIZED, Json(serde_json::json!({"error": "Invalid or missing API key"})))),
130 +
        Some(k) if k == server_key => Ok(next.run(request).await),
131 +
        _ => Err((
132 +
            StatusCode::UNAUTHORIZED,
133 +
            Json(serde_json::json!({"error": "Invalid or missing API key"})),
134 +
        )),
100 135
    }
101 136
}
102 137
103 138
async fn api_list_snippets(
104 139
    State(state): State<AppState>,
105 -
    headers: HeaderMap,
106 -
) -> Result<Json<Vec<Snippet>>, (StatusCode, Json<serde_json::Value>)> {
107 -
    check_api_key(&state, &headers)?;
108 -
    Ok(Json(db::get_all_snippets(&state.db)))
140 +
) -> Json<Vec<Snippet>> {
141 +
    Json(db::get_all_snippets(&state.db))
109 142
}
110 143
111 144
async fn api_get_snippet(
134 167
135 168
async fn api_delete_snippet(
136 169
    State(state): State<AppState>,
137 -
    headers: HeaderMap,
138 170
    Path(short_id): Path<String>,
139 171
) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
140 -
    check_api_key(&state, &headers)?;
141 172
    if db::delete_snippet_by_short_id(&state.db, &short_id) {
142 173
        Ok(Json(serde_json::json!({"deleted": true})))
143 174
    } else {
145 176
    }
146 177
}
147 178
179 +
fn build_api_routes(state: &AppState) -> Router<AppState> {
180 +
    let config = &state.server_config;
181 +
182 +
    let auth_layer = middleware::from_fn_with_state(state.clone(), require_api_key);
183 +
184 +
    // /api/snippets — GET (api_list) and POST (api_create)
185 +
    let list_authed = config.requires_auth("api_list");
186 +
    let create_authed = config.requires_auth("api_create");
187 +
188 +
    // /api/snippets/{short_id} — GET (api_get) and DELETE (api_delete)
189 +
    let get_authed = config.requires_auth("api_get");
190 +
    let delete_authed = config.requires_auth("api_delete");
191 +
192 +
    // Build authed router
193 +
    let mut authed = Router::new();
194 +
    if list_authed {
195 +
        authed = authed.route("/api/snippets", get(api_list_snippets));
196 +
    }
197 +
    if create_authed {
198 +
        authed = authed.route("/api/snippets", post(api_create_snippet));
199 +
    }
200 +
    if get_authed {
201 +
        authed = authed.route("/api/snippets/{short_id}", get(api_get_snippet));
202 +
    }
203 +
    if delete_authed {
204 +
        authed = authed.route("/api/snippets/{short_id}", delete(api_delete_snippet));
205 +
    }
206 +
    let authed = authed.route_layer(auth_layer);
207 +
208 +
    // Build open router
209 +
    let mut open = Router::new();
210 +
    if !list_authed {
211 +
        open = open.route("/api/snippets", get(api_list_snippets));
212 +
    }
213 +
    if !create_authed {
214 +
        open = open.route("/api/snippets", post(api_create_snippet));
215 +
    }
216 +
    if !get_authed {
217 +
        open = open.route("/api/snippets/{short_id}", get(api_get_snippet));
218 +
    }
219 +
    if !delete_authed {
220 +
        open = open.route("/api/snippets/{short_id}", delete(api_delete_snippet));
221 +
    }
222 +
223 +
    authed.merge(open)
224 +
}
225 +
148 226
fn mime_from_path(path: &str) -> &'static str {
149 227
    match path.rsplit('.').next().unwrap_or("") {
150 228
        "css" => "text/css",
184 262
}
185 263
186 264
pub async fn run(host: String, port: u16) {
265 +
    dotenvy::dotenv().ok();
266 +
267 +
    let server_config = ServerConfig::from_env();
268 +
269 +
    // Validate endpoint names
270 +
    let known = ["api_list", "api_create", "api_get", "api_delete", "all", "none"];
271 +
    for name in &server_config.auth_endpoints {
272 +
        if !known.contains(&name.as_str()) {
273 +
            eprintln!("Warning: unknown auth endpoint name '{}' in SIPP_AUTH_ENDPOINTS", name);
274 +
        }
275 +
    }
276 +
277 +
    if !server_config.auth_endpoints.is_empty() && server_config.api_key.is_none() {
278 +
        eprintln!("Warning: SIPP_AUTH_ENDPOINTS is set but SIPP_API_KEY is not configured");
279 +
    }
280 +
281 +
    if server_config.auth_endpoints.is_empty() {
282 +
        println!("Auth: disabled (no endpoints require authentication)");
283 +
    } else {
284 +
        let names: Vec<&str> = server_config.auth_endpoints.iter().map(|s| s.as_str()).collect();
285 +
        println!("Auth: enabled for endpoints: {}", names.join(", "));
286 +
    }
287 +
187 288
    let state = AppState {
188 289
        db: db::init_db(),
189 290
        highlighter: Arc::new(Highlighter::new()),
190 -
        api_key: std::env::var("SIPP_API_KEY").ok(),
291 +
        server_config,
191 292
    };
192 293
294 +
    let api_routes = build_api_routes(&state);
295 +
193 296
    let app = Router::new()
194 297
        .route("/", get(index))
195 298
        .route("/about", get(about))
196 299
        .route("/s/{short_id}", get(view_snippet))
197 300
        .route("/snippets", post(create_snippet))
198 -
        .route("/api/snippets", get(api_list_snippets).post(api_create_snippet))
199 -
        .route("/api/snippets/{short_id}", get(api_get_snippet).delete(api_delete_snippet))
301 +
        .merge(api_routes)
200 302
        .route("/assets/{*path}", get(serve_assets))
201 303
        .route("/static/{*path}", get(serve_static))
202 304
        .with_state(state);