144 lines
3.8 KiB
Python
144 lines
3.8 KiB
Python
from datetime import date
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
|
from pydantic import BaseModel
|
|
from sqlalchemy import func, select
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.auth import get_current_user
|
|
from app.database import get_db
|
|
from app.models.finance import FinanceSnapshot, FinanceTransaction
|
|
|
|
# Router for all finance endpoints. Dependencies declared on the router are
# run for every route registered on it, so each endpoint below goes through
# ``get_current_user`` (presumably an authentication check — see app.auth).
router = APIRouter(
    dependencies=[Depends(get_current_user)],
    prefix="/api/finances",
    tags=["finances"],
)
|
|
|
|
|
|
# --- Schemas ---
|
|
|
|
|
|
class TransactionCreate(BaseModel):
    """Request body used to create a transaction or fully replace one (PUT)."""

    # NOTE: the field name shadows the imported ``datetime.date`` type; the
    # annotation still resolves to the type because no default is assigned.
    date: date
    # Signed amount; the summary endpoint treats negative values as spending.
    amount: float
    category: str
    # Optional free-text note; omitted on PUT means it is reset to None.
    description: str | None = None
|
|
|
|
|
|
class TransactionRead(TransactionCreate):
    """Response model for a stored transaction: the input fields plus its ``id``."""

    id: int

    # Allow building this model directly from ORM objects (the endpoints
    # return ``FinanceTransaction`` instances).
    model_config = {"from_attributes": True}
|
|
|
|
|
|
class SnapshotCreate(BaseModel):
    """Request body used to record a net-worth snapshot."""

    # NOTE: the field name shadows the imported ``datetime.date`` type; the
    # annotation still resolves to the type because no default is assigned.
    date: date
    net_worth: float
    # Optional free-text annotation for the snapshot.
    notes: str | None = None
|
|
|
|
|
|
class SnapshotRead(SnapshotCreate):
    """Response model for a stored snapshot: the input fields plus its ``id``."""

    id: int

    # Allow building this model directly from ORM objects (the endpoints
    # return ``FinanceSnapshot`` instances).
    model_config = {"from_attributes": True}
|
|
|
|
|
|
class CategorySummary(BaseModel):
    """One row of the spending summary: a category and its summed amount."""

    category: str
    # Sum of the category's negative amounts only, so the value is negative.
    total: float
|
|
|
|
|
|
# --- Transaction endpoints ---
|
|
|
|
|
|
@router.get("/transactions", response_model=list[TransactionRead])
def list_transactions(
    from_date: date | None = Query(None, alias="from"),
    to_date: date | None = Query(None, alias="to"),
    category: str | None = None,
    db: Session = Depends(get_db),
):
    """Return transactions, newest date first, optionally filtered.

    Query parameters:
        from: inclusive lower bound on the transaction date.
        to: inclusive upper bound on the transaction date.
        category: exact category match; an empty value applies no filter.

    Transactions that share a date are ordered by descending ``id`` so the
    listing is stable from one request to the next.
    """
    q = select(FinanceTransaction).order_by(
        FinanceTransaction.date.desc(),
        # Tie-breaker: without it, rows on the same date come back in
        # whatever order the database picks, which may differ per query.
        FinanceTransaction.id.desc(),
    )
    if from_date is not None:
        q = q.where(FinanceTransaction.date >= from_date)
    if to_date is not None:
        q = q.where(FinanceTransaction.date <= to_date)
    # Truthiness on purpose: ``?category=`` (empty) means "no category filter".
    if category:
        q = q.where(FinanceTransaction.category == category)
    return db.scalars(q).all()
|
|
|
|
|
|
@router.post("/transactions", response_model=TransactionRead, status_code=201)
def create_transaction(body: TransactionCreate, db: Session = Depends(get_db)):
    """Persist a new transaction and return it (HTTP 201).

    The row is committed immediately, then refreshed so values assigned by
    the database (such as the ``id`` exposed by ``TransactionRead``) are
    loaded before the response is serialized.
    """
    payload = body.model_dump()
    new_txn = FinanceTransaction(**payload)
    db.add(new_txn)
    db.commit()
    # Reload database-assigned columns onto the instance.
    db.refresh(new_txn)
    return new_txn
