#!/usr/bin/env python3
"""
generate_config.py

Generate optimized spaCy training configuration for text classification.

Usage:
    python generate_config.py --categories "Business,Technology,Programming,DevOps"
    python generate_config.py --categories "pos,neg" --architecture bow --output sentiment_config.cfg
"""

import argparse
import json
import sys
from pathlib import Path


# spaCy text-classification config templates, keyed by architecture name.
# The keys match the --architecture choices accepted on the command line.
#
# Each value is a str.format() template rendered by generate_config() with
# two placeholders:
#   {component}  - pipeline component name ("textcat" or "textcat_multilabel")
#   {exclusive}  - "true" / "false", used for the model's exclusive_classes
# Literal braces are doubled so they survive str.format(): for example
# "${{paths.train}}" renders as "${paths.train}", which in spaCy's config
# syntax references another config value.
# Every template ends with its [initialize] section.
CONFIG_TEMPLATES = {
    # TextCatEnsemble.v2: a Tok2Vec.v2 (MultiHashEmbed.v2 + MaxoutWindowEncoder.v2)
    # network combined with a TextCatBOW.v3 linear model.
    # NOTE(review): initial_rate = 0.00005 is 20x lower than the "bow"
    # template's 0.001 and is unusually low for a non-transformer model;
    # confirm this is intentional.
    "ensemble": '''# spaCy TextCategorizer Configuration
# Architecture: TextCatEnsemble (balanced speed/accuracy)
# Generated by generate_config.py

[paths]
train = "train.spacy"
dev = "dev.spacy"
vectors = null

[system]
gpu_allocator = null
seed = 0

[nlp]
lang = "en"
pipeline = ["{component}"]
batch_size = 128

[components]

[components.{component}]
factory = "{component}"
threshold = 0.5

[components.{component}.model]
@architectures = "spacy.TextCatEnsemble.v2"
nO = null

[components.{component}.model.tok2vec]
@architectures = "spacy.Tok2Vec.v2"

[components.{component}.model.tok2vec.embed]
@architectures = "spacy.MultiHashEmbed.v2"
width = 64
rows = [2000, 2000, 500, 1000, 500]
attrs = ["NORM", "LOWER", "PREFIX", "SUFFIX", "SHAPE"]
include_static_vectors = false

[components.{component}.model.tok2vec.encode]
@architectures = "spacy.MaxoutWindowEncoder.v2"
width = ${{components.{component}.model.tok2vec.embed.width}}
window_size = 1
maxout_pieces = 3
depth = 2

[components.{component}.model.linear_model]
@architectures = "spacy.TextCatBOW.v3"
exclusive_classes = {exclusive}
length = 262144
ngram_size = 1
no_output_layer = false

[corpora]

[corpora.train]
@readers = "spacy.Corpus.v1"
path = ${{paths.train}}
max_length = 0

[corpora.dev]
@readers = "spacy.Corpus.v1"
path = ${{paths.dev}}
max_length = 0

[training]
dev_corpus = "corpora.dev"
train_corpus = "corpora.train"
seed = ${{system.seed}}
dropout = 0.1
patience = 1600
max_steps = 20000
eval_frequency = 200
frozen_components = []
annotating_components = []

[training.batcher]
@batchers = "spacy.batch_by_words.v1"
discard_oversize = false
tolerance = 0.2

[training.batcher.size]
@schedules = "compounding.v1"
start = 100
stop = 1000
compound = 1.001

[training.optimizer]
@optimizers = "Adam.v1"
beta1 = 0.9
beta2 = 0.999
L2 = 0.01
grad_clip = 1.0

[training.optimizer.learn_rate]
@schedules = "warmup_linear.v1"
warmup_steps = 250
total_steps = 20000
initial_rate = 0.00005

[training.score_weights]
cats_score = 1.0
cats_score_desc = null
cats_micro_p = null
cats_micro_r = null
cats_micro_f = null
cats_macro_p = null
cats_macro_r = null
cats_macro_f = null
cats_macro_auc = null
cats_f_per_type = null
cats_macro_auc_per_type = null

[pretraining]

[initialize]
vectors = ${{paths.vectors}}
''',

    # TextCatBOW.v3 bag-of-words model on its own: there is no tok2vec, so
    # [paths] has no "vectors" entry and [initialize] is left empty.
    "bow": '''# spaCy TextCategorizer Configuration
# Architecture: TextCatBOW (fastest, good baseline)
# Generated by generate_config.py

[paths]
train = "train.spacy"
dev = "dev.spacy"

[system]
gpu_allocator = null
seed = 0

[nlp]
lang = "en"
pipeline = ["{component}"]
batch_size = 256

[components]

[components.{component}]
factory = "{component}"
threshold = 0.5

[components.{component}.model]
@architectures = "spacy.TextCatBOW.v3"
exclusive_classes = {exclusive}
ngram_size = 1
no_output_layer = false
length = 262144

[corpora]

[corpora.train]
@readers = "spacy.Corpus.v1"
path = ${{paths.train}}
max_length = 0

[corpora.dev]
@readers = "spacy.Corpus.v1"
path = ${{paths.dev}}
max_length = 0

[training]
dev_corpus = "corpora.dev"
train_corpus = "corpora.train"
seed = ${{system.seed}}
dropout = 0.1
patience = 800
max_steps = 10000
eval_frequency = 100

[training.batcher]
@batchers = "spacy.batch_by_words.v1"
discard_oversize = false
tolerance = 0.2

[training.batcher.size]
@schedules = "compounding.v1"
start = 100
stop = 1000
compound = 1.001

[training.optimizer]
@optimizers = "Adam.v1"
beta1 = 0.9
beta2 = 0.999
L2 = 0.01
grad_clip = 1.0

[training.optimizer.learn_rate]
@schedules = "warmup_linear.v1"
warmup_steps = 100
total_steps = 10000
initial_rate = 0.001

[training.score_weights]
cats_score = 1.0

[initialize]
''',

    # TextCatCNN.v2 on top of a HashEmbedCNN.v2 tok2vec.
    # NOTE(review): TextCatCNN.v2 may be deprecated/superseded in newer spaCy
    # releases; confirm it is registered in the installed spaCy version.
    "cnn": '''# spaCy TextCategorizer Configuration
# Architecture: TextCatCNN (better for longer texts)
# Generated by generate_config.py

[paths]
train = "train.spacy"
dev = "dev.spacy"
vectors = null

[system]
gpu_allocator = null
seed = 0

[nlp]
lang = "en"
pipeline = ["{component}"]
batch_size = 128

[components]

[components.{component}]
factory = "{component}"
threshold = 0.5

[components.{component}.model]
@architectures = "spacy.TextCatCNN.v2"
exclusive_classes = {exclusive}
nO = null

[components.{component}.model.tok2vec]
@architectures = "spacy.HashEmbedCNN.v2"
pretrained_vectors = null
width = 96
depth = 4
embed_size = 2000
window_size = 1
maxout_pieces = 3
subword_features = true

[corpora]

[corpora.train]
@readers = "spacy.Corpus.v1"
path = ${{paths.train}}
max_length = 0

[corpora.dev]
@readers = "spacy.Corpus.v1"
path = ${{paths.dev}}
max_length = 0

[training]
dev_corpus = "corpora.dev"
train_corpus = "corpora.train"
seed = ${{system.seed}}
dropout = 0.1
patience = 1600
max_steps = 20000
eval_frequency = 200

[training.batcher]
@batchers = "spacy.batch_by_words.v1"
discard_oversize = false
tolerance = 0.2

[training.batcher.size]
@schedules = "compounding.v1"
start = 100
stop = 1000
compound = 1.001

[training.optimizer]
@optimizers = "Adam.v1"
beta1 = 0.9
beta2 = 0.999
L2 = 0.01
grad_clip = 1.0

[training.optimizer.learn_rate]
@schedules = "warmup_linear.v1"
warmup_steps = 250
total_steps = 20000
initial_rate = 0.00005

[training.score_weights]
cats_score = 1.0

[initialize]
vectors = ${{paths.vectors}}
'''
}


