diff --git a/backend/services/author_service.py b/backend/services/author_service.py index cbf9fec..f2ce4d1 100644 --- a/backend/services/author_service.py +++ b/backend/services/author_service.py @@ -1,5 +1,5 @@ from datetime import datetime -from sqlalchemy import select, func +from sqlalchemy import select, func, desc as sa_desc, asc as sa_asc, literal from sqlalchemy.ext.asyncio import AsyncSession from backend.models.author import Author @@ -17,60 +17,75 @@ async def list_authors( page: int = 1, per_page: int = 25, ) -> tuple[list[dict], int]: - base = select(Author) + # Always compute counts from actual data + post_count = select(func.count(Post.id)).where(Post.author_id == Author.id) + comment_count = select(func.count(Comment.id)).where(Comment.author_id == Author.id) - if subreddit_id or since or until: - # Need to compute activity counts with filters - post_count = ( - select(func.count(Post.id)) - .where(Post.author_id == Author.id) - ) - comment_count = ( - select(func.count(Comment.id)) - .where(Comment.author_id == Author.id) - ) + if subreddit_id: + post_count = post_count.where(Post.subreddit_id == subreddit_id) + comment_count = comment_count.join(Post).where(Post.subreddit_id == subreddit_id) + if since: + post_count = post_count.where(Post.created_utc >= since) + comment_count = comment_count.where(Comment.created_utc >= since) + if until: + post_count = post_count.where(Post.created_utc <= until) + comment_count = comment_count.where(Comment.created_utc <= until) - if subreddit_id: - post_count = post_count.where(Post.subreddit_id == subreddit_id) - comment_count = comment_count.join(Post).where(Post.subreddit_id == subreddit_id) - if since: - post_count = post_count.where(Post.created_utc >= since) - comment_count = comment_count.where(Comment.created_utc >= since) - if until: - post_count = post_count.where(Post.created_utc <= until) - comment_count = comment_count.where(Comment.created_utc <= until) + post_sub = post_count.correlate(Author).scalar_subquery().label("total_posts") + comment_sub = comment_count.correlate(Author).scalar_subquery().label("total_comments") - base = select( - Author, - post_count.correlate(Author).scalar_subquery().label("filtered_posts"), - comment_count.correlate(Author).scalar_subquery().label("filtered_comments"), - ) - else: - base = select(Author) + base = select(Author, post_sub, comment_sub) - count_stmt = select(func.count()).select_from(base.subquery()) + # Only show authors with activity + having_activity = base.having( + (post_sub > 0) | (comment_sub > 0) + ).group_by(Author.id) + + # Count total active authors + count_stmt = select(func.count()).select_from(having_activity.subquery()) total = (await db.execute(count_stmt)).scalar() or 0 - sort_col = getattr(Author, sort_by, Author.total_comments) - if sort_order == "asc": - base = base.order_by(sort_col.asc()) + # Sort by computed counts + if sort_by == "total_posts": + sort_col = post_sub else: - base = base.order_by(sort_col.desc()) + sort_col = comment_sub - base = base.offset((page - 1) * per_page).limit(per_page) + order_fn = sa_desc if sort_order == "desc" else sa_asc + query = base.order_by(order_fn(sort_col)).offset((page - 1) * per_page).limit(per_page) - result = await db.execute(base) + result = await db.execute(query) authors = [] for row in result.all(): author = row[0] data = {c.name: getattr(author, c.name) for c in author.__table__.columns} + data["total_posts"] = row[1] or 0 + data["total_comments"] = row[2] or 0 authors.append(data) return authors, total async def get_author(db: AsyncSession, author_id: int) -> dict | None: - author = await db.get(Author, author_id) - if not author: + post_count = ( + select(func.count(Post.id)) + .where(Post.author_id == author_id) + .scalar_subquery() + ) + comment_count = ( + select(func.count(Comment.id)) + .where(Comment.author_id == author_id) + .scalar_subquery() + ) + result = await db.execute( + select(Author, post_count.label("total_posts"), comment_count.label("total_comments")) + .where(Author.id == author_id) + ) + row = result.first() + if not row: return None - return {c.name: getattr(author, c.name) for c in author.__table__.columns} + author = row[0] + data = {c.name: getattr(author, c.name) for c in author.__table__.columns} + data["total_posts"] = row[1] or 0 + data["total_comments"] = row[2] or 0 + return data