#!/usr/bin/env python3
"""
AWS S3 Cost Calculator

Calculate actual and forecasted monthly costs for S3 static website hosting.
"""

import argparse
import json
import subprocess
from datetime import datetime, timedelta, timezone
from typing import Optional


# S3 Pricing (us-east-1)
# Hard-coded USD rates consumed by the calculate_* helpers in this module.
# NOTE(review): these are static values baked into the script — confirm them
# against the current AWS price list before relying on the results.
S3_STORAGE_PRICE_PER_GB = 0.023  # $/GB/month
S3_GET_REQUEST_PRICE = 0.0004   # $ per 1,000 requests
S3_PUT_REQUEST_PRICE = 0.005    # $ per 1,000 requests
DATA_TRANSFER_PRICE_PER_GB = 0.09  # $/GB (after first 100GB free)
DATA_TRANSFER_FREE_TIER_GB = 100  # GB subtracted from transfer before billing


def run_aws_command(command: list[str]) -> Optional[dict | list]:
    """Execute AWS CLI command and return JSON output."""
    try:
        result = subprocess.run(
            command,
            capture_output=True,
            text=True,
            check=True
        )
        return json.loads(result.stdout) if result.stdout.strip() else None
    except subprocess.CalledProcessError as e:
        print(f"Error running command: {' '.join(command)}")
        print(f"Error: {e.stderr}")
        return None
    except json.JSONDecodeError:
        return None


def get_bucket_size(bucket_name: str) -> tuple[int, int]:
    """Return ``(total_bytes, object_count)`` for the given bucket.

    The object sizes are fetched with ``aws s3api list-objects-v2`` using a
    query that yields only each object's ``Size``. When the CLI call yields
    nothing (failure or no objects) the result is ``(0, 0)``.
    """
    sizes = run_aws_command([
        "aws", "s3api", "list-objects-v2",
        "--bucket", bucket_name,
        "--query", "Contents[].Size",
        "--output", "json",
    ])

    if sizes:
        byte_total = 0
        for entry in sizes:
            # Falsy entries (0 or null) contribute nothing to the total but
            # are still counted as objects below.
            if entry:
                byte_total += entry
        return byte_total, len(sizes)

    return 0, 0


def get_bucket_metrics(
    bucket_name: str,
    days: int = 30,
    *,
    filter_id: str = "AllStorageTypes",
) -> dict:
    """Get CloudWatch request/traffic metrics for an S3 bucket.

    Queries the ``AWS/S3`` namespace once per metric (``GetRequests``,
    ``PutRequests``, ``BytesDownloaded``, ``BytesUploaded``) over the window
    ending now and starting ``days`` days earlier, using a single period that
    spans the whole window, and sums the ``Sum`` statistic of every datapoint
    returned.

    Args:
        bucket_name: Value of the ``BucketName`` dimension.
        days: Length of the look-back window in days.
        filter_id: Value of the ``FilterId`` dimension. The default keeps the
            value this script has always sent.
            NOTE(review): for S3 request metrics, ``FilterId`` is expected to
            be the Id of a request-metrics configuration on the bucket, and
            ``AllStorageTypes`` looks like a ``StorageType`` value — confirm
            which Id the target buckets use and pass it here.

    Returns:
        Dict with keys ``get_requests``, ``put_requests``,
        ``bytes_downloaded`` and ``bytes_uploaded``. A metric stays ``0`` when
        the CLI call fails or returns no datapoints.
    """
    end_time = datetime.now(timezone.utc)
    start_time = end_time - timedelta(days=days)

    # (CloudWatch metric name, key in the returned dict)
    metric_queries = [
        ("GetRequests", "get_requests"),
        ("PutRequests", "put_requests"),
        ("BytesDownloaded", "bytes_downloaded"),
        ("BytesUploaded", "bytes_uploaded"),
    ]
    metrics = {key: 0 for _, key in metric_queries}

    # These arguments are identical for every metric, so build them once.
    window_start = start_time.isoformat()
    window_end = end_time.isoformat()
    period_seconds = str(days * 24 * 3600)

    for metric_name, key in metric_queries:
        command = [
            "aws", "cloudwatch", "get-metric-statistics",
            "--namespace", "AWS/S3",
            "--metric-name", metric_name,
            "--dimensions",
            f"Name=BucketName,Value={bucket_name}",
            f"Name=FilterId,Value={filter_id}",
            "--start-time", window_start,
            "--end-time", window_end,
            "--period", period_seconds,
            "--statistics", "Sum",
            "--output", "json",
        ]
        result = run_aws_command(command)

        # run_aws_command may return a list or None; only a dict can carry
        # the "Datapoints" key we need.
        datapoints = result.get("Datapoints") if isinstance(result, dict) else None
        if datapoints:
            metrics[key] = sum(dp.get("Sum", 0) for dp in datapoints)

    return metrics


