#!/usr/bin/env python3
"""
Nano Banana Pro - AI Image Generation & Editing
Uses Google's Gemini Image Generation API

Usage:
    python nano_banana.py generate "prompt" [options]
    python nano_banana.py edit input.png "prompt" [options]
    python nano_banana.py reference "prompt" ref1.png ref2.png ... [options]

Options:
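    --output, -o    Output file path (default: generated_image.png for generate,
                    edited_image.png for edit, reference_image.png for reference)
    --model, -m     Model: flash, pro (default: flash; pro for reference)
    --ratio, -r     Aspect ratio: 1:1, 2:3, 3:2, 3:4, 4:3, 4:5, 5:4, 9:16, 16:9, 21:9 (default: 1:1)
    --size, -s      Resolution: 1K, 2K, 4K (default: 1K; 2K for reference)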
"""

import argparse
import os
import sys
from pathlib import Path

try:
    from google import genai
    from google.genai import types
except ImportError:
    print("Error: google-genai not installed. Run: pip install google-genai>=1.52.0")
    sys.exit(1)

try:
    from PIL import Image
except ImportError:
    print("Error: Pillow not installed. Run: pip install Pillow")
    sys.exit(1)

# Load a .env file if one exists. python-dotenv is optional: the API key can
# also be exported directly in the environment, so a missing package must not
# stop the CLI (not even --help) the way a missing google-genai or Pillow does.
try:
    from dotenv import load_dotenv
except ImportError:
    pass  # No .env support; rely on the real process environment.
else:
    load_dotenv()

# Short CLI model names -> Gemini model IDs (as of Dec 2025).
#   flash -> gemini-2.5-flash-image: image generation, default settings only
#   pro   -> gemini-3-pro-image-preview: best quality, supports aspect_ratio/size
# The plain "gemini-2.5-flash" model (no "-image" suffix) cannot produce images.
MODELS = dict(
    flash="gemini-2.5-flash-image",
    pro="gemini-3-pro-image-preview",
)

# Short model names whose requests may carry an ImageConfig
# (aspect_ratio / image_size); every other model gets default settings.
MODELS_WITH_IMAGE_CONFIG = ["pro"]

# Aspect ratios accepted by the --ratio option (portrait/landscape pairs).
VALID_RATIOS = [
    "1:1",
    "2:3", "3:2",
    "3:4", "4:3",
    "4:5", "5:4",
    "9:16", "16:9",
    "21:9",
]

# Output resolutions accepted by the --size option.
VALID_SIZES = [f"{k}K" for k in (1, 2, 4)]


def get_client():
    """Create a Gemini API client from an API key found in the environment.

    The key is read from ``GOOGLE_AI_API_KEY``. If that is unset or empty,
    ``GEMINI_API_KEY`` and then ``GOOGLE_API_KEY`` (the variable names the
    google-genai SDK itself recognises) are tried as fallbacks.

    Returns:
        A ``genai.Client`` authenticated with the first non-empty key found.

    Exits the process with status 1, after printing setup instructions to
    stderr, when none of those variables holds a key.
    """
    for var_name in ("GOOGLE_AI_API_KEY", "GEMINI_API_KEY", "GOOGLE_API_KEY"):
        api_key = os.environ.get(var_name)
        if api_key:
            return genai.Client(api_key=api_key)

    # Diagnostics go to stderr so they are not mixed into redirected stdout.
    print("Error: GOOGLE_AI_API_KEY environment variable not set.", file=sys.stderr)
    print("Get your API key at: https://aistudio.google.com/apikey", file=sys.stderr)
    sys.exit(1)


def save_response_images(response, output_path):
    """Write every image part of a generate_content response to disk.

    Text parts are echoed to stdout as ``Model response: ...``. The first
    image is written to ``output_path``; later images get ``_1``, ``_2``, ...
    appended to the file stem (``out.png`` -> ``out_1.png``). Missing parent
    directories of the output path are created.

    Args:
        response: A ``generate_content`` response; only its ``parts``
            attribute is read.
        output_path: Destination (str or Path) for the first image.

    Returns:
        List of written paths as strings, in response order. Empty when the
        response carried no parts or no image parts.
    """
    output_path = Path(output_path)
    saved_files = []

    # google-genai's ``response.parts`` is None when there is no candidate
    # content (e.g. a blocked prompt); treat that as "no images" instead of
    # crashing with a TypeError.
    for part in response.parts or []:
        if part.text is not None:
            print(f"Model response: {part.text}")
        elif part.inline_data is not None:
            index = len(saved_files)
            if index == 0:
                save_path = output_path
            else:
                save_path = output_path.with_stem(f"{output_path.stem}_{index}")

            # Allow outputs like ``-o renders/cat.png`` where renders/ is new.
            save_path.parent.mkdir(parents=True, exist_ok=True)
            part.as_image().save(str(save_path))
            saved_files.append(str(save_path))

    return saved_files


def generate_image(prompt, model="flash", ratio="1:1", size="1K", output="generated_image.png"):
    """Create a new image from a text prompt and save it to disk.

    Args:
        prompt: Text description of the image to generate.
        model: Key into MODELS; unknown keys fall back to "flash".
        ratio: Aspect ratio, sent only when ``model`` is in MODELS_WITH_IMAGE_CONFIG.
        size: Resolution, sent only when ``model`` is in MODELS_WITH_IMAGE_CONFIG.
        output: Path for the first returned image.

    Returns:
        The list of saved file paths produced by save_response_images
        (empty when the model returned no image).
    """
    client = get_client()
    model_id = MODELS.get(model, MODELS["flash"])

    print(f"Generating image with {model} model ({model_id})...")
    print(f"Prompt: {prompt}")

    config_kwargs = {"response_modalities": ['TEXT', 'IMAGE']}
    if model not in MODELS_WITH_IMAGE_CONFIG:
        print("Note: Flash model uses default settings (no custom ratio/size)")
    else:
        # Only models listed in MODELS_WITH_IMAGE_CONFIG accept aspect_ratio/image_size.
        print(f"Settings: ratio={ratio}, size={size}")
        config_kwargs["image_config"] = types.ImageConfig(aspect_ratio=ratio, image_size=size)
    config = types.GenerateContentConfig(**config_kwargs)

    response = client.models.generate_content(model=model_id, contents=[prompt], config=config)
    saved = save_response_images(response, output)

    print(f"Image saved to: {', '.join(saved)}" if saved else "No image was generated.")
    return saved


def edit_image(input_path, prompt, model="flash", ratio="1:1", size="1K", output="edited_image.png"):
    """Edit an existing image according to a text instruction and save the result.

    Args:
        input_path: Path to the image file to edit.
        prompt: Edit instruction sent to the model together with the image.
        model: Key into MODELS; unknown keys fall back to "flash".
        ratio: Aspect ratio, sent only when ``model`` is in MODELS_WITH_IMAGE_CONFIG.
        size: Resolution, sent only when ``model`` is in MODELS_WITH_IMAGE_CONFIG.
        output: Path for the first returned image.

    Returns:
        The list of saved file paths produced by save_response_images
        (empty when the model returned no image).

    Exits the process with status 1 if ``input_path`` does not exist or
    cannot be opened as an image.
    """
    # Validate the input before creating the client so a bad path is
    # reported even when no API key is configured.
    if not os.path.exists(input_path):
        print(f"Error: Input file not found: {input_path}", file=sys.stderr)
        sys.exit(1)

    try:
        image = Image.open(input_path)
    except OSError as exc:  # includes PIL.UnidentifiedImageError and directories
        print(f"Error: Cannot open image {input_path}: {exc}", file=sys.stderr)
        sys.exit(1)

    # Image.open is lazy and keeps the file handle open; the context manager
    # releases it once the request (which serialises the image) has been made.
    with image:
        client = get_client()
        model_id = MODELS.get(model, MODELS["flash"])

        print(f"Editing image with {model} model ({model_id})...")
        print(f"Input: {input_path}")
        print(f"Prompt: {prompt}")

        # Only models listed in MODELS_WITH_IMAGE_CONFIG accept aspect_ratio/image_size.
        config_kwargs = {"response_modalities": ['TEXT', 'IMAGE']}
        if model in MODELS_WITH_IMAGE_CONFIG:
            print(f"Settings: ratio={ratio}, size={size}")
            config_kwargs["image_config"] = types.ImageConfig(aspect_ratio=ratio, image_size=size)
        config = types.GenerateContentConfig(**config_kwargs)

        response = client.models.generate_content(
            model=model_id,
            contents=[prompt, image],
            config=config,
        )

    saved = save_response_images(response, output)
    if saved:
        print(f"Edited image saved to: {', '.join(saved)}")
    else:
        print("No image was generated.")

    return saved


def generate_with_references(prompt, reference_paths, model="pro", ratio="1:1", size="2K", output="reference_image.png"):
    """Generate image using reference images (up to 14)."""
    client = get_client()

    # Reference images work best with Pro model
    if model == "flash":
        print("Note: Reference images work best with Pro model. Consider using --model pro")

    model_id = MODELS.get(model, MODELS["pro"])

    # Validate reference paths
    images = []
    for ref_path in reference_paths:
        if not os.path.exists(ref_path):
            print(f"Error: Reference file not found: {ref_path}")
            sys.exit(1)
        images.append(Image.open(ref_path))

    if len(images) > 14:
        print("Warning: Maximum 14 reference images supported. Using first 14.")
        images = images[:14]

    print(f"Generating image with {model} model...")
    print(f"Prompt: {prompt}")
    print(f"Reference images: {len(images)}")

    config = types.GenerateContentConfig(
        response_modalities=['TEXT', 'IMAGE'],
        image_config=types.ImageConfig(
            aspect_ratio=ratio,
            image_size=size
        )
    )

    contents = [prompt] + images

    response = client.models.generate_content(
        model=model_id,
        contents=contents,
        config=config
    )

    saved = save_response_images(response, output)
    if saved:
        print(f"Image saved to: {', '.join(saved)}")
    else:
        print("No image was generated.")

    return saved


def _add_common_options(subparser, *, output_default, model_default, size_default):
    """Register the -o/-m/-r/-s options shared by every subcommand."""
    subparser.add_argument("-o", "--output", default=output_default, help="Output file path")
    # Derive model choices from MODELS so the CLI cannot drift from the mapping.
    subparser.add_argument("-m", "--model", choices=list(MODELS), default=model_default, help="Model to use")
    subparser.add_argument("-r", "--ratio", choices=VALID_RATIOS, default="1:1", help="Aspect ratio")
    subparser.add_argument("-s", "--size", choices=VALID_SIZES, default=size_default, help="Resolution")


def main():
    """Parse command-line arguments and run the selected subcommand.

    Exits with status 1 when no subcommand is given (after printing help),
    or when the chosen command saved no image, so calling scripts can detect
    failure.
    """
    parser = argparse.ArgumentParser(
        description="Nano Banana Pro - AI Image Generation & Editing",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  %(prog)s generate "a cat wearing sunglasses"
  %(prog)s generate "sunset over mountains" -r 16:9 -s 2K -o sunset.png
  %(prog)s edit photo.jpg "add a rainbow in the sky"
  %(prog)s reference "group photo in office" person1.png person2.png -m pro
        """
    )

    subparsers = parser.add_subparsers(dest="command", help="Command to run")

    # Generate command
    gen_parser = subparsers.add_parser("generate", help="Generate image from text")
    gen_parser.add_argument("prompt", help="Text prompt for image generation")
    _add_common_options(gen_parser, output_default="generated_image.png",
                        model_default="flash", size_default="1K")

    # Edit command
    edit_parser = subparsers.add_parser("edit", help="Edit existing image")
    edit_parser.add_argument("input", help="Input image path")
    edit_parser.add_argument("prompt", help="Edit instruction")
    _add_common_options(edit_parser, output_default="edited_image.png",
                        model_default="flash", size_default="1K")

    # Reference command
    ref_parser = subparsers.add_parser("reference", help="Generate with reference images")
    ref_parser.add_argument("prompt", help="Text prompt")
    ref_parser.add_argument("references", nargs="+", help="Reference image paths (max 14)")
    _add_common_options(ref_parser, output_default="reference_image.png",
                        model_default="pro", size_default="2K")

    args = parser.parse_args()

    if args.command is None:
        parser.print_help()
        sys.exit(1)

    if args.command == "generate":
        saved = generate_image(args.prompt, args.model, args.ratio, args.size, args.output)
    elif args.command == "edit":
        saved = edit_image(args.input, args.prompt, args.model, args.ratio, args.size, args.output)
    else:  # "reference" -- argparse only accepts the three registered subcommands
        saved = generate_with_references(args.prompt, args.references, args.model, args.ratio, args.size, args.output)

    # Non-zero exit when nothing was written, so shell pipelines notice.
    if not saved:
        sys.exit(1)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
