"""Scheduled Reddit collection jobs.

Each public job (``poll_new_posts``, ``poll_hot_posts``, ``collect_comments``,
``update_scores``) is a synchronous entry point that runs its async
implementation with ``asyncio.run`` and writes results through a synchronous
SQLAlchemy session.
"""
import asyncio
import logging
from datetime import datetime, timezone, timedelta

from sqlalchemy import select, create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.dialects.postgresql import insert

from backend.config import settings
from backend.models.subreddit import MonitoredSubreddit
from backend.models.author import Author
from backend.models.post import Post
from backend.models.comment import Comment
from backend.worker.reddit_client import create_client, fetch_json

logger = logging.getLogger(__name__)

# Sync engine for worker (PRAW-replacement uses async httpx, but DB writes are
# sync for simplicity with APScheduler)
_engine = create_engine(
    settings.database_url_sync,
    pool_size=3,
    max_overflow=5,
    pool_recycle=3600,
)
SyncSession = sessionmaker(_engine)


def _get_active_subreddits() -> list[dict]:
    """Return ``{"id", "name"}`` dicts for every active monitored subreddit."""
    with SyncSession() as db:
        stmt = select(MonitoredSubreddit).where(
            MonitoredSubreddit.is_active == True  # noqa: E712
        )
        result = db.execute(stmt)
        # Copy plain values out so nothing ORM-bound escapes the session.
        return [{"id": s.id, "name": s.name} for s in result.scalars()]


def _upsert_author(db: Session, username: str | None) -> int | None:
    """Insert the author or bump ``last_seen_at``, and return the author id.

    Returns ``None`` without touching the database when *username* is empty
    or the ``"[deleted]"`` placeholder.
    """
    if not username or username == "[deleted]":
        return None

    now = datetime.now(timezone.utc)
    stmt = (
        insert(Author)
        .values(username=username, first_seen_at=now, last_seen_at=now)
        .on_conflict_do_update(
            index_elements=[Author.username],
            set_={"last_seen_at": now},
        )
        # RETURNING yields the id for both the insert and the update path,
        # so no follow-up SELECT is needed.
        .returning(Author.id)
    )
    return db.execute(stmt).scalar()


def _parse_post(
    post_data: dict,
    subreddit_id: int,
    db: Session,
    hot_rank: int | None = None,
) -> dict:
    """Convert one listing child into a row dict for the ``Post`` upsert.

    Accepts either the ``{"kind": ..., "data": {...}}`` wrapper or the inner
    ``data`` dict. Upserts the post's author as a side effect.
    """
    data = post_data.get("data", post_data)
    author_id = _upsert_author(db, data.get("author"))
    # ``or 0`` also covers an explicit null timestamp in the payload.
    created = datetime.fromtimestamp(data.get("created_utc") or 0, tz=timezone.utc)
    now = datetime.now(timezone.utc)
    return {
        "reddit_id": data.get("name", f"t3_{data.get('id', '')}"),
        "subreddit_id": subreddit_id,
        "author_id": author_id,
        "title": data.get("title", ""),
        "selftext": data.get("selftext"),
        "url": data.get("url"),
        "permalink": data.get("permalink"),
        "flair": data.get("link_flair_text"),
        "score": data.get("score", 0),
        "upvote_ratio": data.get("upvote_ratio"),
        "num_comments": data.get("num_comments", 0),
        "is_self": data.get("is_self"),
        "over_18": data.get("over_18", False),
        "hot_rank": hot_rank,
        "created_utc": created,
        "collected_at": now,
        "updated_at": now,
    }


def _upsert_posts(db: Session, posts: list[dict], update_hot_rank: bool = False) -> None:
    """Bulk-upsert post rows keyed on ``reddit_id``.

    On conflict only the volatile columns (score, upvote ratio, comment count,
    ``updated_at``) are refreshed; ``hot_rank`` is refreshed as well when
    *update_hot_rank* is true.
    """
    if not posts:
        return

    # PostgreSQL rejects a single INSERT ... ON CONFLICT DO UPDATE that would
    # touch the same row twice, so keep only the first row per reddit_id.
    unique_posts: dict[str, dict] = {}
    for post in posts:
        unique_posts.setdefault(post["reddit_id"], post)

    stmt = insert(Post).values(list(unique_posts.values()))
    excluded = stmt.excluded
    update_set = {
        "score": excluded.score,
        "upvote_ratio": excluded.upvote_ratio,
        "num_comments": excluded.num_comments,
        "updated_at": excluded.updated_at,
    }
    if update_hot_rank:
        update_set["hot_rank"] = excluded.hot_rank

    stmt = stmt.on_conflict_do_update(
        index_elements=[Post.reddit_id],
        set_=update_set,
    )
    db.execute(stmt)