def calculate_storage_cost(size_bytes: int) -> float:
    """Return the monthly storage charge for ``size_bytes`` bytes.

    The size is converted to GB using 1024**3 bytes per GB and multiplied by
    ``S3_STORAGE_PRICE_PER_GB``.
    """
    bytes_per_gb = 1024 ** 3
    stored_gb = size_bytes / bytes_per_gb
    monthly_cost = stored_gb * S3_STORAGE_PRICE_PER_GB
    return monthly_cost


def calculate_request_cost(get_requests: int, put_requests: int) -> float:
    """Return the combined charge for GET and PUT requests.

    Both prices are quoted per 1,000 requests, so each count is scaled to
    thousands before being multiplied by its rate.
    """
    thousands_of_gets = get_requests / 1000
    thousands_of_puts = put_requests / 1000
    return (
        thousands_of_gets * S3_GET_REQUEST_PRICE
        + thousands_of_puts * S3_PUT_REQUEST_PRICE
    )


def calculate_data_transfer_cost(bytes_transferred: int) -> float:
    """Return the data transfer charge after deducting the free allowance.

    The byte count is converted to GB (1024**3 bytes each); only the portion
    above ``DATA_TRANSFER_FREE_TIER_GB`` is billed at
    ``DATA_TRANSFER_PRICE_PER_GB``.
    """
    transferred_gb = bytes_transferred / (1024 ** 3)
    excess_gb = transferred_gb - DATA_TRANSFER_FREE_TIER_GB
    # Nothing is billable while usage is within the free allowance.
    billable_gb = excess_gb if excess_gb > 0 else 0
    return billable_gb * DATA_TRANSFER_PRICE_PER_GB


def format_bytes(bytes_value: int) -> str:
    """Render a byte count as a two-decimal string with a binary unit.

    The value is divided by 1024 until it drops below 1024 or the units
    B, KB, MB, GB, TB are exhausted; anything larger is reported in PB.
    """
    units = ('B', 'KB', 'MB', 'GB', 'TB')
    position = 0
    while position < len(units):
        if bytes_value < 1024:
            return f"{bytes_value:.2f} {units[position]}"
        bytes_value = bytes_value / 1024
        position += 1
    # Still >= 1024 TB after the last step: express in PB without a cap.
    return f"{bytes_value:.2f} PB"


def format_currency(amount: float) -> str:
    """Render an amount as a dollar string.

    Amounts below one cent are shown with six decimal places so tiny costs
    remain visible; everything else uses four decimal places.
    """
    decimals = 6 if amount < 0.01 else 4
    return f"${amount:.{decimals}f}"


def list_buckets() -> list[str]:
    """Return the names of all S3 buckets visible to the AWS CLI.

    An empty list is returned when the CLI call fails or yields nothing.
    """
    bucket_names = run_aws_command([
        "aws", "s3api", "list-buckets",
        "--query", "Buckets[].Name",
        "--output", "json",
    ])
    return bucket_names or []


