| 1 | use axum::{ |
| 2 | extract::{Path, State}, |
| 3 | http::{HeaderMap, Request, StatusCode}, |
| 4 | response::IntoResponse, |
| 5 | routing::{any, delete, get, post}, |
| 6 | Router, |
| 7 | }; |
| 8 | use std::{sync::Arc, time::Duration}; |
| 9 | use tower_http::{ |
| 10 | limit::RequestBodyLimitLayer, |
| 11 | timeout::TimeoutLayer, |
| 12 | trace::TraceLayer, |
| 13 | }; |
| 14 | |
| 15 | mod api; |
| 16 | mod auth; |
| 17 | mod config; |
| 18 | mod db; |
| 19 | mod git; |
| 20 | mod git_http; |
| 21 | mod git_ssh; |
| 22 | mod repos; |
| 23 | mod validate; |
| 24 | mod web; |
| 25 | |
| 26 | use config::Config; |
| 27 | use git_http::GitRunner; |
| 28 | use russh::server::Server as _; |
| 29 | |
| 30 | #[derive(Clone)] |
| 31 | struct AppState { |
| 32 | cfg: Arc<Config>, |
| 33 | pool: sqlx::SqlitePool, |
| 34 | git: GitRunner, |
| 35 | } |
| 36 | |
| 37 | fn extract_token(headers: &HeaderMap) -> Option<String> { |
| 38 | let v = headers.get("authorization")?.to_str().ok()?; |
| 39 | if let Some(token) = v.strip_prefix("Bearer ") { |
| 40 | return Some(token.to_string()); |
| 41 | } |
| 42 | // Git credential helpers send Basic auth — password is the token |
| 43 | if let Some(encoded) = v.strip_prefix("Basic ") { |
| 44 | use base64::Engine; |
| 45 | let decoded = base64::engine::general_purpose::STANDARD.decode(encoded).ok()?; |
| 46 | let decoded = String::from_utf8(decoded).ok()?; |
| 47 | let (_user, token) = decoded.split_once(':')?; |
| 48 | return Some(token.to_string()); |
| 49 | } |
| 50 | None |
| 51 | } |
| 52 | |
| 53 | #[tokio::main] |
| 54 | async fn main() { |
| 55 | tracing_subscriber::fmt() |
| 56 | .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) |
| 57 | .init(); |
| 58 | |
| 59 | let cfg = Arc::new(Config::from_env()); |
| 60 | let pool = db::connect(&cfg.db_url).await; |
| 61 | let git = GitRunner::new(cfg.clone()); |
| 62 | |
| 63 | let api_state = api::ApiState { cfg: cfg.clone(), pool: pool.clone() }; |
| 64 | let app_state = AppState { cfg: cfg.clone(), pool: pool.clone(), git }; |
| 65 | |
| 66 | let api_router = Router::new() |
| 67 | .route("/", get(web::homepage).with_state( |
| 68 | web::WebState { cfg: cfg.clone(), pool: pool.clone() } |
| 69 | )) |
| 70 | .route("/api/signup", post(api::signup)) |
| 71 | .route("/api/login/challenge", post(api::login_challenge)) |
| 72 | .route("/api/login/verify", post(api::login_verify)) |
| 73 | .route("/api/admin/invites", post(api::create_invite)) |
| 74 | .route("/api/stats", get(api::stats)) |
| 75 | .route("/api/repos", post(api::create_repo)) |
| 76 | .route("/api/repos/:username", get(api::list_repos)) |
| 77 | .route("/api/repos/:username/:repo", delete(api::delete_repo)) |
| 78 | .route("/scripts/git-credential-openhub", get(credential_helper)) |
| 79 | .route("/install.sh", get(install_script)) |
| 80 | .with_state(api_state); |
| 81 | |
| 82 | // Catch-all: dispatches git (.git in path) vs web browsing |
| 83 | let catchall_router = Router::new() |
| 84 | .route("/*path", any(catchall_handler)) |
| 85 | .with_state(app_state); |
| 86 | |
| 87 | let app = api_router |
| 88 | .merge(catchall_router) |
| 89 | .layer(TraceLayer::new_for_http()) |
| 90 | .layer(TimeoutLayer::new(Duration::from_secs(60))) |
| 91 | .layer(RequestBodyLimitLayer::new(cfg.max_push_bytes)); |
| 92 | |
| 93 | let listener = tokio::net::TcpListener::bind(&cfg.bind).await.unwrap(); |
| 94 | |
| 95 | // SSH server setup |
| 96 | let host_key = load_or_generate_host_key(&cfg.ssh_host_key_path); |
| 97 | let ssh_config = Arc::new(russh::server::Config { |
| 98 | inactivity_timeout: Some(Duration::from_secs(600)), |
| 99 | auth_rejection_time: Duration::from_secs(3), |
| 100 | auth_rejection_time_initial: Some(Duration::from_secs(0)), |
| 101 | keys: vec![host_key], |
| 102 | ..Default::default() |
| 103 | }); |
| 104 | |
| 105 | let ssh_listener = tokio::net::TcpListener::bind(&cfg.ssh_bind).await.unwrap(); |
| 106 | let mut ssh_server = git_ssh::SshServer::new(cfg.clone(), pool.clone()); |
| 107 | |
| 108 | tracing::info!("HTTP on {}, SSH on {}", cfg.bind, cfg.ssh_bind); |
| 109 | |
| 110 | tokio::select! { |
| 111 | r = axum::serve(listener, app) => { |
| 112 | r.unwrap(); |
| 113 | } |
| 114 | r = ssh_server.run_on_socket(ssh_config, &ssh_listener) => { |
| 115 | if let Err(e) = r { tracing::error!("SSH server error: {e}"); } |
| 116 | } |
| 117 | } |
| 118 | } |
| 119 | |
| 120 | fn load_or_generate_host_key(path: &str) -> russh::keys::key::KeyPair { |
| 121 | let path = std::path::Path::new(path); |
| 122 | |
| 123 | if path.exists() { |
| 124 | return russh::keys::load_secret_key(path, None).expect("Failed to load SSH host key"); |
| 125 | } |
| 126 | |
| 127 | tracing::info!("Generating new SSH host key at {}", path.display()); |
| 128 | |
| 129 | if let Some(parent) = path.parent() { |
| 130 | std::fs::create_dir_all(parent).expect("Failed to create host key directory"); |
| 131 | } |
| 132 | |
| 133 | let key = russh::keys::key::KeyPair::generate_ed25519(); |
| 134 | |
| 135 | #[cfg(unix)] |
| 136 | { |
| 137 | use std::os::unix::fs::OpenOptionsExt; |
| 138 | let file = std::fs::OpenOptions::new() |
| 139 | .write(true) |
| 140 | .create_new(true) |
| 141 | .mode(0o600) |
| 142 | .open(path) |
| 143 | .expect("Failed to create SSH host key file"); |
| 144 | russh::keys::encode_pkcs8_pem(&key, file).expect("Failed to write SSH host key"); |
| 145 | } |
| 146 | #[cfg(not(unix))] |
| 147 | { |
| 148 | let file = std::fs::File::create(path).expect("Failed to create SSH host key file"); |
| 149 | russh::keys::encode_pkcs8_pem(&key, file).expect("Failed to write SSH host key"); |
| 150 | } |
| 151 | |
| 152 | key |
| 153 | } |
| 154 | |
| 155 | async fn install_script() -> impl IntoResponse { |
| 156 | ( |
| 157 | [("content-type", "text/plain")], |
| 158 | include_str!("../../scripts/install.sh"), |
| 159 | ) |
| 160 | } |
| 161 | |
| 162 | async fn credential_helper() -> impl IntoResponse { |
| 163 | ( |
| 164 | [("content-type", "text/plain")], |
| 165 | include_str!("../../scripts/git-credential-openhub"), |
| 166 | ) |
| 167 | } |
| 168 | |
| 169 | fn auth_required() -> axum::response::Response { |
| 170 | axum::response::Response::builder() |
| 171 | .status(StatusCode::UNAUTHORIZED) |
| 172 | .header("WWW-Authenticate", "Basic realm=\"openhub\"") |
| 173 | .body(axum::body::Body::empty()) |
| 174 | .unwrap() |
| 175 | } |
| 176 | |
| 177 | async fn catchall_handler( |
| 178 | State(st): State<AppState>, |
| 179 | Path(path): Path<String>, |
| 180 | headers: HeaderMap, |
| 181 | req: Request<axum::body::Body>, |
| 182 | ) -> axum::response::Response { |
| 183 | // Check if this is a git request (path contains .git segment) |
| 184 | if path.contains(".git") { |
| 185 | return git_transport(&st, &path, headers, req).await; |
| 186 | } |
| 187 | |
| 188 | // Extract query string before consuming the request |
| 189 | let query = req.uri().query().unwrap_or("").to_string(); |
| 190 | |
| 191 | // Otherwise, dispatch to web browsing |
| 192 | web_dispatch(&st, &path, &query).await |
| 193 | } |
| 194 | |
| 195 | async fn git_transport( |
| 196 | st: &AppState, |
| 197 | path: &str, |
| 198 | headers: HeaderMap, |
| 199 | req: Request<axum::body::Body>, |
| 200 | ) -> axum::response::Response { |
| 201 | let mut it = path.splitn(3, '/'); |
| 202 | let user = match it.next() { |
| 203 | Some(u) => u, |
| 204 | None => return StatusCode::NOT_FOUND.into_response(), |
| 205 | }; |
| 206 | let repo_git = match it.next() { |
| 207 | Some(r) => r, |
| 208 | None => return StatusCode::NOT_FOUND.into_response(), |
| 209 | }; |
| 210 | let rest = it.next().unwrap_or(""); |
| 211 | |
| 212 | if !repo_git.ends_with(".git") { |
| 213 | return StatusCode::NOT_FOUND.into_response(); |
| 214 | } |
| 215 | let repo = repo_git.trim_end_matches(".git"); |
| 216 | |
| 217 | if !repos::repo_exists(&st.pool, user, repo).await { |
| 218 | return StatusCode::NOT_FOUND.into_response(); |
| 219 | } |
| 220 | |
| 221 | let is_receive_pack = rest.ends_with("git-receive-pack"); |
| 222 | |
| 223 | let remote_user = if is_receive_pack { |
| 224 | let token = match extract_token(&headers) { |
| 225 | Some(t) => t, |
| 226 | None => return auth_required(), |
| 227 | }; |
| 228 | let (_uid, uname) = match auth::auth_user_by_token(&st.pool, &token).await { |
| 229 | Some(x) => x, |
| 230 | None => return auth_required(), |
| 231 | }; |
| 232 | if uname != user { |
| 233 | return StatusCode::FORBIDDEN.into_response(); |
| 234 | } |
| 235 | Some(uname) |
| 236 | } else { |
| 237 | None |
| 238 | }; |
| 239 | |
| 240 | let path_info = format!("/{}/{}.git/{}", user, repo, rest); |
| 241 | match st.git.run_cgi(&path_info, req, remote_user.as_deref()).await { |
| 242 | Ok(resp) => resp, |
| 243 | Err(status) => status.into_response(), |
| 244 | } |
| 245 | } |
| 246 | |
| 247 | async fn web_dispatch(st: &AppState, path: &str, query: &str) -> axum::response::Response { |
| 248 | let segments: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect(); |
| 249 | let web_state = web::WebState { |
| 250 | cfg: st.cfg.clone(), |
| 251 | pool: st.pool.clone(), |
| 252 | }; |
| 253 | |
| 254 | match segments.as_slice() { |
| 255 | // /:username |
| 256 | [username] => { |
| 257 | if validate::is_reserved_path(username) { |
| 258 | return StatusCode::NOT_FOUND.into_response(); |
| 259 | } |
| 260 | web::user_page( |
| 261 | State(web_state), |
| 262 | Path(username.to_string()), |
| 263 | ).await |
| 264 | } |
| 265 | // /:owner/:repo |
| 266 | [owner, repo] => { |
| 267 | web::repo_page( |
| 268 | State(web_state), |
| 269 | Path(vec![ |
| 270 | ("owner".to_string(), owner.to_string()), |
| 271 | ("repo".to_string(), repo.to_string()), |
| 272 | ]), |
| 273 | ).await |
| 274 | } |
| 275 | // /:owner/:repo/tree/:branch/...path |
| 276 | [owner, repo, "tree", branch, rest @ ..] => { |
| 277 | let subpath = rest.join("/"); |
| 278 | web::repo_tree_page( |
| 279 | State(web_state), |
| 280 | Path((owner.to_string(), repo.to_string(), branch.to_string(), subpath)), |
| 281 | ).await |
| 282 | } |
| 283 | // /:owner/:repo/blob/:branch/...path |
| 284 | [owner, repo, "blob", branch, rest @ ..] => { |
| 285 | let subpath = rest.join("/"); |
| 286 | web::blob_page( |
| 287 | State(web_state), |
| 288 | Path((owner.to_string(), repo.to_string(), branch.to_string(), subpath)), |
| 289 | ).await |
| 290 | } |
| 291 | // /:owner/:repo/commits/:branch |
| 292 | [owner, repo, "commits", branch] => { |
| 293 | let page = query.split('&') |
| 294 | .find_map(|p| p.strip_prefix("page=")) |
| 295 | .and_then(|v| v.parse().ok()); |
| 296 | web::commits_page( |
| 297 | State(web_state), |
| 298 | Path((owner.to_string(), repo.to_string(), branch.to_string())), |
| 299 | axum::extract::Query(web::CommitsQuery { page }), |
| 300 | ).await |
| 301 | } |
| 302 | _ => StatusCode::NOT_FOUND.into_response(), |
| 303 | } |
| 304 | } |