def _parse_comment(
    comment_data: dict,
    post_id: int,
    db: Session,
    parent_map: dict,
) -> dict | None:
    """Convert one comment into a row dict for the ``Comment`` upsert.

    Accepts either the ``{"kind": ..., "data": {...}}`` wrapper or the inner
    ``data`` dict. Returns ``None`` for "more" placeholders and for comments
    without a body. *parent_map* maps comment ``reddit_id`` to database id;
    a parent that is not in the map (e.g. the post itself) yields
    ``parent_comment_id = None``. Upserts the comment's author as a side
    effect.
    """
    data = comment_data.get("data", comment_data)
    # ``kind`` lives on the wrapper, so check it there as well as on the
    # inner dict the original code inspected.
    is_more = comment_data.get("kind") == "more" or data.get("kind") == "more"
    if is_more or not data.get("body"):
        return None

    reddit_id = data.get("name", f"t1_{data.get('id', '')}")
    author_id = _upsert_author(db, data.get("author"))
    created = datetime.fromtimestamp(data.get("created_utc") or 0, tz=timezone.utc)
    parent_reddit_id = data.get("parent_id", "")
    now = datetime.now(timezone.utc)
    return {
        "reddit_id": reddit_id,
        "post_id": post_id,
        "parent_comment_id": parent_map.get(parent_reddit_id),
        "author_id": author_id,
        "body": data.get("body", ""),
        "score": data.get("score", 0),
        "created_utc": created,
        "collected_at": now,
        "updated_at": now,
    }


def _iter_comment_tree(children: list[dict]):
    """Yield each comment's inner ``data`` dict in pre-order (parent first).

    "more" placeholders are skipped. Replies are followed only when
    ``replies`` is a dict.
    """
    for child in children:
        if child.get("kind") == "more":
            continue
        c_data = child.get("data", {})
        yield c_data
        replies = c_data.get("replies")
        if isinstance(replies, dict):
            yield from _iter_comment_tree(replies.get("data", {}).get("children", []))


def _store_comments(db: Session, post_id: int, comment_listing: list[dict]) -> int:
    """Upsert every comment in *comment_listing* and return how many were written.

    Comments are written one at a time in pre-order and each returned id is
    added to the parent map immediately, so a reply can reference a parent
    that was first seen earlier in the same run.
    NOTE(review): rows stored before this fix keep a NULL parent, because
    ``parent_comment_id`` is still not part of the conflict update.
    """
    existing = db.execute(
        select(Comment.id, Comment.reddit_id).where(Comment.post_id == post_id)
    )
    parent_map = {reddit_id: comment_id for comment_id, reddit_id in existing}

    written = 0
    for c_data in _iter_comment_tree(comment_listing):
        parsed = _parse_comment(c_data, post_id, db, parent_map)
        if parsed is None:
            continue
        stmt = insert(Comment).values(parsed)
        stmt = stmt.on_conflict_do_update(
            index_elements=[Comment.reddit_id],
            set_={
                "score": stmt.excluded.score,
                "body": stmt.excluded.body,
                "updated_at": stmt.excluded.updated_at,
            },
        ).returning(Comment.id)
        comment_id = db.execute(stmt).scalar()
        if comment_id is not None:
            parent_map[parsed["reddit_id"]] = comment_id
        written += 1
    return written


async def _poll_listing_async(listing: str, *, with_hot_rank: bool, log_message: str) -> None:
    """Fetch ``/r/<name>/<listing>`` for each active subreddit and upsert its posts.

    When *with_hot_rank* is true each post gets its 1-based position in the
    listing as ``hot_rank`` and that column is refreshed on conflict.
    *log_message* is a ``%``-style template taking the subreddit name and the
    number of listing children.
    """
    subreddits = _get_active_subreddits()
    if not subreddits:
        return

    client = create_client()
    async with client:
        for sub in subreddits:
            data = await fetch_json(client, f"/r/{sub['name']}/{listing}", {"limit": "100"})
            if not data:
                continue
            children = data.get("data", {}).get("children", [])
            if not children:
                continue

            with SyncSession() as db:
                posts = [
                    _parse_post(
                        child,
                        sub["id"],
                        db,
                        hot_rank=rank if with_hot_rank else None,
                    )
                    for rank, child in enumerate(children, start=1)
                ]
                _upsert_posts(db, posts, update_hot_rank=with_hot_rank)
                db.commit()
            logger.info(log_message, sub["name"], len(children))