def generate_config(categories: list, architecture: str, multilabel: bool) -> str:
    """Render a spaCy text-classification training config.

    Args:
        categories: Category (label) names. Surrounding whitespace is
            stripped, blank names are dropped and duplicates are removed
            (first occurrence wins). If any names remain, they are written as
            ``labels`` under ``[initialize.components.<component>]``, so the
            generated config reflects the requested categories. If none
            remain (or ``categories`` is empty/None), no labels section is
            emitted.
        architecture: Template key in ``CONFIG_TEMPLATES``
            (e.g. ``"ensemble"``, ``"bow"``, ``"cnn"``).
        multilabel: True selects the ``textcat_multilabel`` component with
            non-exclusive classes; False selects ``textcat`` with mutually
            exclusive classes.

    Returns:
        The complete config file contents.

    Raises:
        ValueError: If ``architecture`` is not a known template name.
    """
    try:
        template = CONFIG_TEMPLATES[architecture]
    except KeyError:
        valid = ", ".join(sorted(CONFIG_TEMPLATES))
        raise ValueError(
            f"Unknown architecture: {architecture} (expected one of: {valid})"
        ) from None

    component = "textcat_multilabel" if multilabel else "textcat"
    exclusive = "false" if multilabel else "true"

    config = template.format(component=component, exclusive=exclusive)

    # Strip names, drop blanks, and de-duplicate while keeping input order.
    stripped = (name.strip() for name in (categories or ()))
    labels = list(dict.fromkeys(name for name in stripped if name))

    if labels:
        # Every template ends with its [initialize] section, so appending here
        # keeps these subsections grouped under it. The parent section
        # [initialize.components] is declared explicitly, mirroring how the
        # templates declare [components] before [components.<name>].
        # json.dumps quotes/escapes each name so commas, quotes or non-ASCII
        # characters stay a valid list value; this text is appended after
        # str.format(), so braces in names are never treated as placeholders.
        config += (
            "\n[initialize.components]\n"
            f"\n[initialize.components.{component}]\n"
            f"labels = {json.dumps(labels)}\n"
        )
    return config


