import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import pandas as pd
import scipy.signal
import os
import argparse

# Ensure the output directory exists before any figure is written:
# save_plot() writes every file under 'plots/'. This runs at import time and
# the path is relative to the current working directory; exist_ok=True makes
# it a no-op when the directory is already there.
os.makedirs('plots', exist_ok=True)

def save_plot(fig, name):
    """
    Export a figure to the plots/ directory in three formats

    Parameters:
    - fig: Figure object providing write_html() and write_image()
    - name: Root file name (no extension); files are written as
      plots/<name>.html, plots/<name>.png and plots/<name>.svg
    """
    base_path = 'plots/' + name

    # Interactive version; include_plotlyjs='cdn' is passed through to
    # write_html unchanged.
    fig.write_html(base_path + '.html', include_plotlyjs='cdn')

    # Static versions, written in the same order as before: PNG then SVG.
    for extension in ('.png', '.svg'):
        fig.write_image(base_path + extension)

def generate_sample_data(sample_rate=100, duration=10):
    """
    Generate synthetic vibration data for testing

    The signal is band-pass filtered white noise (normalised to unit RMS)
    plus sine tones at 30 Hz and 80.25 Hz. The noise band is 10-50 Hz
    whenever that fits below the Nyquist frequency; for lower sample rates
    the band is shrunk to stay inside (0, Nyquist), and any tone at or above
    Nyquist is left out so it cannot alias into the spectrum.

    Parameters:
    - sample_rate: Sample rate in Hz (must be > 0)
    - duration: Duration in seconds (must be > 0)

    Returns:
    - df: DataFrame with a single 'Z-axis' column, indexed by 'Time (s)'

    Raises:
    - ValueError: if sample_rate or duration is not positive, or if together
      they yield less than one sample
    """
    if not sample_rate > 0 or not duration > 0:
        raise ValueError("sample_rate and duration must both be positive")

    # Derive the sample count as an integer so the length never depends on
    # floating-point rounding of a fractional arange step.
    n_samples = int(round(duration * sample_rate))
    if n_samples < 1:
        raise ValueError("sample_rate * duration must yield at least one sample")
    t = np.arange(n_samples) / sample_rate

    # butter() requires 0 < low < high < Nyquist, so clamp the nominal
    # 10-50 Hz band for low sample rates.
    nyquist = sample_rate / 2
    band_low = min(10.0, 0.2 * nyquist)
    band_high = min(50.0, 0.9 * nyquist)

    white_noise = np.random.randn(n_samples)
    sos = scipy.signal.butter(4, [band_low, band_high], btype='bandpass',
                              fs=sample_rate, output='sos')
    band_noise = scipy.signal.sosfilt(sos, white_noise)

    # Normalise the noise to unit RMS (skip if the filter output is all zero).
    rms = np.sqrt(np.mean(band_noise**2))
    if rms > 0:
        band_noise = band_noise / rms

    # Only add tones that can be represented at this sample rate.
    sine_wave = np.zeros(n_samples)
    for tone_hz in (30, 80.25):
        if tone_hz < nyquist:
            sine_wave = sine_wave + np.sin(t * 2 * np.pi * tone_hz)
    sine_on_random = sine_wave + band_noise

    # Only Z-axis data
    noise = pd.DataFrame({
        'Z-axis': sine_on_random
    }, index=t)
    noise.index.name = 'Time (s)'

    print(f"Generated {len(noise)} synthetic samples at {sample_rate} Hz ({duration:.2f} seconds)")

    return noise

def load_data_from_csv_fixed_rate(filename, sample_rate, data_cols=None):
    """
    Load sensor data from CSV with known sample rate

    The rows are assumed to be evenly spaced; row k is assigned the time
    k / sample_rate.

    Parameters:
    - filename: Path to CSV file
    - sample_rate: Sample rate in Hz (must be > 0)
    - data_cols: Column name, or list of column names, to analyze
      (None = all numeric columns)

    Returns:
    - df: DataFrame with data indexed by time
    - sample_rate: Sample rate in Hz

    Raises:
    - ValueError: if sample_rate is not positive
    - FileNotFoundError: if the CSV file does not exist
    - KeyError: if a requested column is not present in the CSV
    """
    if not sample_rate > 0:
        raise ValueError(f"sample_rate must be positive, got {sample_rate}")

    df = pd.read_csv(filename)

    # Select data columns
    if data_cols is None:
        df = df.select_dtypes(include=[np.number])
    else:
        # A bare string would select a Series; wrap it so the result is
        # always a DataFrame.
        if isinstance(data_cols, str):
            data_cols = [data_cols]
        df = df[list(data_cols)]

    # Create time index from the integer sample count, which is exact in
    # length and avoids accumulating error from a fractional arange step.
    n_samples = len(df)
    duration = n_samples / sample_rate
    df.index = pd.Index(np.arange(n_samples) / sample_rate, name='Time (s)')

    print(f"Loaded {len(df)} samples at {sample_rate} Hz ({duration:.2f} seconds)")
    print(f"Columns: {list(df.columns)}")

    return df, sample_rate

