diff --git a/job_script_tagger.py b/job_script_tagger.py new file mode 100644 index 0000000..d3087ed --- /dev/null +++ b/job_script_tagger.py @@ -0,0 +1,1174 @@ +#!/usr/bin/env python3 + +import os +import re +import csv +import json +import sqlite3 +import tarfile +import logging +import threading +import argparse +import pandas as pd +import time +from pathlib import Path +from typing import Dict, List, Tuple, Set, Any, Optional, Union +from concurrent.futures import ThreadPoolExecutor, as_completed +import webdataset as wds +from PIL import Image +import io +import hashlib + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.FileHandler("job_tagger_dataset.log"), + logging.StreamHandler() + ] +) +logger = logging.getLogger("JobScriptTagger") + +class JobScriptTagger: + """Process job scripts and images to create ML datasets""" + + def __init__( + self, + db_path: str, + keywords_csv: str, + output_dir: str, + input_mode: str = 'tar', + input_paths: List[str] = None, + num_workers: int = None, + shard_size: int = 1000, + resume: bool = False, + max_debug_samples: int = 100, + start_time_limit: Optional[int] = None, + sqlite_output_path: Optional[str] = None + ): + """ + Initialize the tagger + + Args: + db_path: Path to the SQLite database + keywords_csv: Path to the CSV file with application names and keywords + output_dir: Directory for output files + input_mode: 'tar' for tar archives or 'dir' for directories + input_paths: List of tar archives or directories containing images + num_workers: Number of parallel workers (defaults to CPU count) + shard_size: Number of samples per output shard + resume: Whether to resume from a previous run + max_debug_samples: Maximum number of debug samples to keep + start_time_limit: Unix timestamp limit for filtering jobs (None for no limit) + sqlite_output_path: Path to SQLite database for output (if None, use WebDataset mode) + """ + self.db_path = db_path + self.keywords_csv = keywords_csv + self.output_dir = output_dir + self.input_mode = input_mode + self.input_paths = input_paths or [] + self.num_workers = num_workers or os.cpu_count() + self.shard_size = shard_size + self.resume = resume + self.max_debug_samples = max_debug_samples + self.start_time_limit = start_time_limit + self.sqlite_output_path = sqlite_output_path + + # Ensure output directory exists (only needed for WebDataset mode) + if not sqlite_output_path: + os.makedirs(output_dir, exist_ok=True) + + # Thread-local storage for database connections + self.local_db = threading.local() + + self.app_keywords = {} # Maps app names to lists of keywords + + # Track processed folders and progress + if sqlite_output_path: + self.progress_file = 'progress_sqlite.json' + else: + self.progress_file = os.path.join(output_dir, 'progress.json') + self.processed_folders = set() + self.current_shard = 0 + self.stats = { + 'valid_count': 0, + 'multi_tag_count': 0, + 'unknown_count': 0, + 'error_count': 0, + 'total_processed_folders': 0, + 'total_folders': 0, + 'total_files': 0, + 'processed_files': 0 + } + + # User statistics tracking + self.user_stats = {} # tag -> set of users + self.user_stats_recent = {} # tag -> set of users (within time limit) + + # Debug sample collections + self.multi_tag_samples = [] + + # Multi-label logging - keep track of files per shard + self.multi_label_logs = {} # shard_id -> list of filenames + + # Initialize or load progress + self._init_progress() + + 
logger.info(f"Initialized with {self.num_workers} workers, input mode: {input_mode}") + + def _track_user_stats(self, tag: str, user: str, start_time: Optional[int] = None): + """ + Track user statistics for a given tag + + Args: + tag: Application tag + user: HPC user + start_time: Job start time (Unix timestamp) + """ + if not user: + return + + # Track overall user stats + if tag not in self.user_stats: + self.user_stats[tag] = set() + self.user_stats[tag].add(user) + + # Track recent user stats if start_time is within limit + if self.start_time_limit is not None and start_time is not None: + if start_time >= self.start_time_limit: + if tag not in self.user_stats_recent: + self.user_stats_recent[tag] = set() + self.user_stats_recent[tag].add(user) + + def _init_progress(self): + """Initialize or load progress tracking""" + if self.resume and os.path.exists(self.progress_file): + try: + with open(self.progress_file, 'r') as f: + progress_data = json.load(f) + + self.processed_folders = set(progress_data.get('processed_folders', [])) + self.current_shard = progress_data.get('current_shard', 0) + self.stats = progress_data.get('stats', self.stats) + + # Load user stats + user_stats_data = progress_data.get('user_stats', {}) + self.user_stats = {tag: set(users) for tag, users in user_stats_data.items()} + + user_stats_recent_data = progress_data.get('user_stats_recent', {}) + self.user_stats_recent = {tag: set(users) for tag, users in user_stats_recent_data.items()} + + logger.info(f"Resumed from previous run: {len(self.processed_folders)} folders processed, " + f"current shard: {self.current_shard}") + except Exception as e: + logger.warning(f"Failed to load progress file: {e}. Starting fresh.") + self._save_progress() + else: + self._save_progress() + + def _save_progress(self): + """Save current progress to file""" + progress_data = { + 'processed_folders': list(self.processed_folders), + 'current_shard': self.current_shard, + 'stats': self.stats, + 'user_stats': {tag: list(users) for tag, users in self.user_stats.items()}, + 'user_stats_recent': {tag: list(users) for tag, users in self.user_stats_recent.items()}, + 'last_updated': time.strftime('%Y-%m-%d %H:%M:%S') + } + + with open(self.progress_file, 'w') as f: + json.dump(progress_data, f, indent=2) + + logger.debug(f"Progress saved: {self.stats['total_processed_folders']} folders processed") + + def get_db_connection(self): + """ + Get a thread-local database connection + + Returns: + SQLite connection object + """ + # Check if this thread already has a connection + if not hasattr(self.local_db, 'conn') or self.local_db.conn is None: + try: + self.local_db.conn = sqlite3.connect(self.db_path) + logger.debug(f"Created new database connection for thread {threading.get_ident()}") + except sqlite3.Error as e: + logger.error(f"Database connection error in thread {threading.get_ident()}: {e}") + raise + + return self.local_db.conn + + def close_db_connections(self): + """Close all database connections""" + # Since connections are thread-local, we can only close the one in the current thread + if hasattr(self.local_db, 'conn') and self.local_db.conn is not None: + self.local_db.conn.close() + self.local_db.conn = None + logger.info("Database connection closed") + + def load_keywords(self): + """Load application keywords from CSV file""" + try: + # Read CSV file + df = pd.read_csv(self.keywords_csv) + + # Process each row + for _, row in df.iterrows(): + app_name = row['Name'].strip() + # Split keywords by semicolon and strip whitespace + 
keywords = [kw.strip() for kw in row['Keyword'].split(';') if kw.strip()] + if app_name and keywords: + self.app_keywords[app_name] = keywords + + logger.info(f"Loaded {len(self.app_keywords)} applications with keywords") + + # Log the apps and their keywords for verification + for app_name, keywords in self.app_keywords.items(): + logger.debug(f"App: {app_name}, Keywords: {keywords}") + + except Exception as e: + logger.error(f"Error loading keywords from {self.keywords_csv}: {e}") + raise + + def extract_job_id(self, filename: str) -> Optional[int]: + """ + Extract job ID from image filename + + Args: + filename: Image filename in format {job_id}_{timestamp}_color.png + + Returns: + Job ID if successfully extracted, None otherwise + """ + match = re.match(r'(\d+)_\d+_color\.png$', os.path.basename(filename)) + if match: + return int(match.group(1)) + return None + + def get_job_script(self, job_id: int) -> Optional[Dict[str, Any]]: + """ + Fetch job script from database + + Args: + job_id: Job ID + + Returns: + Dict with job script and metadata if found, None otherwise + """ + try: + conn = self.get_db_connection() + cursor = conn.cursor() + + query = "SELECT meta_data, duration, hpc_user, start_time FROM job WHERE job_id = ?" + cursor.execute(query, (job_id,)) + result = cursor.fetchone() + + if result: + meta_data = json.loads(result[0]) if result[0] else {} + duration = result[1] + hpc_user = result[2] + start_time = result[3] + script = meta_data.get('jobScript', '') + name = meta_data.get('jobName', '') + + # Check start_time limit if specified + if self.start_time_limit is not None and start_time is not None: + if start_time < self.start_time_limit: + return None # Job is too old + + return { + 'script': script, + 'meta_data': meta_data, + 'duration': duration, + 'hpc_user': hpc_user, + 'start_time': start_time, + 'interactive': True if name == "interactive" else False + } + return None + + except sqlite3.Error as e: + logger.error(f"Error fetching job script for job {job_id}: {e}") + return None + + def is_valid_script(self, script: str) -> bool: + """ + Check if a script is valid (starts with #!) + + Args: + script: Job script text + + Returns: + True if script is valid, False otherwise + """ + if not script: + return False + + # Check if script starts with #! 
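+        # Illustrative behaviour (examples added for clarity, not part of the
+        # original patch; they follow directly from the check below):
+        #   is_valid_script("#!/bin/bash\n#SBATCH -N 1\nsrun ./app")  -> True
+        #   is_valid_script("module load foo\n./app")                 -> False
+        #   is_valid_script("")                                       -> False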
+ return script.strip().startswith('#!') + + def filter_script_comments(self, script: str) -> Tuple[str, List[str]]: + """ + Filter out commented and empty lines from script + + Args: + script: Job script text + + Returns: + Tuple of (filtered_script_text, filtered_lines) + """ + lines = script.splitlines() + filtered_lines = [] + + for line in lines: + stripped = line.strip() + # Keep shebang line and non-comment lines + if stripped and (stripped.startswith('#!') or not stripped.startswith('#')): + filtered_lines.append(line) + + filtered_script = '\n'.join(filtered_lines) + return filtered_script, filtered_lines + + def check_keywords(self, script: str) -> Tuple[List[str], Dict[str, List[str]]]: + """ + Check which applications' keywords match in the script and extract matching lines + + Args: + script: Job script text + + Returns: + Tuple of (matching_app_names, {app_name: [matching_lines]}) + """ + # Filter out commented lines + filtered_script, filtered_lines = self.filter_script_comments(script) + + script_lower = filtered_script.lower() + + matched_apps = [] + matched_lines = {} + + for app_name, keywords in self.app_keywords.items(): + app_matches = [] + + for keyword in keywords: + keyword_lower = keyword.lower() + + # Check if keyword is in filtered script + if keyword_lower in script_lower: + matched_apps.append(app_name) + + # Find lines containing this keyword in filtered lines + for line in filtered_lines: + if keyword_lower in line.lower(): + app_matches.append(line.strip()) + + # Break after first matching keyword for this app + break + + # Save unique matching lines for this app + if app_matches: + matched_lines[app_name] = list(set(app_matches)) + + return matched_apps, matched_lines + + def get_folder_path(self, file_info: Dict[str, str]) -> str: + """ + Get the subfolder path for a file + + Args: + file_info: File information dictionary + + Returns: + Subfolder path string + """ + if self.input_mode == 'tar': + # For tar files, use the parent directory in the archive + path_parts = file_info['file_path'].split('/') + if len(path_parts) > 1: + # Return the tar name and top-level directory + return f"{os.path.basename(file_info['tar_path'])}:{path_parts[0]}" + else: + return os.path.basename(file_info['tar_path']) + else: + # For directories, use the relative path's parent directory + path = Path(file_info['rel_path']) + if path.parent != Path('.'): + return f"{os.path.basename(file_info['dir_path'])}/{path.parent}" + else: + return os.path.basename(file_info['dir_path']) + + def list_input_folders(self) -> Dict[str, List[Dict[str, str]]]: + """ + List all image files from input sources grouped by subfolder + + Returns: + Dictionary mapping folder paths to lists of file information + """ + folders = {} + total_files = 0 + + if self.input_mode == 'tar': + # Process tar archives + for tar_path in self.input_paths: + if not os.path.exists(tar_path): + logger.warning(f"Tar archive not found: {tar_path}") + continue + + try: + with tarfile.open(tar_path, 'r') as tar: + for member in tar.getmembers(): + if member.name.endswith('_color.png'): + file_info = { + 'source_type': 'tar', + 'tar_path': tar_path, + 'file_path': member.name, + 'file_id': f"{tar_path}:{member.name}", + 'filename': os.path.basename(member.name) + } + + # Get folder path + folder_path = self.get_folder_path(file_info) + + # Skip if folder already processed + if folder_path in self.processed_folders and self.resume: + continue + + # Add to folders dictionary + if folder_path not in folders: + 
folders[folder_path] = [] + folders[folder_path].append(file_info) + total_files += 1 + + except Exception as e: + logger.error(f"Error reading tar archive {tar_path}: {e}") + else: + # Process directories + for dir_path in self.input_paths: + if not os.path.isdir(dir_path): + logger.warning(f"Directory not found: {dir_path}") + continue + + for root, _, files in os.walk(dir_path): + # Filter image files + image_files = [f for f in files if f.endswith('_color.png')] + + if not image_files: + continue + + for filename in image_files: + file_path = os.path.join(root, filename) + rel_path = os.path.relpath(file_path, dir_path) + + file_info = { + 'source_type': 'dir', + 'dir_path': dir_path, + 'file_path': file_path, + 'rel_path': rel_path, + 'file_id': f"{dir_path}:{rel_path}", + 'filename': filename + } + + # Get folder path + folder_path = self.get_folder_path(file_info) + + # Skip if folder already processed + if folder_path in self.processed_folders and self.resume: + continue + + # Add to folders dictionary + if folder_path not in folders: + folders[folder_path] = [] + folders[folder_path].append(file_info) + total_files += 1 + + # Update stats + self.stats['total_files'] = total_files + self.stats['total_folders'] = len(folders) + + logger.info(f"Found {len(folders)} unprocessed folders with {total_files} files") + return folders + + def process_file(self, file_info: Dict[str, str]) -> Dict[str, Any]: + """ + Process a single image file (worker function for thread pool) + + Args: + file_info: Dictionary with file information + + Returns: + Dictionary with processing results + """ + result = { + 'file_info': file_info, + 'job_id': None, + 'tags': [], + 'valid': False, + 'error': None, + 'script': None, + 'script_excerpt': None, + 'matching_lines': {}, + 'duration': None, + 'hpc_user': None, + 'start_time': None + } + + try: + # Extract job ID from filename + job_id = self.extract_job_id(file_info['filename']) + if not job_id: + result['error'] = f"Could not extract job ID from filename: {file_info['filename']}" + return result + + result['job_id'] = job_id + + # Get job script + job_data = self.get_job_script(job_id) + + # Check if job was filtered out by start_time_limit + if job_data is None: + result['error'] = f"Job {job_id} filtered out by start_time_limit" + result['tags'] = ["Filtered"] + result['valid'] = False + return result + + # Store duration from job data + result['duration'] = job_data.get('duration') if job_data else None + result['hpc_user'] = job_data.get('hpc_user') if job_data else None + result['start_time'] = job_data.get('start_time') if job_data else None + + if job_data['interactive']: + result['tags'] = ["Interactive"] + result['valid'] = True + # Track user stats for interactive jobs + if job_data.get('hpc_user'): + self._track_user_stats("Interactive", job_data['hpc_user'], job_data.get('start_time')) + return result + + # If no script, assign "Unknown" app label + if not job_data or not job_data['script']: + result['tags'] = ["Unknown"] + result['valid'] = False + # Track user stats for unknown jobs + if job_data and job_data.get('hpc_user'): + self._track_user_stats("Unknown", job_data['hpc_user'], job_data.get('start_time')) + # result['error'] = "No job script found" + return result + + script = job_data['script'] + + # If script doesn't start with #!, assign "Unknown" app label + if not self.is_valid_script(script): + result['tags'] = ["Unknown"] + result['valid'] = True # Mark as valid with "Unknown" tag + result['error'] = "Script doesn't start with 
#!" + # Track user stats for unknown jobs + if job_data.get('hpc_user'): + self._track_user_stats("Unknown", job_data['hpc_user'], job_data.get('start_time')) + return result + + # Store script excerpt for debugging + lines = script.strip().split('\n') + result['script_excerpt'] = '\n'.join(lines[:60]) + if len(lines) > 60: + result['script_excerpt'] += f"\n... [+ {len(lines) - 60} more lines]" + + # Store full script + result['script'] = script + + # Check keywords + matched_apps, matching_lines = self.check_keywords(script) + + # No matches, assign "Unknown" + if not matched_apps: + result['tags'] = ["Unknown"] + result['valid'] = True + result['error'] = "No matching keywords found" + # Track user stats for unknown jobs + if job_data.get('hpc_user'): + self._track_user_stats("Unknown", job_data['hpc_user'], job_data.get('start_time')) + else: + result['tags'] = matched_apps + result['matching_lines'] = matching_lines + result['valid'] = len(matched_apps) == 1 # Valid if exactly one tag + + # Track user stats for matched apps + if job_data.get('hpc_user'): + for tag in matched_apps: + self._track_user_stats(tag, job_data['hpc_user'], job_data.get('start_time')) + + except Exception as e: + result['error'] = str(e) + result['tags'] = ["Unknown"] # Default to "Unknown" on error + result['valid'] = False + logger.error(f"Error processing file {file_info['file_id']}: {e}") + + # Track user stats for error cases if we have job data + if 'job_data' in locals() and job_data and job_data.get('hpc_user'): + self._track_user_stats("Unknown", job_data['hpc_user'], job_data.get('start_time')) + + return result + + def read_image_data(self, file_info: Dict[str, str]) -> Optional[bytes]: + """ + Read image data from the source (tar or directory) + + Args: + file_info: File information dictionary + + Returns: + Image data as bytes if successful, None otherwise + """ + try: + if file_info['source_type'] == 'tar': + # Extract from tar archive + with tarfile.open(file_info['tar_path'], 'r') as tar: + img_file = tar.extractfile(file_info['file_path']) + if img_file: + return img_file.read() + else: + # Read from filesystem + with open(file_info['file_path'], 'rb') as f: + return f.read() + + return None + except Exception as e: + logger.error(f"Error reading image {file_info['file_id']}: {e}") + return None + + def log_multi_label_files(self, shard_idx: int, multi_label_samples: List[Dict[str, Any]]): + """ + Log files with multiple labels to a file + + Args: + shard_idx: Shard index number (ignored in SQLite mode) + multi_label_samples: List of samples with multiple labels + """ + if not multi_label_samples: + return + + if self.sqlite_output_path: + log_path = 'multi_label_sqlite.txt' + else: + log_path = os.path.join(self.output_dir, f"multi_label_shard_{shard_idx:06d}.txt") + + with open(log_path, 'w') as f: + f.write(f"Multi-label files in shard {shard_idx} ({len(multi_label_samples)} files):\n\n") + + for i, sample in enumerate(multi_label_samples, 1): + f.write(f"File {i}: {sample['file_info']['filename']}\n") + f.write(f"Job ID: {sample['job_id']}\n") + f.write(f"Labels: {', '.join(sample['tags'])}\n") + + # Write matching lines for each app + for app, lines in sample['matching_lines'].items(): + f.write(f"\n {app} matching lines:\n") + for line in lines: + f.write(f" {line}\n") + + f.write("\n" + "-" * 80 + "\n\n") + + logger.info(f"Logged {len(multi_label_samples)} multi-label files to {log_path}") + + def log_unknown_script_files(self, shard_idx: int, unknown_samples: List[Dict[str, Any]]): + 
""" + Log files with unknown scripts to a file + + Args: + shard_idx: Shard index number (ignored in SQLite mode) + unknown_samples: List of samples with unknown scripts + """ + if not unknown_samples: + return + + if self.sqlite_output_path: + log_path = 'unknown_script_sqlite.txt' + else: + log_path = os.path.join(self.output_dir, f"unknown_script_shard_{shard_idx:06d}.txt") + + with open(log_path, 'w', encoding='utf-8') as f: + f.write(f"Unknown script files in shard {shard_idx} ({len(unknown_samples)} files):\n\n") + + for i, sample in enumerate(unknown_samples, 1): + f.write(f"File {i}: {sample['file_info']['filename']}\n") + f.write(f"Job ID: {sample['job_id']}\n") + f.write(f"Error: {sample.get('error', '')}\n") + + # Filter out comment lines (starting with #) from script excerpt + script_excerpt = sample.get('script_excerpt', '') + if script_excerpt: + lines = script_excerpt.split('\n') + filtered_lines = [line for line in lines if not line.strip().startswith('#')] + filtered_excerpt = '\n'.join(filtered_lines) + f.write(f"Script excerpt (comments removed):\n{filtered_excerpt}\n") + else: + f.write("Script excerpt: (none)\n") + + f.write("\n" + "-" * 80 + "\n\n") + + logger.info(f"Logged {len(unknown_samples)} unknown script files to {log_path}") + + def write_shard(self, samples: List[Dict[str, Any]], shard_idx: int): + """ + Write samples to a WebDataset shard + + Args: + samples: List of samples to write + shard_idx: Shard index number + """ + if not samples: + return + + # Create shard filename with proper padding + shard_path = os.path.join( + self.output_dir, + f"roofline_dataset_{shard_idx:06d}.tar" + ) + + # Track multi-label and unknown samples for this shard + multi_label_samples = [] + unknown_samples = [] + + valid_count = 0 + with wds.TarWriter(shard_path) as sink: + for sample in samples: + try: + file_info = sample['file_info'] + + # Extract image data + img_data = self.read_image_data(file_info) + + if not img_data: + logger.warning(f"Could not read image for job {sample['job_id']}") + continue + + # Check if this is a multi-label sample + if len(sample['tags']) > 1: + multi_label_samples.append(sample) + # Check if this is an unknown sample + if sample['tags'] == ["Unknown"]: + unknown_samples.append(sample) + + # Create a unique key from original filename + key = file_info['filename'].split('.')[0] # Remove extension + + # Extract lines that contain the keywords for the chosen app + keyword_lines = [] + if sample['tags'][0] != "Unknown" and sample['tags'][0] in sample['matching_lines']: + keyword_lines = sample['matching_lines'][sample['tags'][0]] + + # Create sample + wds_sample = { + "__key__": key, + "png": img_data, + "json": json.dumps({ + "job_id": sample['job_id'], + "tags": sample['tags'], # Use the first tag for single-label classification + "script": keyword_lines, # Add lines containing detected keywords + "duration": sample.get('duration'), # Add job duration + "hpc_user": sample.get('hpc_user'), # Add HPC user + "start_time": sample.get('start_time') # Add job start time + }) + } + + # Write to WebDataset + sink.write(wds_sample) + valid_count += 1 + + except Exception as e: + logger.error(f"Error adding sample {sample.get('job_id')} to WebDataset: {e}") + + # Log multi-label samples for this shard + if multi_label_samples: + self.log_multi_label_files(shard_idx, multi_label_samples) + + # Log unknown samples for this shard + if unknown_samples: + self.log_unknown_script_files(shard_idx, unknown_samples) + + logger.info(f"Created shard {shard_path} 
with {valid_count} samples") + return valid_count + + def write_to_sqlite(self, samples: List[Dict[str, Any]]): + """ + Write samples to SQLite database + + Args: + samples: List of samples to write + """ + if not samples: + return + + try: + # Connect to output database + conn = sqlite3.connect(self.sqlite_output_path) + cursor = conn.cursor() + + # Create table if it doesn't exist + cursor.execute(''' + CREATE TABLE IF NOT EXISTS job_tags ( + job_id INTEGER PRIMARY KEY, + tags TEXT NOT NULL, + script_lines TEXT, + duration REAL, + hpc_user TEXT, + start_time INTEGER, + filename TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + ''') + + # Track multi-label and unknown samples for logging + multi_label_samples = [] + unknown_samples = [] + + # Insert samples + for sample in samples: + try: + # Check if this is a multi-label sample + if len(sample['tags']) > 1: + multi_label_samples.append(sample) + # Check if this is an unknown sample + if sample['tags'] == ["Unknown"]: + unknown_samples.append(sample) + + # Extract lines that contain the keywords for the chosen app + keyword_lines = [] + if sample['tags'] and sample['tags'][0] != "Unknown" and sample['tags'][0] in sample.get('matching_lines', {}): + keyword_lines = sample['matching_lines'][sample['tags'][0]] + + # Create filename key + filename = sample['file_info']['filename'].split('.')[0] # Remove extension + + cursor.execute(''' + INSERT OR REPLACE INTO job_tags + (job_id, tags, script_lines, duration, hpc_user, start_time, filename) + VALUES (?, ?, ?, ?, ?, ?, ?) + ''', ( + sample['job_id'], + json.dumps(sample['tags']), + json.dumps(keyword_lines), + sample.get('duration'), + sample.get('hpc_user'), + sample.get('start_time'), + filename + )) + + except Exception as e: + logger.error(f"Error inserting sample {sample.get('job_id')} to SQLite: {e}") + + conn.commit() + logger.info(f"Inserted {len(samples)} samples to SQLite database") + + # Log multi-label samples + if multi_label_samples: + self.log_multi_label_files(0, multi_label_samples) # Use 0 as shard_idx for SQLite mode + + # Log unknown samples + if unknown_samples: + self.log_unknown_script_files(0, unknown_samples) # Use 0 as shard_idx for SQLite mode + + except Exception as e: + logger.error(f"Error writing to SQLite database: {e}") + raise + finally: + if 'conn' in locals(): + conn.close() + + def update_debug_samples(self, results: List[Dict[str, Any]]): + """ + Update debug samples collections with new results + + Args: + results: List of processing results + """ + for result in results: + # Only track multi-tag samples for debugging + if len(result['tags']) > 1: + # Add to multi-tag samples if not full + if len(self.multi_tag_samples) < self.max_debug_samples: + self.multi_tag_samples.append(result) + + def write_debug_file(self): + """Write debug samples to file""" + if self.sqlite_output_path: + debug_path = 'multi_tag_samples_sqlite.json' + else: + debug_path = os.path.join(self.output_dir, 'multi_tag_samples.json') + + # Process multi-tag samples + debug_data = [] + for sample in self.multi_tag_samples: + debug_entry = { + "job_id": sample['job_id'], + "file_id": sample['file_info']['file_id'], + "filename": sample['file_info']['filename'], + "tags": sample['tags'], + "tag_count": len(sample['tags']), + "matching_lines": sample['matching_lines'] + } + debug_data.append(debug_entry) + + # Write debug file + with open(debug_path, 'w') as f: + json.dump(debug_data, f, indent=2) + + logger.info(f"Wrote {len(debug_data)} multi-tag samples to {debug_path}") + 
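+    # Sketch of the JSON written by write_user_stats() below. The keys mirror
+    # stats_data; tag names, users, and numbers are placeholders only:
+    #
+    # {
+    #   "overall_stats": {"Interactive": 2, "Unknown": 1},
+    #   "overall_users": {"Interactive": ["user_a", "user_b"], "Unknown": ["user_c"]},
+    #   "total_unique_users_overall": 3,
+    #   "recent_stats": {...},                 # present only when start_time_limit is set
+    #   "recent_users": {...},
+    #   "total_unique_users_recent": 1,
+    #   "start_time_limit": 1672531200,
+    #   "start_time_limit_readable": "2023-01-01 00:00:00"
+    # }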
+ def write_user_stats(self): + """Write user statistics to file""" + if self.sqlite_output_path: + user_stats_path = 'user_statistics_sqlite.json' + else: + user_stats_path = os.path.join(self.output_dir, 'user_statistics.json') + + # Prepare user stats data + stats_data = { + 'overall_stats': {tag: len(users) for tag, users in self.user_stats.items()}, + 'overall_users': {tag: sorted(list(users)) for tag, users in self.user_stats.items()}, + 'total_unique_users_overall': len(set().union(*self.user_stats.values())) if self.user_stats else 0 + } + + # Add recent stats if time limit is set + if self.start_time_limit is not None: + stats_data['recent_stats'] = {tag: len(users) for tag, users in self.user_stats_recent.items()} + stats_data['recent_users'] = {tag: sorted(list(users)) for tag, users in self.user_stats_recent.items()} + stats_data['total_unique_users_recent'] = len(set().union(*self.user_stats_recent.values())) if self.user_stats_recent else 0 + stats_data['start_time_limit'] = self.start_time_limit + stats_data['start_time_limit_readable'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(self.start_time_limit)) + + # Write user stats file + with open(user_stats_path, 'w') as f: + json.dump(stats_data, f, indent=2) + + logger.info(f"Wrote user statistics to {user_stats_path}") + + # Log summary + logger.info("User Statistics Summary:") + for tag, count in stats_data['overall_stats'].items(): + logger.info(f" {tag}: {count} distinct users") + logger.info(f" Total unique users overall: {stats_data['total_unique_users_overall']}") + + if self.start_time_limit is not None: + logger.info(f"Recent User Statistics (since {stats_data['start_time_limit_readable']}):") + for tag, count in stats_data['recent_stats'].items(): + logger.info(f" {tag}: {count} distinct users") + logger.info(f" Total unique users recent: {stats_data['total_unique_users_recent']}") + + def process_all(self): + """ + Main processing function + """ + # Load keywords + self.load_keywords() + + try: + # Get all folders with image files + folders = self.list_input_folders() + + if not folders: + logger.info("No unprocessed folders found.") + return + + # Sort folders for more predictable processing + folder_paths = sorted(folders.keys()) + + valid_samples_buffer = [] + start_time = time.time() + + # Process each folder + for i, folder_path in enumerate(folder_paths): + files = folders[folder_path] + logger.info(f"Processing folder {i+1}/{len(folder_paths)}: {folder_path} ({len(files)} files)") + + # Process folder + valid_samples, other_samples = self.process_folder(folder_path, files) + + # Add to buffers + valid_samples_buffer.extend(valid_samples) + + # Update debug samples + self.update_debug_samples(other_samples) + + # Write samples if buffer is large enough + if self.sqlite_output_path: + # SQLite mode: write all buffered samples at once + if valid_samples_buffer: + self.write_to_sqlite(valid_samples_buffer) + valid_samples_buffer = [] + else: + # WebDataset mode: write shards + if len(valid_samples_buffer) >= self.shard_size: + # Write full shards + while len(valid_samples_buffer) >= self.shard_size: + shard_samples = valid_samples_buffer[:self.shard_size] + valid_samples_buffer = valid_samples_buffer[self.shard_size:] + + self.write_shard(shard_samples, self.current_shard) + self.current_shard += 1 + + # Save progress + self._save_progress() + + # Calculate and log progress + elapsed = time.time() - start_time + progress = self.stats['total_processed_folders'] / len(folder_paths) * 100 + folders_per_sec = 
self.stats['total_processed_folders'] / elapsed if elapsed > 0 else 0 + + logger.info(f"Progress: {self.stats['total_processed_folders']}/{len(folder_paths)} folders " + f"({progress:.1f}%) - {folders_per_sec:.2f} folders/sec") + + logger.info(f"Files: {self.stats['processed_files']}/{self.stats['total_files']} " + f"(Valid: {self.stats['valid_count']}, Multi-tag: {self.stats['multi_tag_count']}, " + f"Unknown: {self.stats['unknown_count']}, Errors: {self.stats['error_count']})") + + # Project completion time + if folders_per_sec > 0: + remaining_folders = len(folder_paths) - self.stats['total_processed_folders'] + remaining_time = remaining_folders / folders_per_sec + eta = time.strftime('%H:%M:%S', time.gmtime(remaining_time)) + logger.info(f"Estimated time remaining: {eta}") + + # Write remaining samples + if valid_samples_buffer: + if self.sqlite_output_path: + self.write_to_sqlite(valid_samples_buffer) + else: + self.write_shard(valid_samples_buffer, self.current_shard) + self.current_shard += 1 + + # Write debug samples + self.write_debug_file() + + # Write user statistics + self.write_user_stats() + + # Final progress update + self._save_progress() + + # Log final statistics + elapsed = time.time() - start_time + logger.info( + f"Processing complete in {elapsed:.1f}s.\n" + f"Total folders: {len(folder_paths)}\n" + f"Total files: {self.stats['total_files']}\n" + f"Valid samples: {self.stats['valid_count']}\n" + f"Multi-tag samples: {self.stats['multi_tag_count']}\n" + f"Unknown samples: {self.stats['unknown_count']}\n" + f"Errors: {self.stats['error_count']}\n" + f"Shards created: {self.current_shard}" + ) + + finally: + # Close database connections + self.close_db_connections() + + def process_folder(self, folder_path: str, files: List[Dict[str, str]]) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: + """ + Process all files in a folder + + Args: + folder_path: Folder path identifier + files: List of file info dictionaries in the folder + + Returns: + Tuple of (valid_samples, other_samples) + """ + results = [] + valid_samples = [] + other_samples = [] + + with ThreadPoolExecutor(max_workers=self.num_workers) as executor: + # Submit tasks + future_to_file = { + executor.submit(self.process_file, file_info): file_info['file_id'] + for file_info in files + } + + # Process results as they complete + for future in as_completed(future_to_file): + file_id = future_to_file[future] + try: + result = future.result() + results.append(result) + + # Update statistics + if result['error'] and not result['valid']: + self.stats['error_count'] += 1 + other_samples.append(result) + elif result['valid']: + if result['tags'][0] == "Unknown": + self.stats['unknown_count'] += 1 + self.stats['valid_count'] += 1 + valid_samples.append(result) + elif len(result['tags']) > 1: + self.stats['multi_tag_count'] += 1 + # For multi-tag samples, still consider them valid but track separately + valid_samples.append(result) + other_samples.append(result) + + except Exception as e: + logger.error(f"Exception processing file {file_id}: {e}") + self.stats['error_count'] += 1 + + # Mark folder as processed + self.processed_folders.add(folder_path) + self.stats['total_processed_folders'] = len(self.processed_folders) + self.stats['processed_files'] += len(files) + + return valid_samples, other_samples + + +def main(): + """Main entry point""" + parser = argparse.ArgumentParser(description="Process job scripts and create ML datasets") + + parser.add_argument('--db', required=True, help='Path to the SQLite database') 
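+    # The keywords CSV (consumed by load_keywords) is expected to provide 'Name'
+    # and 'Keyword' columns, with several keywords per application separated by
+    # semicolons. Illustrative example rows (application names are made up):
+    #   Name,Keyword
+    #   GROMACS,gmx mdrun;gmx_mpi
+    #   LAMMPS,lmp;lammps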
+    parser.add_argument('--keywords', required=True, help='Path to the keywords CSV file')
+    parser.add_argument('--output-dir', default='datasets', help='Output directory for datasets')
+    parser.add_argument('--sqlite-output', help='Path to SQLite database for output (instead of WebDataset shards)')
+    parser.add_argument('--workers', type=int, default=None, help='Number of parallel workers')
+    parser.add_argument('--shard-size', type=int, default=1000, help='Number of samples per shard')
+    parser.add_argument('--resume', action='store_true', help='Resume from previous run')
+    parser.add_argument('--debug-samples', type=int, default=100,
+                        help='Maximum number of debug samples to keep per category')
+    parser.add_argument('--start-time-limit', type=int, default=None,
+                        help='Unix timestamp limit for filtering jobs (e.g., for the past 2 years, use the timestamp from 2 years ago). '
+                             'If not provided, all jobs are included. The timestamp for 2 years ago can be computed with: '
+                             'python -c "import time; print(int(time.time()) - 2*365*24*3600)"')
+
+    # Input source arguments
+    input_group = parser.add_mutually_exclusive_group(required=True)
+    input_group.add_argument('--tars', nargs='+', help='Paths to tar archives containing images')
+    input_group.add_argument('--dirs', nargs='+', help='Paths to directories containing images')
+
+    args = parser.parse_args()
+
+    # Determine input mode and paths
+    input_mode = 'tar' if args.tars else 'dir'
+    input_paths = args.tars if args.tars else args.dirs
+
+    logger.info(f"Input paths: {input_paths}")
+
+    # Initialize tagger
+    tagger = JobScriptTagger(
+        db_path=args.db,
+        keywords_csv=args.keywords,
+        output_dir=args.output_dir,
+        input_mode=input_mode,
+        input_paths=input_paths,
+        num_workers=args.workers,
+        shard_size=args.shard_size,
+        resume=args.resume,
+        max_debug_samples=args.debug_samples,
+        start_time_limit=args.start_time_limit,
+        sqlite_output_path=args.sqlite_output
+    )
+
+    # Run processing
+    tagger.process_all()
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
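
A minimal invocation might look like the following (all paths and the worker
count are placeholders, not values taken from the patch):

    python job_script_tagger.py \
        --db jobs.sqlite \
        --keywords app_keywords.csv \
        --output-dir datasets \
        --tars images_part1.tar images_part2.tar \
        --workers 8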