290 lines
10 KiB
Python
290 lines
10 KiB
Python
import asyncio
|
|
import logging
|
|
from datetime import datetime, timezone, timedelta
|
|
|
|
from sqlalchemy import select, create_engine
|
|
from sqlalchemy.orm import Session, sessionmaker
|
|
from sqlalchemy.dialects.postgresql import insert
|
|
|
|
from backend.config import settings
|
|
from backend.models.subreddit import MonitoredSubreddit
|
|
from backend.models.author import Author
|
|
from backend.models.post import Post
|
|
from backend.models.comment import Comment
|
|
from backend.worker.reddit_client import create_client, fetch_json
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# One synchronous engine shared by every task in this module. The pool keeps
# 3 persistent connections, allows 5 extra under load, and recycles any
# connection older than 3600 seconds.
_engine = create_engine(
    settings.database_url_sync,
    pool_size=3,
    max_overflow=5,
    pool_recycle=3600,
)

# Session factory bound to the shared engine; used as ``with SyncSession() as db``.
SyncSession = sessionmaker(_engine)
|
|
|
|
|
|
def _get_active_subreddits() -> list[dict]:
    """Return ``{"id": ..., "name": ...}`` for every active monitored subreddit.

    The ORM rows are copied into plain dicts while the session is still open,
    so the result can be used after the session has been closed.
    """
    only_active = MonitoredSubreddit.is_active == True  # noqa: E712
    query = select(MonitoredSubreddit).where(only_active)

    subreddits: list[dict] = []
    with SyncSession() as db:
        for row in db.execute(query).scalars():
            subreddits.append({"id": row.id, "name": row.name})
    return subreddits
|
|
|
|
|
|
def _upsert_author(db: Session, username: str | None) -> int | None:
    """Insert or refresh an author row and return its primary key.

    A new username is inserted with ``first_seen_at`` and ``last_seen_at`` set
    to the current UTC time; an existing username only has ``last_seen_at``
    refreshed. The statement is executed on ``db`` but not committed here.

    Args:
        db: Open session the upsert is executed on.
        username: Author name as delivered by the API.

    Returns:
        The author's database id, or ``None`` when ``username`` is empty,
        ``None`` or the ``"[deleted]"`` placeholder (no row is written then).
    """
    if not username or username == "[deleted]":
        return None

    now = datetime.now(timezone.utc)
    stmt = (
        insert(Author)
        .values(username=username, first_seen_at=now, last_seen_at=now)
        .on_conflict_do_update(
            index_elements=[Author.username],
            set_={"last_seen_at": now},
        )
        # RETURNING yields the id for both the insert and the conflict-update
        # path, so no follow-up SELECT is needed. This runs once per parsed
        # post/comment, which halves the round-trips.
        .returning(Author.id)
    )
    return db.execute(stmt).scalar()
|
|
|
|
|
|
def _parse_post(post_data: dict, subreddit_id: int, db: Session, hot_rank: int | None = None) -> dict:
    """Parse a Pullpush submission object into a dict for DB upsert.

    The submission's author is upserted through ``_upsert_author`` as a side
    effect so the returned row can carry ``author_id``.

    Fields that have a fallback use ``value or fallback`` rather than
    ``dict.get(key, fallback)``: the latter does not apply the fallback when
    the key is present with a JSON ``null``, which would otherwise crash the
    timestamp conversion or leak ``None`` into the row.

    Args:
        post_data: One submission object from the API response.
        subreddit_id: Database id of the subreddit the post belongs to.
        db: Open session, used only for the author upsert.
        hot_rank: Optional rank to store with the post.

    Returns:
        A dict of column values keyed by the ``Post`` column names used by
        ``_upsert_posts``.
    """
    author_id = _upsert_author(db, post_data.get("author"))

    # Missing/null timestamps fall back to the epoch; numeric strings are accepted.
    created_raw = post_data.get("created_utc") or 0
    created = datetime.fromtimestamp(float(created_raw), tz=timezone.utc)

    # Prefer the fullname; rebuild it from the bare id when it is absent or null.
    reddit_id = post_data.get("name") or f"t3_{post_data.get('id', '')}"

    # One timestamp so collected_at and updated_at agree on first insert.
    now = datetime.now(timezone.utc)

    return {
        "reddit_id": reddit_id,
        "subreddit_id": subreddit_id,
        "author_id": author_id,
        "title": post_data.get("title") or "",
        "selftext": post_data.get("selftext"),
        "url": post_data.get("url"),
        "permalink": post_data.get("permalink"),
        "flair": post_data.get("link_flair_text"),
        "score": post_data.get("score") or 0,
        "upvote_ratio": post_data.get("upvote_ratio"),
        "num_comments": post_data.get("num_comments") or 0,
        "is_self": post_data.get("is_self"),
        "over_18": post_data.get("over_18") or False,
        "hot_rank": hot_rank,
        "created_utc": created,
        "collected_at": now,
        "updated_at": now,
    }