def compute_psd(data, sample_rate, bin_width=1):
    """
    Estimate the power spectral density of every column with Welch's method

    Parameters:
    - data: DataFrame of evenly sampled signals, one column per channel
    - sample_rate: Sample rate in Hz (must be > 0)
    - bin_width: Requested frequency resolution in Hz (must be > 0); the
      segment length is sample_rate / bin_width, capped at the data length

    Returns:
    - psd: DataFrame of PSD values (scaling='density'), one column per input
      column, indexed by 'frequency (Hz)'

    Raises:
    - ValueError: if data is empty or sample_rate / bin_width is not positive
    """
    if data.empty:
        raise ValueError("Cannot compute a PSD from an empty data set")
    if not sample_rate > 0 or not bin_width > 0:
        raise ValueError("sample_rate and bin_width must both be positive")

    # A segment cannot be longer than the record; clamp it here so the
    # frequency spacing of the result is known rather than silently changed.
    nperseg = max(1, min(int(sample_rate / bin_width), len(data)))

    spectra = {}
    freqs = None
    for col in data.columns:
        freqs, spectra[col] = scipy.signal.welch(
            data[col].to_numpy(),
            fs=sample_rate,
            nperseg=nperseg,
            scaling='density'
        )

    psd = pd.DataFrame(spectra, index=freqs)
    psd.index.name = 'frequency (Hz)'
    return psd


def main(argv=None):
    """
    Command-line entry point: load (or synthesize) Z-axis data, then save and
    show a time-series plot and a PSD plot

    Parameters:
    - argv: Argument list to parse (None = use sys.argv[1:])
    """
    # Parse command-line arguments
    parser = argparse.ArgumentParser(
        description='Analyze Z-axis vibration data and generate PSD plots',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python geo.py                    # Use sample data
  python geo.py data.csv           # Analyze CSV file
  python geo.py data.csv --sample-rate 250
  python geo.py data.csv --sample-rate 250 --output my_analysis
  python geo.py data.csv -r 250 -o my_analysis
        """
    )

    parser.add_argument('csv_file',
                        nargs='?',
                        default=None,
                        help='Path to CSV file containing Z-axis data (optional, uses sample data if not provided)')
    parser.add_argument('-r', '--sample-rate',
                        type=float,
                        default=100,
                        help='Sample rate in Hz (default: 100)')
    parser.add_argument('-o', '--output',
                        default=None,
                        help='Root name for output plots (default: CSV filename or "sample")')
    parser.add_argument('-c', '--column',
                        default='z',
                        help='Name of Z-axis column (default: z)')

    args = parser.parse_args(argv)

    if not args.sample_rate > 0:
        parser.error('--sample-rate must be a positive number')

    # --- Load data ---
    if args.csv_file is None:
        # Generate sample data
        noise = generate_sample_data(sample_rate=args.sample_rate, duration=10)

        # Use the rate the data was actually generated at (taken from its
        # time index) so the PSD frequency axis always matches the samples.
        sample_rate = args.sample_rate
        if len(noise) > 1:
            step = float(np.median(np.diff(noise.index.to_numpy())))
            if step > 0:
                sample_rate = round(1.0 / step, 6)

        output_name = args.output if args.output else 'sample'
        print("✓ Using synthetic sample data (Z-axis)")
    else:
        # Load from CSV; keep the try body to the one call that can raise
        # the errors handled below.
        try:
            noise, sample_rate = load_data_from_csv_fixed_rate(
                args.csv_file,
                sample_rate=args.sample_rate,
                data_cols=[args.column]
            )
        except FileNotFoundError:
            print(f"✗ Error: File '{args.csv_file}' not found")
            raise SystemExit(1)
        except KeyError:
            print(f"✗ Error: Column '{args.column}' not found in CSV")
            print("  Run with --column <name> to specify the correct column")
            raise SystemExit(1)

        output_name = args.output if args.output else os.path.splitext(os.path.basename(args.csv_file))[0]
        print("✓ Successfully loaded real data (Z-axis)")

    if noise.empty:
        print("✗ Error: No data samples to analyze")
        raise SystemExit(1)

    # --- Plot time series ---
    fig = px.line(noise).update_layout(
        yaxis_title_text='Acceleration Z-axis (g)',
        legend_title_text='',
        title='Time Series Data (Z-axis)'
    )
    save_plot(fig, f'{output_name}_time_series')
    fig.show()

    # --- Calculate PSD using Welch's method ---
    bin_width = 1  # Hz - adjust this for frequency resolution
    psd = compute_psd(noise, sample_rate, bin_width)

    # Report the spacing actually achieved; it is coarser than bin_width when
    # the record is shorter than one full segment.
    resolution = psd.index[1] - psd.index[0] if len(psd) > 1 else bin_width

    # --- Plot PSD ---
    fig = px.line(psd).update_layout(
        yaxis_title_text='Acceleration (g²/Hz)',
        xaxis_title_text='Frequency (Hz)',
        legend_title_text='',
        title='Power Spectral Density (Z-axis)'
    )
    save_plot(fig, f'{output_name}_psd')
    fig.show()

    print("\n✓ Analysis complete!")
    print(f"  Sample rate: {sample_rate} Hz")
    print(f"  Duration: {len(noise)/sample_rate:.2f} seconds")
    print(f"  Samples: {len(noise)}")
    print("  Channel analyzed: Z-axis only")
    print(f"  Frequency resolution: {resolution:g} Hz")
    print(f"  Max frequency: {psd.index[-1]:.1f} Hz")
    print("\n  Plots saved to: plots/")
    print(f"    - plots/{output_name}_time_series.html/png/svg")
    print(f"    - plots/{output_name}_psd.html/png/svg")


if __name__ == "__main__":
    main()