import matplotlib.pyplot as plt
import numpy as np
import json

from matplotlib import cm
from matplotlib.ticker import LinearLocator

# Load the precomputed tables. JSON is UTF-8 by specification, so decode it
# explicitly instead of relying on the platform's default text encoding.
with open("computed_tables.json", "r", encoding="utf-8") as f:
    data = json.load(f)

# Extract data. Surfaces are plotted against the real depth/angle values (not
# list positions) so the axis ticks carry physical units. dtype=float turns
# JSON nulls into NaN instead of producing an object array.
X_values = np.asarray(data["depths"], dtype=float)  # Depth in km
Y_values = np.asarray(data["angles"], dtype=float)  # Angle in degrees
Z = np.asarray(data["p_table"], dtype=float)
Z2 = np.asarray(data["s_table"], dtype=float)

# Tables are expected as table[depth_index][angle_index], i.e. shape
# (len(depths), len(angles)).
# NOTE(review): confirm this layout against whatever writes
# computed_tables.json; a square table stored angle-major would pass this
# check but plot transposed.
expected_shape = (X_values.size, Y_values.size)
for table_name, table in (("p_table", Z), ("s_table", Z2)):
    if table.shape != expected_shape:
        raise ValueError(
            f"{table_name} has shape {table.shape}; expected {expected_shape} "
            "(len(depths), len(angles))"
        )

# indexing="ij" yields X[i, j] = depths[i] and Y[i, j] = angles[j]. This
# matches the table layout and puts depth on the x-axis and angle on the
# y-axis, as the axis labels state.
X, Y = np.meshgrid(X_values, Y_values, indexing="ij")

fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

# Plot the first surface (P table)
surf = ax.plot_surface(X, Y, Z, cmap=cm.viridis, linewidth=1, antialiased=True, label="P-Wave Table")

# Plot the second surface (S table)
surf2 = ax.plot_surface(X, Y, Z2, cmap=cm.magma, linewidth=1, antialiased=True, label="S-Wave Table")

# Axis titles, applied in x, y, z order with a shared font size and padding.
for set_label, text in (
    (ax.set_xlabel, "Depth (km)"),
    (ax.set_ylabel, "Angle (°)"),
    (ax.set_zlabel, "Value"),
):
    set_label(text, fontsize=12, labelpad=10)

# Z-axis ticks: ten evenly spaced ticks, each printed with two decimals.
z_axis = ax.zaxis
z_axis.set_major_locator(LinearLocator(10))
z_axis.set_major_formatter('{x:.02f}')

# One colour bar per surface, each captioned with the table it maps.
# Each bar is created and labelled before the next one is added.
colorbars = []
for mappable, caption in ((surf, "P-Wave Value"), (surf2, "S-Wave Value")):
    bar = fig.colorbar(mappable, shrink=0.5, aspect=5, pad=0.1)
    bar.set_label(caption, fontsize=10)
    colorbars.append(bar)
cbar1, cbar2 = colorbars

# Figure title, then hand the figure to the interactive viewer.
ax.set_title("3D Visualization of P-Table and S-Table", fontsize=14, pad=20)

plt.show()