|
|
|
|
|
|
def _upsert_posts(db: Session, posts: list[dict], update_hot_rank: bool = False) -> None:
    """Bulk-upsert parsed posts, keyed on ``Post.reddit_id``.

    New posts are inserted with every supplied column. For posts that already
    exist only ``score``, ``upvote_ratio``, ``num_comments`` and ``updated_at``
    are overwritten, plus ``hot_rank`` when ``update_hot_rank`` is true.
    The statement is executed on ``db`` but not committed here.

    Args:
        db: Open session the statement is executed on.
        posts: Row dicts as produced by ``_parse_post``. An empty list is a no-op.
        update_hot_rank: Also overwrite ``hot_rank`` on existing rows.
    """
    if not posts:
        return

    # PostgreSQL rejects a multi-row INSERT ... ON CONFLICT DO UPDATE that
    # would touch the same row twice, so keep only the first occurrence of
    # each reddit_id (for ranked input this is the best-ranked one).
    seen_ids = set()
    rows = []
    for post in posts:
        reddit_id = post.get("reddit_id")
        if reddit_id is not None:
            if reddit_id in seen_ids:
                continue
            seen_ids.add(reddit_id)
        rows.append(post)

    stmt = insert(Post).values(rows)
    incoming = stmt.excluded  # the values proposed for insertion
    update_set = {
        "score": incoming.score,
        "upvote_ratio": incoming.upvote_ratio,
        "num_comments": incoming.num_comments,
        "updated_at": incoming.updated_at,
    }
    if update_hot_rank:
        update_set["hot_rank"] = incoming.hot_rank

    db.execute(
        stmt.on_conflict_do_update(
            index_elements=[Post.reddit_id],
            set_=update_set,
        )
    )
|
|
|
|
|
|
def _parse_comment(comment_data: dict, post_id: int, db: Session, parent_map: dict) -> dict | None:
    """Parse a Pullpush comment object into a dict for DB upsert.

    The comment's author is upserted through ``_upsert_author`` as a side
    effect so the returned row can carry ``author_id``.

    Fields that have a fallback use ``value or fallback`` rather than
    ``dict.get(key, fallback)``, so a key present with a JSON ``null`` is
    treated like a missing key.

    Args:
        comment_data: One comment object from the API response.
        post_id: Database id of the post the comment belongs to.
        db: Open session, used only for the author upsert.
        parent_map: Maps a comment's reddit id to its database id. A
            ``parent_id`` not found in it (for example a top-level comment
            whose parent is the post) yields ``parent_comment_id=None``.

    Returns:
        A dict of ``Comment`` column values, or ``None`` when the comment has
        no body (nothing is written in that case).
    """
    body = comment_data.get("body")
    if not body:
        return None

    # Prefer the fullname; rebuild it from the bare id when it is absent or null.
    reddit_id = comment_data.get("name") or f"t1_{comment_data.get('id', '')}"
    author_id = _upsert_author(db, comment_data.get("author"))

    # Missing/null timestamps fall back to the epoch; numeric strings are accepted.
    created_raw = comment_data.get("created_utc") or 0
    created = datetime.fromtimestamp(float(created_raw), tz=timezone.utc)

    parent_reddit_id = comment_data.get("parent_id", "")
    parent_comment_id = parent_map.get(parent_reddit_id)

    # One timestamp so collected_at and updated_at agree on first insert.
    now = datetime.now(timezone.utc)

    return {
        "reddit_id": reddit_id,
        "post_id": post_id,
        "parent_comment_id": parent_comment_id,
        "author_id": author_id,
        "body": body,
        "score": comment_data.get("score") or 0,
        "created_utc": created,
        "collected_at": now,
        "updated_at": now,
    }
|
|
|
|
|
|
def poll_new_posts():
    """Blocking entry point: poll the newest submissions of every active subreddit.

    Runs the async implementation to completion via ``asyncio.run``.
    """
    job = _poll_new_posts_async()
    asyncio.run(job)
|
|
|
|
|
|
async def _poll_new_posts_async():
    """Fetch and store the newest submissions for each active subreddit.

    For every active subreddit one request is made for up to 100 submissions
    sorted by ``created_utc`` descending. The parsed posts are upserted and
    committed in a session of their own. Subreddits with an empty response
    are skipped.
    """
    active = _get_active_subreddits()
    if not active:
        return

    http = create_client()
    async with http:
        for subreddit in active:
            query = {
                "subreddit": subreddit["name"],
                "sort": "created_utc",
                "sort_type": "desc",
                "size": 100,
            }
            response = await fetch_json(http, "/reddit/search/submission/", query)
            # An empty response and an empty "data" list are both "nothing to do".
            submissions = response.get("data", []) if response else []
            if not submissions:
                continue

            with SyncSession() as db:
                rows = []
                for submission in submissions:
                    rows.append(_parse_post(submission, subreddit["id"], db))
                _upsert_posts(db, rows)
                db.commit()
            logger.info(f"r/{subreddit['name']}: upserted {len(submissions)} new posts")
|
|
|
|
|
|
def poll_hot_posts():
    """Blocking entry point: refresh the approximated hot ranking of every active subreddit.

    Runs the async implementation to completion via ``asyncio.run``.
    """
    job = _poll_hot_posts_async()
    asyncio.run(job)
|
|
|
|
|
|
async def _poll_hot_posts_async():
    """Approximate each subreddit's hot list from recent high-scoring posts.

    For every active subreddit one request is made for up to 100 submissions
    created within the last 24 hours, sorted by score descending. A post's
    1-based position in that response is stored as its ``hot_rank``.

    NOTE(review): posts that are absent from the current response keep
    whatever ``hot_rank`` they were given earlier; nothing here clears it.
    Confirm that readers of ``hot_rank`` account for that.
    """
    active = _get_active_subreddits()
    if not active:
        return

    window_start = datetime.now(timezone.utc) - timedelta(hours=24)
    after_epoch = int(window_start.timestamp())

    http = create_client()
    async with http:
        for subreddit in active:
            query = {
                "subreddit": subreddit["name"],
                "sort": "score",
                "sort_type": "desc",
                "size": 100,
                "after": after_epoch,
            }
            response = await fetch_json(http, "/reddit/search/submission/", query)
            # An empty response and an empty "data" list are both "nothing to do".
            submissions = response.get("data", []) if response else []
            if not submissions:
                continue

            with SyncSession() as db:
                ranked = []
                for rank, submission in enumerate(submissions, start=1):
                    ranked.append(
                        _parse_post(submission, subreddit["id"], db, hot_rank=rank)
                    )
                _upsert_posts(db, ranked, update_hot_rank=True)
                db.commit()
            logger.info(f"r/{subreddit['name']}: updated hot ranks for {len(submissions)} posts")
|
|
|
|
|
|
def collect_comments():
    """Blocking entry point: collect recent comments of every active subreddit.

    Runs the async implementation to completion via ``asyncio.run``.
    """
    job = _collect_comments_async()
    asyncio.run(job)
|
|
|
|
|
|
def _upsert_comment(db: Session, comment: dict) -> int | None:
    """Upsert one parsed comment and return the database id of the stored row.

    On a ``reddit_id`` conflict only ``score``, ``body`` and ``updated_at``
    are overwritten.
    """
    stmt = insert(Comment).values(comment)
    stmt = stmt.on_conflict_do_update(
        index_elements=[Comment.reddit_id],
        set_={
            "score": stmt.excluded.score,
            "body": stmt.excluded.body,
            "updated_at": stmt.excluded.updated_at,
        },
    ).returning(Comment.id)
    return db.execute(stmt).scalar()


async def _collect_comments_async():
    """Fetch and store recent comments for each active subreddit.

    For every active subreddit one request is made for up to 100 comments
    created within the last 48 hours, sorted newest first. Comments whose post
    is not stored yet, or that have no body, are skipped. The remaining ones
    are upserted one by one and committed per subreddit.

    Comments are written oldest first, and each stored comment's id is added
    to the parent lookup immediately. A reply that arrives in the same batch
    as its parent is therefore linked to it instead of being stored with
    ``parent_comment_id=None``.
    """
    subreddits = _get_active_subreddits()
    if not subreddits:
        return

    cutoff_epoch = int((datetime.now(timezone.utc) - timedelta(hours=48)).timestamp())

    client = create_client()
    async with client:
        for sub in subreddits:
            data = await fetch_json(client, "/reddit/search/comment/", {
                "subreddit": sub["name"],
                "sort": "created_utc",
                "sort_type": "desc",
                "size": 100,
                "after": cutoff_epoch,
            })
            if not data:
                continue
            comments_data = data.get("data", [])
            if not comments_data:
                continue

            link_ids = {c.get("link_id") for c in comments_data if c.get("link_id")}
            if not link_ids:
                continue

            # Only parents referenced by this batch can ever be looked up, so
            # there is no need to load every stored comment of the subreddit.
            parent_ids = {
                c["parent_id"]
                for c in comments_data
                if isinstance(c.get("parent_id"), str) and c["parent_id"]
            }

            with SyncSession() as db:
                # Lookup: reddit post fullname -> our DB post ID
                post_rows = db.execute(
                    select(Post.id, Post.reddit_id).where(Post.reddit_id.in_(list(link_ids)))
                )
                post_lookup = {reddit_id: post_id for post_id, reddit_id in post_rows}

                # Lookup: comment reddit_id -> our DB comment ID
                parent_map = {}
                if parent_ids:
                    parent_rows = db.execute(
                        select(Comment.id, Comment.reddit_id)
                        .join(Post)
                        .where(Post.subreddit_id == sub["id"])
                        .where(Comment.reddit_id.in_(list(parent_ids)))
                    )
                    parent_map = {reddit_id: cid for cid, reddit_id in parent_rows}

                upserted = 0
                # The response is newest first; walk it backwards so parents
                # are stored (and registered) before their replies.
                for c in reversed(comments_data):
                    post_id = post_lookup.get(c.get("link_id"))
                    if not post_id:
                        continue  # Post not in our DB yet
                    parsed = _parse_comment(c, post_id, db, parent_map)
                    if not parsed:
                        continue
                    comment_id = _upsert_comment(db, parsed)
                    if comment_id is not None:
                        parent_map[parsed["reddit_id"]] = comment_id
                    upserted += 1

                if upserted:
                    db.commit()
                    logger.info(f"r/{sub['name']}: upserted {upserted} comments")
|
|
|
|
|
|
def update_scores():
    """Blocking entry point: re-fetch recent posts so their stored scores are refreshed.

    Runs the async implementation to completion via ``asyncio.run``.
    """
    job = _update_scores_async()
    asyncio.run(job)
|
|
|
|
|
|
async def _update_scores_async():
    """Re-fetch recent submissions so existing rows pick up fresh counters.

    For every active subreddit one request is made for up to 100 submissions
    created within the last 7 days, newest first. The parsed posts are
    upserted and committed in a session of their own. A single summary line
    is logged once all subreddits have been processed.
    """
    active = _get_active_subreddits()
    if not active:
        return

    window_start = datetime.now(timezone.utc) - timedelta(days=7)
    after_epoch = int(window_start.timestamp())

    http = create_client()
    async with http:
        for subreddit in active:
            query = {
                "subreddit": subreddit["name"],
                "sort": "created_utc",
                "sort_type": "desc",
                "size": 100,
                "after": after_epoch,
            }
            response = await fetch_json(http, "/reddit/search/submission/", query)
            # An empty response and an empty "data" list are both "nothing to do".
            submissions = response.get("data", []) if response else []
            if not submissions:
                continue

            with SyncSession() as db:
                rows = []
                for submission in submissions:
                    rows.append(_parse_post(submission, subreddit["id"], db))
                _upsert_posts(db, rows)
                db.commit()

    logger.info(f"Score update complete for {len(active)} subreddits")
|