import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import pandas as pd
import scipy.signal
import os

# Ensure the output folder for all saved figures is present; an existing
# folder is left untouched.
os.makedirs(name='plots', exist_ok=True)

def save_plot(fig, name, directory='plots'):
    """Save a Plotly figure as interactive HTML plus static PNG and SVG images.

    Files are written as ``<directory>/<name>.html``, ``.png`` and ``.svg``.
    The HTML references plotly.js from a CDN (include_plotlyjs='cdn') instead
    of embedding it. Static image export requires a Plotly image-export
    engine (e.g. kaleido) to be installed.

    Parameters:
    - fig: figure exposing write_html / write_image
      (e.g. plotly.graph_objects.Figure)
    - name: base file name without extension
    - directory: output directory, created if missing (default 'plots')
    """
    # Create the directory here as well, so this helper does not depend on the
    # module-level makedirs having run in the current working directory.
    os.makedirs(directory, exist_ok=True)
    base_path = os.path.join(directory, name)
    fig.write_html(base_path + '.html', include_plotlyjs='cdn')
    for extension in ('.png', '.svg'):
        fig.write_image(base_path + extension)

def calculate_psd(signal_data, sample_rate, bin_width=1):
    """Calculate the power spectral density of a signal using Welch's method.

    Parameters:
    - signal_data: array-like of samples (Welch operates along the last axis)
    - sample_rate: sampling frequency in Hz; must be > 0
    - bin_width: desired frequency resolution in Hz; must be > 0. The Welch
      segment length is int(sample_rate / bin_width) samples.

    Returns:
    - freqs: array of sample frequencies (Hz)
    - psd_values: PSD at each frequency in (input units)^2/Hz
      (scaling='density')

    Raises:
    - ValueError: if sample_rate or bin_width is not positive, or if
      bin_width is so large that a segment would be shorter than one sample
    """
    if sample_rate <= 0:
        raise ValueError(f"sample_rate must be positive, got {sample_rate}")
    if bin_width <= 0:
        raise ValueError(f"bin_width must be positive, got {bin_width}")

    nperseg = int(sample_rate / bin_width)
    if nperseg < 1:
        raise ValueError(
            f"bin_width={bin_width} Hz is coarser than sample_rate={sample_rate} Hz "
            "allows (segment length would be < 1 sample)"
        )

    # A signal shorter than one segment cannot reach the requested resolution.
    # Clamp explicitly: scipy does the same internally but emits a UserWarning.
    n_samples = np.shape(signal_data)[-1] if np.ndim(signal_data) > 0 else 0
    if 0 < n_samples < nperseg:
        nperseg = n_samples

    freqs, psd_values = scipy.signal.welch(
        signal_data,
        fs=sample_rate,
        nperseg=nperseg,
        scaling='density'
    )
    return freqs, psd_values

# Measurement positions on a 2-D floor plan, mapping name -> (x, y).
# Coordinates are in meters or any consistent unit; edit these pairs to
# describe your own apartment.
locations = dict([
    ('Living Room Center', (5, 5)),
    ('Living Room Corner', (2, 2)),
    ('Kitchen', (8, 2)),
    ('Bedroom 1', (2, 8)),
    ('Bedroom 2', (8, 8)),
    ('Hallway', (5, 6)),
    ('Bathroom', (7, 5)),
    ('Near Window', (1, 5)),
])

# Simulated recordings: one `duration`-second trace per location sampled at
# `sample_rate` Hz. Swap this section for real sensor data when available.
sample_rate = 250
duration = 10
t = np.arange(0, duration, 1/sample_rate)

np.random.seed(42)  # fixed seed so the simulated data is reproducible
location_signals = {}
location_psds = {}

# (substring of the location name, tone frequency in Hz, tone amplitude).
# Rules are checked in order and the first match wins; locations matching
# no rule get broadband noise only.
tone_rules = (
    ('Living Room', 30, 2),    # strong 30 Hz
    ('Kitchen', 60, 1.5),      # strong 60 Hz (appliances)
    ('Bedroom', 20, 1.2),      # lower-frequency content
    ('Window', 40, 1.8),       # external noise
)

for location in locations:
    base_noise = np.random.randn(len(t)) * 0.5
    rule = next((r for r in tone_rules if r[0] in location), None)
    if rule is not None:
        _, tone_freq, tone_amp = rule
        base_noise += np.sin(t * 2 * np.pi * tone_freq) * tone_amp

    location_signals[location] = base_noise
    freqs, psd_values = calculate_psd(base_noise, sample_rate)
    location_psds[location] = pd.DataFrame({'frequency': freqs, 'psd': psd_values})

# Long-format table: one row per (location, frequency bin), each row tagged
# with the location name and its floor-plan coordinates.
all_data = [
    location_psds[location].assign(location=location, x=coords[0], y=coords[1])
    for location, coords in locations.items()
]