def main():
    """Command-line entry point: parse arguments, render and save a config.

    Exits with status 2 (via ``argparse``) on invalid arguments, including
    fewer than two distinct non-empty categories. Exits with an error message
    if the output file cannot be written.
    """
    parser = argparse.ArgumentParser(
        description="Generate spaCy text classification config",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Architectures:
    ensemble  - TextCatEnsemble (default, balanced speed/accuracy)
    bow       - TextCatBOW (fastest, good baseline)
    cnn       - TextCatCNN (better for longer texts)

Examples:
    # Generate default config
    python generate_config.py --categories "Business,Technology,Programming"

    # Fast baseline
    python generate_config.py --categories "pos,neg" --architecture bow

    # Multi-label classification
    python generate_config.py --categories "urgent,billing,support" --multilabel
        """
    )

    parser.add_argument(
        "--categories", "-c",
        required=True,
        help="Comma-separated list of category names"
    )
    parser.add_argument(
        "--architecture", "-a",
        choices=["ensemble", "bow", "cnn"],
        default="ensemble",
        help="Model architecture (default: ensemble)"
    )
    parser.add_argument(
        "--multilabel", "-m",
        action="store_true",
        help="Enable multi-label classification"
    )
    parser.add_argument(
        "--output", "-o",
        default="config.cfg",
        help="Output config file (default: config.cfg)"
    )

    args = parser.parse_args()

    # Parse categories. Strip whitespace and drop empty entries, such as
    # those from "a,,b" or a trailing comma. Then de-duplicate while keeping
    # first-seen order, so "a," or "a,a" is correctly rejected below instead
    # of slipping past the count check.
    stripped = (name.strip() for name in args.categories.split(","))
    categories = list(dict.fromkeys(name for name in stripped if name))
    if len(categories) < 2:
        parser.error("Need at least 2 distinct, non-empty categories")

    print(f"Categories: {', '.join(categories)}")
    print(f"Architecture: {args.architecture}")
    print(f"Type: {'multi-label' if args.multilabel else 'single-label'}")

    # Generate config
    config = generate_config(categories, args.architecture, args.multilabel)

    # Save config. Create missing parent directories, and write UTF-8
    # explicitly instead of relying on the platform's default encoding.
    output_path = Path(args.output)
    try:
        output_path.parent.mkdir(parents=True, exist_ok=True)
        output_path.write_text(config, encoding="utf-8")
    except OSError as err:
        sys.exit(f"Error: could not write {output_path}: {err}")
    print(f"\n✓ Config saved to {args.output}")

    print("\nNext steps:")
    print(f"  1. Verify config: python -m spacy debug config {args.output}")
    print(f"  2. Check data: python -m spacy debug data {args.output}")
    print(f"  3. Train: python -m spacy train {args.output} --output ./output")

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
