#!/usr/bin/env python3 import os import re import csv import json import sqlite3 import tarfile import logging import threading import argparse import pandas as pd import time from pathlib import Path from typing import Dict, List, Tuple, Set, Any, Optional, Union from concurrent.futures import ThreadPoolExecutor, as_completed import webdataset as wds from PIL import Image import io import hashlib # Configure logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', handlers=[ logging.FileHandler("job_tagger_dataset.log"), logging.StreamHandler() ] ) logger = logging.getLogger("JobScriptTagger") class JobScriptTagger: """Process job scripts and images to create ML datasets""" def __init__( self, db_path: str, keywords_csv: str, output_dir: str, input_mode: str = 'tar', input_paths: List[str] = None, num_workers: int = None, shard_size: int = 1000, resume: bool = False, max_debug_samples: int = 100, start_time_limit: Optional[int] = None, sqlite_output_path: Optional[str] = None ): """ Initialize the tagger Args: db_path: Path to the SQLite database keywords_csv: Path to the CSV file with application names and keywords output_dir: Directory for output files input_mode: 'tar' for tar archives or 'dir' for directories input_paths: List of tar archives or directories containing images num_workers: Number of parallel workers (defaults to CPU count) shard_size: Number of samples per output shard resume: Whether to resume from a previous run max_debug_samples: Maximum number of debug samples to keep start_time_limit: Unix timestamp limit for filtering jobs (None for no limit) sqlite_output_path: Path to SQLite database for output (if None, use WebDataset mode) """ self.db_path = db_path self.keywords_csv = keywords_csv self.output_dir = output_dir self.input_mode = input_mode self.input_paths = input_paths or [] self.num_workers = num_workers or os.cpu_count() self.shard_size = shard_size self.resume = resume self.max_debug_samples = max_debug_samples self.start_time_limit = start_time_limit self.sqlite_output_path = sqlite_output_path # Ensure output directory exists (only needed for WebDataset mode) if not sqlite_output_path: os.makedirs(output_dir, exist_ok=True) # Thread-local storage for database connections self.local_db = threading.local() self.app_keywords = {} # Maps app names to lists of keywords # Track processed folders and progress if sqlite_output_path: self.progress_file = 'progress_sqlite.json' else: self.progress_file = os.path.join(output_dir, 'progress.json') self.processed_folders = set() self.current_shard = 0 self.stats = { 'valid_count': 0, 'multi_tag_count': 0, 'unknown_count': 0, 'error_count': 0, 'total_processed_folders': 0, 'total_folders': 0, 'total_files': 0, 'processed_files': 0 } # User statistics tracking self.user_stats = {} # tag -> set of users self.user_stats_recent = {} # tag -> set of users (within time limit) # Debug sample collections self.multi_tag_samples = [] # Multi-label logging - keep track of files per shard self.multi_label_logs = {} # shard_id -> list of filenames # Initialize or load progress self._init_progress() logger.info(f"Initialized with {self.num_workers} workers, input mode: {input_mode}") def _track_user_stats(self, tag: str, user: str, start_time: Optional[int] = None): """ Track user statistics for a given tag Args: tag: Application tag user: HPC user start_time: Job start time (Unix timestamp) """ if not user: return # Track overall user stats if tag not in self.user_stats: self.user_stats[tag] = set() self.user_stats[tag].add(user) # Track recent user stats if start_time is within limit if self.start_time_limit is not None and start_time is not None: if start_time >= self.start_time_limit: if tag not in self.user_stats_recent: self.user_stats_recent[tag] = set() self.user_stats_recent[tag].add(user) def _init_progress(self): """Initialize or load progress tracking""" if self.resume and os.path.exists(self.progress_file): try: with open(self.progress_file, 'r') as f: progress_data = json.load(f) self.processed_folders = set(progress_data.get('processed_folders', [])) self.current_shard = progress_data.get('current_shard', 0) self.stats = progress_data.get('stats', self.stats) # Load user stats user_stats_data = progress_data.get('user_stats', {}) self.user_stats = {tag: set(users) for tag, users in user_stats_data.items()} user_stats_recent_data = progress_data.get('user_stats_recent', {}) self.user_stats_recent = {tag: set(users) for tag, users in user_stats_recent_data.items()} logger.info(f"Resumed from previous run: {len(self.processed_folders)} folders processed, " f"current shard: {self.current_shard}") except Exception as e: logger.warning(f"Failed to load progress file: {e}. Starting fresh.") self._save_progress() else: self._save_progress() def _save_progress(self): """Save current progress to file""" progress_data = { 'processed_folders': list(self.processed_folders), 'current_shard': self.current_shard, 'stats': self.stats, 'user_stats': {tag: list(users) for tag, users in self.user_stats.items()}, 'user_stats_recent': {tag: list(users) for tag, users in self.user_stats_recent.items()}, 'last_updated': time.strftime('%Y-%m-%d %H:%M:%S') } with open(self.progress_file, 'w') as f: json.dump(progress_data, f, indent=2) logger.debug(f"Progress saved: {self.stats['total_processed_folders']} folders processed") def get_db_connection(self): """ Get a thread-local database connection Returns: SQLite connection object """ # Check if this thread already has a connection if not hasattr(self.local_db, 'conn') or self.local_db.conn is None: try: self.local_db.conn = sqlite3.connect(self.db_path) logger.debug(f"Created new database connection for thread {threading.get_ident()}") except sqlite3.Error as e: logger.error(f"Database connection error in thread {threading.get_ident()}: {e}") raise return self.local_db.conn def close_db_connections(self): """Close all database connections""" # Since connections are thread-local, we can only close the one in the current thread if hasattr(self.local_db, 'conn') and self.local_db.conn is not None: self.local_db.conn.close() self.local_db.conn = None logger.info("Database connection closed") def load_keywords(self): """Load application keywords from CSV file""" try: # Read CSV file df = pd.read_csv(self.keywords_csv) # Process each row for _, row in df.iterrows(): app_name = row['Name'].strip() # Split keywords by semicolon and strip whitespace keywords = [kw.strip() for kw in row['Keyword'].split(';') if kw.strip()] if app_name and keywords: self.app_keywords[app_name] = keywords logger.info(f"Loaded {len(self.app_keywords)} applications with keywords") # Log the apps and their keywords for verification for app_name, keywords in self.app_keywords.items(): logger.debug(f"App: {app_name}, Keywords: {keywords}") except Exception as e: logger.error(f"Error loading keywords from {self.keywords_csv}: {e}") raise def extract_job_id(self, filename: str) -> Optional[int]: """ Extract job ID from image filename Args: filename: Image filename in format {job_id}_{timestamp}_color.png Returns: Job ID if successfully extracted, None otherwise """ match = re.match(r'(\d+)_\d+_color\.png$', os.path.basename(filename)) if match: return int(match.group(1)) return None def get_job_script(self, job_id: int) -> Optional[Dict[str, Any]]: """ Fetch job script from database Args: job_id: Job ID Returns: Dict with job script and metadata if found, None otherwise """ try: conn = self.get_db_connection() cursor = conn.cursor() query = "SELECT meta_data, duration, hpc_user, start_time FROM job WHERE job_id = ?" cursor.execute(query, (job_id,)) result = cursor.fetchone() if result: meta_data = json.loads(result[0]) if result[0] else {} duration = result[1] hpc_user = result[2] start_time = result[3] script = meta_data.get('jobScript', '') name = meta_data.get('jobName', '') # Check start_time limit if specified if self.start_time_limit is not None and start_time is not None: if start_time < self.start_time_limit: return None # Job is too old return { 'script': script, 'meta_data': meta_data, 'duration': duration, 'hpc_user': hpc_user, 'start_time': start_time, 'interactive': True if name == "interactive" else False } return None except sqlite3.Error as e: logger.error(f"Error fetching job script for job {job_id}: {e}") return None def is_valid_script(self, script: str) -> bool: """ Check if a script is valid (starts with #!) Args: script: Job script text Returns: True if script is valid, False otherwise """ if not script: return False # Check if script starts with #! return script.strip().startswith('#!') def filter_script_comments(self, script: str) -> Tuple[str, List[str]]: """ Filter out commented and empty lines from script Args: script: Job script text Returns: Tuple of (filtered_script_text, filtered_lines) """ lines = script.splitlines() filtered_lines = [] for line in lines: stripped = line.strip() # Keep shebang line and non-comment lines if stripped and (stripped.startswith('#!') or not stripped.startswith('#')): filtered_lines.append(line) filtered_script = '\n'.join(filtered_lines) return filtered_script, filtered_lines def check_keywords(self, script: str) -> Tuple[List[str], Dict[str, List[str]]]: """ Check which applications' keywords match in the script and extract matching lines Args: script: Job script text Returns: Tuple of (matching_app_names, {app_name: [matching_lines]}) """ # Filter out commented lines filtered_script, filtered_lines = self.filter_script_comments(script) script_lower = filtered_script.lower() matched_apps = [] matched_lines = {} for app_name, keywords in self.app_keywords.items(): app_matches = [] for keyword in keywords: keyword_lower = keyword.lower() # Check if keyword is in filtered script if keyword_lower in script_lower: matched_apps.append(app_name) # Find lines containing this keyword in filtered lines for line in filtered_lines: if keyword_lower in line.lower(): app_matches.append(line.strip()) # Break after first matching keyword for this app break # Save unique matching lines for this app if app_matches: matched_lines[app_name] = list(set(app_matches)) return matched_apps, matched_lines def get_folder_path(self, file_info: Dict[str, str]) -> str: """ Get the subfolder path for a file Args: file_info: File information dictionary Returns: Subfolder path string """ if self.input_mode == 'tar': # For tar files, use the parent directory in the archive path_parts = file_info['file_path'].split('/') if len(path_parts) > 1: # Return the tar name and top-level directory return f"{os.path.basename(file_info['tar_path'])}:{path_parts[0]}" else: return os.path.basename(file_info['tar_path']) else: # For directories, use the relative path's parent directory path = Path(file_info['rel_path']) if path.parent != Path('.'): return f"{os.path.basename(file_info['dir_path'])}/{path.parent}" else: return os.path.basename(file_info['dir_path']) def list_input_folders(self) -> Dict[str, List[Dict[str, str]]]: """ List all image files from input sources grouped by subfolder Returns: Dictionary mapping folder paths to lists of file information """ folders = {} total_files = 0 if self.input_mode == 'tar': # Process tar archives for tar_path in self.input_paths: if not os.path.exists(tar_path): logger.warning(f"Tar archive not found: {tar_path}") continue try: with tarfile.open(tar_path, 'r') as tar: for member in tar.getmembers(): if member.name.endswith('_color.png'): file_info = { 'source_type': 'tar', 'tar_path': tar_path, 'file_path': member.name, 'file_id': f"{tar_path}:{member.name}", 'filename': os.path.basename(member.name) } # Get folder path folder_path = self.get_folder_path(file_info) # Skip if folder already processed if folder_path in self.processed_folders and self.resume: continue # Add to folders dictionary if folder_path not in folders: folders[folder_path] = [] folders[folder_path].append(file_info) total_files += 1 except Exception as e: logger.error(f"Error reading tar archive {tar_path}: {e}") else: # Process directories for dir_path in self.input_paths: if not os.path.isdir(dir_path): logger.warning(f"Directory not found: {dir_path}") continue for root, _, files in os.walk(dir_path): # Filter image files image_files = [f for f in files if f.endswith('_color.png')] if not image_files: continue for filename in image_files: file_path = os.path.join(root, filename) rel_path = os.path.relpath(file_path, dir_path) file_info = { 'source_type': 'dir', 'dir_path': dir_path, 'file_path': file_path, 'rel_path': rel_path, 'file_id': f"{dir_path}:{rel_path}", 'filename': filename } # Get folder path folder_path = self.get_folder_path(file_info) # Skip if folder already processed if folder_path in self.processed_folders and self.resume: continue # Add to folders dictionary if folder_path not in folders: folders[folder_path] = [] folders[folder_path].append(file_info) total_files += 1 # Update stats self.stats['total_files'] = total_files self.stats['total_folders'] = len(folders) logger.info(f"Found {len(folders)} unprocessed folders with {total_files} files") return folders def process_file(self, file_info: Dict[str, str]) -> Dict[str, Any]: """ Process a single image file (worker function for thread pool) Args: file_info: Dictionary with file information Returns: Dictionary with processing results """ result = { 'file_info': file_info, 'job_id': None, 'tags': [], 'valid': False, 'error': None, 'script': None, 'script_excerpt': None, 'matching_lines': {}, 'duration': None, 'hpc_user': None, 'start_time': None } try: # Extract job ID from filename job_id = self.extract_job_id(file_info['filename']) if not job_id: result['error'] = f"Could not extract job ID from filename: {file_info['filename']}" return result result['job_id'] = job_id # Get job script job_data = self.get_job_script(job_id) # Check if job was filtered out by start_time_limit if job_data is None: result['error'] = f"Job {job_id} filtered out by start_time_limit" result['tags'] = ["Filtered"] result['valid'] = False return result # Store duration from job data result['duration'] = job_data.get('duration') if job_data else None result['hpc_user'] = job_data.get('hpc_user') if job_data else None result['start_time'] = job_data.get('start_time') if job_data else None if job_data['interactive']: result['tags'] = ["Interactive"] result['valid'] = True # Track user stats for interactive jobs if job_data.get('hpc_user'): self._track_user_stats("Interactive", job_data['hpc_user'], job_data.get('start_time')) return result # If no script, assign "Unknown" app label if not job_data or not job_data['script']: result['tags'] = ["Unknown"] result['valid'] = False # Track user stats for unknown jobs if job_data and job_data.get('hpc_user'): self._track_user_stats("Unknown", job_data['hpc_user'], job_data.get('start_time')) # result['error'] = "No job script found" return result script = job_data['script'] # If script doesn't start with #!, assign "Unknown" app label if not self.is_valid_script(script): result['tags'] = ["Unknown"] result['valid'] = True # Mark as valid with "Unknown" tag result['error'] = "Script doesn't start with #!" # Track user stats for unknown jobs if job_data.get('hpc_user'): self._track_user_stats("Unknown", job_data['hpc_user'], job_data.get('start_time')) return result # Store script excerpt for debugging lines = script.strip().split('\n') result['script_excerpt'] = '\n'.join(lines[:60]) if len(lines) > 60: result['script_excerpt'] += f"\n... [+ {len(lines) - 60} more lines]" # Store full script result['script'] = script # Check keywords matched_apps, matching_lines = self.check_keywords(script) # No matches, assign "Unknown" if not matched_apps: result['tags'] = ["Unknown"] result['valid'] = True result['error'] = "No matching keywords found" # Track user stats for unknown jobs if job_data.get('hpc_user'): self._track_user_stats("Unknown", job_data['hpc_user'], job_data.get('start_time')) else: result['tags'] = matched_apps result['matching_lines'] = matching_lines result['valid'] = len(matched_apps) == 1 # Valid if exactly one tag # Track user stats for matched apps if job_data.get('hpc_user'): for tag in matched_apps: self._track_user_stats(tag, job_data['hpc_user'], job_data.get('start_time')) except Exception as e: result['error'] = str(e) result['tags'] = ["Unknown"] # Default to "Unknown" on error result['valid'] = False logger.error(f"Error processing file {file_info['file_id']}: {e}") # Track user stats for error cases if we have job data if 'job_data' in locals() and job_data and job_data.get('hpc_user'): self._track_user_stats("Unknown", job_data['hpc_user'], job_data.get('start_time')) return result def read_image_data(self, file_info: Dict[str, str]) -> Optional[bytes]: """ Read image data from the source (tar or directory) Args: file_info: File information dictionary Returns: Image data as bytes if successful, None otherwise """ try: if file_info['source_type'] == 'tar': # Extract from tar archive with tarfile.open(file_info['tar_path'], 'r') as tar: img_file = tar.extractfile(file_info['file_path']) if img_file: return img_file.read() else: # Read from filesystem with open(file_info['file_path'], 'rb') as f: return f.read() return None except Exception as e: logger.error(f"Error reading image {file_info['file_id']}: {e}") return None def log_multi_label_files(self, shard_idx: int, multi_label_samples: List[Dict[str, Any]]): """ Log files with multiple labels to a file Args: shard_idx: Shard index number (ignored in SQLite mode) multi_label_samples: List of samples with multiple labels """ if not multi_label_samples: return if self.sqlite_output_path: log_path = 'multi_label_sqlite.txt' else: log_path = os.path.join(self.output_dir, f"multi_label_shard_{shard_idx:06d}.txt") with open(log_path, 'w') as f: f.write(f"Multi-label files in shard {shard_idx} ({len(multi_label_samples)} files):\n\n") for i, sample in enumerate(multi_label_samples, 1): f.write(f"File {i}: {sample['file_info']['filename']}\n") f.write(f"Job ID: {sample['job_id']}\n") f.write(f"Labels: {', '.join(sample['tags'])}\n") # Write matching lines for each app for app, lines in sample['matching_lines'].items(): f.write(f"\n {app} matching lines:\n") for line in lines: f.write(f" {line}\n") f.write("\n" + "-" * 80 + "\n\n") logger.info(f"Logged {len(multi_label_samples)} multi-label files to {log_path}") def log_unknown_script_files(self, shard_idx: int, unknown_samples: List[Dict[str, Any]]): """ Log files with unknown scripts to a file Args: shard_idx: Shard index number (ignored in SQLite mode) unknown_samples: List of samples with unknown scripts """ if not unknown_samples: return if self.sqlite_output_path: log_path = 'unknown_script_sqlite.txt' else: log_path = os.path.join(self.output_dir, f"unknown_script_shard_{shard_idx:06d}.txt") with open(log_path, 'w', encoding='utf-8') as f: f.write(f"Unknown script files in shard {shard_idx} ({len(unknown_samples)} files):\n\n") for i, sample in enumerate(unknown_samples, 1): f.write(f"File {i}: {sample['file_info']['filename']}\n") f.write(f"Job ID: {sample['job_id']}\n") f.write(f"Error: {sample.get('error', '')}\n") # Filter out comment lines (starting with #) from script excerpt script_excerpt = sample.get('script_excerpt', '') if script_excerpt: lines = script_excerpt.split('\n') filtered_lines = [line for line in lines if not line.strip().startswith('#')] filtered_excerpt = '\n'.join(filtered_lines) f.write(f"Script excerpt (comments removed):\n{filtered_excerpt}\n") else: f.write("Script excerpt: (none)\n") f.write("\n" + "-" * 80 + "\n\n") logger.info(f"Logged {len(unknown_samples)} unknown script files to {log_path}") def write_shard(self, samples: List[Dict[str, Any]], shard_idx: int): """ Write samples to a WebDataset shard Args: samples: List of samples to write shard_idx: Shard index number """ if not samples: return # Create shard filename with proper padding shard_path = os.path.join( self.output_dir, f"roofline_dataset_{shard_idx:06d}.tar" ) # Track multi-label and unknown samples for this shard multi_label_samples = [] unknown_samples = [] valid_count = 0 with wds.TarWriter(shard_path) as sink: for sample in samples: try: file_info = sample['file_info'] # Extract image data img_data = self.read_image_data(file_info) if not img_data: logger.warning(f"Could not read image for job {sample['job_id']}") continue # Check if this is a multi-label sample if len(sample['tags']) > 1: multi_label_samples.append(sample) # Check if this is an unknown sample if sample['tags'] == ["Unknown"]: unknown_samples.append(sample) # Create a unique key from original filename key = file_info['filename'].split('.')[0] # Remove extension # Extract lines that contain the keywords for the chosen app keyword_lines = [] if sample['tags'][0] != "Unknown" and sample['tags'][0] in sample['matching_lines']: keyword_lines = sample['matching_lines'][sample['tags'][0]] # Create sample wds_sample = { "__key__": key, "png": img_data, "json": json.dumps({ "job_id": sample['job_id'], "tags": sample['tags'], # Use the first tag for single-label classification "script": keyword_lines, # Add lines containing detected keywords "duration": sample.get('duration'), # Add job duration "hpc_user": sample.get('hpc_user'), # Add HPC user "start_time": sample.get('start_time') # Add job start time }) } # Write to WebDataset sink.write(wds_sample) valid_count += 1 except Exception as e: logger.error(f"Error adding sample {sample.get('job_id')} to WebDataset: {e}") # Log multi-label samples for this shard if multi_label_samples: self.log_multi_label_files(shard_idx, multi_label_samples) # Log unknown samples for this shard if unknown_samples: self.log_unknown_script_files(shard_idx, unknown_samples) logger.info(f"Created shard {shard_path} with {valid_count} samples") return valid_count def write_to_sqlite(self, samples: List[Dict[str, Any]]): """ Write samples to SQLite database Args: samples: List of samples to write """ if not samples: return try: # Connect to output database conn = sqlite3.connect(self.sqlite_output_path) cursor = conn.cursor() # Create table if it doesn't exist cursor.execute(''' CREATE TABLE IF NOT EXISTS job_tags ( job_id INTEGER PRIMARY KEY, tags TEXT NOT NULL, script_lines TEXT, duration REAL, hpc_user TEXT, start_time INTEGER, filename TEXT, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ) ''') # Track multi-label and unknown samples for logging multi_label_samples = [] unknown_samples = [] # Insert samples for sample in samples: try: # Check if this is a multi-label sample if len(sample['tags']) > 1: multi_label_samples.append(sample) # Check if this is an unknown sample if sample['tags'] == ["Unknown"]: unknown_samples.append(sample) # Extract lines that contain the keywords for the chosen app keyword_lines = [] if sample['tags'] and sample['tags'][0] != "Unknown" and sample['tags'][0] in sample.get('matching_lines', {}): keyword_lines = sample['matching_lines'][sample['tags'][0]] # Create filename key filename = sample['file_info']['filename'].split('.')[0] # Remove extension cursor.execute(''' INSERT OR REPLACE INTO job_tags (job_id, tags, script_lines, duration, hpc_user, start_time, filename) VALUES (?, ?, ?, ?, ?, ?, ?) ''', ( sample['job_id'], json.dumps(sample['tags']), json.dumps(keyword_lines), sample.get('duration'), sample.get('hpc_user'), sample.get('start_time'), filename )) except Exception as e: logger.error(f"Error inserting sample {sample.get('job_id')} to SQLite: {e}") conn.commit() logger.info(f"Inserted {len(samples)} samples to SQLite database") # Log multi-label samples if multi_label_samples: self.log_multi_label_files(0, multi_label_samples) # Use 0 as shard_idx for SQLite mode # Log unknown samples if unknown_samples: self.log_unknown_script_files(0, unknown_samples) # Use 0 as shard_idx for SQLite mode except Exception as e: logger.error(f"Error writing to SQLite database: {e}") raise finally: if 'conn' in locals(): conn.close() def update_debug_samples(self, results: List[Dict[str, Any]]): """ Update debug samples collections with new results Args: results: List of processing results """ for result in results: # Only track multi-tag samples for debugging if len(result['tags']) > 1: # Add to multi-tag samples if not full if len(self.multi_tag_samples) < self.max_debug_samples: self.multi_tag_samples.append(result) def write_debug_file(self): """Write debug samples to file""" if self.sqlite_output_path: debug_path = 'multi_tag_samples_sqlite.json' else: debug_path = os.path.join(self.output_dir, 'multi_tag_samples.json') # Process multi-tag samples debug_data = [] for sample in self.multi_tag_samples: debug_entry = { "job_id": sample['job_id'], "file_id": sample['file_info']['file_id'], "filename": sample['file_info']['filename'], "tags": sample['tags'], "tag_count": len(sample['tags']), "matching_lines": sample['matching_lines'] } debug_data.append(debug_entry) # Write debug file with open(debug_path, 'w') as f: json.dump(debug_data, f, indent=2) logger.info(f"Wrote {len(debug_data)} multi-tag samples to {debug_path}") def write_user_stats(self): """Write user statistics to file""" if self.sqlite_output_path: user_stats_path = 'user_statistics_sqlite.json' else: user_stats_path = os.path.join(self.output_dir, 'user_statistics.json') # Prepare user stats data stats_data = { 'overall_stats': {tag: len(users) for tag, users in self.user_stats.items()}, 'overall_users': {tag: sorted(list(users)) for tag, users in self.user_stats.items()}, 'total_unique_users_overall': len(set().union(*self.user_stats.values())) if self.user_stats else 0 } # Add recent stats if time limit is set if self.start_time_limit is not None: stats_data['recent_stats'] = {tag: len(users) for tag, users in self.user_stats_recent.items()} stats_data['recent_users'] = {tag: sorted(list(users)) for tag, users in self.user_stats_recent.items()} stats_data['total_unique_users_recent'] = len(set().union(*self.user_stats_recent.values())) if self.user_stats_recent else 0 stats_data['start_time_limit'] = self.start_time_limit stats_data['start_time_limit_readable'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(self.start_time_limit)) # Write user stats file with open(user_stats_path, 'w') as f: json.dump(stats_data, f, indent=2) logger.info(f"Wrote user statistics to {user_stats_path}") # Log summary logger.info("User Statistics Summary:") for tag, count in stats_data['overall_stats'].items(): logger.info(f" {tag}: {count} distinct users") logger.info(f" Total unique users overall: {stats_data['total_unique_users_overall']}") if self.start_time_limit is not None: logger.info(f"Recent User Statistics (since {stats_data['start_time_limit_readable']}):") for tag, count in stats_data['recent_stats'].items(): logger.info(f" {tag}: {count} distinct users") logger.info(f" Total unique users recent: {stats_data['total_unique_users_recent']}") def process_all(self): """ Main processing function """ # Load keywords self.load_keywords() try: # Get all folders with image files folders = self.list_input_folders() if not folders: logger.info("No unprocessed folders found.") return # Sort folders for more predictable processing folder_paths = sorted(folders.keys()) valid_samples_buffer = [] start_time = time.time() # Process each folder for i, folder_path in enumerate(folder_paths): files = folders[folder_path] logger.info(f"Processing folder {i+1}/{len(folder_paths)}: {folder_path} ({len(files)} files)") # Process folder valid_samples, other_samples = self.process_folder(folder_path, files) # Add to buffers valid_samples_buffer.extend(valid_samples) # Update debug samples self.update_debug_samples(other_samples) # Write samples if buffer is large enough if self.sqlite_output_path: # SQLite mode: write all buffered samples at once if valid_samples_buffer: self.write_to_sqlite(valid_samples_buffer) valid_samples_buffer = [] else: # WebDataset mode: write shards if len(valid_samples_buffer) >= self.shard_size: # Write full shards while len(valid_samples_buffer) >= self.shard_size: shard_samples = valid_samples_buffer[:self.shard_size] valid_samples_buffer = valid_samples_buffer[self.shard_size:] self.write_shard(shard_samples, self.current_shard) self.current_shard += 1 # Save progress self._save_progress() # Calculate and log progress elapsed = time.time() - start_time progress = self.stats['total_processed_folders'] / len(folder_paths) * 100 folders_per_sec = self.stats['total_processed_folders'] / elapsed if elapsed > 0 else 0 logger.info(f"Progress: {self.stats['total_processed_folders']}/{len(folder_paths)} folders " f"({progress:.1f}%) - {folders_per_sec:.2f} folders/sec") logger.info(f"Files: {self.stats['processed_files']}/{self.stats['total_files']} " f"(Valid: {self.stats['valid_count']}, Multi-tag: {self.stats['multi_tag_count']}, " f"Unknown: {self.stats['unknown_count']}, Errors: {self.stats['error_count']})") # Project completion time if folders_per_sec > 0: remaining_folders = len(folder_paths) - self.stats['total_processed_folders'] remaining_time = remaining_folders / folders_per_sec eta = time.strftime('%H:%M:%S', time.gmtime(remaining_time)) logger.info(f"Estimated time remaining: {eta}") # Write remaining samples if valid_samples_buffer: if self.sqlite_output_path: self.write_to_sqlite(valid_samples_buffer) else: self.write_shard(valid_samples_buffer, self.current_shard) self.current_shard += 1 # Write debug samples self.write_debug_file() # Write user statistics self.write_user_stats() # Final progress update self._save_progress() # Log final statistics elapsed = time.time() - start_time logger.info( f"Processing complete in {elapsed:.1f}s.\n" f"Total folders: {len(folder_paths)}\n" f"Total files: {self.stats['total_files']}\n" f"Valid samples: {self.stats['valid_count']}\n" f"Multi-tag samples: {self.stats['multi_tag_count']}\n" f"Unknown samples: {self.stats['unknown_count']}\n" f"Errors: {self.stats['error_count']}\n" f"Shards created: {self.current_shard}" ) finally: # Close database connections self.close_db_connections() def process_folder(self, folder_path: str, files: List[Dict[str, str]]) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: """ Process all files in a folder Args: folder_path: Folder path identifier files: List of file info dictionaries in the folder Returns: Tuple of (valid_samples, other_samples) """ results = [] valid_samples = [] other_samples = [] with ThreadPoolExecutor(max_workers=self.num_workers) as executor: # Submit tasks future_to_file = { executor.submit(self.process_file, file_info): file_info['file_id'] for file_info in files } # Process results as they complete for future in as_completed(future_to_file): file_id = future_to_file[future] try: result = future.result() results.append(result) # Update statistics if result['error'] and not result['valid']: self.stats['error_count'] += 1 other_samples.append(result) elif result['valid']: if result['tags'][0] == "Unknown": self.stats['unknown_count'] += 1 self.stats['valid_count'] += 1 valid_samples.append(result) elif len(result['tags']) > 1: self.stats['multi_tag_count'] += 1 # For multi-tag samples, still consider them valid but track separately valid_samples.append(result) other_samples.append(result) except Exception as e: logger.error(f"Exception processing file {file_id}: {e}") self.stats['error_count'] += 1 # Mark folder as processed self.processed_folders.add(folder_path) self.stats['total_processed_folders'] = len(self.processed_folders) self.stats['processed_files'] += len(files) return valid_samples, other_samples def main(): """Main entry point""" parser = argparse.ArgumentParser(description="Process job scripts and create ML datasets") parser.add_argument('--db', required=True, help='Path to the SQLite database') parser.add_argument('--keywords', required=True, help='Path to the keywords CSV file') parser.add_argument('--output-dir', default='datasets', help='Output directory for datasets') parser.add_argument('--sqlite-output', help='Path to SQLite database for output (instead of WebDataset shards)') parser.add_argument('--workers', type=int, default=None, help='Number of parallel workers') parser.add_argument('--shard-size', type=int, default=1000, help='Number of samples per shard') parser.add_argument('--resume', action='store_true', help='Resume from previous run') parser.add_argument('--debug-samples', type=int, default=100, help='Maximum number of debug samples to keep per category') parser.add_argument('--start-time-limit', type=int, default=None, help='Unix timestamp limit for filtering jobs (e.g., for past 2 years: use timestamp from 2 years ago). ' 'If not provided, all jobs are included. You can calculate 2 years ago as: ' 'python -c "import time; print(int(time.time()) - 2*365*24*3600)"') # Input source arguments input_group = parser.add_mutually_exclusive_group(required=True) input_group.add_argument('--tars', nargs='+', help='Paths to tar archives containing images') input_group.add_argument('--dirs', nargs='+', help='Paths to directories containing images') args = parser.parse_args() # Determine input mode and paths input_mode = 'tar' if args.tars else 'dir' input_paths = args.tars if args.tars else args.dirs print(input_paths) # Initialize tagger tagger = JobScriptTagger( db_path=args.db, keywords_csv=args.keywords, output_dir=args.output_dir, input_mode=input_mode, input_paths=input_paths, num_workers=args.workers, shard_size=args.shard_size, resume=args.resume, max_debug_samples=args.debug_samples, start_time_limit=args.start_time_limit, sqlite_output_path=args.sqlite_output ) # Run processing tagger.process_all() if __name__ == "__main__": main()