def poll_new_posts():
    """Fetch /new for each active subreddit and upsert posts."""
    asyncio.run(_poll_new_posts_async())


async def _poll_new_posts_async():
    """Async implementation of :func:`poll_new_posts`."""
    await _poll_listing_async(
        "new",
        with_hot_rank=False,
        log_message="r/%s: upserted %d new posts",
    )


def poll_hot_posts():
    """Fetch /hot for each active subreddit and update hot_rank."""
    asyncio.run(_poll_hot_posts_async())


async def _poll_hot_posts_async():
    """Async implementation of :func:`poll_hot_posts`."""
    await _poll_listing_async(
        "hot",
        with_hot_rank=True,
        log_message="r/%s: updated hot ranks for %d posts",
    )


def collect_comments():
    """Fetch comments for recent posts."""
    asyncio.run(_collect_comments_async())


async def _collect_comments_async():
    """Fetch and upsert comments for up to 50 posts created in the last 48 hours.

    Posts are limited to active subreddits and taken newest first.
    """
    cutoff = datetime.now(timezone.utc) - timedelta(hours=48)
    with SyncSession() as db:
        stmt = (
            select(Post.id, Post.reddit_id, Post.subreddit_id)
            .join(MonitoredSubreddit)
            .where(
                MonitoredSubreddit.is_active == True,  # noqa: E712
                Post.created_utc >= cutoff,
            )
            .order_by(Post.created_utc.desc())
            .limit(50)
        )
        recent_posts = [
            {"id": row[0], "reddit_id": row[1], "subreddit_id": row[2]}
            for row in db.execute(stmt)
        ]

    if not recent_posts:
        return

    client = create_client()
    async with client:
        for post in recent_posts:
            short_id = post["reddit_id"].removeprefix("t3_")
            data = await fetch_json(
                client,
                f"/comments/{short_id}",
                {"limit": "500", "sort": "new"},
            )
            # The code reads the comment listing from index 1, so anything
            # that is not a list of at least two items is skipped.
            if not isinstance(data, list) or len(data) < 2:
                continue
            comment_listing = data[1].get("data", {}).get("children", [])

            with SyncSession() as db:
                written = _store_comments(db, post["id"], comment_listing)
                if written:
                    db.commit()
            if written:
                logger.info("Post %s: upserted %d comments", short_id, written)


def update_scores():
    """Re-fetch recent posts to update scores and comment counts."""
    asyncio.run(_update_scores_async())


async def _update_scores_async():
    """Re-poll ``/new`` for every active subreddit with posts from the last 7 days.

    NOTE(review): only posts still within the newest 100 of a subreddit get
    refreshed; the collected reddit ids are not used to request specific
    older posts. Confirm whether that is the intended coverage.
    """
    cutoff = datetime.now(timezone.utc) - timedelta(days=7)
    with SyncSession() as db:
        stmt = (
            select(Post.reddit_id, Post.subreddit_id, MonitoredSubreddit.name)
            .join(MonitoredSubreddit)
            .where(
                MonitoredSubreddit.is_active == True,  # noqa: E712
                Post.created_utc >= cutoff,
            )
        )
        posts_by_sub: dict[str, list[str]] = {}
        for reddit_id, _, sub_name in db.execute(stmt):
            posts_by_sub.setdefault(sub_name, []).append(reddit_id)

    if not posts_by_sub:
        return

    # Score updates piggyback on the new/hot polls — the upsert already
    # updates scores. This job runs the same /new upsert once more per
    # subreddit.
    client = create_client()
    async with client:
        for sub_name in posts_by_sub:
            data = await fetch_json(client, f"/r/{sub_name}/new", {"limit": "100"})
            if not data:
                continue
            children = data.get("data", {}).get("children", [])

            with SyncSession() as db:
                sub = db.execute(
                    select(MonitoredSubreddit).where(MonitoredSubreddit.name == sub_name)
                ).scalar_one_or_none()
                if not sub:
                    continue
                posts = [_parse_post(child, sub.id, db) for child in children]
                _upsert_posts(db, posts)
                db.commit()

    logger.info("Score update complete for %d subreddits", len(posts_by_sub))