# -*- coding: utf-8 -*-
"""
Created on Fri Sep 25 14:09:12 2020

@author: jbobowsk
"""
# In this tutorial we will produce surface plots.  
import numpy as np
import matplotlib.pyplot as plt

# As an example, we will plot the surface z = x*exp(-x^2 - y^2).  To do that
# we first need 1D arrays of x values and y values spanning [-2, 2) with a
# spacing of 0.35 (np.arange excludes the end point, so the last value is
# about 1.85).
x = np.arange(-2, 2, 0.35)
y = np.arange(-2, 2, 0.35)
n, m = len(x), len(y)
print(n, m)

# 'np.meshgrid' expands the two 1D arrays into 2D arrays that together hold
# every (x, y) coordinate of the grid.  Plotting those coordinates as black
# dots (format string 'k.') lets us visualize the grid itself.
x, y = np.meshgrid(x, y)
plt.plot(x, y, 'k.')

# Evaluate z at every grid point; numpy applies the formula element-wise.
z = x * np.exp(-(x**2 + y**2))
print(z)

# Filled contour plot of the data, drawn on a new figure.
plt.figure()
plt.contourf(x, y, z)

# Surface plot.  The 3D axes come from requesting the '3d' projection when the
# axes are created, and the figure size is set explicitly (15 x 15 inches).
# 'cmap' selects the colormap; matplotlib offers many to choose from.
surface_fig = plt.figure(figsize=(15, 15))
ax = surface_fig.add_subplot(111, projection='3d')
ax.plot_surface(x, y, z, cmap='viridis')
ax.set(title='Surface plot')

# The figures above are not smooth because the grid was too coarse.  Here we
# rebuild them on a much finer grid of (x, y) coordinates (spacing 0.01).
x = np.arange(-2, 2, 0.01)
y = np.arange(-2, 2, 0.01)
n, m = len(x), len(y)
x, y = np.meshgrid(x, y)

# Show the (much denser) grid using small black dots (markersize=1).
plt.figure(figsize=(15, 15))
plt.plot(x, y, 'k.', markersize=1)

# Recompute z on the fine grid and draw its filled contour plot.
z = x * np.exp(-(x**2 + y**2))
plt.figure()
plt.contourf(x, y, z)

# Here's a surface plot.  I changed the orientation of the figure using 'azim'
# and 'elev'.  For fun, I also changed the 'cmap' option and made the surface
# slightly transparent using 'alpha'.  I also limited how finely the surface is
# drawn using 'rcount' (row count) and 'ccount' (column count): the data are
# downsampled to at most 30 rows and 30 columns, which also sets the number of
# grid lines that appear on the surface.  Finally, I labelled the x-, y-, and
# z-axes.
fig = plt.figure(figsize=(15, 15))
ax = plt.axes(projection='3d')
# Set both angles in ONE call: view_init() resets any angle passed as None to
# its initial value, so calling view_init(azim=30) and then view_init(elev=15)
# would silently discard the azim=30 setting.
ax.view_init(elev=15, azim=30)
ax.plot_surface(x, y, z, cmap='seismic', alpha=0.75, linewidth=1,
                edgecolors='k', rcount=30, ccount=30)
ax.set_title('Surface plot')
ax.set_xlabel('x axis')
ax.set_ylabel('y axis')
ax.set_zlabel('z axis')

# Display all of the figures created above; without this call nothing is
# shown when the file is run as a plain (non-interactive) script.
plt.show()