This script processes image files from tar archives or directories, extracts job IDs from the filenames, fetches the corresponding job scripts from a SQLite database, and identifies application tags based on keyword matches. Valid samples are saved to sharded WebDatasets (or to a SQLite table when --sqlite-output is given), while problematic ones are logged.
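
A typical WebDataset-mode invocation (the flags come from the script's argument parser below; the script filename and all paths are illustrative placeholders) might look like:

    python job_script_tagger.py --db jobs.sqlite --keywords app_keywords.csv \
        --tars /data/rooflines_2023.tar --output-dir datasets --workers 8
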
#!/usr/bin/env python3

import os
import re
import csv
import json
import sqlite3
import tarfile
import logging
import threading
import argparse
import pandas as pd
import time
from pathlib import Path
from typing import Dict, List, Tuple, Set, Any, Optional, Union
from concurrent.futures import ThreadPoolExecutor, as_completed
import webdataset as wds
from PIL import Image
import io
import hashlib

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("job_tagger_dataset.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger("JobScriptTagger")


class JobScriptTagger:
    """Process job scripts and images to create ML datasets"""

    def __init__(
        self,
        db_path: str,
        keywords_csv: str,
        output_dir: str,
        input_mode: str = 'tar',
        input_paths: Optional[List[str]] = None,
        num_workers: Optional[int] = None,
        shard_size: int = 1000,
        resume: bool = False,
        max_debug_samples: int = 100,
        start_time_limit: Optional[int] = None,
        sqlite_output_path: Optional[str] = None
    ):
        """
        Initialize the tagger

        Args:
            db_path: Path to the SQLite database
            keywords_csv: Path to the CSV file with application names and keywords
            output_dir: Directory for output files
            input_mode: 'tar' for tar archives or 'dir' for directories
            input_paths: List of tar archives or directories containing images
            num_workers: Number of parallel workers (defaults to CPU count)
            shard_size: Number of samples per output shard
            resume: Whether to resume from a previous run
            max_debug_samples: Maximum number of debug samples to keep
            start_time_limit: Unix timestamp limit for filtering jobs (None for no limit)
            sqlite_output_path: Path to SQLite database for output (if None, use WebDataset mode)
        """
        self.db_path = db_path
        self.keywords_csv = keywords_csv
        self.output_dir = output_dir
        self.input_mode = input_mode
        self.input_paths = input_paths or []
        self.num_workers = num_workers or os.cpu_count()
        self.shard_size = shard_size
        self.resume = resume
        self.max_debug_samples = max_debug_samples
        self.start_time_limit = start_time_limit
        self.sqlite_output_path = sqlite_output_path

        # Ensure output directory exists (only needed for WebDataset mode)
        if not sqlite_output_path:
            os.makedirs(output_dir, exist_ok=True)

        # Thread-local storage for database connections
        self.local_db = threading.local()

        self.app_keywords = {}  # Maps app names to lists of keywords

        # Track processed folders and progress
        if sqlite_output_path:
            self.progress_file = 'progress_sqlite.json'
        else:
            self.progress_file = os.path.join(output_dir, 'progress.json')
        self.processed_folders = set()
        self.current_shard = 0
        self.stats = {
            'valid_count': 0,
            'multi_tag_count': 0,
            'unknown_count': 0,
            'error_count': 0,
            'total_processed_folders': 0,
            'total_folders': 0,
            'total_files': 0,
            'processed_files': 0
        }

        # User statistics tracking
        self.user_stats = {}  # tag -> set of users
        self.user_stats_recent = {}  # tag -> set of users (within time limit)

        # Debug sample collections
        self.multi_tag_samples = []

        # Multi-label logging - keep track of files per shard
        self.multi_label_logs = {}  # shard_id -> list of filenames

        # Initialize or load progress
        self._init_progress()

        logger.info(f"Initialized with {self.num_workers} workers, input mode: {input_mode}")

    def _track_user_stats(self, tag: str, user: str, start_time: Optional[int] = None):
        """
        Track user statistics for a given tag

        Args:
            tag: Application tag
            user: HPC user
            start_time: Job start time (Unix timestamp)
        """
        if not user:
            return

        # Track overall user stats
        if tag not in self.user_stats:
            self.user_stats[tag] = set()
        self.user_stats[tag].add(user)

        # Track recent user stats if start_time is within limit
        if self.start_time_limit is not None and start_time is not None:
            if start_time >= self.start_time_limit:
                if tag not in self.user_stats_recent:
                    self.user_stats_recent[tag] = set()
                self.user_stats_recent[tag].add(user)

    def _init_progress(self):
        """Initialize or load progress tracking"""
        if self.resume and os.path.exists(self.progress_file):
            try:
                with open(self.progress_file, 'r') as f:
                    progress_data = json.load(f)

                self.processed_folders = set(progress_data.get('processed_folders', []))
                self.current_shard = progress_data.get('current_shard', 0)
                self.stats = progress_data.get('stats', self.stats)

                # Load user stats
                user_stats_data = progress_data.get('user_stats', {})
                self.user_stats = {tag: set(users) for tag, users in user_stats_data.items()}

                user_stats_recent_data = progress_data.get('user_stats_recent', {})
                self.user_stats_recent = {tag: set(users) for tag, users in user_stats_recent_data.items()}

                logger.info(f"Resumed from previous run: {len(self.processed_folders)} folders processed, "
                            f"current shard: {self.current_shard}")
            except Exception as e:
                logger.warning(f"Failed to load progress file: {e}. Starting fresh.")
                self._save_progress()
        else:
            self._save_progress()

    def _save_progress(self):
        """Save current progress to file"""
        progress_data = {
            'processed_folders': list(self.processed_folders),
            'current_shard': self.current_shard,
            'stats': self.stats,
            'user_stats': {tag: list(users) for tag, users in self.user_stats.items()},
            'user_stats_recent': {tag: list(users) for tag, users in self.user_stats_recent.items()},
            'last_updated': time.strftime('%Y-%m-%d %H:%M:%S')
        }

        with open(self.progress_file, 'w') as f:
            json.dump(progress_data, f, indent=2)

        logger.debug(f"Progress saved: {self.stats['total_processed_folders']} folders processed")

    def get_db_connection(self):
        """
        Get a thread-local database connection

        Returns:
            SQLite connection object
        """
        # Check if this thread already has a connection
        if not hasattr(self.local_db, 'conn') or self.local_db.conn is None:
            try:
                self.local_db.conn = sqlite3.connect(self.db_path)
                logger.debug(f"Created new database connection for thread {threading.get_ident()}")
            except sqlite3.Error as e:
                logger.error(f"Database connection error in thread {threading.get_ident()}: {e}")
                raise

        return self.local_db.conn

    def close_db_connections(self):
        """Close the database connection owned by the calling thread"""
        # Since connections are thread-local, we can only close the one in the current thread
        if hasattr(self.local_db, 'conn') and self.local_db.conn is not None:
            self.local_db.conn.close()
            self.local_db.conn = None
            logger.info("Database connection closed")

    def load_keywords(self):
        """Load application keywords from CSV file"""
        try:
            # Read CSV file
            df = pd.read_csv(self.keywords_csv)

            # Process each row
            for _, row in df.iterrows():
                app_name = row['Name'].strip()
                # Split keywords by semicolon and strip whitespace
                keywords = [kw.strip() for kw in row['Keyword'].split(';') if kw.strip()]
                if app_name and keywords:
                    self.app_keywords[app_name] = keywords

            logger.info(f"Loaded {len(self.app_keywords)} applications with keywords")

            # Log the apps and their keywords for verification
            for app_name, keywords in self.app_keywords.items():
                logger.debug(f"App: {app_name}, Keywords: {keywords}")

        except Exception as e:
            logger.error(f"Error loading keywords from {self.keywords_csv}: {e}")
            raise
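
    # Illustrative sketch of the keywords CSV layout expected by load_keywords()
    # (the column names 'Name' and 'Keyword' and the semicolon separator come from the
    # code above; the application names and keywords themselves are made-up examples):
    #
    #   Name,Keyword
    #   GROMACS,gmx;gromacs
    #   LAMMPS,lmp;lammps
    #   Quantum Espresso,pw.x;espresso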

    def extract_job_id(self, filename: str) -> Optional[int]:
        """
        Extract job ID from image filename

        Args:
            filename: Image filename in format {job_id}_{timestamp}_color.png

        Returns:
            Job ID if successfully extracted, None otherwise
        """
        match = re.match(r'(\d+)_\d+_color\.png$', os.path.basename(filename))
        if match:
            return int(match.group(1))
        return None

    def get_job_script(self, job_id: int) -> Optional[Dict[str, Any]]:
        """
        Fetch job script from database

        Args:
            job_id: Job ID

        Returns:
            Dict with job script and metadata if found, None otherwise
        """
        try:
            conn = self.get_db_connection()
            cursor = conn.cursor()

            query = "SELECT meta_data, duration, hpc_user, start_time FROM job WHERE job_id = ?"
            cursor.execute(query, (job_id,))
            result = cursor.fetchone()

            if result:
                meta_data = json.loads(result[0]) if result[0] else {}
                duration = result[1]
                hpc_user = result[2]
                start_time = result[3]
                script = meta_data.get('jobScript', '')
                name = meta_data.get('jobName', '')

                # Check start_time limit if specified
                if self.start_time_limit is not None and start_time is not None:
                    if start_time < self.start_time_limit:
                        return None  # Job is too old

                return {
                    'script': script,
                    'meta_data': meta_data,
                    'duration': duration,
                    'hpc_user': hpc_user,
                    'start_time': start_time,
                    'interactive': name == "interactive"
                }
            return None

        except sqlite3.Error as e:
            logger.error(f"Error fetching job script for job {job_id}: {e}")
            return None

    def is_valid_script(self, script: str) -> bool:
        """
        Check if a script is valid (starts with #!)

        Args:
            script: Job script text

        Returns:
            True if script is valid, False otherwise
        """
        if not script:
            return False

        # Check if script starts with #!
        return script.strip().startswith('#!')

    def filter_script_comments(self, script: str) -> Tuple[str, List[str]]:
        """
        Filter out commented and empty lines from script

        Args:
            script: Job script text

        Returns:
            Tuple of (filtered_script_text, filtered_lines)
        """
        lines = script.splitlines()
        filtered_lines = []

        for line in lines:
            stripped = line.strip()
            # Keep shebang line and non-comment lines
            if stripped and (stripped.startswith('#!') or not stripped.startswith('#')):
                filtered_lines.append(line)

        filtered_script = '\n'.join(filtered_lines)
        return filtered_script, filtered_lines

    def check_keywords(self, script: str) -> Tuple[List[str], Dict[str, List[str]]]:
        """
        Check which applications' keywords match in the script and extract matching lines

        Args:
            script: Job script text

        Returns:
            Tuple of (matching_app_names, {app_name: [matching_lines]})
        """
        # Filter out commented lines
        filtered_script, filtered_lines = self.filter_script_comments(script)

        script_lower = filtered_script.lower()

        matched_apps = []
        matched_lines = {}

        for app_name, keywords in self.app_keywords.items():
            app_matches = []

            for keyword in keywords:
                keyword_lower = keyword.lower()

                # Check if keyword is in filtered script
                if keyword_lower in script_lower:
                    matched_apps.append(app_name)

                    # Find lines containing this keyword in filtered lines
                    for line in filtered_lines:
                        if keyword_lower in line.lower():
                            app_matches.append(line.strip())

                    # Break after first matching keyword for this app
                    break

            # Save unique matching lines for this app
            if app_matches:
                matched_lines[app_name] = list(set(app_matches))

        return matched_apps, matched_lines
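
    # Illustrative example of check_keywords() behaviour (hypothetical app/keyword names):
    # with self.app_keywords = {"GROMACS": ["gmx", "gromacs"]} and a job script containing
    # the line "module load gromacs/2023", the method returns
    # (["GROMACS"], {"GROMACS": ["module load gromacs/2023"]}); comment lines other than
    # the shebang are dropped by filter_script_comments() before matching.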

    def get_folder_path(self, file_info: Dict[str, str]) -> str:
        """
        Get the subfolder path for a file

        Args:
            file_info: File information dictionary

        Returns:
            Subfolder path string
        """
        if self.input_mode == 'tar':
            # For tar files, use the parent directory in the archive
            path_parts = file_info['file_path'].split('/')
            if len(path_parts) > 1:
                # Return the tar name and top-level directory
                return f"{os.path.basename(file_info['tar_path'])}:{path_parts[0]}"
            else:
                return os.path.basename(file_info['tar_path'])
        else:
            # For directories, use the relative path's parent directory
            path = Path(file_info['rel_path'])
            if path.parent != Path('.'):
                return f"{os.path.basename(file_info['dir_path'])}/{path.parent}"
            else:
                return os.path.basename(file_info['dir_path'])

    def list_input_folders(self) -> Dict[str, List[Dict[str, str]]]:
        """
        List all image files from input sources grouped by subfolder

        Returns:
            Dictionary mapping folder paths to lists of file information
        """
        folders = {}
        total_files = 0

        if self.input_mode == 'tar':
            # Process tar archives
            for tar_path in self.input_paths:
                if not os.path.exists(tar_path):
                    logger.warning(f"Tar archive not found: {tar_path}")
                    continue

                try:
                    with tarfile.open(tar_path, 'r') as tar:
                        for member in tar.getmembers():
                            if member.name.endswith('_color.png'):
                                file_info = {
                                    'source_type': 'tar',
                                    'tar_path': tar_path,
                                    'file_path': member.name,
                                    'file_id': f"{tar_path}:{member.name}",
                                    'filename': os.path.basename(member.name)
                                }

                                # Get folder path
                                folder_path = self.get_folder_path(file_info)

                                # Skip if folder already processed
                                if folder_path in self.processed_folders and self.resume:
                                    continue

                                # Add to folders dictionary
                                if folder_path not in folders:
                                    folders[folder_path] = []
                                folders[folder_path].append(file_info)
                                total_files += 1

                except Exception as e:
                    logger.error(f"Error reading tar archive {tar_path}: {e}")
        else:
            # Process directories
            for dir_path in self.input_paths:
                if not os.path.isdir(dir_path):
                    logger.warning(f"Directory not found: {dir_path}")
                    continue

                for root, _, files in os.walk(dir_path):
                    # Filter image files
                    image_files = [f for f in files if f.endswith('_color.png')]

                    if not image_files:
                        continue

                    for filename in image_files:
                        file_path = os.path.join(root, filename)
                        rel_path = os.path.relpath(file_path, dir_path)

                        file_info = {
                            'source_type': 'dir',
                            'dir_path': dir_path,
                            'file_path': file_path,
                            'rel_path': rel_path,
                            'file_id': f"{dir_path}:{rel_path}",
                            'filename': filename
                        }

                        # Get folder path
                        folder_path = self.get_folder_path(file_info)

                        # Skip if folder already processed
                        if folder_path in self.processed_folders and self.resume:
                            continue

                        # Add to folders dictionary
                        if folder_path not in folders:
                            folders[folder_path] = []
                        folders[folder_path].append(file_info)
                        total_files += 1

        # Update stats
        self.stats['total_files'] = total_files
        self.stats['total_folders'] = len(folders)

        logger.info(f"Found {len(folders)} unprocessed folders with {total_files} files")
        return folders

    def process_file(self, file_info: Dict[str, str]) -> Dict[str, Any]:
        """
        Process a single image file (worker function for thread pool)

        Args:
            file_info: Dictionary with file information

        Returns:
            Dictionary with processing results
        """
        result = {
            'file_info': file_info,
            'job_id': None,
            'tags': [],
            'valid': False,
            'error': None,
            'script': None,
            'script_excerpt': None,
            'matching_lines': {},
            'duration': None,
            'hpc_user': None,
            'start_time': None
        }

        try:
            # Extract job ID from filename
            job_id = self.extract_job_id(file_info['filename'])
            if not job_id:
                result['error'] = f"Could not extract job ID from filename: {file_info['filename']}"
                return result

            result['job_id'] = job_id

            # Get job script
            job_data = self.get_job_script(job_id)

            # get_job_script returns None both when the job is missing from the database
            # and when it was filtered out by start_time_limit
            if job_data is None:
                result['error'] = f"Job {job_id} not found or filtered out by start_time_limit"
                result['tags'] = ["Filtered"]
                result['valid'] = False
                return result

            # Store duration from job data
            result['duration'] = job_data.get('duration') if job_data else None
            result['hpc_user'] = job_data.get('hpc_user') if job_data else None
            result['start_time'] = job_data.get('start_time') if job_data else None

            if job_data['interactive']:
                result['tags'] = ["Interactive"]
                result['valid'] = True
                # Track user stats for interactive jobs
                if job_data.get('hpc_user'):
                    self._track_user_stats("Interactive", job_data['hpc_user'], job_data.get('start_time'))
                return result

            # If no script, assign "Unknown" app label
            if not job_data or not job_data['script']:
                result['tags'] = ["Unknown"]
                result['valid'] = False
                # Track user stats for unknown jobs
                if job_data and job_data.get('hpc_user'):
                    self._track_user_stats("Unknown", job_data['hpc_user'], job_data.get('start_time'))
                # result['error'] = "No job script found"
                return result

            script = job_data['script']

            # If script doesn't start with #!, assign "Unknown" app label
            if not self.is_valid_script(script):
                result['tags'] = ["Unknown"]
                result['valid'] = True  # Mark as valid with "Unknown" tag
                result['error'] = "Script doesn't start with #!"
                # Track user stats for unknown jobs
                if job_data.get('hpc_user'):
                    self._track_user_stats("Unknown", job_data['hpc_user'], job_data.get('start_time'))
                return result

            # Store script excerpt for debugging
            lines = script.strip().split('\n')
            result['script_excerpt'] = '\n'.join(lines[:60])
            if len(lines) > 60:
                result['script_excerpt'] += f"\n... [+ {len(lines) - 60} more lines]"

            # Store full script
            result['script'] = script

            # Check keywords
            matched_apps, matching_lines = self.check_keywords(script)

            # No matches, assign "Unknown"
            if not matched_apps:
                result['tags'] = ["Unknown"]
                result['valid'] = True
                result['error'] = "No matching keywords found"
                # Track user stats for unknown jobs
                if job_data.get('hpc_user'):
                    self._track_user_stats("Unknown", job_data['hpc_user'], job_data.get('start_time'))
            else:
                result['tags'] = matched_apps
                result['matching_lines'] = matching_lines
                result['valid'] = len(matched_apps) == 1  # Valid if exactly one tag

                # Track user stats for matched apps
                if job_data.get('hpc_user'):
                    for tag in matched_apps:
                        self._track_user_stats(tag, job_data['hpc_user'], job_data.get('start_time'))

        except Exception as e:
            result['error'] = str(e)
            result['tags'] = ["Unknown"]  # Default to "Unknown" on error
            result['valid'] = False
            logger.error(f"Error processing file {file_info['file_id']}: {e}")

            # Track user stats for error cases if we have job data
            if 'job_data' in locals() and job_data and job_data.get('hpc_user'):
                self._track_user_stats("Unknown", job_data['hpc_user'], job_data.get('start_time'))

        return result

    def read_image_data(self, file_info: Dict[str, str]) -> Optional[bytes]:
        """
        Read image data from the source (tar or directory)

        Args:
            file_info: File information dictionary

        Returns:
            Image data as bytes if successful, None otherwise
        """
        try:
            if file_info['source_type'] == 'tar':
                # Extract from tar archive
                with tarfile.open(file_info['tar_path'], 'r') as tar:
                    img_file = tar.extractfile(file_info['file_path'])
                    if img_file:
                        return img_file.read()
            else:
                # Read from filesystem
                with open(file_info['file_path'], 'rb') as f:
                    return f.read()

            return None
        except Exception as e:
            logger.error(f"Error reading image {file_info['file_id']}: {e}")
            return None

    def log_multi_label_files(self, shard_idx: int, multi_label_samples: List[Dict[str, Any]]):
        """
        Log files with multiple labels to a file

        Args:
            shard_idx: Shard index number (ignored in SQLite mode)
            multi_label_samples: List of samples with multiple labels
        """
        if not multi_label_samples:
            return

        if self.sqlite_output_path:
            log_path = 'multi_label_sqlite.txt'
        else:
            log_path = os.path.join(self.output_dir, f"multi_label_shard_{shard_idx:06d}.txt")

        with open(log_path, 'w') as f:
            f.write(f"Multi-label files in shard {shard_idx} ({len(multi_label_samples)} files):\n\n")

            for i, sample in enumerate(multi_label_samples, 1):
                f.write(f"File {i}: {sample['file_info']['filename']}\n")
                f.write(f"Job ID: {sample['job_id']}\n")
                f.write(f"Labels: {', '.join(sample['tags'])}\n")

                # Write matching lines for each app
                for app, lines in sample['matching_lines'].items():
                    f.write(f"\n  {app} matching lines:\n")
                    for line in lines:
                        f.write(f"    {line}\n")

                f.write("\n" + "-" * 80 + "\n\n")

        logger.info(f"Logged {len(multi_label_samples)} multi-label files to {log_path}")

    def log_unknown_script_files(self, shard_idx: int, unknown_samples: List[Dict[str, Any]]):
        """
        Log files with unknown scripts to a file

        Args:
            shard_idx: Shard index number (ignored in SQLite mode)
            unknown_samples: List of samples with unknown scripts
        """
        if not unknown_samples:
            return

        if self.sqlite_output_path:
            log_path = 'unknown_script_sqlite.txt'
        else:
            log_path = os.path.join(self.output_dir, f"unknown_script_shard_{shard_idx:06d}.txt")

        with open(log_path, 'w', encoding='utf-8') as f:
            f.write(f"Unknown script files in shard {shard_idx} ({len(unknown_samples)} files):\n\n")

            for i, sample in enumerate(unknown_samples, 1):
                f.write(f"File {i}: {sample['file_info']['filename']}\n")
                f.write(f"Job ID: {sample['job_id']}\n")
                f.write(f"Error: {sample.get('error', '')}\n")

                # Filter out comment lines (starting with #) from script excerpt
                script_excerpt = sample.get('script_excerpt', '')
                if script_excerpt:
                    lines = script_excerpt.split('\n')
                    filtered_lines = [line for line in lines if not line.strip().startswith('#')]
                    filtered_excerpt = '\n'.join(filtered_lines)
                    f.write(f"Script excerpt (comments removed):\n{filtered_excerpt}\n")
                else:
                    f.write("Script excerpt: (none)\n")

                f.write("\n" + "-" * 80 + "\n\n")

        logger.info(f"Logged {len(unknown_samples)} unknown script files to {log_path}")

    def write_shard(self, samples: List[Dict[str, Any]], shard_idx: int):
        """
        Write samples to a WebDataset shard

        Args:
            samples: List of samples to write
            shard_idx: Shard index number
        """
        if not samples:
            return

        # Create shard filename with proper padding
        shard_path = os.path.join(
            self.output_dir,
            f"roofline_dataset_{shard_idx:06d}.tar"
        )

        # Track multi-label and unknown samples for this shard
        multi_label_samples = []
        unknown_samples = []

        valid_count = 0
        with wds.TarWriter(shard_path) as sink:
            for sample in samples:
                try:
                    file_info = sample['file_info']

                    # Extract image data
                    img_data = self.read_image_data(file_info)

                    if not img_data:
                        logger.warning(f"Could not read image for job {sample['job_id']}")
                        continue

                    # Check if this is a multi-label sample
                    if len(sample['tags']) > 1:
                        multi_label_samples.append(sample)
                    # Check if this is an unknown sample
                    if sample['tags'] == ["Unknown"]:
                        unknown_samples.append(sample)

                    # Create a unique key from original filename
                    key = file_info['filename'].split('.')[0]  # Remove extension

                    # Extract lines that contain the keywords for the chosen app
                    keyword_lines = []
                    if sample['tags'][0] != "Unknown" and sample['tags'][0] in sample['matching_lines']:
                        keyword_lines = sample['matching_lines'][sample['tags'][0]]

                    # Create sample
                    wds_sample = {
                        "__key__": key,
                        "png": img_data,
                        "json": json.dumps({
                            "job_id": sample['job_id'],
                            "tags": sample['tags'],  # Full tag list (single-label samples have exactly one entry)
                            "script": keyword_lines,  # Add lines containing detected keywords
                            "duration": sample.get('duration'),  # Add job duration
                            "hpc_user": sample.get('hpc_user'),  # Add HPC user
                            "start_time": sample.get('start_time')  # Add job start time
                        })
                    }

                    # Write to WebDataset
                    sink.write(wds_sample)
                    valid_count += 1

                except Exception as e:
                    logger.error(f"Error adding sample {sample.get('job_id')} to WebDataset: {e}")

        # Log multi-label samples for this shard
        if multi_label_samples:
            self.log_multi_label_files(shard_idx, multi_label_samples)

        # Log unknown samples for this shard
        if unknown_samples:
            self.log_unknown_script_files(shard_idx, unknown_samples)

        logger.info(f"Created shard {shard_path} with {valid_count} samples")
        return valid_count
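
    # Minimal sketch (not part of this class) of how the shards written above could be
    # read back with the webdataset library; the brace pattern assumes shard indices
    # 000000-000009 and is purely illustrative:
    #
    #   ds = wds.WebDataset("datasets/roofline_dataset_{000000..000009}.tar")
    #   for sample in ds:
    #       meta = json.loads(sample["json"])        # tags, keyword lines, duration, ...
    #       png_bytes = sample["png"]                # raw PNG bytes of the roofline image
    #       image = Image.open(io.BytesIO(png_bytes))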

    def write_to_sqlite(self, samples: List[Dict[str, Any]]):
        """
        Write samples to SQLite database

        Args:
            samples: List of samples to write
        """
        if not samples:
            return

        try:
            # Connect to output database
            conn = sqlite3.connect(self.sqlite_output_path)
            cursor = conn.cursor()

            # Create table if it doesn't exist
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS job_tags (
                    job_id INTEGER PRIMARY KEY,
                    tags TEXT NOT NULL,
                    script_lines TEXT,
                    duration REAL,
                    hpc_user TEXT,
                    start_time INTEGER,
                    filename TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            ''')

            # Track multi-label and unknown samples for logging
            multi_label_samples = []
            unknown_samples = []

            # Insert samples
            for sample in samples:
                try:
                    # Check if this is a multi-label sample
                    if len(sample['tags']) > 1:
                        multi_label_samples.append(sample)
                    # Check if this is an unknown sample
                    if sample['tags'] == ["Unknown"]:
                        unknown_samples.append(sample)

                    # Extract lines that contain the keywords for the chosen app
                    keyword_lines = []
                    if sample['tags'] and sample['tags'][0] != "Unknown" and sample['tags'][0] in sample.get('matching_lines', {}):
                        keyword_lines = sample['matching_lines'][sample['tags'][0]]

                    # Create filename key
                    filename = sample['file_info']['filename'].split('.')[0]  # Remove extension

                    cursor.execute('''
                        INSERT OR REPLACE INTO job_tags
                        (job_id, tags, script_lines, duration, hpc_user, start_time, filename)
                        VALUES (?, ?, ?, ?, ?, ?, ?)
                    ''', (
                        sample['job_id'],
                        json.dumps(sample['tags']),
                        json.dumps(keyword_lines),
                        sample.get('duration'),
                        sample.get('hpc_user'),
                        sample.get('start_time'),
                        filename
                    ))

                except Exception as e:
                    logger.error(f"Error inserting sample {sample.get('job_id')} to SQLite: {e}")

            conn.commit()
            logger.info(f"Inserted {len(samples)} samples to SQLite database")

            # Log multi-label samples
            if multi_label_samples:
                self.log_multi_label_files(0, multi_label_samples)  # Use 0 as shard_idx for SQLite mode

            # Log unknown samples
            if unknown_samples:
                self.log_unknown_script_files(0, unknown_samples)  # Use 0 as shard_idx for SQLite mode

        except Exception as e:
            logger.error(f"Error writing to SQLite database: {e}")
            raise
        finally:
            if 'conn' in locals():
                conn.close()
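
    # Minimal sketch (not part of this class) of how the job_tags table written above
    # could be inspected from the command line; the database path is a placeholder:
    #
    #   sqlite3 tags_output.db "SELECT tags, COUNT(*) FROM job_tags GROUP BY tags ORDER BY COUNT(*) DESC"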

    def update_debug_samples(self, results: List[Dict[str, Any]]):
        """
        Update debug samples collections with new results

        Args:
            results: List of processing results
        """
        for result in results:
            # Only track multi-tag samples for debugging
            if len(result['tags']) > 1:
                # Add to multi-tag samples if not full
                if len(self.multi_tag_samples) < self.max_debug_samples:
                    self.multi_tag_samples.append(result)

    def write_debug_file(self):
        """Write debug samples to file"""
        if self.sqlite_output_path:
            debug_path = 'multi_tag_samples_sqlite.json'
        else:
            debug_path = os.path.join(self.output_dir, 'multi_tag_samples.json')

        # Process multi-tag samples
        debug_data = []
        for sample in self.multi_tag_samples:
            debug_entry = {
                "job_id": sample['job_id'],
                "file_id": sample['file_info']['file_id'],
                "filename": sample['file_info']['filename'],
                "tags": sample['tags'],
                "tag_count": len(sample['tags']),
                "matching_lines": sample['matching_lines']
            }
            debug_data.append(debug_entry)

        # Write debug file
        with open(debug_path, 'w') as f:
            json.dump(debug_data, f, indent=2)

        logger.info(f"Wrote {len(debug_data)} multi-tag samples to {debug_path}")

    def write_user_stats(self):
        """Write user statistics to file"""
        if self.sqlite_output_path:
            user_stats_path = 'user_statistics_sqlite.json'
        else:
            user_stats_path = os.path.join(self.output_dir, 'user_statistics.json')

        # Prepare user stats data
        stats_data = {
            'overall_stats': {tag: len(users) for tag, users in self.user_stats.items()},
            'overall_users': {tag: sorted(list(users)) for tag, users in self.user_stats.items()},
            'total_unique_users_overall': len(set().union(*self.user_stats.values())) if self.user_stats else 0
        }

        # Add recent stats if time limit is set
        if self.start_time_limit is not None:
            stats_data['recent_stats'] = {tag: len(users) for tag, users in self.user_stats_recent.items()}
            stats_data['recent_users'] = {tag: sorted(list(users)) for tag, users in self.user_stats_recent.items()}
            stats_data['total_unique_users_recent'] = len(set().union(*self.user_stats_recent.values())) if self.user_stats_recent else 0
            stats_data['start_time_limit'] = self.start_time_limit
            stats_data['start_time_limit_readable'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(self.start_time_limit))

        # Write user stats file
        with open(user_stats_path, 'w') as f:
            json.dump(stats_data, f, indent=2)

        logger.info(f"Wrote user statistics to {user_stats_path}")

        # Log summary
        logger.info("User Statistics Summary:")
        for tag, count in stats_data['overall_stats'].items():
            logger.info(f"  {tag}: {count} distinct users")
        logger.info(f"  Total unique users overall: {stats_data['total_unique_users_overall']}")

        if self.start_time_limit is not None:
            logger.info(f"Recent User Statistics (since {stats_data['start_time_limit_readable']}):")
            for tag, count in stats_data['recent_stats'].items():
                logger.info(f"  {tag}: {count} distinct users")
            logger.info(f"  Total unique users recent: {stats_data['total_unique_users_recent']}")

    def process_all(self):
        """
        Main processing function
        """
        # Load keywords
        self.load_keywords()

        try:
            # Get all folders with image files
            folders = self.list_input_folders()

            if not folders:
                logger.info("No unprocessed folders found.")
                return

            # Sort folders for more predictable processing
            folder_paths = sorted(folders.keys())

            valid_samples_buffer = []
            start_time = time.time()

            # Process each folder
            for i, folder_path in enumerate(folder_paths):
                files = folders[folder_path]
                logger.info(f"Processing folder {i+1}/{len(folder_paths)}: {folder_path} ({len(files)} files)")

                # Process folder
                valid_samples, other_samples = self.process_folder(folder_path, files)

                # Add to buffers
                valid_samples_buffer.extend(valid_samples)

                # Update debug samples
                self.update_debug_samples(other_samples)

                # Write samples if buffer is large enough
                if self.sqlite_output_path:
                    # SQLite mode: write all buffered samples at once
                    if valid_samples_buffer:
                        self.write_to_sqlite(valid_samples_buffer)
                        valid_samples_buffer = []
                else:
                    # WebDataset mode: write shards
                    if len(valid_samples_buffer) >= self.shard_size:
                        # Write full shards
                        while len(valid_samples_buffer) >= self.shard_size:
                            shard_samples = valid_samples_buffer[:self.shard_size]
                            valid_samples_buffer = valid_samples_buffer[self.shard_size:]

                            self.write_shard(shard_samples, self.current_shard)
                            self.current_shard += 1

                # Save progress
                self._save_progress()

                # Calculate and log progress
                elapsed = time.time() - start_time
                progress = self.stats['total_processed_folders'] / len(folder_paths) * 100
                folders_per_sec = self.stats['total_processed_folders'] / elapsed if elapsed > 0 else 0

                logger.info(f"Progress: {self.stats['total_processed_folders']}/{len(folder_paths)} folders "
                            f"({progress:.1f}%) - {folders_per_sec:.2f} folders/sec")

                logger.info(f"Files: {self.stats['processed_files']}/{self.stats['total_files']} "
                            f"(Valid: {self.stats['valid_count']}, Multi-tag: {self.stats['multi_tag_count']}, "
                            f"Unknown: {self.stats['unknown_count']}, Errors: {self.stats['error_count']})")

                # Project completion time
                if folders_per_sec > 0:
                    remaining_folders = len(folder_paths) - self.stats['total_processed_folders']
                    remaining_time = remaining_folders / folders_per_sec
                    eta = time.strftime('%H:%M:%S', time.gmtime(remaining_time))
                    logger.info(f"Estimated time remaining: {eta}")

            # Write remaining samples
            if valid_samples_buffer:
                if self.sqlite_output_path:
                    self.write_to_sqlite(valid_samples_buffer)
                else:
                    self.write_shard(valid_samples_buffer, self.current_shard)
                    self.current_shard += 1

            # Write debug samples
            self.write_debug_file()

            # Write user statistics
            self.write_user_stats()

            # Final progress update
            self._save_progress()

            # Log final statistics
            elapsed = time.time() - start_time
            logger.info(
                f"Processing complete in {elapsed:.1f}s.\n"
                f"Total folders: {len(folder_paths)}\n"
                f"Total files: {self.stats['total_files']}\n"
                f"Valid samples: {self.stats['valid_count']}\n"
                f"Multi-tag samples: {self.stats['multi_tag_count']}\n"
                f"Unknown samples: {self.stats['unknown_count']}\n"
                f"Errors: {self.stats['error_count']}\n"
                f"Shards created: {self.current_shard}"
            )

        finally:
            # Close database connections
            self.close_db_connections()

    def process_folder(self, folder_path: str, files: List[Dict[str, str]]) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """
        Process all files in a folder

        Args:
            folder_path: Folder path identifier
            files: List of file info dictionaries in the folder

        Returns:
            Tuple of (valid_samples, other_samples)
        """
        results = []
        valid_samples = []
        other_samples = []

        with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
            # Submit tasks
            future_to_file = {
                executor.submit(self.process_file, file_info): file_info['file_id']
                for file_info in files
            }

            # Process results as they complete
            for future in as_completed(future_to_file):
                file_id = future_to_file[future]
                try:
                    result = future.result()
                    results.append(result)

                    # Update statistics
                    if result['error'] and not result['valid']:
                        self.stats['error_count'] += 1
                        other_samples.append(result)
                    elif result['valid']:
                        if result['tags'][0] == "Unknown":
                            self.stats['unknown_count'] += 1
                        self.stats['valid_count'] += 1
                        valid_samples.append(result)
                    elif len(result['tags']) > 1:
                        self.stats['multi_tag_count'] += 1
                        # For multi-tag samples, still consider them valid but track separately
                        valid_samples.append(result)
                        other_samples.append(result)

                except Exception as e:
                    logger.error(f"Exception processing file {file_id}: {e}")
                    self.stats['error_count'] += 1

        # Mark folder as processed
        self.processed_folders.add(folder_path)
        self.stats['total_processed_folders'] = len(self.processed_folders)
        self.stats['processed_files'] += len(files)

        return valid_samples, other_samples


def main():
    """Main entry point"""
    parser = argparse.ArgumentParser(description="Process job scripts and create ML datasets")

    parser.add_argument('--db', required=True, help='Path to the SQLite database')
    parser.add_argument('--keywords', required=True, help='Path to the keywords CSV file')
    parser.add_argument('--output-dir', default='datasets', help='Output directory for datasets')
    parser.add_argument('--sqlite-output', help='Path to SQLite database for output (instead of WebDataset shards)')
    parser.add_argument('--workers', type=int, default=None, help='Number of parallel workers')
    parser.add_argument('--shard-size', type=int, default=1000, help='Number of samples per shard')
    parser.add_argument('--resume', action='store_true', help='Resume from previous run')
    parser.add_argument('--debug-samples', type=int, default=100,
                        help='Maximum number of debug samples to keep per category')
    parser.add_argument('--start-time-limit', type=int, default=None,
                        help='Unix timestamp limit for filtering jobs (e.g., for past 2 years: use timestamp from 2 years ago). '
                             'If not provided, all jobs are included. You can calculate 2 years ago as: '
                             'python -c "import time; print(int(time.time()) - 2*365*24*3600)"')

    # Input source arguments
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument('--tars', nargs='+', help='Paths to tar archives containing images')
    input_group.add_argument('--dirs', nargs='+', help='Paths to directories containing images')

    args = parser.parse_args()

    # Determine input mode and paths
    input_mode = 'tar' if args.tars else 'dir'
    input_paths = args.tars if args.tars else args.dirs

    print(input_paths)

    # Initialize tagger
    tagger = JobScriptTagger(
        db_path=args.db,
        keywords_csv=args.keywords,
        output_dir=args.output_dir,
        input_mode=input_mode,
        input_paths=input_paths,
        num_workers=args.workers,
        shard_size=args.shard_size,
        resume=args.resume,
        max_debug_samples=args.debug_samples,
        start_time_limit=args.start_time_limit,
        sqlite_output_path=args.sqlite_output
    )

    # Run processing
    tagger.process_all()


if __name__ == "__main__":
    main()
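
# A second illustrative invocation (not part of the script; the filename and all paths
# are placeholders) using the SQLite output mode, directory input, and a start-time cutoff:
#
#   python job_script_tagger.py --db jobs.sqlite --keywords app_keywords.csv \
#       --dirs /data/roofline_images --sqlite-output tags_output.db \
#       --start-time-limit "$(python -c 'import time; print(int(time.time()) - 2*365*24*3600)')"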