diff --git a/backend/services/author_service.py b/backend/services/author_service.py index f2ce4d1..37fefa4 100644 --- a/backend/services/author_service.py +++ b/backend/services/author_service.py @@ -1,5 +1,5 @@ from datetime import datetime -from sqlalchemy import select, func, desc as sa_desc, asc as sa_asc, literal +from sqlalchemy import select, func, desc as sa_desc, asc as sa_asc from sqlalchemy.ext.asyncio import AsyncSession from backend.models.author import Author @@ -17,51 +17,62 @@ async def list_authors( page: int = 1, per_page: int = 25, ) -> tuple[list[dict], int]: - # Always compute counts from actual data - post_count = select(func.count(Post.id)).where(Post.author_id == Author.id) - comment_count = select(func.count(Comment.id)).where(Comment.author_id == Author.id) + # Build count subqueries + post_q = select(func.count(Post.id)).where(Post.author_id == Author.id) + comment_q = select(func.count(Comment.id)).where(Comment.author_id == Author.id) if subreddit_id: - post_count = post_count.where(Post.subreddit_id == subreddit_id) - comment_count = comment_count.join(Post).where(Post.subreddit_id == subreddit_id) + post_q = post_q.where(Post.subreddit_id == subreddit_id) + comment_q = comment_q.join(Post).where(Post.subreddit_id == subreddit_id) if since: - post_count = post_count.where(Post.created_utc >= since) - comment_count = comment_count.where(Comment.created_utc >= since) + post_q = post_q.where(Post.created_utc >= since) + comment_q = comment_q.where(Comment.created_utc >= since) if until: - post_count = post_count.where(Post.created_utc <= until) - comment_count = comment_count.where(Comment.created_utc <= until) + post_q = post_q.where(Post.created_utc <= until) + comment_q = comment_q.where(Comment.created_utc <= until) - post_sub = post_count.correlate(Author).scalar_subquery().label("total_posts") - comment_sub = comment_count.correlate(Author).scalar_subquery().label("total_comments") + post_sub = post_q.correlate(Author).scalar_subquery().label("total_posts") + comment_sub = comment_q.correlate(Author).scalar_subquery().label("total_comments") - base = select(Author, post_sub, comment_sub) + # Build main query as a subquery so we can filter on computed columns + inner = select( + Author.id, + Author.username, + Author.first_seen_at, + Author.last_seen_at, + post_sub, + comment_sub, + ).subquery() - # Only show authors with activity - having_activity = base.having( - (post_sub > 0) | (comment_sub > 0) - ).group_by(Author.id) - - # Count total active authors - count_stmt = select(func.count()).select_from(having_activity.subquery()) + # Count authors with any activity + count_stmt = select(func.count()).select_from(inner).where( + (inner.c.total_posts > 0) | (inner.c.total_comments > 0) + ) total = (await db.execute(count_stmt)).scalar() or 0 - # Sort by computed counts - if sort_by == "total_posts": - sort_col = post_sub - else: - sort_col = comment_sub - + # Sort and paginate + sort_col = inner.c.total_posts if sort_by == "total_posts" else inner.c.total_comments order_fn = sa_desc if sort_order == "desc" else sa_asc - query = base.order_by(order_fn(sort_col)).offset((page - 1) * per_page).limit(per_page) + + query = ( + select(inner) + .where((inner.c.total_posts > 0) | (inner.c.total_comments > 0)) + .order_by(order_fn(sort_col)) + .offset((page - 1) * per_page) + .limit(per_page) + ) result = await db.execute(query) authors = [] for row in result.all(): - author = row[0] - data = {c.name: getattr(author, c.name) for c in author.__table__.columns} - data["total_posts"] = row[1] or 0 - data["total_comments"] = row[2] or 0 - authors.append(data) + authors.append({ + "id": row.id, + "username": row.username, + "first_seen_at": row.first_seen_at, + "last_seen_at": row.last_seen_at, + "total_posts": row.total_posts or 0, + "total_comments": row.total_comments or 0, + }) return authors, total @@ -78,8 +89,11 @@ async def get_author(db: AsyncSession, author_id: int) -> dict | None: .scalar_subquery() ) result = await db.execute( - select(Author, post_count.label("total_posts"), comment_count.label("total_comments")) - .where(Author.id == author_id) + select( + Author, + post_count.label("total_posts"), + comment_count.label("total_comments"), + ).where(Author.id == author_id) ) row = result.first() if not row: