789 lines
26 KiB
Python
789 lines
26 KiB
Python
"""
|
|
Database Helper for PDM Migration
==================================

Interactive tool for running SELECT queries, transforming results, and
inserting new rows — with mandatory terminal confirmation before any
write operation touches the database.

Usage:
    python db_helper.py --db target_db --task copy_with_new_id
    python db_helper.py --db source_db --query "SELECT TOP 10 * FROM Documents"
    python db_helper.py --db target_db --task copy_with_new_id --dry-run
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
import argparse
|
|
import sys
|
|
import os
|
|
import glob
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import List, Optional, Dict, Any, Callable, Tuple, Set
|
|
|
|
# db_utils lives one directory up
|
|
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
|
from db_utils import DatabaseConnection
|
|
|
|
|
|
# =============================================================================
|
|
# CONFIGURATION
|
|
# =============================================================================
|
|
|
|
# Paths are resolved relative to this script, not the working directory:
# config.json lives one level above this script's folder, and saved .sql
# files live in a "queries" folder next to it.
_SCRIPT_DIR = Path(__file__).resolve().parent
CONFIG_PATH = _SCRIPT_DIR.parent / "config.json"
QUERIES_DIR = _SCRIPT_DIR / "queries"
|
|
|
|
|
|
def load_config() -> dict:
    """
    Read and parse the project-root config.json (the file at CONFIG_PATH).

    Returns:
        The decoded JSON document.
    """
    raw_text = CONFIG_PATH.read_text(encoding="utf-8")
    return json.loads(raw_text)
|
|
|
|
|
|
def load_query(name: str) -> str:
    """
    Return the SQL text of a saved query from the queries/ folder.

    Args:
        name: Query name, i.e. the file name without the .sql extension
            ("get_var47" reads queries/get_var47.sql).

    Returns:
        The file contents with leading/trailing whitespace stripped.

    Raises:
        FileNotFoundError: no such .sql file; the message lists the names
            that do exist.
    """
    candidate = QUERIES_DIR / f"{name}.sql"
    if candidate.exists():
        return candidate.read_text(encoding="utf-8").strip()

    known = sorted(path.stem for path in QUERIES_DIR.glob("*.sql"))
    raise FileNotFoundError(
        f"Query '{name}' not found at {candidate}\n"
        f"Available queries: {known}"
    )
|
|
|
|
|
|
def list_queries() -> List[str]:
    """List saved query names (stems of the .sql files in QUERIES_DIR), sorted."""
    names = [sql_file.stem for sql_file in QUERIES_DIR.glob("*.sql")]
    names.sort()
    return names
|
|
|
|
|
|
# =============================================================================
|
|
# LOGGING
|
|
# =============================================================================
|
|
|
|
def setup_logging(log_file: Optional[str] = None) -> logging.Logger:
    """
    Configure the "db_helper" logger with a file handler and a console handler.

    The file handler records DEBUG and above; the console handler shows INFO
    and above. Both use the same "time - level - message" format.

    Calling this more than once replaces (and closes) the handlers attached
    by the previous call instead of stacking new ones, so messages are never
    emitted twice.

    Args:
        log_file: Path of the log file. Defaults to
            db_helper_<YYYYmmdd_HHMMSS>.log in the current directory.

    Returns:
        The configured logger.
    """
    if log_file is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        log_file = f"db_helper_{timestamp}.log"

    logger = logging.getLogger("db_helper")
    logger.setLevel(logging.DEBUG)

    # Remove handlers from an earlier call; otherwise each call adds another
    # pair and every record is written once per call.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")

    # File handler — everything. Explicit UTF-8: log messages contain "—"
    # and arbitrary vault paths that the platform default encoding may not
    # be able to represent.
    fh = logging.FileHandler(log_file, encoding="utf-8")
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(formatter)

    # Console handler — INFO and above
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    ch.setFormatter(formatter)

    logger.addHandler(fh)
    logger.addHandler(ch)

    return logger
|
|
|
|
|
|
# =============================================================================
|
|
# DATABASE CONNECTION
|
|
# =============================================================================
|
|
|
|
def connect_db(config_key: str) -> DatabaseConnection:
    """
    Build a DatabaseConnection from one named block of config.json.

    Args:
        config_key: Top-level key in config.json, e.g. "source_db" or
            "target_db". Its value is passed unchanged to DatabaseConnection
            and must contain at least "database" and "server" (both are
            read for the log line).

    Returns:
        The DatabaseConnection constructed from that block.

    Raises:
        ValueError: config_key is not present in config.json.
    """
    logger = logging.getLogger("db_helper")
    all_config = load_config()

    if config_key not in all_config:
        db_keys = [key for key in all_config if key.endswith("_db")]
        raise ValueError(
            f"Config key '{config_key}' not found in {CONFIG_PATH}. "
            f"Available keys: {db_keys}"
        )

    db_config = all_config[config_key]
    database = db_config["database"]
    server = db_config["server"]
    logger.info(f"Connecting to {database} on {server} ({config_key})")

    return DatabaseConnection(db_config)
|
|
|
|
|
|
# =============================================================================
|
|
# SELECT
|
|
# =============================================================================
|
|
|
|
def run_select(
    db: DatabaseConnection,
    query: str,
    params: Optional[tuple] = None,
    preview_rows: int = 10,
) -> List[Dict[str, Any]]:
    """
    Run a SELECT, log it, optionally print a preview table, and return the rows.

    Args:
        db: Open DatabaseConnection; its execute_query(query, params) result
            is returned as-is.
        query: SQL SELECT statement.
        params: Optional query parameters (logged at DEBUG when non-empty).
        preview_rows: Number of leading rows printed to the console;
            0 (or less) disables the preview.

    Returns:
        The list of row dicts returned by db.execute_query.
    """
    logger = logging.getLogger("db_helper")
    logger.info(f"Running SELECT:\n{query}")
    if params:
        logger.debug(f" Params: {params}")

    result = db.execute_query(query, params)
    logger.info(f" Returned {len(result)} row(s)")

    if result and preview_rows > 0:
        _print_table(result[:preview_rows])
        remaining = len(result) - preview_rows
        if remaining > 0:
            print(f" ... and {remaining} more rows")

    return result
|
|
|
|
|
|
def _print_table(rows: List[Dict[str, Any]]) -> None:
|
|
"""Pretty-print a list of row dicts as an aligned console table."""
|
|
if not rows:
|
|
return
|
|
columns = list(rows[0].keys())
|
|
# Compute column widths (header vs data)
|
|
widths = {col: len(col) for col in columns}
|
|
str_rows = []
|
|
for row in rows:
|
|
str_row = {col: str(row[col]) for col in columns}
|
|
for col in columns:
|
|
widths[col] = max(widths[col], len(str_row[col]))
|
|
str_rows.append(str_row)
|
|
|
|
header = " | ".join(col.ljust(widths[col]) for col in columns)
|
|
sep = "-+-".join("-" * widths[col] for col in columns)
|
|
print(f" {header}")
|
|
print(f" {sep}")
|
|
for sr in str_rows:
|
|
line = " | ".join(sr[col].ljust(widths[col]) for col in columns)
|
|
print(f" {line}")
|
|
|
|
|
|
# =============================================================================
|
|
# CONFIRMATION GATE
|
|
# =============================================================================
|
|
|
|
def preview_and_confirm(
    action: str,
    sql: str,
    rows: List[Dict[str, Any]],
    preview_rows: int = 5,
    dry_run: bool = False,
    total_row_count: Optional[int] = None,
) -> bool:
    """
    Show the user what's about to happen and ask for confirmation.

    Args:
        action: Short description ("INSERT into Documents")
        sql: The SQL statement that will be executed
        rows: The data rows that will be written (or a sample of them)
        preview_rows: How many sample rows to display
        dry_run: If True, show the preview but return False without prompting
        total_row_count: If `rows` is only a sample, pass the full count
            here so the prompt shows the real number of rows
            that will be written.

    Returns:
        True only if the user answers "y" or "yes" (case-insensitive).
        False on dry run, on any other answer, and when stdin is exhausted
        (EOF), since then nobody can confirm.
    """
    logger = logging.getLogger("db_helper")
    full_count = total_row_count if total_row_count is not None else len(rows)

    print("\n" + "=" * 60)
    print(f" ACTION: {action}")
    print(f" ROWS: {full_count}")
    print(f" SQL: {sql}")
    print("=" * 60)

    if rows and preview_rows > 0:
        shown = min(preview_rows, len(rows))
        print(f"\n Sample data ({shown} of {full_count}):")
        _print_table(rows[:preview_rows])

    if dry_run:
        print("\n [DRY RUN] — no changes will be made.")
        logger.info(f"[DRY RUN] Would {action} ({full_count} rows)")
        return False

    print()
    try:
        response = input(" Execute this? [y/N]: ").strip().lower()
    except EOFError:
        # Non-interactive stdin: the confirmation gate must fail closed
        # instead of letting EOFError escape as an unhandled exception.
        print()
        logger.warning("No input available for confirmation (EOF) — treating as 'no'.")
        response = ""

    if response in ("y", "yes"):
        logger.info(f"User confirmed: {action} ({full_count} rows)")
        return True

    logger.info(f"User declined: {action}")
    print(" Aborted.")
    return False
|
|
|
|
|
|
# =============================================================================
|
|
# INSERT
|
|
# =============================================================================
|
|
|
|
def _parse_insert_columns(sql: str) -> Optional[List[str]]:
|
|
"""
|
|
Extract the column name list from a standard INSERT statement.
|
|
|
|
Matches 'INSERT INTO <table> (col1, col2, ...) VALUES ...'. Returns
|
|
None if the INSERT has no explicit column list (e.g. 'INSERT INTO t
|
|
VALUES (...)') so the caller can fall back to positional labels.
|
|
"""
|
|
import re
|
|
# Match the first parenthesised group after INSERT INTO <table>
|
|
# Table name may be bracketed/dotted: [db].[dbo].[Table]
|
|
m = re.search(
|
|
r"INSERT\s+INTO\s+[\[\]\w\.]+\s*\(([^)]+)\)\s*VALUES",
|
|
sql,
|
|
re.IGNORECASE | re.DOTALL,
|
|
)
|
|
if not m:
|
|
return None
|
|
cols = [c.strip().strip("[]") for c in m.group(1).split(",")]
|
|
return [c for c in cols if c]
|
|
|
|
|
|
def _build_insert_preview_rows(
|
|
rows: List[Dict[str, Any]],
|
|
params_builder: Callable[[Dict[str, Any]], tuple],
|
|
column_names: Optional[List[str]],
|
|
) -> List[Dict[str, Any]]:
|
|
"""
|
|
Apply params_builder to each row and return dicts keyed by the INSERT's
|
|
column names — so the preview shows exactly what will be written.
|
|
Falls back to positional labels ('col_0', 'col_1', ...) if the column
|
|
list couldn't be parsed.
|
|
"""
|
|
preview = []
|
|
for row in rows:
|
|
params = params_builder(row)
|
|
if column_names and len(column_names) == len(params):
|
|
preview.append(dict(zip(column_names, params)))
|
|
else:
|
|
preview.append({f"col_{i}": v for i, v in enumerate(params)})
|
|
return preview
|
|
|
|
|
|
def run_insert(
    db: DatabaseConnection,
    insert_sql: str,
    rows: List[Dict[str, Any]],
    params_builder: Callable[[Dict[str, Any]], tuple],
    action: str = "INSERT rows",
    dry_run: bool = False,
    preview_columns: Optional[List[str]] = None,
) -> Dict[str, int]:
    """
    Insert rows with confirmation, logging, and transaction safety.

    Flow:
        1. Build a preview of the first 5 rows from the exact params that
           will be sent and ask for confirmation (preview_and_confirm).
           Stop here if the user declines or dry_run is set.
        2. Execute insert_sql once per row via
           db.execute_non_query_no_commit(). A failure whose message contains
           "duplicate" or "violation of" counts as skipped; any other failure
           counts as an error. Either way processing continues with the next
           row. (Assumes the no-commit calls accumulate in one open
           transaction until commit()/rollback() — depends on
           DatabaseConnection; confirm there.)
        3. With no errors, commit. Otherwise ask whether to commit anyway;
           any answer other than y/yes — including end-of-input on stdin —
           rolls back and resets the inserted count to 0.

    An exception raised by params_builder is not caught here; it propagates
    to the caller with the batch still uncommitted.

    Args:
        db: Active DatabaseConnection
        insert_sql: Parameterised INSERT statement (use ? placeholders)
        rows: Row dicts (typically from run_select, possibly transformed)
        params_builder: Callable that converts a row dict into the param
            tuple matching the INSERT's ? placeholders
        action: Description shown in the confirmation prompt
        dry_run: If True, preview only — don't execute
        preview_columns: Optional list of column names for the preview
            display. If None, parsed from the INSERT SQL.

    Returns:
        Dict with counts: inserted, skipped, errors. All zero when there is
        nothing to insert, the user declines, or dry_run is set.
    """
    logger = logging.getLogger("db_helper")
    stats = {"inserted": 0, "skipped": 0, "errors": 0}

    if not rows:
        logger.info("No rows to insert.")
        return stats

    # Build the preview from the ACTUAL params that will be sent to the DB
    # (not the raw SELECT rows) so users see what will really be inserted.
    column_names = preview_columns or _parse_insert_columns(insert_sql)
    preview_rows = _build_insert_preview_rows(
        rows[:5], params_builder, column_names
    )
    # Pass the full row count so preview_and_confirm reports it accurately
    # even though only the sample was transformed.
    if not preview_and_confirm(
        action, insert_sql, preview_rows,
        total_row_count=len(rows),
        dry_run=dry_run,
    ):
        return stats

    # Execute row-by-row without committing so each row can be logged
    # individually and the whole batch committed or rolled back at the end.
    total = len(rows)
    # Update progress ~50 times across the batch (every row for tiny
    # batches). Keeps the terminal feeling alive without spamming.
    progress_step = max(1, total // 50)
    print()  # blank line before the progress indicator

    for i, row in enumerate(rows, 1):
        params = params_builder(row)
        try:
            db.execute_non_query_no_commit(insert_sql, params)
            stats["inserted"] += 1
            logger.debug(f" [{i}/{total}] Inserted: {params}")
        except Exception as exc:
            err_msg = str(exc)
            if "duplicate" in err_msg.lower() or "violation of" in err_msg.lower():
                stats["skipped"] += 1
                logger.warning(f" [{i}/{total}] Skipped (duplicate): {params}")
            else:
                stats["errors"] += 1
                logger.error(f" [{i}/{total}] Error: {exc} | params={params}")

        # Live progress (overwrites the same line)
        if i % progress_step == 0 or i == total:
            pct = (i / total) * 100
            print(
                f"\r Progress: {i}/{total} ({pct:5.1f}%) "
                f"inserted={stats['inserted']} skipped={stats['skipped']} "
                f"errors={stats['errors']}",
                end="",
                flush=True,
            )
    print()  # end the progress line

    # Commit or rollback
    if stats["errors"] == 0:
        db.commit()
        logger.info(
            f"Committed. Inserted: {stats['inserted']}, "
            f"Skipped: {stats['skipped']}"
        )
    else:
        print(
            f"\n {stats['errors']} error(s) occurred. "
            f"Commit anyway? [y/N]: ", end="", flush=True
        )
        try:
            resp = input().strip().lower()
        except EOFError:
            # No one can answer: take the safe default (roll back) instead of
            # escaping with an unhandled EOFError.
            print()
            resp = ""
        if resp in ("y", "yes"):
            db.commit()
            logger.info(f"Committed with errors. {stats}")
        else:
            db.rollback()
            stats["inserted"] = 0
            logger.warning(f"Rolled back all inserts. {stats}")
            print(" Rolled back.")

    # Summary
    print(f"\n Results: {stats}")
    return stats
|
|
|
|
|
|
# =============================================================================
|
|
# PREDEFINED TASKS
|
|
# =============================================================================
|
|
# Each task is a function that receives (db, args) and orchestrates a
|
|
# SELECT → transform → INSERT workflow. Register new tasks in TASK_REGISTRY
|
|
# at the bottom of this section.
|
|
|
|
def task_copy_with_new_id(db: DatabaseConnection, args: argparse.Namespace) -> None:
    """
    Template task: SELECT rows, offset their ID, and INSERT them as new rows.

    The table, columns, filter and transform below are placeholders; replace
    them with the real ones before pointing this at a live database. The
    INSERT goes through run_insert, which previews and asks for confirmation
    (and writes nothing when --dry-run is given).
    """
    logger = logging.getLogger("db_helper")

    # ----- 1. Source rows -----
    select_sql = """
SELECT TOP 10
ID, Name, Description
FROM YourTable
WHERE SomeCondition = 1
"""
    source_rows = run_select(db, select_sql)
    if not source_rows:
        logger.info("No source rows found — nothing to do.")
        return

    # ----- 2. Transform: shallow copy of each row with a shifted ID -----
    id_offset = 1000  # example offset; adjust to the real ID scheme
    new_rows = [{**src, "ID": src["ID"] + id_offset} for src in source_rows]

    # ----- 3. Write the transformed rows -----
    insert_sql = """
INSERT INTO YourTable (ID, Name, Description)
VALUES (?, ?, ?)
"""
    run_insert(
        db,
        insert_sql,
        new_rows,
        params_builder=lambda rec: (rec["ID"], rec["Name"], rec["Description"]),
        action="INSERT transformed rows into YourTable",
        dry_run=args.dry_run,
    )
|
|
|
|
|
|
def task_check_vv50(db: DatabaseConnection, args: argparse.Namespace) -> None:
    """
    For every document that has VariableID=57 (in DWS paths), check whether
    it also has a VariableValue row for VariableID=50.

    Steps:
        1. Run DWS_GET_VV-57.sql → list of documents (needs a DocumentID
           column; FileName / FullVaultPath are used when present)
        2. For each DocumentID, run Get_All_VV_Per_DocID.sql with the
           DocumentID as its single ? parameter (needs a VariableID column)
        3. Log whether VariableID=50 is present or missing
        4. Write the documents that already HAVE VariableID=50 to
           has_vv50_<YYYYmmdd_HHMMSS>.txt in the current directory as CSV
           (DocumentID,FileName,FullVaultPath). The copy_57_to_50 task uses
           this file as its exclusion list.

    The file in step 4 is written even when no document has VV-50 (header
    only). Otherwise copy_57_to_50, which auto-detects the newest
    has_vv50_*.txt, would fall back to a stale file from an earlier run.
    """
    import csv  # local: only this task writes CSV

    logger = logging.getLogger("db_helper")

    # ----- Step 1: Get all documents with VV-57 -----
    step1_sql = load_query("DWS_GET_VV-57")
    docs = run_select(db, step1_sql, preview_rows=5)

    if not docs:
        logger.info("No documents returned — nothing to check.")
        return

    # ----- Step 2 & 3: Check each document for VV-50 -----
    step2_sql = load_query("Get_All_VV_Per_DocID")

    has_vv50 = []
    missing_vv50 = []

    total = len(docs)
    for i, doc in enumerate(docs, 1):
        doc_id = doc["DocumentID"]
        file_name = doc.get("FileName", "")
        full_path = doc.get("FullVaultPath", file_name)

        # One query per document.
        var_rows = db.execute_query(step2_sql, (doc_id,))
        var_ids = {row["VariableID"] for row in var_rows}

        if 50 in var_ids:
            has_vv50.append(doc)
            logger.debug(
                f" [{i}/{total}] VV-50 EXISTS | DocID={doc_id} | {full_path}"
            )
        else:
            missing_vv50.append(doc)
            logger.info(
                f" [{i}/{total}] VV-50 MISSING | DocID={doc_id} | {full_path}"
            )

    # ----- Summary -----
    logger.info("=" * 60)
    logger.info("VV-50 CHECK COMPLETE")
    logger.info("=" * 60)
    logger.info(f"Total documents checked: {total}")
    logger.info(f" Has VV-50: {len(has_vv50)}")
    logger.info(f" Missing VV-50: {len(missing_vv50)}")

    # ----- Step 4: Save the has-VV-50 list (copy_57_to_50's exclusion list) -----
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    out_file = f"has_vv50_{timestamp}.txt"
    # csv.writer quotes fields containing commas (file names / vault paths
    # can contain them), so the file stays valid CSV. newline="" is what
    # the csv module requires for the file object.
    with open(out_file, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["DocumentID", "FileName", "FullVaultPath"])
        for doc in has_vv50:
            writer.writerow([
                doc["DocumentID"],
                doc.get("FileName", ""),
                doc.get("FullVaultPath", ""),
            ])
    if not has_vv50:
        logger.info("No document has VV-50 — list contains only the header.")
    logger.info(f"Has VV-50 list saved to: {out_file}")
|
|
|
|
def copy_57_to_50(db: DatabaseConnection, args: argparse.Namespace) -> None:
    """
    Copy DWS "Number" values (VariableID 57) into "Drawing Number" (VariableID 50).

    DWS had a variable called Number (VariableID=57), but we want that
    information to show up on the data cards in the "Drawing Number" field
    (VariableID=50). For every VariableID=57 row in the DWS folder, a new
    VariableValue row is inserted that is identical except that
    VariableID=50.

    Documents that already have a VariableID=50 row must not get another
    one. Their DocumentIDs come from a has_vv50_<date>.txt file written by
    the check_vv50 task: the --exclude-file argument, or else the most
    recent one found in the current directory.

    Steps:
        1. Run DWS_VV-57_FullList.sql → all VV-57 rows in DWS paths
        2. Load the DocumentIDs listed in the has_vv50 exclusion file
        3. Drop every row whose DocumentID is in that list
        4. Insert a VariableID=50 copy of each remaining row (run_insert
           previews and asks for confirmation; --dry-run writes nothing)
    """
    logger = logging.getLogger("db_helper")

    # ----- Step 1: all VV-57 rows in DWS paths -----
    source_rows = run_select(
        db, load_query("DWS_VV-57_FullList"), preview_rows=5
    )
    if not source_rows:
        logger.info("No VV-57 rows found — nothing to copy.")
        return

    # ----- Step 2: DocumentIDs that already have VV-50 -----
    exclusion_path = args.exclude_file or _find_latest_has_vv50_file()
    already_has_50 = _load_excluded_doc_ids(exclusion_path)

    # ----- Step 3: filter -----
    pending = [
        row for row in source_rows if row["DocumentID"] not in already_has_50
    ]
    n_skipped = len(source_rows) - len(pending)
    logger.info(
        f"After filter: {len(pending)} rows to insert, "
        f"{n_skipped} skipped (DocumentID already has VV-50)"
    )
    if not pending:
        logger.info("Nothing to insert after filtering.")
        return

    # ----- Step 4: insert (with preview + confirmation) -----
    # Column order MUST match the ? placeholders of INSERT_VV50_Copy.sql,
    # which start with VariableID; that one is not copied but forced to 50.
    copied_columns = (
        "DocumentID", "ProjectID", "RevisionNo", "ConfigurationID",
        "ValueText", "ValueInt", "ValueFloat", "ValueDate", "ValueCache",
        "IsLongText",
    )

    def build_params(row: Dict[str, Any]) -> tuple:
        return (50, *(row[col] for col in copied_columns))

    run_insert(
        db,
        load_query("INSERT_VV50_Copy"),
        pending,
        params_builder=build_params,
        action="INSERT VariableID=50 copies of DWS VV-57 rows",
        dry_run=args.dry_run,
    )
|
|
|
|
|
|
def _find_latest_has_vv50_file() -> Optional[str]:
|
|
"""Find the most recent has_vv50_*.txt file in the current directory."""
|
|
logger = logging.getLogger("db_helper")
|
|
matches = sorted(glob.glob("has_vv50_*.txt"))
|
|
if not matches:
|
|
return None
|
|
latest = matches[-1]
|
|
logger.info(f"Auto-detected exclusion file: {latest}")
|
|
return latest
|
|
|
|
|
|
def _load_excluded_doc_ids(path: Optional[str]) -> Set[int]:
|
|
"""
|
|
Load DocumentIDs from a has_vv50_*.txt file (CSV format with header).
|
|
|
|
Returns an empty set if no file is provided and prompts the user to
|
|
confirm they want to proceed without any exclusions.
|
|
"""
|
|
logger = logging.getLogger("db_helper")
|
|
|
|
if not path:
|
|
logger.warning(
|
|
"No exclusion file found — ALL VV-57 DocumentIDs will get a "
|
|
"VV-50 copy, including ones that may already have VV-50."
|
|
)
|
|
resp = input(
|
|
" Proceed without an exclusion list? [y/N]: "
|
|
).strip().lower()
|
|
if resp not in ("y", "yes"):
|
|
logger.info("User aborted — no exclusion file.")
|
|
raise SystemExit(1)
|
|
return set()
|
|
|
|
excluded: Set[int] = set()
|
|
with open(path, "r", encoding="utf-8") as f:
|
|
header = f.readline() # discard "DocumentID,FileName,FullVaultPath"
|
|
for line in f:
|
|
line = line.strip()
|
|
if not line:
|
|
continue
|
|
first = line.split(",", 1)[0].strip()
|
|
if first.isdigit():
|
|
excluded.add(int(first))
|
|
logger.info(f"Loaded {len(excluded)} DocumentIDs to exclude from {path}")
|
|
return excluded
|
|
|
|
|
|
# --task name → task function. main() calls the selected entry as
# fn(db, args), and parse_arguments() offers these keys as the valid --task
# choices, so registering a task here is all that's needed to expose it.
TASK_REGISTRY: Dict[str, Callable] = dict(
    copy_with_new_id=task_copy_with_new_id,
    check_vv50=task_check_vv50,
    copy_57_to_50=copy_57_to_50,
)
|
|
|
|
|
|
# =============================================================================
|
|
# CLI
|
|
# =============================================================================
|
|
|
|
def parse_arguments() -> argparse.Namespace:
    """
    Parse the command line.

    --task and --query are mutually exclusive: main() only ever runs one of
    them, so accepting both would silently ignore --task. argparse now
    rejects that combination with a usage error.
    """
    parser = argparse.ArgumentParser(
        description="Database helper for PDM migration — interactive SQL tasks",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python db_helper.py --db target_db --task copy_with_new_id
python db_helper.py --db target_db --task copy_with_new_id --dry-run
python db_helper.py --db source_db --query get_var47
python db_helper.py --db source_db --query "SELECT TOP 10 * FROM Documents"
python db_helper.py --list-queries
""",
    )
    parser.add_argument(
        "--db",
        help='Config key for the database: "source_db" or "target_db"',
    )

    mode = parser.add_mutually_exclusive_group()
    mode.add_argument(
        "--task",
        choices=list(TASK_REGISTRY.keys()),
        help="Name of a predefined task to run",
    )
    mode.add_argument(
        "--query",
        help=(
            "Run a SELECT query. Pass a query name to load from "
            "helpers/queries/<name>.sql, or pass raw SQL in quotes."
        ),
    )

    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Preview what would happen without executing writes",
    )
    parser.add_argument(
        "--list-queries",
        action="store_true",
        help="List all available saved queries and exit",
    )
    parser.add_argument(
        "--exclude-file",
        help=(
            "Path to a has_vv50_*.txt file whose DocumentIDs should be "
            "excluded from copy_57_to_50. If omitted, the most recent "
            "has_vv50_*.txt in the current directory is used."
        ),
    )
    return parser.parse_args()