df_all = pd.concat(all_data, ignore_index=True)

# Overlay every location's PSD on one log-scale chart for side-by-side comparison
print("\nGenerating individual PSD plots for each location...")
comparison_traces = [
    go.Scatter(
        x=location_psds[loc_name]['frequency'],
        y=location_psds[loc_name]['psd'],
        mode='lines',
        name=loc_name,
    )
    for loc_name in locations
]
fig_all_locations = go.Figure(data=comparison_traces)

fig_all_locations.update_layout(
    title='PSD Comparison - All Locations',
    xaxis_title='Frequency (Hz)',
    yaxis_title='PSD (g²/Hz)',
    yaxis_type='log',
    width=1000,
    height=600,
    hovermode='x unified',
)
save_plot(fig_all_locations, '5_psd_all_locations')
fig_all_locations.show()

# Create an individual PSD plot (with peak markers) for each location
for location in locations:
    psd_df = location_psds[location]

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=psd_df['frequency'],
        y=psd_df['psd'],
        mode='lines',
        name=location,
        line=dict(width=2)
    ))

    # Mark local maxima whose PSD is at least 10% of this location's maximum.
    # find_peaks returns their indices in ascending (frequency) order.
    peak_indices = scipy.signal.find_peaks(psd_df['psd'].values, height=np.max(psd_df['psd']) * 0.1)[0]
    if len(peak_indices) > 0:
        fig.add_trace(go.Scatter(
            x=psd_df['frequency'].iloc[peak_indices],
            y=psd_df['psd'].iloc[peak_indices],
            mode='markers',
            name='Peaks',
            marker=dict(size=10, color='red', symbol='x')
        ))

    fig.update_layout(
        title=f'PSD - {location}',
        xaxis_title='Frequency (Hz)',
        yaxis_title='PSD (g²/Hz)',
        yaxis_type='log',
        width=800,
        height=500,
        showlegend=True
    )

    # Build a filesystem-safe file name. '&' becomes 'and', and every character
    # other than a letter, digit, '-' or '_' (spaces, '/', ':', ...) becomes '_',
    # so a user-edited location name cannot break the output path.
    safe_name = ''.join(
        char if char.isalnum() or char in '-_' else '_'
        for char in location.replace('&', 'and')
    )
    save_plot(fig, f'6_psd_{safe_name}')
    fig.show()

    # Print the five strongest peaks. peak_indices is in frequency order, so
    # rank by PSD (largest first) before truncating; slicing peak_indices[:5]
    # directly would list the five lowest-frequency peaks instead.
    if len(peak_indices) > 0:
        print(f"\n{location} - Peak frequencies:")
        peak_heights = psd_df['psd'].values[peak_indices]
        strongest = peak_indices[np.argsort(-peak_heights, kind='stable')[:5]]
        for idx in strongest:
            print(f"  {psd_df['frequency'].iloc[idx]:.1f} Hz: {psd_df['psd'].iloc[idx]:.6f} g²/Hz")

# Static spatial map of PSD near a single frequency
def create_frequency_heatmap(df, target_freq, freq_tolerance=0.5):
    """
    Plot each location's mean PSD near one frequency on the 2-D floor plan.

    Parameters:
    - df: DataFrame with columns ['frequency', 'psd', 'location', 'x', 'y']
    - target_freq: centre frequency to map (Hz)
    - freq_tolerance: half-width (Hz) of the inclusive window
      [target_freq - freq_tolerance, target_freq + freq_tolerance] whose PSD
      values are averaged per location

    Returns:
    - (fig, location_avg): the plotly express scatter figure and the
      per-location averages (columns location, x, y, psd) it was built from
    """
    low = target_freq - freq_tolerance
    high = target_freq + freq_tolerance
    in_band = df['frequency'].between(low, high)

    location_avg = (
        df[in_band]
        .groupby(['location', 'x', 'y'])['psd']
        .mean()
        .reset_index()
    )

    # Marker size and colour both encode the averaged PSD
    fig = px.scatter(
        location_avg,
        x='x',
        y='y',
        size='psd',
        color='psd',
        hover_name='location',
        hover_data={'x': ':.1f', 'y': ':.1f', 'psd': ':.4f'},
        color_continuous_scale='Viridis',
        size_max=50,
        title=f'PSD Intensity at {target_freq} Hz Across Apartment',
    )

    fig.update_layout(
        xaxis_title='X Position (m)',
        yaxis_title='Y Position (m)',
        coloraxis_colorbar_title='PSD (g²/Hz)',
        width=800,
        height=600,
    )

    # Label each marker with its location name, nudged upward
    label_columns = zip(location_avg['location'], location_avg['x'], location_avg['y'])
    for loc_name, x_pos, y_pos in label_columns:
        fig.add_annotation(
            x=x_pos,
            y=y_pos,
            text=loc_name,
            showarrow=False,
            yshift=15,
            font=dict(size=9),
        )

    return fig, location_avg

