slurm-application-detection/job_script_tagger.py
BoleMa 2bd43009c3 Added Job Script Tagger and WebDataset Creator
This script processes image files from tar archives or directories, extracts job IDs from their filenames, fetches the corresponding job scripts from a SQLite database, and assigns application tags based on keyword matches. Valid samples are saved to sharded WebDatasets, while problematic ones are logged.
2025-10-14 15:05:15 +02:00

#!/usr/bin/env python3
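"""
Tag HPC job-script images with application labels and package them into datasets.

Images named {job_id}_{timestamp}_color.png are read from tar archives or
directories, the matching job scripts are fetched from a SQLite job database,
and application tags are assigned by keyword matching against a CSV of known
applications. Valid samples are written to sharded WebDatasets (or, with
--sqlite-output, to a SQLite table); multi-label and unknown samples are logged
separately.

Example invocation (a sketch; all file paths are illustrative placeholders):

    python job_script_tagger.py --db jobs.sqlite --keywords app_keywords.csv --output-dir datasets --tars images.tar

See main() for the full list of command-line options.
"""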
import os
import re
import csv
import json
import sqlite3
import tarfile
import logging
import threading
import argparse
import pandas as pd
import time
from pathlib import Path
from typing import Dict, List, Tuple, Set, Any, Optional, Union
from concurrent.futures import ThreadPoolExecutor, as_completed
import webdataset as wds
from PIL import Image
import io
import hashlib
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler("job_tagger_dataset.log"),
logging.StreamHandler()
]
)
logger = logging.getLogger("JobScriptTagger")
class JobScriptTagger:
"""Process job scripts and images to create ML datasets"""
def __init__(
self,
db_path: str,
keywords_csv: str,
output_dir: str,
input_mode: str = 'tar',
        input_paths: Optional[List[str]] = None,
        num_workers: Optional[int] = None,
shard_size: int = 1000,
resume: bool = False,
max_debug_samples: int = 100,
start_time_limit: Optional[int] = None,
sqlite_output_path: Optional[str] = None
):
"""
Initialize the tagger
Args:
db_path: Path to the SQLite database
keywords_csv: Path to the CSV file with application names and keywords
output_dir: Directory for output files
input_mode: 'tar' for tar archives or 'dir' for directories
input_paths: List of tar archives or directories containing images
num_workers: Number of parallel workers (defaults to CPU count)
shard_size: Number of samples per output shard
resume: Whether to resume from a previous run
max_debug_samples: Maximum number of debug samples to keep
start_time_limit: Unix timestamp limit for filtering jobs (None for no limit)
sqlite_output_path: Path to SQLite database for output (if None, use WebDataset mode)
"""
self.db_path = db_path
self.keywords_csv = keywords_csv
self.output_dir = output_dir
self.input_mode = input_mode
self.input_paths = input_paths or []
self.num_workers = num_workers or os.cpu_count()
self.shard_size = shard_size
self.resume = resume
self.max_debug_samples = max_debug_samples
self.start_time_limit = start_time_limit
self.sqlite_output_path = sqlite_output_path
# Ensure output directory exists (only needed for WebDataset mode)
if not sqlite_output_path:
os.makedirs(output_dir, exist_ok=True)
# Thread-local storage for database connections
self.local_db = threading.local()
self.app_keywords = {} # Maps app names to lists of keywords
# Track processed folders and progress
if sqlite_output_path:
self.progress_file = 'progress_sqlite.json'
else:
self.progress_file = os.path.join(output_dir, 'progress.json')
self.processed_folders = set()
self.current_shard = 0
self.stats = {
'valid_count': 0,
'multi_tag_count': 0,
'unknown_count': 0,
'error_count': 0,
'total_processed_folders': 0,
'total_folders': 0,
'total_files': 0,
'processed_files': 0
}
# User statistics tracking
self.user_stats = {} # tag -> set of users
self.user_stats_recent = {} # tag -> set of users (within time limit)
# Debug sample collections
self.multi_tag_samples = []
# Multi-label logging - keep track of files per shard
self.multi_label_logs = {} # shard_id -> list of filenames
# Initialize or load progress
self._init_progress()
logger.info(f"Initialized with {self.num_workers} workers, input mode: {input_mode}")
def _track_user_stats(self, tag: str, user: str, start_time: Optional[int] = None):
"""
Track user statistics for a given tag
Args:
tag: Application tag
user: HPC user
start_time: Job start time (Unix timestamp)
"""
if not user:
return
# Track overall user stats
if tag not in self.user_stats:
self.user_stats[tag] = set()
self.user_stats[tag].add(user)
# Track recent user stats if start_time is within limit
if self.start_time_limit is not None and start_time is not None:
if start_time >= self.start_time_limit:
if tag not in self.user_stats_recent:
self.user_stats_recent[tag] = set()
self.user_stats_recent[tag].add(user)
def _init_progress(self):
"""Initialize or load progress tracking"""
if self.resume and os.path.exists(self.progress_file):
try:
with open(self.progress_file, 'r') as f:
progress_data = json.load(f)
self.processed_folders = set(progress_data.get('processed_folders', []))
self.current_shard = progress_data.get('current_shard', 0)
self.stats = progress_data.get('stats', self.stats)
# Load user stats
user_stats_data = progress_data.get('user_stats', {})
self.user_stats = {tag: set(users) for tag, users in user_stats_data.items()}
user_stats_recent_data = progress_data.get('user_stats_recent', {})
self.user_stats_recent = {tag: set(users) for tag, users in user_stats_recent_data.items()}
logger.info(f"Resumed from previous run: {len(self.processed_folders)} folders processed, "
f"current shard: {self.current_shard}")
except Exception as e:
logger.warning(f"Failed to load progress file: {e}. Starting fresh.")
self._save_progress()
else:
self._save_progress()
def _save_progress(self):
"""Save current progress to file"""
progress_data = {
'processed_folders': list(self.processed_folders),
'current_shard': self.current_shard,
'stats': self.stats,
'user_stats': {tag: list(users) for tag, users in self.user_stats.items()},
'user_stats_recent': {tag: list(users) for tag, users in self.user_stats_recent.items()},
'last_updated': time.strftime('%Y-%m-%d %H:%M:%S')
}
with open(self.progress_file, 'w') as f:
json.dump(progress_data, f, indent=2)
logger.debug(f"Progress saved: {self.stats['total_processed_folders']} folders processed")
def get_db_connection(self):
"""
Get a thread-local database connection
Returns:
SQLite connection object
"""
# Check if this thread already has a connection
if not hasattr(self.local_db, 'conn') or self.local_db.conn is None:
try:
self.local_db.conn = sqlite3.connect(self.db_path)
logger.debug(f"Created new database connection for thread {threading.get_ident()}")
except sqlite3.Error as e:
logger.error(f"Database connection error in thread {threading.get_ident()}: {e}")
raise
return self.local_db.conn
def close_db_connections(self):
"""Close all database connections"""
# Since connections are thread-local, we can only close the one in the current thread
if hasattr(self.local_db, 'conn') and self.local_db.conn is not None:
self.local_db.conn.close()
self.local_db.conn = None
logger.info("Database connection closed")
def load_keywords(self):
"""Load application keywords from CSV file"""
try:
# Read CSV file
df = pd.read_csv(self.keywords_csv)
# Process each row
for _, row in df.iterrows():
app_name = row['Name'].strip()
# Split keywords by semicolon and strip whitespace
keywords = [kw.strip() for kw in row['Keyword'].split(';') if kw.strip()]
if app_name and keywords:
self.app_keywords[app_name] = keywords
logger.info(f"Loaded {len(self.app_keywords)} applications with keywords")
# Log the apps and their keywords for verification
for app_name, keywords in self.app_keywords.items():
logger.debug(f"App: {app_name}, Keywords: {keywords}")
except Exception as e:
logger.error(f"Error loading keywords from {self.keywords_csv}: {e}")
raise
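    # Expected keywords CSV layout: load_keywords() reads the columns 'Name' and
    # 'Keyword', with multiple keywords per application separated by ';'.
    # The rows below are only an illustrative sketch, not the real keyword list:
    #
    #   Name,Keyword
    #   GROMACS,gmx;mdrun
    #   Quantum ESPRESSO,pw.x;ph.x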
def extract_job_id(self, filename: str) -> Optional[int]:
"""
Extract job ID from image filename
Args:
filename: Image filename in format {job_id}_{timestamp}_color.png
Returns:
Job ID if successfully extracted, None otherwise
"""
match = re.match(r'(\d+)_\d+_color\.png$', os.path.basename(filename))
if match:
return int(match.group(1))
return None
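    # Illustrative behaviour (made-up values):
    #   extract_job_id("1234567_1700000000_color.png") -> 1234567
    #   extract_job_id("notes.txt")                     -> None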
def get_job_script(self, job_id: int) -> Optional[Dict[str, Any]]:
"""
Fetch job script from database
Args:
job_id: Job ID
Returns:
Dict with job script and metadata if found, None otherwise
"""
try:
conn = self.get_db_connection()
cursor = conn.cursor()
query = "SELECT meta_data, duration, hpc_user, start_time FROM job WHERE job_id = ?"
cursor.execute(query, (job_id,))
result = cursor.fetchone()
if result:
meta_data = json.loads(result[0]) if result[0] else {}
duration = result[1]
hpc_user = result[2]
start_time = result[3]
script = meta_data.get('jobScript', '')
name = meta_data.get('jobName', '')
# Check start_time limit if specified
if self.start_time_limit is not None and start_time is not None:
if start_time < self.start_time_limit:
return None # Job is too old
return {
'script': script,
'meta_data': meta_data,
'duration': duration,
'hpc_user': hpc_user,
'start_time': start_time,
                    'interactive': name == "interactive"
}
return None
except sqlite3.Error as e:
logger.error(f"Error fetching job script for job {job_id}: {e}")
return None
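    # get_job_script() expects the input database to contain a 'job' table with
    # at least the columns job_id, meta_data, duration, hpc_user and start_time,
    # where meta_data is a JSON blob holding 'jobScript' and 'jobName'.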
def is_valid_script(self, script: str) -> bool:
"""
Check if a script is valid (starts with #!)
Args:
script: Job script text
Returns:
True if script is valid, False otherwise
"""
if not script:
return False
# Check if script starts with #!
return script.strip().startswith('#!')
def filter_script_comments(self, script: str) -> Tuple[str, List[str]]:
"""
Filter out commented and empty lines from script
Args:
script: Job script text
Returns:
Tuple of (filtered_script_text, filtered_lines)
"""
lines = script.splitlines()
filtered_lines = []
for line in lines:
stripped = line.strip()
# Keep shebang line and non-comment lines
if stripped and (stripped.startswith('#!') or not stripped.startswith('#')):
filtered_lines.append(line)
filtered_script = '\n'.join(filtered_lines)
return filtered_script, filtered_lines
def check_keywords(self, script: str) -> Tuple[List[str], Dict[str, List[str]]]:
"""
Check which applications' keywords match in the script and extract matching lines
Args:
script: Job script text
Returns:
Tuple of (matching_app_names, {app_name: [matching_lines]})
"""
# Filter out commented lines
filtered_script, filtered_lines = self.filter_script_comments(script)
script_lower = filtered_script.lower()
matched_apps = []
matched_lines = {}
for app_name, keywords in self.app_keywords.items():
app_matches = []
for keyword in keywords:
keyword_lower = keyword.lower()
# Check if keyword is in filtered script
if keyword_lower in script_lower:
matched_apps.append(app_name)
# Find lines containing this keyword in filtered lines
for line in filtered_lines:
if keyword_lower in line.lower():
app_matches.append(line.strip())
# Break after first matching keyword for this app
break
# Save unique matching lines for this app
if app_matches:
matched_lines[app_name] = list(set(app_matches))
return matched_apps, matched_lines
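    # Illustrative outcome (hypothetical keyword table): with
    #   self.app_keywords = {"GROMACS": ["gmx"], "LAMMPS": ["lmp"]}
    # and a script containing the line "gmx mdrun -deffnm run", check_keywords()
    # returns (["GROMACS"], {"GROMACS": ["gmx mdrun -deffnm run"]}). Commented
    # lines (other than the shebang) never contribute matches.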
def get_folder_path(self, file_info: Dict[str, str]) -> str:
"""
Get the subfolder path for a file
Args:
file_info: File information dictionary
Returns:
Subfolder path string
"""
if self.input_mode == 'tar':
# For tar files, use the parent directory in the archive
path_parts = file_info['file_path'].split('/')
if len(path_parts) > 1:
# Return the tar name and top-level directory
return f"{os.path.basename(file_info['tar_path'])}:{path_parts[0]}"
else:
return os.path.basename(file_info['tar_path'])
else:
# For directories, use the relative path's parent directory
path = Path(file_info['rel_path'])
if path.parent != Path('.'):
return f"{os.path.basename(file_info['dir_path'])}/{path.parent}"
else:
return os.path.basename(file_info['dir_path'])
def list_input_folders(self) -> Dict[str, List[Dict[str, str]]]:
"""
List all image files from input sources grouped by subfolder
Returns:
Dictionary mapping folder paths to lists of file information
"""
folders = {}
total_files = 0
if self.input_mode == 'tar':
# Process tar archives
for tar_path in self.input_paths:
if not os.path.exists(tar_path):
logger.warning(f"Tar archive not found: {tar_path}")
continue
try:
with tarfile.open(tar_path, 'r') as tar:
for member in tar.getmembers():
if member.name.endswith('_color.png'):
file_info = {
'source_type': 'tar',
'tar_path': tar_path,
'file_path': member.name,
'file_id': f"{tar_path}:{member.name}",
'filename': os.path.basename(member.name)
}
# Get folder path
folder_path = self.get_folder_path(file_info)
# Skip if folder already processed
if folder_path in self.processed_folders and self.resume:
continue
# Add to folders dictionary
if folder_path not in folders:
folders[folder_path] = []
folders[folder_path].append(file_info)
total_files += 1
except Exception as e:
logger.error(f"Error reading tar archive {tar_path}: {e}")
else:
# Process directories
for dir_path in self.input_paths:
if not os.path.isdir(dir_path):
logger.warning(f"Directory not found: {dir_path}")
continue
for root, _, files in os.walk(dir_path):
# Filter image files
image_files = [f for f in files if f.endswith('_color.png')]
if not image_files:
continue
for filename in image_files:
file_path = os.path.join(root, filename)
rel_path = os.path.relpath(file_path, dir_path)
file_info = {
'source_type': 'dir',
'dir_path': dir_path,
'file_path': file_path,
'rel_path': rel_path,
'file_id': f"{dir_path}:{rel_path}",
'filename': filename
}
# Get folder path
folder_path = self.get_folder_path(file_info)
# Skip if folder already processed
if folder_path in self.processed_folders and self.resume:
continue
# Add to folders dictionary
if folder_path not in folders:
folders[folder_path] = []
folders[folder_path].append(file_info)
total_files += 1
# Update stats
self.stats['total_files'] = total_files
self.stats['total_folders'] = len(folders)
logger.info(f"Found {len(folders)} unprocessed folders with {total_files} files")
return folders
def process_file(self, file_info: Dict[str, str]) -> Dict[str, Any]:
"""
Process a single image file (worker function for thread pool)
Args:
file_info: Dictionary with file information
Returns:
Dictionary with processing results
"""
result = {
'file_info': file_info,
'job_id': None,
'tags': [],
'valid': False,
'error': None,
'script': None,
'script_excerpt': None,
'matching_lines': {},
'duration': None,
'hpc_user': None,
'start_time': None
}
try:
# Extract job ID from filename
job_id = self.extract_job_id(file_info['filename'])
            if job_id is None:
result['error'] = f"Could not extract job ID from filename: {file_info['filename']}"
return result
result['job_id'] = job_id
# Get job script
job_data = self.get_job_script(job_id)
# Check if job was filtered out by start_time_limit
if job_data is None:
result['error'] = f"Job {job_id} filtered out by start_time_limit"
result['tags'] = ["Filtered"]
result['valid'] = False
return result
# Store duration from job data
result['duration'] = job_data.get('duration') if job_data else None
result['hpc_user'] = job_data.get('hpc_user') if job_data else None
result['start_time'] = job_data.get('start_time') if job_data else None
if job_data['interactive']:
result['tags'] = ["Interactive"]
result['valid'] = True
# Track user stats for interactive jobs
if job_data.get('hpc_user'):
self._track_user_stats("Interactive", job_data['hpc_user'], job_data.get('start_time'))
return result
# If no script, assign "Unknown" app label
if not job_data or not job_data['script']:
result['tags'] = ["Unknown"]
result['valid'] = False
# Track user stats for unknown jobs
if job_data and job_data.get('hpc_user'):
self._track_user_stats("Unknown", job_data['hpc_user'], job_data.get('start_time'))
# result['error'] = "No job script found"
return result
script = job_data['script']
# If script doesn't start with #!, assign "Unknown" app label
if not self.is_valid_script(script):
result['tags'] = ["Unknown"]
result['valid'] = True # Mark as valid with "Unknown" tag
result['error'] = "Script doesn't start with #!"
# Track user stats for unknown jobs
if job_data.get('hpc_user'):
self._track_user_stats("Unknown", job_data['hpc_user'], job_data.get('start_time'))
return result
# Store script excerpt for debugging
lines = script.strip().split('\n')
result['script_excerpt'] = '\n'.join(lines[:60])
if len(lines) > 60:
result['script_excerpt'] += f"\n... [+ {len(lines) - 60} more lines]"
# Store full script
result['script'] = script
# Check keywords
matched_apps, matching_lines = self.check_keywords(script)
# No matches, assign "Unknown"
if not matched_apps:
result['tags'] = ["Unknown"]
result['valid'] = True
result['error'] = "No matching keywords found"
# Track user stats for unknown jobs
if job_data.get('hpc_user'):
self._track_user_stats("Unknown", job_data['hpc_user'], job_data.get('start_time'))
else:
result['tags'] = matched_apps
result['matching_lines'] = matching_lines
result['valid'] = len(matched_apps) == 1 # Valid if exactly one tag
# Track user stats for matched apps
if job_data.get('hpc_user'):
for tag in matched_apps:
self._track_user_stats(tag, job_data['hpc_user'], job_data.get('start_time'))
except Exception as e:
result['error'] = str(e)
result['tags'] = ["Unknown"] # Default to "Unknown" on error
result['valid'] = False
logger.error(f"Error processing file {file_info['file_id']}: {e}")
# Track user stats for error cases if we have job data
if 'job_data' in locals() and job_data and job_data.get('hpc_user'):
self._track_user_stats("Unknown", job_data['hpc_user'], job_data.get('start_time'))
return result
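    # Tag assignment summary for process_file():
    #   - job missing from the DB or older than start_time_limit -> tags=["Filtered"], valid=False
    #   - jobName == "interactive"                               -> tags=["Interactive"], valid=True
    #   - no job script stored in meta_data                      -> tags=["Unknown"], valid=False
    #   - script does not start with '#!'                        -> tags=["Unknown"], valid=True
    #   - script has no keyword matches                          -> tags=["Unknown"], valid=True
    #   - exactly one application matches                        -> tags=[app], valid=True
    #   - several applications match                             -> tags=apps, valid=False (logged as multi-label)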
def read_image_data(self, file_info: Dict[str, str]) -> Optional[bytes]:
"""
Read image data from the source (tar or directory)
Args:
file_info: File information dictionary
Returns:
Image data as bytes if successful, None otherwise
"""
try:
if file_info['source_type'] == 'tar':
# Extract from tar archive
with tarfile.open(file_info['tar_path'], 'r') as tar:
img_file = tar.extractfile(file_info['file_path'])
if img_file:
return img_file.read()
else:
# Read from filesystem
with open(file_info['file_path'], 'rb') as f:
return f.read()
return None
except Exception as e:
logger.error(f"Error reading image {file_info['file_id']}: {e}")
return None
def log_multi_label_files(self, shard_idx: int, multi_label_samples: List[Dict[str, Any]]):
"""
Log files with multiple labels to a file
Args:
shard_idx: Shard index number (ignored in SQLite mode)
multi_label_samples: List of samples with multiple labels
"""
if not multi_label_samples:
return
if self.sqlite_output_path:
log_path = 'multi_label_sqlite.txt'
else:
log_path = os.path.join(self.output_dir, f"multi_label_shard_{shard_idx:06d}.txt")
        with open(log_path, 'w', encoding='utf-8') as f:
f.write(f"Multi-label files in shard {shard_idx} ({len(multi_label_samples)} files):\n\n")
for i, sample in enumerate(multi_label_samples, 1):
f.write(f"File {i}: {sample['file_info']['filename']}\n")
f.write(f"Job ID: {sample['job_id']}\n")
f.write(f"Labels: {', '.join(sample['tags'])}\n")
# Write matching lines for each app
for app, lines in sample['matching_lines'].items():
f.write(f"\n {app} matching lines:\n")
for line in lines:
f.write(f" {line}\n")
f.write("\n" + "-" * 80 + "\n\n")
logger.info(f"Logged {len(multi_label_samples)} multi-label files to {log_path}")
def log_unknown_script_files(self, shard_idx: int, unknown_samples: List[Dict[str, Any]]):
"""
Log files with unknown scripts to a file
Args:
shard_idx: Shard index number (ignored in SQLite mode)
unknown_samples: List of samples with unknown scripts
"""
if not unknown_samples:
return
if self.sqlite_output_path:
log_path = 'unknown_script_sqlite.txt'
else:
log_path = os.path.join(self.output_dir, f"unknown_script_shard_{shard_idx:06d}.txt")
with open(log_path, 'w', encoding='utf-8') as f:
f.write(f"Unknown script files in shard {shard_idx} ({len(unknown_samples)} files):\n\n")
for i, sample in enumerate(unknown_samples, 1):
f.write(f"File {i}: {sample['file_info']['filename']}\n")
f.write(f"Job ID: {sample['job_id']}\n")
f.write(f"Error: {sample.get('error', '')}\n")
# Filter out comment lines (starting with #) from script excerpt
script_excerpt = sample.get('script_excerpt', '')
if script_excerpt:
lines = script_excerpt.split('\n')
filtered_lines = [line for line in lines if not line.strip().startswith('#')]
filtered_excerpt = '\n'.join(filtered_lines)
f.write(f"Script excerpt (comments removed):\n{filtered_excerpt}\n")
else:
f.write("Script excerpt: (none)\n")
f.write("\n" + "-" * 80 + "\n\n")
logger.info(f"Logged {len(unknown_samples)} unknown script files to {log_path}")
def write_shard(self, samples: List[Dict[str, Any]], shard_idx: int):
"""
Write samples to a WebDataset shard
Args:
samples: List of samples to write
shard_idx: Shard index number
"""
if not samples:
return
# Create shard filename with proper padding
shard_path = os.path.join(
self.output_dir,
f"roofline_dataset_{shard_idx:06d}.tar"
)
# Track multi-label and unknown samples for this shard
multi_label_samples = []
unknown_samples = []
valid_count = 0
with wds.TarWriter(shard_path) as sink:
for sample in samples:
try:
file_info = sample['file_info']
# Extract image data
img_data = self.read_image_data(file_info)
if not img_data:
logger.warning(f"Could not read image for job {sample['job_id']}")
continue
# Check if this is a multi-label sample
if len(sample['tags']) > 1:
multi_label_samples.append(sample)
# Check if this is an unknown sample
if sample['tags'] == ["Unknown"]:
unknown_samples.append(sample)
# Create a unique key from original filename
key = file_info['filename'].split('.')[0] # Remove extension
# Extract lines that contain the keywords for the chosen app
keyword_lines = []
if sample['tags'][0] != "Unknown" and sample['tags'][0] in sample['matching_lines']:
keyword_lines = sample['matching_lines'][sample['tags'][0]]
# Create sample
wds_sample = {
"__key__": key,
"png": img_data,
"json": json.dumps({
"job_id": sample['job_id'],
"tags": sample['tags'], # Use the first tag for single-label classification
"script": keyword_lines, # Add lines containing detected keywords
"duration": sample.get('duration'), # Add job duration
"hpc_user": sample.get('hpc_user'), # Add HPC user
"start_time": sample.get('start_time') # Add job start time
})
}
# Write to WebDataset
sink.write(wds_sample)
valid_count += 1
except Exception as e:
logger.error(f"Error adding sample {sample.get('job_id')} to WebDataset: {e}")
# Log multi-label samples for this shard
if multi_label_samples:
self.log_multi_label_files(shard_idx, multi_label_samples)
# Log unknown samples for this shard
if unknown_samples:
self.log_unknown_script_files(shard_idx, unknown_samples)
logger.info(f"Created shard {shard_path} with {valid_count} samples")
return valid_count
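    # Reading the shards back (a sketch, assuming the standard webdataset API;
    # the brace range below is a placeholder for the shards actually produced):
    #
    #   import webdataset as wds
    #   ds = (wds.WebDataset("datasets/roofline_dataset_{000000..000009}.tar")
    #         .decode("pil")
    #         .to_tuple("png", "json"))
    #   for image, meta in ds:
    #       ...  # meta["tags"], meta["job_id"], meta["duration"], ...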
def write_to_sqlite(self, samples: List[Dict[str, Any]]):
"""
Write samples to SQLite database
Args:
samples: List of samples to write
"""
if not samples:
return
try:
# Connect to output database
conn = sqlite3.connect(self.sqlite_output_path)
cursor = conn.cursor()
# Create table if it doesn't exist
cursor.execute('''
CREATE TABLE IF NOT EXISTS job_tags (
job_id INTEGER PRIMARY KEY,
tags TEXT NOT NULL,
script_lines TEXT,
duration REAL,
hpc_user TEXT,
start_time INTEGER,
filename TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
''')
# Track multi-label and unknown samples for logging
multi_label_samples = []
unknown_samples = []
# Insert samples
for sample in samples:
try:
# Check if this is a multi-label sample
if len(sample['tags']) > 1:
multi_label_samples.append(sample)
# Check if this is an unknown sample
if sample['tags'] == ["Unknown"]:
unknown_samples.append(sample)
# Extract lines that contain the keywords for the chosen app
keyword_lines = []
if sample['tags'] and sample['tags'][0] != "Unknown" and sample['tags'][0] in sample.get('matching_lines', {}):
keyword_lines = sample['matching_lines'][sample['tags'][0]]
# Create filename key
filename = sample['file_info']['filename'].split('.')[0] # Remove extension
cursor.execute('''
INSERT OR REPLACE INTO job_tags
(job_id, tags, script_lines, duration, hpc_user, start_time, filename)
VALUES (?, ?, ?, ?, ?, ?, ?)
''', (
sample['job_id'],
json.dumps(sample['tags']),
json.dumps(keyword_lines),
sample.get('duration'),
sample.get('hpc_user'),
sample.get('start_time'),
filename
))
except Exception as e:
logger.error(f"Error inserting sample {sample.get('job_id')} to SQLite: {e}")
conn.commit()
logger.info(f"Inserted {len(samples)} samples to SQLite database")
# Log multi-label samples
if multi_label_samples:
self.log_multi_label_files(0, multi_label_samples) # Use 0 as shard_idx for SQLite mode
# Log unknown samples
if unknown_samples:
self.log_unknown_script_files(0, unknown_samples) # Use 0 as shard_idx for SQLite mode
except Exception as e:
logger.error(f"Error writing to SQLite database: {e}")
raise
finally:
if 'conn' in locals():
conn.close()
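    # Example query against the output database created above (the file name is
    # whatever --sqlite-output was set to; 'tags' is stored as a JSON array):
    #
    #   sqlite3 tags.db "SELECT tags, COUNT(*) FROM job_tags GROUP BY tags ORDER BY COUNT(*) DESC;"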
def update_debug_samples(self, results: List[Dict[str, Any]]):
"""
Update debug samples collections with new results
Args:
results: List of processing results
"""
for result in results:
# Only track multi-tag samples for debugging
if len(result['tags']) > 1:
# Add to multi-tag samples if not full
if len(self.multi_tag_samples) < self.max_debug_samples:
self.multi_tag_samples.append(result)
def write_debug_file(self):
"""Write debug samples to file"""
if self.sqlite_output_path:
debug_path = 'multi_tag_samples_sqlite.json'
else:
debug_path = os.path.join(self.output_dir, 'multi_tag_samples.json')
# Process multi-tag samples
debug_data = []
for sample in self.multi_tag_samples:
debug_entry = {
"job_id": sample['job_id'],
"file_id": sample['file_info']['file_id'],
"filename": sample['file_info']['filename'],
"tags": sample['tags'],
"tag_count": len(sample['tags']),
"matching_lines": sample['matching_lines']
}
debug_data.append(debug_entry)
# Write debug file
with open(debug_path, 'w') as f:
json.dump(debug_data, f, indent=2)
logger.info(f"Wrote {len(debug_data)} multi-tag samples to {debug_path}")
def write_user_stats(self):
"""Write user statistics to file"""
if self.sqlite_output_path:
user_stats_path = 'user_statistics_sqlite.json'
else:
user_stats_path = os.path.join(self.output_dir, 'user_statistics.json')
# Prepare user stats data
stats_data = {
'overall_stats': {tag: len(users) for tag, users in self.user_stats.items()},
'overall_users': {tag: sorted(list(users)) for tag, users in self.user_stats.items()},
'total_unique_users_overall': len(set().union(*self.user_stats.values())) if self.user_stats else 0
}
# Add recent stats if time limit is set
if self.start_time_limit is not None:
stats_data['recent_stats'] = {tag: len(users) for tag, users in self.user_stats_recent.items()}
stats_data['recent_users'] = {tag: sorted(list(users)) for tag, users in self.user_stats_recent.items()}
stats_data['total_unique_users_recent'] = len(set().union(*self.user_stats_recent.values())) if self.user_stats_recent else 0
stats_data['start_time_limit'] = self.start_time_limit
stats_data['start_time_limit_readable'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(self.start_time_limit))
# Write user stats file
with open(user_stats_path, 'w') as f:
json.dump(stats_data, f, indent=2)
logger.info(f"Wrote user statistics to {user_stats_path}")
# Log summary
logger.info("User Statistics Summary:")
for tag, count in stats_data['overall_stats'].items():
logger.info(f" {tag}: {count} distinct users")
logger.info(f" Total unique users overall: {stats_data['total_unique_users_overall']}")
if self.start_time_limit is not None:
logger.info(f"Recent User Statistics (since {stats_data['start_time_limit_readable']}):")
for tag, count in stats_data['recent_stats'].items():
logger.info(f" {tag}: {count} distinct users")
logger.info(f" Total unique users recent: {stats_data['total_unique_users_recent']}")
def process_all(self):
"""
Main processing function
"""
# Load keywords
self.load_keywords()
try:
# Get all folders with image files
folders = self.list_input_folders()
if not folders:
logger.info("No unprocessed folders found.")
return
# Sort folders for more predictable processing
folder_paths = sorted(folders.keys())
valid_samples_buffer = []
start_time = time.time()
# Process each folder
for i, folder_path in enumerate(folder_paths):
files = folders[folder_path]
logger.info(f"Processing folder {i+1}/{len(folder_paths)}: {folder_path} ({len(files)} files)")
# Process folder
valid_samples, other_samples = self.process_folder(folder_path, files)
# Add to buffers
valid_samples_buffer.extend(valid_samples)
# Update debug samples
self.update_debug_samples(other_samples)
# Write samples if buffer is large enough
if self.sqlite_output_path:
# SQLite mode: write all buffered samples at once
if valid_samples_buffer:
self.write_to_sqlite(valid_samples_buffer)
valid_samples_buffer = []
else:
# WebDataset mode: write shards
if len(valid_samples_buffer) >= self.shard_size:
# Write full shards
while len(valid_samples_buffer) >= self.shard_size:
shard_samples = valid_samples_buffer[:self.shard_size]
valid_samples_buffer = valid_samples_buffer[self.shard_size:]
self.write_shard(shard_samples, self.current_shard)
self.current_shard += 1
# Save progress
self._save_progress()
# Calculate and log progress
elapsed = time.time() - start_time
                # Use folders completed in this run, since folder_paths only
                # contains folders that were not processed in previous runs
                folders_done = i + 1
                progress = folders_done / len(folder_paths) * 100
                folders_per_sec = folders_done / elapsed if elapsed > 0 else 0
                logger.info(f"Progress: {folders_done}/{len(folder_paths)} folders "
                            f"({progress:.1f}%) - {folders_per_sec:.2f} folders/sec")
logger.info(f"Files: {self.stats['processed_files']}/{self.stats['total_files']} "
f"(Valid: {self.stats['valid_count']}, Multi-tag: {self.stats['multi_tag_count']}, "
f"Unknown: {self.stats['unknown_count']}, Errors: {self.stats['error_count']})")
# Project completion time
if folders_per_sec > 0:
                    remaining_folders = len(folder_paths) - folders_done
remaining_time = remaining_folders / folders_per_sec
eta = time.strftime('%H:%M:%S', time.gmtime(remaining_time))
logger.info(f"Estimated time remaining: {eta}")
# Write remaining samples
if valid_samples_buffer:
if self.sqlite_output_path:
self.write_to_sqlite(valid_samples_buffer)
else:
self.write_shard(valid_samples_buffer, self.current_shard)
self.current_shard += 1
# Write debug samples
self.write_debug_file()
# Write user statistics
self.write_user_stats()
# Final progress update
self._save_progress()
# Log final statistics
elapsed = time.time() - start_time
logger.info(
f"Processing complete in {elapsed:.1f}s.\n"
f"Total folders: {len(folder_paths)}\n"
f"Total files: {self.stats['total_files']}\n"
f"Valid samples: {self.stats['valid_count']}\n"
f"Multi-tag samples: {self.stats['multi_tag_count']}\n"
f"Unknown samples: {self.stats['unknown_count']}\n"
f"Errors: {self.stats['error_count']}\n"
f"Shards created: {self.current_shard}"
)
finally:
# Close database connections
self.close_db_connections()
def process_folder(self, folder_path: str, files: List[Dict[str, str]]) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
"""
Process all files in a folder
Args:
folder_path: Folder path identifier
files: List of file info dictionaries in the folder
Returns:
Tuple of (valid_samples, other_samples)
"""
results = []
valid_samples = []
other_samples = []
with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
# Submit tasks
future_to_file = {
executor.submit(self.process_file, file_info): file_info['file_id']
for file_info in files
}
# Process results as they complete
for future in as_completed(future_to_file):
file_id = future_to_file[future]
try:
result = future.result()
results.append(result)
# Update statistics
if result['error'] and not result['valid']:
self.stats['error_count'] += 1
other_samples.append(result)
elif result['valid']:
if result['tags'][0] == "Unknown":
self.stats['unknown_count'] += 1
self.stats['valid_count'] += 1
valid_samples.append(result)
elif len(result['tags']) > 1:
self.stats['multi_tag_count'] += 1
# For multi-tag samples, still consider them valid but track separately
valid_samples.append(result)
other_samples.append(result)
except Exception as e:
logger.error(f"Exception processing file {file_id}: {e}")
self.stats['error_count'] += 1
# Mark folder as processed
self.processed_folders.add(folder_path)
self.stats['total_processed_folders'] = len(self.processed_folders)
self.stats['processed_files'] += len(files)
return valid_samples, other_samples
def main():
"""Main entry point"""
parser = argparse.ArgumentParser(description="Process job scripts and create ML datasets")
parser.add_argument('--db', required=True, help='Path to the SQLite database')
parser.add_argument('--keywords', required=True, help='Path to the keywords CSV file')
parser.add_argument('--output-dir', default='datasets', help='Output directory for datasets')
parser.add_argument('--sqlite-output', help='Path to SQLite database for output (instead of WebDataset shards)')
parser.add_argument('--workers', type=int, default=None, help='Number of parallel workers')
parser.add_argument('--shard-size', type=int, default=1000, help='Number of samples per shard')
parser.add_argument('--resume', action='store_true', help='Resume from previous run')
parser.add_argument('--debug-samples', type=int, default=100,
help='Maximum number of debug samples to keep per category')
parser.add_argument('--start-time-limit', type=int, default=None,
help='Unix timestamp limit for filtering jobs (e.g., for past 2 years: use timestamp from 2 years ago). '
'If not provided, all jobs are included. You can calculate 2 years ago as: '
'python -c "import time; print(int(time.time()) - 2*365*24*3600)"')
# Input source arguments
input_group = parser.add_mutually_exclusive_group(required=True)
input_group.add_argument('--tars', nargs='+', help='Paths to tar archives containing images')
input_group.add_argument('--dirs', nargs='+', help='Paths to directories containing images')
args = parser.parse_args()
# Determine input mode and paths
input_mode = 'tar' if args.tars else 'dir'
input_paths = args.tars if args.tars else args.dirs
    logger.info(f"Input paths: {input_paths}")
# Initialize tagger
tagger = JobScriptTagger(
db_path=args.db,
keywords_csv=args.keywords,
output_dir=args.output_dir,
input_mode=input_mode,
input_paths=input_paths,
num_workers=args.workers,
shard_size=args.shard_size,
resume=args.resume,
max_debug_samples=args.debug_samples,
start_time_limit=args.start_time_limit,
sqlite_output_path=args.sqlite_output
)
# Run processing
tagger.process_all()
if __name__ == "__main__":
main()