|
|
|
|
|
|
def _resolve_query(query_arg: str) -> str:
|
|
"""
|
|
Resolve a --query argument to SQL text.
|
|
|
|
If it looks like a SQL statement (contains a space), use it as-is.
|
|
Otherwise treat it as a saved query name and load from queries/<name>.sql.
|
|
"""
|
|
if " " in query_arg:
|
|
return query_arg
|
|
return load_query(query_arg)
|
|
|
|
|
|
def main() -> int:
    """
    Command-line entry point.

    Returns:
        Process exit code: 0 on success; 1 on errors (missing --db or mode,
        missing query/config file, connection failure, unhandled
        exception); 130 when interrupted with Ctrl-C.
    """
    args = parse_arguments()

    # --list-queries doesn't need a DB connection or logging
    if args.list_queries:
        queries = list_queries()
        if queries:
            print(f"Available queries in {QUERIES_DIR}:")
            for name in queries:
                # Show the first line of each .sql as a description
                sql_path = QUERIES_DIR / f"{name}.sql"
                first_line = sql_path.read_text(encoding="utf-8").split("\n")[0]
                print(f" {name:30s} {first_line}")
        else:
            print(f"No .sql files found in {QUERIES_DIR}")
        return 0

    if not args.db:
        print("Error: --db is required (unless using --list-queries)")
        return 1

    logger = setup_logging()

    logger.info("=" * 60)
    logger.info("DB HELPER")
    logger.info("=" * 60)
    logger.info(f"Database: {args.db}")
    logger.info(f"Task: {args.task or '(ad-hoc query)'}")
    logger.info(f"Dry run: {args.dry_run}")

    # Validate the requested mode before opening a connection.
    if not args.query and not args.task:
        logger.error("Provide either --task, --query, or --list-queries")
        return 1

    # Report connection problems like every other error instead of letting
    # them escape as a raw traceback.
    try:
        db = connect_db(args.db)
    except (FileNotFoundError, ValueError) as exc:
        # e.g. config.json missing, or --db key not present in it
        logger.error(str(exc))
        return 1
    except KeyboardInterrupt:
        logger.warning("Interrupted by user")
        return 130
    except Exception:
        logger.exception(f"Could not connect using config key '{args.db}'")
        return 1

    def _rollback() -> None:
        # A failing rollback (e.g. dead connection) must not mask the error
        # being handled; db.close() in `finally` still runs.
        try:
            db.rollback()
        except Exception:
            logger.exception("Rollback failed")

    try:
        if args.query:
            sql = _resolve_query(args.query)
            logger.info(f"Resolved query:\n{sql}")
            run_select(db, sql)
        else:
            task_fn = TASK_REGISTRY[args.task]
            task_fn(db, args)

    except FileNotFoundError as exc:
        logger.error(str(exc))
        _rollback()
        return 1
    except KeyboardInterrupt:
        logger.warning("Interrupted by user")
        _rollback()
        return 130
    except Exception:
        logger.exception("Unhandled exception")
        _rollback()
        return 1
    finally:
        db.close()

    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # main() returns the exit code; SystemExit hands it to the shell.
    raise SystemExit(main())
|