# Create an interactive plot with frequency slider
def create_interactive_heatmap(df, freq_step=5, initial_freq=30,
                               min_marker_size=6, max_marker_size=50):
    """Create an animated spatial PSD map with a frequency slider.

    Each slider step shows, for every location, the mean PSD inside an
    inclusive window of +/- freq_step / 2 Hz around the step frequency, drawn
    as a marker at the location's (x, y) coordinates.

    Parameters:
    - df: DataFrame with columns ['frequency', 'psd', 'location', 'x', 'y']
    - freq_step: spacing (Hz) between slider frequencies; must be > 0.
      Steps run from 0 Hz up to the highest frequency present in df.
    - initial_freq: frequency (Hz) shown when the figure opens; the slider
      starts at the step closest to it
    - min_marker_size, max_marker_size: marker diameter range in pixels.
      Sizes scale linearly with PSD relative to the largest averaged PSD
      over all steps, so sizes are comparable between frequencies and do not
      depend on the absolute units of the PSD.

    Returns:
    - plotly.graph_objects.Figure with one animation frame per step

    Raises:
    - ValueError: if freq_step is not positive
    """
    if freq_step <= 0:
        raise ValueError(f"freq_step must be positive, got {freq_step}")

    half_window = freq_step / 2
    # Cover the whole spectrum present in df (up to Nyquist for Welch output)
    # rather than a hard-coded range.
    max_freq = df['frequency'].max() if not df.empty else 0
    freq_range = np.arange(0, max_freq + half_window, freq_step)

    # Mean PSD per location for every slider step
    averages = [
        df[df['frequency'].between(freq - half_window, freq + half_window)]
        .groupby(['location', 'x', 'y'])['psd'].mean().reset_index()
        for freq in freq_range
    ]

    # Largest averaged PSD across all steps; reference for marker sizes
    peak_psd = max((avg['psd'].max() for avg in averages if not avg.empty), default=0.0)

    def _make_trace(location_avg):
        """Build the scatter trace showing one step's per-location averages."""
        if peak_psd > 0:
            sizes = (min_marker_size
                     + (max_marker_size - min_marker_size) * location_avg['psd'] / peak_psd)
        else:
            sizes = min_marker_size
        return go.Scatter(
            x=location_avg['x'],
            y=location_avg['y'],
            mode='markers+text',
            marker=dict(
                size=sizes,
                color=location_avg['psd'],
                colorscale='Viridis',
                showscale=True,
                colorbar=dict(title='PSD (g²/Hz)')
            ),
            text=location_avg['location'],
            textposition='top center',
            hovertemplate='<b>%{text}</b><br>PSD: %{marker.color:.4f} g²/Hz<extra></extra>'
        )

    frames = [
        go.Frame(data=[_make_trace(avg)], name=f'{freq:g}')
        for freq, avg in zip(freq_range, averages)
    ]

    # Start on the step nearest initial_freq, and draw the initial figure from
    # exactly that step's data so the slider position matches what is shown.
    active = int(np.argmin(np.abs(freq_range - initial_freq)))
    fig = go.Figure(data=[_make_trace(averages[active])], frames=frames)

    fig.update_layout(
        title='Interactive PSD Heatmap - Use Slider to Change Frequency',
        xaxis_title='X Position (m)',
        yaxis_title='Y Position (m)',
        updatemenus=[{
            'type': 'buttons',
            'showactive': False,
            'buttons': [
                {'label': 'Play', 'method': 'animate',
                 'args': [None, {'frame': {'duration': 200}}]},
                {'label': 'Pause', 'method': 'animate',
                 'args': [[None], {'frame': {'duration': 0}, 'mode': 'immediate'}]}
            ]
        }],
        sliders=[{
            'active': active,
            'steps': [
                {'args': [[f.name], {'frame': {'duration': 0}, 'mode': 'immediate'}],
                 'label': f'{f.name} Hz', 'method': 'animate'}
                for f in frames
            ]
        }],
        width=900,
        height=700
    )

    return fig

# Build, save and display the frequency-slider figure
fig_interactive = create_interactive_heatmap(df_all)
save_plot(fig_interactive, '4_interactive_heatmap')
fig_interactive.show()

# Short how-to for switching from simulated to measured data
usage_notes = (
    "\nTo use with real data:",
    "1. Record accelerometer data at each location in your apartment",
    "2. Replace the simulated signals with your actual data",
    "3. Update the 'locations' dictionary with actual coordinates",
    "4. Run the script to generate spatial frequency heatmaps",
)
for note in usage_notes:
    print(note)