|
|
|
|
|
|
@router.put("/transactions/{txn_id}", response_model=TransactionRead)
def update_transaction(
    txn_id: int, body: TransactionCreate, db: Session = Depends(get_db)
):
    """Overwrite every field of transaction ``txn_id`` with ``body``.

    This is a full replacement: fields left out of the request take their
    schema defaults (``description`` becomes ``None``).

    Raises:
        HTTPException: 404 if no transaction with ``txn_id`` exists.
    """
    existing = db.get(FinanceTransaction, txn_id)
    if not existing:
        raise HTTPException(status_code=404, detail="Transaction not found")
    new_values = body.model_dump()
    for field_name in new_values:
        setattr(existing, field_name, new_values[field_name])
    db.commit()
    db.refresh(existing)
    return existing
|
|
|
|
|
|
@router.delete("/transactions/{txn_id}", status_code=204)
def delete_transaction(txn_id: int, db: Session = Depends(get_db)):
    """Delete transaction ``txn_id``; the route is declared with status 204.

    Raises:
        HTTPException: 404 if no transaction with ``txn_id`` exists.
    """
    target = db.get(FinanceTransaction, txn_id)
    if target:
        db.delete(target)
        db.commit()
        return None
    raise HTTPException(status_code=404, detail="Transaction not found")
|
|
|
|
|
|
# --- Snapshot endpoints ---
|
|
|
|
|
|
@router.get("/snapshots", response_model=list[SnapshotRead])
def list_snapshots(db: Session = Depends(get_db)):
    """Return every net-worth snapshot, newest date first.

    Snapshots that share a date are ordered by descending ``id`` so the
    listing is stable from one request to the next.
    """
    q = select(FinanceSnapshot).order_by(
        FinanceSnapshot.date.desc(),
        # Tie-breaker: without it, rows on the same date come back in
        # whatever order the database picks, which may differ per query.
        FinanceSnapshot.id.desc(),
    )
    return db.scalars(q).all()
|
|
|
|
|
|
@router.post("/snapshots", response_model=SnapshotRead, status_code=201)
def create_snapshot(body: SnapshotCreate, db: Session = Depends(get_db)):
    """Store a new net-worth snapshot and return it (HTTP 201).

    Committed immediately, then refreshed so database-assigned values (such
    as the ``id`` exposed by ``SnapshotRead``) are present in the response.
    """
    payload = body.model_dump()
    record = FinanceSnapshot(**payload)
    db.add(record)
    db.commit()
    # Reload database-assigned columns onto the instance.
    db.refresh(record)
    return record
|
|
|
|
|
|
# --- Summary endpoint ---
|
|
|
|
|
|
def _period_bounds(period: str, today: date) -> tuple[date, date]:
    """Return the half-open ``[start, end)`` range of the period containing *today*.

    ``"month"`` selects the calendar month; any other value selects the
    calendar year (the endpoint's query pattern only admits "month"/"year").
    """
    if period == "month":
        start = today.replace(day=1)
        # December rolls over into January of the following year.
        if start.month == 12:
            end = start.replace(year=start.year + 1, month=1)
        else:
            end = start.replace(month=start.month + 1)
    else:
        start = today.replace(month=1, day=1)
        end = start.replace(year=start.year + 1)
    return start, end


@router.get("/summary", response_model=list[CategorySummary])
def spending_summary(
    period: str = Query("month", pattern="^(month|year)$"),
    db: Session = Depends(get_db),
):
    """Sum spending per category for the current calendar month or year.

    Only transactions with a negative ``amount`` are included, so every
    ``total`` is negative. Categories are ordered by total ascending, i.e.
    the largest spend (most negative) first. The period is determined from
    the server's local date (``date.today()``), and transactions dated after
    the end of the period (e.g. future-dated entries) are excluded.

    Args:
        period: ``"month"`` (default) or ``"year"``.
    """
    start, end = _period_bounds(period, date.today())

    rows = db.execute(
        select(
            FinanceTransaction.category,
            func.sum(FinanceTransaction.amount).label("total"),
        )
        .where(FinanceTransaction.date >= start)
        # Upper bound: without it, anything dated after this period would be
        # counted as spending in the current month/year.
        .where(FinanceTransaction.date < end)
        .where(FinanceTransaction.amount < 0)
        .group_by(FinanceTransaction.category)
        .order_by(func.sum(FinanceTransaction.amount))
    ).all()

    return [CategorySummary(category=r.category, total=float(r.total)) for r in rows]
|