mirror of
https://github.com/harivansh-afk/betterNAS.git
synced 2026-04-15 08:03:42 +00:00
662 lines
18 KiB
Go
662 lines
18 KiB
Go
package main
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"database/sql"
|
|
"encoding/hex"
|
|
"errors"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"sort"
|
|
"strings"
|
|
"time"
|
|
|
|
"golang.org/x/crypto/bcrypt"
|
|
_ "modernc.org/sqlite"
|
|
)
|
|
|
|
var (
|
|
errUsernameTaken = errors.New("username already taken")
|
|
errInvalidLogin = errors.New("invalid username or password")
|
|
errSessionExpired = errors.New("session expired or invalid")
|
|
)
|
|
|
|
const sqliteSchema = `
|
|
CREATE TABLE IF NOT EXISTS ordinals (
|
|
name TEXT PRIMARY KEY,
|
|
value INTEGER NOT NULL DEFAULT 0
|
|
);
|
|
INSERT OR IGNORE INTO ordinals (name, value) VALUES ('node', 0), ('export', 0);
|
|
|
|
CREATE TABLE IF NOT EXISTS nodes (
|
|
id TEXT PRIMARY KEY,
|
|
machine_id TEXT NOT NULL UNIQUE,
|
|
owner_id TEXT REFERENCES users(id),
|
|
display_name TEXT NOT NULL DEFAULT '',
|
|
agent_version TEXT NOT NULL DEFAULT '',
|
|
status TEXT NOT NULL DEFAULT 'online',
|
|
last_seen_at TEXT,
|
|
direct_address TEXT,
|
|
relay_address TEXT
|
|
);
|
|
|
|
CREATE TABLE IF NOT EXISTS node_tokens (
|
|
node_id TEXT PRIMARY KEY REFERENCES nodes(id),
|
|
token_hash TEXT NOT NULL
|
|
);
|
|
|
|
CREATE TABLE IF NOT EXISTS exports (
|
|
id TEXT PRIMARY KEY,
|
|
node_id TEXT NOT NULL REFERENCES nodes(id),
|
|
owner_id TEXT REFERENCES users(id),
|
|
label TEXT NOT NULL DEFAULT '',
|
|
path TEXT NOT NULL,
|
|
mount_path TEXT NOT NULL DEFAULT '',
|
|
capacity_bytes INTEGER,
|
|
UNIQUE(node_id, path)
|
|
);
|
|
|
|
CREATE TABLE IF NOT EXISTS export_protocols (
|
|
export_id TEXT NOT NULL REFERENCES exports(id) ON DELETE CASCADE,
|
|
protocol TEXT NOT NULL,
|
|
PRIMARY KEY (export_id, protocol)
|
|
);
|
|
|
|
CREATE TABLE IF NOT EXISTS export_tags (
|
|
export_id TEXT NOT NULL REFERENCES exports(id) ON DELETE CASCADE,
|
|
tag TEXT NOT NULL,
|
|
PRIMARY KEY (export_id, tag)
|
|
);
|
|
|
|
CREATE TABLE IF NOT EXISTS users (
|
|
id TEXT PRIMARY KEY,
|
|
username TEXT NOT NULL UNIQUE COLLATE NOCASE,
|
|
password_hash TEXT NOT NULL,
|
|
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now'))
|
|
);
|
|
|
|
CREATE TABLE IF NOT EXISTS sessions (
|
|
token TEXT PRIMARY KEY,
|
|
user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
|
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now')),
|
|
expires_at TEXT NOT NULL
|
|
);
|
|
`
|
|
|
|
type sqliteStore struct {
|
|
db *sql.DB
|
|
}
|
|
|
|
func newSQLiteStore(dbPath string) (*sqliteStore, error) {
|
|
dir := filepath.Dir(dbPath)
|
|
if err := os.MkdirAll(dir, 0o750); err != nil {
|
|
return nil, fmt.Errorf("create database directory %s: %w", dir, err)
|
|
}
|
|
|
|
db, err := sql.Open("sqlite", dbPath+"?_pragma=journal_mode(wal)&_pragma=foreign_keys(1)")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("open database %s: %w", dbPath, err)
|
|
}
|
|
|
|
if _, err := db.Exec(sqliteSchema); err != nil {
|
|
db.Close()
|
|
return nil, fmt.Errorf("initialize database schema: %w", err)
|
|
}
|
|
if err := migrateSQLiteSchema(db); err != nil {
|
|
db.Close()
|
|
return nil, err
|
|
}
|
|
|
|
return &sqliteStore{db: db}, nil
|
|
}
|
|
|
|
func migrateSQLiteSchema(db *sql.DB) error {
|
|
migrations := []string{
|
|
"ALTER TABLE nodes ADD COLUMN owner_id TEXT REFERENCES users(id)",
|
|
"ALTER TABLE exports ADD COLUMN owner_id TEXT REFERENCES users(id)",
|
|
}
|
|
for _, statement := range migrations {
|
|
if _, err := db.Exec(statement); err != nil && !strings.Contains(err.Error(), "duplicate column name") {
|
|
return fmt.Errorf("run sqlite migration %q: %w", statement, err)
|
|
}
|
|
}
|
|
|
|
if _, err := db.Exec(`
|
|
UPDATE exports
|
|
SET owner_id = (
|
|
SELECT owner_id
|
|
FROM nodes
|
|
WHERE nodes.id = exports.node_id
|
|
)
|
|
WHERE owner_id IS NULL
|
|
`); err != nil {
|
|
return fmt.Errorf("backfill export owners: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *sqliteStore) nextOrdinal(tx *sql.Tx, name string) (int, error) {
|
|
var value int
|
|
err := tx.QueryRow("UPDATE ordinals SET value = value + 1 WHERE name = ? RETURNING value", name).Scan(&value)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("next ordinal %q: %w", name, err)
|
|
}
|
|
return value, nil
|
|
}
|
|
|
|
func ordinalToNodeID(ordinal int) string {
|
|
if ordinal == 1 {
|
|
return "dev-node"
|
|
}
|
|
return fmt.Sprintf("dev-node-%d", ordinal)
|
|
}
|
|
|
|
func ordinalToExportID(ordinal int) string {
|
|
if ordinal == 1 {
|
|
return "dev-export"
|
|
}
|
|
return fmt.Sprintf("dev-export-%d", ordinal)
|
|
}
|
|
|
|
func (s *sqliteStore) registerNode(ownerID string, request nodeRegistrationRequest, registeredAt time.Time) (nodeRegistrationResult, error) {
|
|
tx, err := s.db.Begin()
|
|
if err != nil {
|
|
return nodeRegistrationResult{}, fmt.Errorf("begin transaction: %w", err)
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
// Check if machine already registered.
|
|
var nodeID string
|
|
var existingOwnerID sql.NullString
|
|
err = tx.QueryRow("SELECT id, owner_id FROM nodes WHERE machine_id = ?", request.MachineID).Scan(&nodeID, &existingOwnerID)
|
|
if err == sql.ErrNoRows {
|
|
ordinal, err := s.nextOrdinal(tx, "node")
|
|
if err != nil {
|
|
return nodeRegistrationResult{}, err
|
|
}
|
|
nodeID = ordinalToNodeID(ordinal)
|
|
} else if err != nil {
|
|
return nodeRegistrationResult{}, fmt.Errorf("lookup node by machine_id: %w", err)
|
|
} else if existingOwnerID.Valid && strings.TrimSpace(existingOwnerID.String) != "" && existingOwnerID.String != ownerID {
|
|
return nodeRegistrationResult{}, errNodeOwnedByAnotherUser
|
|
}
|
|
|
|
// Upsert node.
|
|
_, err = tx.Exec(`
|
|
INSERT INTO nodes (id, machine_id, owner_id, display_name, agent_version, status, last_seen_at, direct_address, relay_address)
|
|
VALUES (?, ?, ?, ?, ?, 'online', ?, ?, ?)
|
|
ON CONFLICT(id) DO UPDATE SET
|
|
owner_id = excluded.owner_id,
|
|
display_name = excluded.display_name,
|
|
agent_version = excluded.agent_version,
|
|
status = 'online',
|
|
last_seen_at = excluded.last_seen_at,
|
|
direct_address = excluded.direct_address,
|
|
relay_address = excluded.relay_address
|
|
`, nodeID, request.MachineID, ownerID, request.DisplayName, request.AgentVersion,
|
|
registeredAt.UTC().Format(time.RFC3339),
|
|
nullableString(request.DirectAddress), nullableString(request.RelayAddress))
|
|
if err != nil {
|
|
return nodeRegistrationResult{}, fmt.Errorf("upsert node: %w", err)
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return nodeRegistrationResult{}, fmt.Errorf("commit registration: %w", err)
|
|
}
|
|
|
|
node, _ := s.nodeByID(nodeID)
|
|
return nodeRegistrationResult{
|
|
Node: node,
|
|
}, nil
|
|
}
|
|
|
|
func (s *sqliteStore) upsertExports(nodeID string, ownerID string, request nodeExportsRequest) ([]storageExport, error) {
|
|
tx, err := s.db.Begin()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("begin transaction: %w", err)
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
// Verify node exists.
|
|
var exists bool
|
|
err = tx.QueryRow("SELECT 1 FROM nodes WHERE id = ? AND owner_id = ?", nodeID, ownerID).Scan(&exists)
|
|
if err != nil {
|
|
return nil, errNodeNotFound
|
|
}
|
|
|
|
// Collect current export IDs for this node (by path).
|
|
currentExports := make(map[string]string) // path -> exportID
|
|
rows, err := tx.Query("SELECT id, path FROM exports WHERE node_id = ?", nodeID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("query current exports: %w", err)
|
|
}
|
|
for rows.Next() {
|
|
var id, path string
|
|
if err := rows.Scan(&id, &path); err != nil {
|
|
rows.Close()
|
|
return nil, fmt.Errorf("scan current export: %w", err)
|
|
}
|
|
currentExports[path] = id
|
|
}
|
|
rows.Close()
|
|
|
|
keepPaths := make(map[string]struct{}, len(request.Exports))
|
|
for _, input := range request.Exports {
|
|
exportID, exists := currentExports[input.Path]
|
|
if !exists {
|
|
ordinal, err := s.nextOrdinal(tx, "export")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
exportID = ordinalToExportID(ordinal)
|
|
}
|
|
|
|
_, err = tx.Exec(`
|
|
INSERT INTO exports (id, node_id, owner_id, label, path, mount_path, capacity_bytes)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?)
|
|
ON CONFLICT(id) DO UPDATE SET
|
|
owner_id = excluded.owner_id,
|
|
label = excluded.label,
|
|
mount_path = excluded.mount_path,
|
|
capacity_bytes = excluded.capacity_bytes
|
|
`, exportID, nodeID, ownerID, input.Label, input.Path, input.MountPath, nullableInt64(input.CapacityBytes))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("upsert export %q: %w", input.Path, err)
|
|
}
|
|
|
|
// Replace protocols.
|
|
if _, err := tx.Exec("DELETE FROM export_protocols WHERE export_id = ?", exportID); err != nil {
|
|
return nil, fmt.Errorf("clear export protocols: %w", err)
|
|
}
|
|
for _, protocol := range input.Protocols {
|
|
if _, err := tx.Exec("INSERT INTO export_protocols (export_id, protocol) VALUES (?, ?)", exportID, protocol); err != nil {
|
|
return nil, fmt.Errorf("insert export protocol: %w", err)
|
|
}
|
|
}
|
|
|
|
// Replace tags.
|
|
if _, err := tx.Exec("DELETE FROM export_tags WHERE export_id = ?", exportID); err != nil {
|
|
return nil, fmt.Errorf("clear export tags: %w", err)
|
|
}
|
|
for _, tag := range input.Tags {
|
|
if _, err := tx.Exec("INSERT INTO export_tags (export_id, tag) VALUES (?, ?)", exportID, tag); err != nil {
|
|
return nil, fmt.Errorf("insert export tag: %w", err)
|
|
}
|
|
}
|
|
|
|
keepPaths[input.Path] = struct{}{}
|
|
}
|
|
|
|
// Remove exports not in the input.
|
|
for path, exportID := range currentExports {
|
|
if _, keep := keepPaths[path]; !keep {
|
|
if _, err := tx.Exec("DELETE FROM exports WHERE id = ?", exportID); err != nil {
|
|
return nil, fmt.Errorf("delete stale export %q: %w", exportID, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return nil, fmt.Errorf("commit exports: %w", err)
|
|
}
|
|
|
|
return s.listExportsForNode(nodeID), nil
|
|
}
|
|
|
|
func (s *sqliteStore) recordHeartbeat(nodeID string, ownerID string, request nodeHeartbeatRequest) error {
|
|
result, err := s.db.Exec(
|
|
"UPDATE nodes SET status = ?, last_seen_at = ? WHERE id = ? AND owner_id = ?",
|
|
request.Status, request.LastSeenAt, nodeID, ownerID)
|
|
if err != nil {
|
|
return fmt.Errorf("update heartbeat: %w", err)
|
|
}
|
|
affected, _ := result.RowsAffected()
|
|
if affected == 0 {
|
|
return errNodeNotFound
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *sqliteStore) listExports(ownerID string) []storageExport {
|
|
rows, err := s.db.Query("SELECT id, node_id, owner_id, label, path, mount_path, capacity_bytes FROM exports WHERE owner_id = ? ORDER BY id", ownerID)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
defer rows.Close()
|
|
|
|
var exports []storageExport
|
|
for rows.Next() {
|
|
e := s.scanExport(rows)
|
|
if e.ID != "" {
|
|
exports = append(exports, e)
|
|
}
|
|
}
|
|
if exports == nil {
|
|
exports = []storageExport{}
|
|
}
|
|
|
|
// Load protocols and tags for each export.
|
|
for i := range exports {
|
|
exports[i].Protocols = s.loadExportProtocols(exports[i].ID)
|
|
exports[i].Tags = s.loadExportTags(exports[i].ID)
|
|
}
|
|
|
|
return exports
|
|
}
|
|
|
|
func (s *sqliteStore) listNodes(ownerID string) []nasNode {
|
|
rows, err := s.db.Query("SELECT id, machine_id, owner_id, display_name, agent_version, status, last_seen_at, direct_address, relay_address FROM nodes WHERE owner_id = ? ORDER BY id", ownerID)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
defer rows.Close()
|
|
|
|
var nodes []nasNode
|
|
for rows.Next() {
|
|
node := s.scanNode(rows)
|
|
if node.ID != "" {
|
|
nodes = append(nodes, node)
|
|
}
|
|
}
|
|
if nodes == nil {
|
|
nodes = []nasNode{}
|
|
}
|
|
|
|
return nodes
|
|
}
|
|
|
|
func (s *sqliteStore) listExportsForNode(nodeID string) []storageExport {
|
|
rows, err := s.db.Query("SELECT id, node_id, owner_id, label, path, mount_path, capacity_bytes FROM exports WHERE node_id = ? ORDER BY id", nodeID)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
defer rows.Close()
|
|
|
|
var exports []storageExport
|
|
for rows.Next() {
|
|
e := s.scanExport(rows)
|
|
if e.ID != "" {
|
|
exports = append(exports, e)
|
|
}
|
|
}
|
|
if exports == nil {
|
|
exports = []storageExport{}
|
|
}
|
|
|
|
for i := range exports {
|
|
exports[i].Protocols = s.loadExportProtocols(exports[i].ID)
|
|
exports[i].Tags = s.loadExportTags(exports[i].ID)
|
|
}
|
|
|
|
sort.Slice(exports, func(i, j int) bool { return exports[i].ID < exports[j].ID })
|
|
return exports
|
|
}
|
|
|
|
func (s *sqliteStore) exportContext(exportID string, ownerID string) (exportContext, bool) {
|
|
var e storageExport
|
|
var capacityBytes sql.NullInt64
|
|
var exportOwnerID sql.NullString
|
|
err := s.db.QueryRow(
|
|
"SELECT id, node_id, owner_id, label, path, mount_path, capacity_bytes FROM exports WHERE id = ? AND owner_id = ?",
|
|
exportID, ownerID).Scan(&e.ID, &e.NasNodeID, &exportOwnerID, &e.Label, &e.Path, &e.MountPath, &capacityBytes)
|
|
if err != nil {
|
|
return exportContext{}, false
|
|
}
|
|
if exportOwnerID.Valid {
|
|
e.OwnerID = exportOwnerID.String
|
|
}
|
|
if capacityBytes.Valid {
|
|
e.CapacityBytes = &capacityBytes.Int64
|
|
}
|
|
e.Protocols = s.loadExportProtocols(e.ID)
|
|
e.Tags = s.loadExportTags(e.ID)
|
|
|
|
node, ok := s.nodeByID(e.NasNodeID)
|
|
if !ok {
|
|
return exportContext{}, false
|
|
}
|
|
|
|
return exportContext{export: e, node: node}, true
|
|
}
|
|
|
|
func (s *sqliteStore) nodeByID(nodeID string) (nasNode, bool) {
|
|
row := s.db.QueryRow(
|
|
"SELECT id, machine_id, owner_id, display_name, agent_version, status, last_seen_at, direct_address, relay_address FROM nodes WHERE id = ?",
|
|
nodeID)
|
|
n := s.scanNode(row)
|
|
if n.ID == "" {
|
|
return nasNode{}, false
|
|
}
|
|
|
|
return n, true
|
|
}
|
|
|
|
type sqliteNodeScanner interface {
|
|
Scan(dest ...any) error
|
|
}
|
|
|
|
func (s *sqliteStore) scanNode(scanner sqliteNodeScanner) nasNode {
|
|
var n nasNode
|
|
var directAddr, relayAddr sql.NullString
|
|
var lastSeenAt sql.NullString
|
|
var ownerID sql.NullString
|
|
err := scanner.Scan(&n.ID, &n.MachineID, &ownerID, &n.DisplayName, &n.AgentVersion, &n.Status, &lastSeenAt, &directAddr, &relayAddr)
|
|
if err != nil {
|
|
return nasNode{}
|
|
}
|
|
if ownerID.Valid {
|
|
n.OwnerID = ownerID.String
|
|
}
|
|
if lastSeenAt.Valid {
|
|
n.LastSeenAt = lastSeenAt.String
|
|
}
|
|
if directAddr.Valid {
|
|
n.DirectAddress = &directAddr.String
|
|
}
|
|
if relayAddr.Valid {
|
|
n.RelayAddress = &relayAddr.String
|
|
}
|
|
return n
|
|
}
|
|
|
|
func (s *sqliteStore) nodeAuthByMachineID(machineID string) (nodeAuthState, bool) {
|
|
var state nodeAuthState
|
|
var tokenHash sql.NullString
|
|
err := s.db.QueryRow(`
|
|
SELECT n.id, nt.token_hash
|
|
FROM nodes n
|
|
LEFT JOIN node_tokens nt ON nt.node_id = n.id
|
|
WHERE n.machine_id = ?
|
|
`, machineID).Scan(&state.NodeID, &tokenHash)
|
|
if err != nil {
|
|
return nodeAuthState{}, false
|
|
}
|
|
if tokenHash.Valid {
|
|
state.TokenHash = tokenHash.String
|
|
}
|
|
return state, true
|
|
}
|
|
|
|
func (s *sqliteStore) nodeAuthByID(nodeID string) (nodeAuthState, bool) {
|
|
var state nodeAuthState
|
|
var tokenHash sql.NullString
|
|
err := s.db.QueryRow(`
|
|
SELECT n.id, nt.token_hash
|
|
FROM nodes n
|
|
LEFT JOIN node_tokens nt ON nt.node_id = n.id
|
|
WHERE n.id = ?
|
|
`, nodeID).Scan(&state.NodeID, &tokenHash)
|
|
if err != nil {
|
|
return nodeAuthState{}, false
|
|
}
|
|
if tokenHash.Valid {
|
|
state.TokenHash = tokenHash.String
|
|
}
|
|
return state, true
|
|
}
|
|
|
|
// --- helpers ---
|
|
|
|
func (s *sqliteStore) scanExport(rows *sql.Rows) storageExport {
|
|
var e storageExport
|
|
var capacityBytes sql.NullInt64
|
|
var ownerID sql.NullString
|
|
if err := rows.Scan(&e.ID, &e.NasNodeID, &ownerID, &e.Label, &e.Path, &e.MountPath, &capacityBytes); err != nil {
|
|
return storageExport{}
|
|
}
|
|
if ownerID.Valid {
|
|
e.OwnerID = ownerID.String
|
|
}
|
|
if capacityBytes.Valid {
|
|
e.CapacityBytes = &capacityBytes.Int64
|
|
}
|
|
return e
|
|
}
|
|
|
|
func (s *sqliteStore) loadExportProtocols(exportID string) []string {
|
|
rows, err := s.db.Query("SELECT protocol FROM export_protocols WHERE export_id = ? ORDER BY protocol", exportID)
|
|
if err != nil {
|
|
return []string{}
|
|
}
|
|
defer rows.Close()
|
|
|
|
var protocols []string
|
|
for rows.Next() {
|
|
var p string
|
|
if err := rows.Scan(&p); err == nil {
|
|
protocols = append(protocols, p)
|
|
}
|
|
}
|
|
if protocols == nil {
|
|
return []string{}
|
|
}
|
|
return protocols
|
|
}
|
|
|
|
func (s *sqliteStore) loadExportTags(exportID string) []string {
|
|
rows, err := s.db.Query("SELECT tag FROM export_tags WHERE export_id = ? ORDER BY tag", exportID)
|
|
if err != nil {
|
|
return []string{}
|
|
}
|
|
defer rows.Close()
|
|
|
|
var tags []string
|
|
for rows.Next() {
|
|
var t string
|
|
if err := rows.Scan(&t); err == nil {
|
|
tags = append(tags, t)
|
|
}
|
|
}
|
|
if tags == nil {
|
|
return []string{}
|
|
}
|
|
return tags
|
|
}
|
|
|
|
func nullableString(p *string) sql.NullString {
|
|
if p == nil {
|
|
return sql.NullString{}
|
|
}
|
|
return sql.NullString{String: *p, Valid: true}
|
|
}
|
|
|
|
func nullableInt64(p *int64) sql.NullInt64 {
|
|
if p == nil {
|
|
return sql.NullInt64{}
|
|
}
|
|
return sql.NullInt64{Int64: *p, Valid: true}
|
|
}
|
|
|
|
// --- user auth ---
|
|
|
|
func (s *sqliteStore) createUser(username string, password string) (user, error) {
|
|
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return user{}, fmt.Errorf("hash password: %w", err)
|
|
}
|
|
|
|
id, err := newSessionToken()
|
|
if err != nil {
|
|
return user{}, err
|
|
}
|
|
|
|
var u user
|
|
err = s.db.QueryRow(`
|
|
INSERT INTO users (id, username, password_hash) VALUES (?, ?, ?)
|
|
RETURNING id, username, created_at
|
|
`, id, username, string(hash)).Scan(&u.ID, &u.Username, &u.CreatedAt)
|
|
if err != nil {
|
|
if strings.Contains(err.Error(), "UNIQUE constraint") {
|
|
return user{}, errUsernameTaken
|
|
}
|
|
return user{}, fmt.Errorf("create user: %w", err)
|
|
}
|
|
|
|
return u, nil
|
|
}
|
|
|
|
func (s *sqliteStore) authenticateUser(username string, password string) (user, error) {
|
|
var u user
|
|
var passwordHash string
|
|
err := s.db.QueryRow(
|
|
"SELECT id, username, password_hash, created_at FROM users WHERE username = ?",
|
|
username).Scan(&u.ID, &u.Username, &passwordHash, &u.CreatedAt)
|
|
if err != nil {
|
|
return user{}, errInvalidLogin
|
|
}
|
|
|
|
if err := bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(password)); err != nil {
|
|
return user{}, errInvalidLogin
|
|
}
|
|
|
|
return u, nil
|
|
}
|
|
|
|
func (s *sqliteStore) createSession(userID string, ttl time.Duration) (string, error) {
|
|
token, err := newSessionToken()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
expiresAt := time.Now().UTC().Add(ttl).Format(time.RFC3339)
|
|
_, err = s.db.Exec(
|
|
"INSERT INTO sessions (token, user_id, expires_at) VALUES (?, ?, ?)",
|
|
token, userID, expiresAt)
|
|
if err != nil {
|
|
return "", fmt.Errorf("create session: %w", err)
|
|
}
|
|
|
|
// Clean up expired sessions opportunistically.
|
|
_, _ = s.db.Exec("DELETE FROM sessions WHERE expires_at < ?", time.Now().UTC().Format(time.RFC3339))
|
|
|
|
return token, nil
|
|
}
|
|
|
|
func (s *sqliteStore) validateSession(token string) (user, error) {
|
|
var u user
|
|
err := s.db.QueryRow(`
|
|
SELECT u.id, u.username, u.created_at
|
|
FROM sessions s
|
|
JOIN users u ON u.id = s.user_id
|
|
WHERE s.token = ? AND s.expires_at > ?
|
|
`, token, time.Now().UTC().Format(time.RFC3339)).Scan(&u.ID, &u.Username, &u.CreatedAt)
|
|
if err != nil {
|
|
return user{}, errSessionExpired
|
|
}
|
|
return u, nil
|
|
}
|
|
|
|
func (s *sqliteStore) deleteSession(token string) error {
|
|
_, err := s.db.Exec("DELETE FROM sessions WHERE token = ?", token)
|
|
return err
|
|
}
|
|
|
|
func newSessionToken() (string, error) {
|
|
raw := make([]byte, 32)
|
|
if _, err := rand.Read(raw); err != nil {
|
|
return "", fmt.Errorf("generate session token: %w", err)
|
|
}
|
|
return hex.EncodeToString(raw), nil
|
|
}
|