def analyze_bucket(bucket_name: str, forecast_visits: int = 0) -> dict:
    """Analyze a single bucket, print a cost report and return the figures.

    The "actual" cost combines the bucket's current stored size with the
    request and download metrics of the last 30 days. When ``forecast_visits``
    is positive, a forecast is added using two simple heuristics: every visit
    issues 2 GET requests and downloads one object of average size
    (total size / object count).

    Args:
        bucket_name: Name of the bucket to analyze.
        forecast_visits: Expected visits per month; values <= 0 disable the
            forecast.

    Returns:
        Dict with ``bucket_name``, ``size_bytes``, ``object_count``,
        ``metrics`` and ``actual_cost`` (``storage``, ``requests``,
        ``transfer``, ``total``). When a forecast was computed, the dict also
        contains ``forecast_cost`` so the figures are available to callers
        (e.g. JSON output) and not only printed.
    """
    print(f"\nAnalyzing bucket: {bucket_name}")
    print("=" * 50)

    # Get bucket size
    total_size, object_count = get_bucket_size(bucket_name)
    print(f"Total size: {format_bytes(total_size)}")
    print(f"Object count: {object_count}")

    # Get metrics (last 30 days)
    metrics = get_bucket_metrics(bucket_name, days=30)
    print("\nLast 30 days metrics:")
    print(f"  GET requests: {metrics['get_requests']:,.0f}")
    print(f"  PUT requests: {metrics['put_requests']:,.0f}")
    print(f"  Data downloaded: {format_bytes(metrics['bytes_downloaded'])}")
    print(f"  Data uploaded: {format_bytes(metrics['bytes_uploaded'])}")

    # Calculate actual costs; only downloaded bytes are charged as transfer.
    storage_cost = calculate_storage_cost(total_size)
    request_cost = calculate_request_cost(metrics['get_requests'], metrics['put_requests'])
    transfer_cost = calculate_data_transfer_cost(metrics['bytes_downloaded'])
    total_actual = storage_cost + request_cost + transfer_cost

    print("\nActual Monthly Cost (based on last 30 days):")
    print(f"  Storage: {format_currency(storage_cost)}")
    print(f"  Requests: {format_currency(request_cost)}")
    print(f"  Data Transfer: {format_currency(transfer_cost)}")
    print("  " + "-" * 25)
    print(f"  TOTAL: {format_currency(total_actual)}")

    report = {
        "bucket_name": bucket_name,
        "size_bytes": total_size,
        "object_count": object_count,
        "metrics": metrics,
        "actual_cost": {
            "storage": storage_cost,
            "requests": request_cost,
            "transfer": transfer_cost,
            "total": total_actual,
        },
    }

    # Forecast costs if visits provided
    if forecast_visits > 0:
        # max(..., 1) avoids a division by zero for an empty bucket.
        avg_page_size = total_size / max(object_count, 1)
        forecast_get_requests = forecast_visits * 2  # Assume 2 requests per visit
        forecast_transfer = forecast_visits * avg_page_size

        forecast_request_cost = calculate_request_cost(forecast_get_requests, 0)
        forecast_transfer_cost = calculate_data_transfer_cost(forecast_transfer)
        total_forecast = storage_cost + forecast_request_cost + forecast_transfer_cost

        print(f"\nForecasted Monthly Cost ({forecast_visits:,} visits/month):")
        print(f"  Storage: {format_currency(storage_cost)}")
        print(f"  Requests (~{forecast_get_requests:,} GET): {format_currency(forecast_request_cost)}")
        print(f"  Data Transfer (~{format_bytes(forecast_transfer)}): {format_currency(forecast_transfer_cost)}")
        print("  " + "-" * 25)
        print(f"  TOTAL: {format_currency(total_forecast)}")

        # Previously these figures were printed and then discarded.
        report["forecast_cost"] = {
            "visits": forecast_visits,
            "get_requests": forecast_get_requests,
            "transfer_bytes": forecast_transfer,
            "storage": storage_cost,
            "requests": forecast_request_cost,
            "transfer": forecast_transfer_cost,
            "total": total_forecast,
        }

    return report


def main():
    """Command-line entry point.

    Without ``--bucket`` or ``--all`` the available buckets are listed.
    Otherwise the selected bucket(s) are analyzed and, when more than one
    bucket was analyzed, a combined total is printed.

    With ``--json`` the human-readable report (and any CLI error messages) is
    written to stderr and stdout receives only the JSON document, so the
    output can be piped straight into a JSON parser.
    """
    # Function-scope imports: only main needs them.
    import contextlib
    import sys

    parser = argparse.ArgumentParser(
        description="Calculate AWS S3 hosting costs",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Analyze a specific bucket
  uv run cost.py --bucket my-website-bucket

  # Analyze all buckets
  uv run cost.py --all

  # Forecast costs for 10,000 monthly visits
  uv run cost.py --bucket my-bucket --forecast 10000

  # Output as JSON
  uv run cost.py --bucket my-bucket --json
        """
    )

    parser.add_argument("--bucket", "-b", help="S3 bucket name to analyze")
    parser.add_argument("--all", "-a", action="store_true", help="Analyze all buckets")
    parser.add_argument("--forecast", "-f", type=int, default=0, help="Forecast visits per month")
    parser.add_argument("--json", "-j", action="store_true", help="Output as JSON")

    args = parser.parse_args()

    if not args.bucket and not args.all:
        # List available buckets
        buckets = list_buckets()
        if buckets:
            print("Available S3 buckets:")
            for bucket in buckets:
                print(f"  - {bucket}")
            print("\nUse --bucket <name> or --all to analyze")
        else:
            print("No S3 buckets found or unable to list buckets.")
        return

    # Keep stdout clean for the JSON document when --json is requested.
    report_stream = sys.stderr if args.json else sys.stdout

    with contextlib.redirect_stdout(report_stream):
        # --all takes precedence over --bucket when both are given.
        buckets = list_buckets() if args.all else [args.bucket]
        if not buckets:
            # Only reachable with --all; previously this case was silent.
            print("No S3 buckets found or unable to list buckets.")

        results = [analyze_bucket(bucket, args.forecast) for bucket in buckets]

        # Summary
        if len(results) > 1:
            total_cost = sum(r["actual_cost"]["total"] for r in results)
            print("\n" + "=" * 50)
            print(f"TOTAL MONTHLY COST (all buckets): {format_currency(total_cost)}")

    if args.json:
        print(json.dumps(results, indent=2))


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
