Add multi_label_booster.py
This commit is contained in:
152
multi_label_booster.py
Normal file
152
multi_label_booster.py
Normal file
@@ -0,0 +1,152 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import ast
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.preprocessing import MultiLabelBinarizer
|
||||
from sklearn.metrics import classification_report
|
||||
from xgboost import XGBClassifier
|
||||
from sklearn.multiclass import OneVsRestClassifier
|
||||
from sklearn import set_config
|
||||
import joblib
|
||||
# Metadata routing must be on so sample_weight can be forwarded to the
# per-label estimators during fit.
set_config(enable_metadata_routing=True)

# ---------------------------------------------------------
# 1. Load & clean
# ---------------------------------------------------------
print("Loading data from roofline_features.h5...")
df = pd.read_hdf("roofline_features.h5", key="features")
print(f"Loaded {len(df)} samples with {len(df.columns)} columns")

print("Cleaning data...")
original_shape = df.shape

# Columns without a single value carry no information: remove them first.
df = df.dropna(axis=1, how="all")

# Blank / whitespace-only strings count as missing values; afterwards every
# row that has at least one missing value is discarded.
df = df.replace(r"^\s*$", np.nan, regex=True).dropna(axis=0, how="any")

print(f"After cleaning: {len(df)} samples remaining (removed {original_shape[0] - len(df)} rows, {original_shape[1] - len(df.columns)} columns)")
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# 2. Parse label column into Python lists
|
||||
# ---------------------------------------------------------
|
||||
# Progress message announcing the label-parsing stage.
print("Parsing labels...")
|
||||
def parse_label(x):
    """Convert a raw label cell into a list of labels.

    Parameters
    ----------
    x : object
        Either a string holding a Python literal, e.g.
        '["OpenFOAM","Gaussian"]', or an already-parsed list. Any other
        type is treated as "no labels".

    Returns
    -------
    list
        The parsed labels. An empty list is returned when the value is not
        a string or list, when the string is not a valid Python literal, or
        when the literal is not a string or an iterable.
    """
    if isinstance(x, list):
        return x
    if not isinstance(x, str):
        return []

    try:
        parsed = ast.literal_eval(x)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        # These are the errors literal_eval is documented to raise on
        # malformed input; such cells simply have no usable labels.
        return []

    if isinstance(parsed, str):
        # A lone quoted label such as '"OpenFOAM"' is one label; calling
        # list() on it would split it into single characters.
        return [parsed]

    try:
        return list(parsed)
    except TypeError:
        # Non-iterable literal (number, None, bool): not a label collection.
        return []
|
||||
|
||||
# Turn every raw label cell into a Python list.
df["label"] = df["label"].apply(parse_label)

# Keep only the rows that still carry at least one label after parsing.
original_len = len(df)
has_labels = df["label"].map(len).gt(0)
df = df[has_labels]
print(f"Parsed labels: {len(df)} samples remaining (removed {original_len - len(df)} samples with empty labels)")
|
||||
|
||||
# ---------------------------------------------------------
# 3. Features and multi-label target
# ---------------------------------------------------------
print("Preparing features and targets...")

# Targets: one binary indicator column per distinct label.
y_lists = df["label"]
mlb = MultiLabelBinarizer()
Y = mlb.fit_transform(y_lists)
all_classes = mlb.classes_

# Features: every column except the job identifier and the label itself.
X = df.drop(columns=["job_id", "label"])

print(f"Feature matrix shape: {X.shape}")
print(f"Target matrix shape: {Y.shape}")
print(f"Number of unique classes: {len(all_classes)}")
print(f"Classes: {list(all_classes)}")
|
||||
|
||||
# ---------------------------------------------------------
# 4. Split (stratification isn't directly supported for multi-label)
# ---------------------------------------------------------
print("Splitting data into train/validation sets...")

# 80/20 split with a fixed seed so runs are repeatable.
split_options = {"test_size": 0.2, "random_state": 42}
X_train, X_val, Y_train, Y_val = train_test_split(X, Y, **split_options)

print(f"Training set: {X_train.shape[0]} samples")
print(f"Validation set: {X_val.shape[0]} samples")
|
||||
|
||||
# ---------------------------------------------------------
# 5. Handle imbalance
# ---------------------------------------------------------
print("Calculating sample weights for class imbalance...")

# Per-label weight: N / (2 * count). The small epsilon keeps the division
# finite for a label with no positives in the training split.
n_train = len(Y_train)
label_counts = Y_train.sum(axis=0)
weights_per_class = n_train / (2.0 * (label_counts + 1e-6))

# Per-sample weight: sum of the weights of the labels the sample carries
# (simple heuristic).
sample_weights = (Y_train * weights_per_class).sum(axis=1)

# The two report lines below previously printed the bare format specs
# ".1f" and ".3f"; they now print the values those specs were meant for.
print("Label frequencies in training set:")
for class_name, count in zip(all_classes, label_counts):
    percentage = (count / n_train) * 100
    print(f"  {class_name}: {count} samples ({percentage:.1f}%)")

print(f"Sample weights range: {sample_weights.min():.3f} - {sample_weights.max():.3f}")
|
||||
|
||||
# ---------------------------------------------------------
# 6. Train One-vs-Rest XGBoost
# ---------------------------------------------------------
print("Setting up XGBoost model...")

# Hyper-parameters shared by every per-label binary classifier.
xgb_params = dict(
    objective="binary:logistic",
    eval_metric="logloss",
    tree_method="hist",  # or "gpu_hist" if GPU is available
    learning_rate=0.1,
    max_depth=6,
    n_estimators=300,
    subsample=0.8,
    colsample_bytree=0.8,
    random_state=42,
)
xgb_base = XGBClassifier(**xgb_params)

# Ask for sample_weight to be routed to the base estimator's fit().
xgb_base.set_fit_request(sample_weight=True)

# One independent binary XGBoost model per label.
model = OneVsRestClassifier(xgb_base, n_jobs=-1)

print("Training One-vs-Rest XGBoost model...")
# Duration grows with the dataset size and the number of classes.
model.fit(X_train, Y_train, sample_weight=sample_weights)
print("Training completed!")
|
||||
|
||||
# ---------------------------------------------------------
# 7. Evaluate
# ---------------------------------------------------------
print("Evaluating model on validation set...")
Y_pred = model.predict(X_val)

# Report only on label columns that have at least one true or predicted
# positive in the validation split.
positives_per_class = Y_val.sum(axis=0) + Y_pred.sum(axis=0)
active_columns = np.where(positives_per_class > 0)[0]
present = np.unique(active_columns)

report = classification_report(
    Y_val,
    Y_pred,
    labels=present,  # only evaluate classes that exist
    target_names=all_classes[present],
    zero_division=0,
)

print("Classification Report:")
print("=" * 50)
print(report)
|
||||
|
||||
print("Saving model...")

# Bundle the fitted model, the label binarizer and the feature column
# order into a single artifact.
model_data = dict(
    model=model,
    mlb=mlb,
    feature_columns=X.columns.tolist(),
)
joblib.dump(model_data, 'xgb_model.joblib')

print("Model saved to 'xgb_model.joblib'")
print("To load the model later, use: model_data = joblib.load('xgb_model.joblib')")
print("Processing complete!")
|
Reference in New Issue
Block a user